package org.aksw.jenax.reprogen.util;

import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.BiConsumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;

import org.aksw.commons.util.string.StringUtils;
import org.aksw.jenax.arq.util.node.NodeTransformLib2;
import org.aksw.jenax.arq.util.var.Vars;
import org.aksw.jenax.reprogen.core.MapperProxyUtils;
import org.aksw.jenax.reprogen.hashid.HashIdCxt;
import org.apache.jena.graph.Graph;
import org.apache.jena.graph.Node;
import org.apache.jena.graph.NodeFactory;
import org.apache.jena.graph.Triple;
import org.apache.jena.rdf.model.Model;
import org.apache.jena.rdf.model.ModelFactory;
import org.apache.jena.rdf.model.Property;
import org.apache.jena.rdf.model.RDFNode;
import org.apache.jena.rdf.model.Resource;
import org.apache.jena.rdf.model.ResourceFactory;
import org.apache.jena.sparql.graph.NodeTransform;
import org.apache.jena.sparql.graph.NodeTransformLib;
import org.apache.jena.sparql.util.Closure;
import org.apache.jena.sparql.util.ModelUtils;
import org.apache.jena.sparql.util.graph.GraphUtils;
import org.apache.jena.system.G;
import org.apache.jena.util.ResourceUtils;

public class Skolemize {
    /** Property holding the skolemized id (only the hash part, without URI prefix and such). */
    public static final Property skolemId = ResourceFactory.createProperty("http://tmp.aksw.org/skolemId");

    /**
     * Currently a no-op; the method is retained only for API compatibility.
     *
     * @param r ignored
     */
    public static void skolemize2(Resource r) {
        // No implementation yet.
    }

    /**
     * Skolemizes the blank nodes of the closure of {@code r} using a two-phase approach.
     * <ol>
     *   <li>For each blank node, a local signature is computed from its direct neighborhood
     *       (all triples in which it occurs as subject or object), with every blank node
     *       replaced by a single constant.</li>
     *   <li>The signature is computed again, this time with every neighboring blank node
     *       replaced by its first-phase signature.</li>
     * </ol>
     * Each blank node then receives a {@link #skolemId} literal holding the first 8 characters
     * of its second-phase signature. It is then renamed in {@code r}'s model to
     * {@code r.getURI() + "-bn" + id}.
     *
     * <p>NOTE: the 8-character prefix is not collision-free. Blank nodes with identical
     * signatures (or colliding prefixes) are renamed to the same IRI and thus get merged.
     *
     * @param r a URI resource; its model is modified in place
     * @throws IllegalArgumentException if {@code r} is not a URI resource
     */
    public static void skolemize(Resource r) {
        if (!r.isURIResource()) {
            throw new IllegalArgumentException("This skolemization function requires a URI resource as input");
        }

        String baseUri = r.getURI();
        Model model = r.getModel();
        Graph g = Closure.closure(r, false).getGraph();

        // Collect all blank nodes (and variables) occurring in the closure.
        Iterable<Node> allNodes = () -> GraphUtils.allNodes(g);
        Set<Node> blankNodes = StreamSupport.stream(allNodes.spliterator(), false)
                .filter(x -> x.isBlank() || x.isVariable())
                .collect(Collectors.toSet());

        // Phase 1: every blank node / variable is collapsed into the same constant.
        NodeTransform unifyBlankNodes = node -> node.isBlank() || node.isVariable() ? Vars.a : node;
        Map<Node, Node> nodeToLocalHash = blankNodes.stream()
                .collect(Collectors.toMap(
                        x -> x,
                        x -> NodeFactory.createLiteral(createSignature(g, x, unifyBlankNodes))));

        // Phase 2: blank nodes are replaced by their phase-1 signature literal.
        Map<Node, String> nodeToGlobalHash = blankNodes.stream()
                .collect(Collectors.toMap(
                        x -> x,
                        x -> createSignature(g, x, node -> nodeToLocalHash.getOrDefault(node, node))));

        // NOTE(review): variable nodes are collected above; confirm that
        // ModelUtils.convertGraphNodeToRDFNode accepts them (it may reject variables).
        Map<Resource, String> map = blankNodes.stream()
                .collect(Collectors.toMap(
                        n -> (Resource) ModelUtils.convertGraphNodeToRDFNode(n, model),
                        n -> nodeToGlobalHash.get(n).substring(0, 8)));

        // The id literal is added before renaming, so the rename also moves that triple to the new IRI.
        for (Map.Entry<Resource, String> e : map.entrySet()) {
            e.getKey().addLiteral(skolemId, e.getValue());
        }

        for (Map.Entry<Resource, String> e : map.entrySet()) {
            ResourceUtils.renameResource(e.getKey(), baseUri + "-bn" + e.getValue());
        }
    }

    /**
     * Computes the MD5 hash of the string form of {@link #createRawSignature(Graph, Node, Function)}.
     *
     * @param g the graph to look up the neighborhood of {@code n} in
     * @param n the node whose signature to compute
     * @param nodeTransform applied to every node of each neighborhood triple
     * @return the hash string as produced by {@code StringUtils.md5Hash}
     */
    public static String createSignature(Graph g, Node n, Function<? super Node, ? extends Node> nodeTransform) {
        List<Triple> rawSig = createRawSignature(g, n, nodeTransform);
        String result = StringUtils.md5Hash("" + rawSig);
        return result;
    }

    /**
     * Collects all triples of {@code g} in which {@code n} is the subject or the object.
     * It applies {@code nodeTransform} to each triple and sorts the result by the triples'
     * string form, so the outcome does not depend on graph iteration order.
     *
     * <p>A triple with {@code n} as both subject and object is matched by both lookups and
     * therefore appears twice in the result.
     *
     * @param g the graph to search
     * @param n the node whose neighborhood is collected
     * @param nodeTransform applied to every node of each collected triple
     * @return a new, sorted list of transformed triples
     */
    public static List<Triple> createRawSignature(Graph g, Node n, Function<? super Node, ? extends Node> nodeTransform) {
        List<Triple> triples = g.find(n, Node.ANY, Node.ANY).andThen(g.find(Node.ANY, Node.ANY, n)).toList();

        NodeTransform fn = nodeTransform::apply;

        return triples.stream()
                .map(triple -> NodeTransformLib.transform(fn, triple))
                .sorted((a, b) -> a.toString().compareTo(b.toString()))
                .collect(Collectors.toList());
    }

    /**
     * Same as {@link #skolemize(Resource, Model, String, Class, BiConsumer)} without a static model.
     */
    public static <T extends RDFNode> T skolemize(
            Resource root,
            String baseIri,
            Class<T> cls,
            BiConsumer<Resource, Map<Node, Node>> postProcessor) {
        return skolemize(root, null, baseIri, cls, postProcessor);
    }

    /**
     * Same as {@link #skolemize(Resource, Model, String, Class, BiConsumer)} without a post processor.
     */
    public static <T extends RDFNode> T skolemize(
            Resource root,
            Model staticModel,
            String baseIri,
            Class<T> cls) {
        return skolemize(root, staticModel, baseIri, cls, null);
    }

    /**
     * Skolemizes the resource root and all relevant reachable resources.
     * The root's model is internally unioned with the staticModel (if non null).
     * Resources that appear in the static model are excluded from renaming.
     * Note that for the computation of structural hashes, the static model may still need to be
     * traversed even if the traversed resources are excluded from renaming.
     *
     * <p>Renames are applied in place to the root's own model only.
     *
     * @param <T> the view type of the returned node
     * @param root the resource to skolemize
     * @param staticModel model that will be union'd in but to which renames will not be applied; may be null
     * @param baseIri base IRI passed to the hash-id context to build the new node names
     * @param cls the view class used both for hash computation and for the returned node
     * @param postProcessor optional callback receiving the resulting root resource and the full
     *        rename mapping obtained from the hash-id context (including entries whose rename was
     *        blocked because the node occurs in the static model); may be null
     * @return the (possibly renamed) root, viewed as {@code cls}, in the root's model. If the root
     *         has no rename mapping, or its rename was blocked, the original root node is used.
     */
    public static <T extends RDFNode> T skolemize(
            Resource root,
            Model staticModel,
            String baseIri,
            Class<T> cls,
            BiConsumer<Resource, Map<Node, Node>> postProcessor) {

        Graph staticGraph = staticModel != null ? staticModel.getGraph() : null;
        Model rootModel = root.getModel();

        Resource mergedRoot = staticModel != null && rootModel != staticModel
            ? root.inModel(ModelFactory.createUnion(rootModel, staticModel))
            : root;

        T q = mergedRoot.as(cls);

        HashIdCxt hashIdCxt = MapperProxyUtils.getHashId(q);
        Map<Node, Node> renames = hashIdCxt.getNodeMapping(baseIri);

        // Nodes that also occur in the static model keep their identity.
        Set<Node> blockedRenames = staticGraph == null
                ? Set.of()
                : renames.keySet().stream().filter(n -> G.containsNode(staticGraph, n)).collect(Collectors.toSet());

        // Resolve the root with the same rule as the rename function below (null = no rename).
        // This ensures the returned resource refers to a node that is actually present in the model.
        Node oldRoot = q.asNode();
        Node mappedRoot = blockedRenames.contains(oldRoot) ? null : renames.get(oldRoot);
        Node newRoot = mappedRoot != null ? mappedRoot : oldRoot;

        // Apply an in-place node transform to the root's model only (not to the static model).
        NodeTransformLib2.applyNodeTransform(NodeTransformLib2.wrapWithNullAsIdentity(
                n -> blockedRenames.contains(n) ? null : renames.get(n)),
                rootModel);

        Resource resultRes = rootModel.asRDFNode(newRoot).asResource();

        if (postProcessor != null) {
            postProcessor.accept(resultRes, renames);
        }

        T result = resultRes.as(cls);
        return result;
    }
}

