/*
 * Decompiled with CFR 0.152.
 */
package org.aksw.jena_sparql_api_sparql_path2.playground;

import com.google.common.collect.HashMultimap;
import java.util.AbstractMap;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import org.aksw.commons.collections.multimaps.BiHashMultimap;
import org.aksw.jena_sparql_api.sparql_path2.JGraphTUtils;
import org.aksw.jena_sparql_api.sparql_path2.Nfa;
import org.aksw.jena_sparql_api.sparql_path2.Pair;
import org.aksw.jena_sparql_api.sparql_path2.PredicateClass;
import org.aksw.jena_sparql_api.sparql_path2.ValueSet;
import org.aksw.jena_sparql_api_sparql_path2.playground.JoinSummaryService;
import org.aksw.jena_sparql_api_sparql_path2.playground.MapUtils;
import org.aksw.jena_sparql_api_sparql_path2.playground.NfaAnalysisResult;
import org.aksw.jenax.arq.util.var.Vars;
import org.aksw.jenax.dataaccess.sparql.factory.execution.query.QueryExecutionFactory;
import org.apache.jena.graph.Node;
import org.apache.jena.query.QueryExecution;
import org.apache.jena.query.ResultSet;
import org.apache.jena.sparql.engine.binding.Binding;
import org.jgrapht.Graph;
import org.jgrapht.graph.DefaultDirectedGraph;
import org.jgrapht.graph.DefaultEdge;

/**
 * Playground helpers for analysing an {@link Nfa} whose transitions are mapped to
 * {@link PredicateClass}es: estimating per-state predicate costs by propagating them along the
 * NFA with the help of a {@link JoinSummaryService}, and loading predicate / join summaries via
 * SPARQL queries.
 *
 * <p>{@link #trimPredicates} and {@link #getReferencedPredicates} are unfinished and currently
 * always return {@code null}.
 */
public class EdgeReducer {

    /** Maximum number of frontier expansion rounds performed by {@link #estimateFrontierCost}. */
    private static final int MAX_FRONTIER_STEPS = 100;

    /**
     * Estimates, for every NFA state, a cost per predicate in two directions (index 0 and
     * index 1 of each {@link Pair}; index 1 is passed to the join summary service as
     * {@code reverse = true}).
     *
     * <p>The start states are seeded with {@code initPredFreqs} (summed into any existing
     * entries). Then the frontier of open states is expanded round by round:
     * <ul>
     *   <li>along an epsilon transition the source costs are merged into the target, combining
     *       collisions with {@code Math.max};</li>
     *   <li>along any other transition, every currently open predicate accepted by the
     *       transition's predicate set is joined via {@code joinSummaryService}; each joined
     *       predicate not yet known at the target contributes {@code baseCost * joinCost}
     *       (a missing base cost counts as 0), and an edge is recorded in the join graph.</li>
     * </ul>
     * Predicates newly introduced at a target become the open predicates of the next round.
     * Expansion stops when no predicates are open or after {@value #MAX_FRONTIER_STEPS} rounds.
     * Progress is printed to {@code System.out}.
     *
     * @param nfa the automaton to analyse
     * @param isEpsilon tells whether a transition is an epsilon transition
     * @param transToPredicateClass maps a non-epsilon transition to its predicate class
     * @param initPredFreqs initial per-direction predicate costs for the start states
     * @param joinSummaryService provides, per predicate, the joinable predicates and their costs
     * @return the per-state cost maps together with the join graph that was built
     * @throws RuntimeException if the join summary service returns {@code null}
     */
    public static <S, T> NfaAnalysisResult<S> estimateFrontierCost(Nfa<S, T> nfa, Predicate<T> isEpsilon, Function<T, PredicateClass> transToPredicateClass, Pair<Map<Node, Number>> initPredFreqs, JoinSummaryService joinSummaryService) {
        Graph<S, T> graph = nfa.getGraph();

        Map<S, Pair<Map<Node, Number>>> stateToDiPredToCost = new HashMap<>();
        for (S state : graph.vertexSet()) {
            stateToDiPredToCost.put(state, new Pair<Map<Node, Number>>(new HashMap<>(), new HashMap<>()));
        }

        Pair<Set<Node>> currOpenDiPreds = new Pair<Set<Node>>(new HashSet<>(), new HashSet<>());
        Graph<Node, DefaultEdge> joinGraph = new DefaultDirectedGraph<>(DefaultEdge.class);

        // Seed the start states with the initial frequencies; all seeded predicates are open.
        for (S state : nfa.getStartStates()) {
            Pair<Map<Node, Number>> diPredToCost = stateToDiPredToCost.get(state);
            for (int i = 0; i < 2; ++i) {
                Map<Node, Number> predToCost = diPredToCost.get(i);
                Map<Node, Number> initPredToCost = initPredFreqs.get(i);
                MapUtils.mergeMapsInPlace(predToCost, initPredToCost, (a, b) -> a.doubleValue() + b.doubleValue());
                currOpenDiPreds.get(i).addAll(predToCost.keySet());
            }
        }

        Set<S> currOpenStates = nfa.getStartStates();
        int remainingSteps = MAX_FRONTIER_STEPS;
        while ((!currOpenDiPreds.get(0).isEmpty() || !currOpenDiPreds.get(1).isEmpty()) && remainingSteps-- > 0) {
            Set<S> nextOpenStates = new HashSet<>();
            Pair<Set<Node>> nextOpenDiPreds = new Pair<Set<Node>>(new HashSet<>(), new HashSet<>());

            Map<S, Pair<Integer>> predsInFrontier = currOpenStates.stream().collect(Collectors.toMap(
                    s -> s,
                    s -> {
                        Pair<Map<Node, Number>> diPredToCost = stateToDiPredToCost.get(s);
                        return new Pair<Integer>(diPredToCost.get(0).size(), diPredToCost.get(1).size());
                    }));
            System.out.println("ENTERING states: " + String.valueOf(predsInFrontier));

            for (S srcState : currOpenStates) {
                Pair<Map<Node, Number>> srcDiPredToCost = stateToDiPredToCost.get(srcState);
                for (T trans : graph.outgoingEdgesOf(srcState)) {
                    S tgtState = graph.getEdgeTarget(trans);
                    Pair<Map<Node, Number>> tgtDiPredToCost = stateToDiPredToCost.get(tgtState);
                    nextOpenStates.add(tgtState);

                    if (isEpsilon.test(trans)) {
                        propagateEpsilon(srcDiPredToCost, tgtDiPredToCost, nextOpenDiPreds);
                        continue;
                    }

                    PredicateClass transPredClass = transToPredicateClass.apply(trans);
                    for (int i = 0; i < 2; ++i) {
                        boolean reverse = i == 1;
                        ValueSet<Node> transPredSet = transPredClass.get(i);
                        if (transPredSet.isEmpty()) {
                            continue;
                        }
                        Set<Node> nextOpenPreds = nextOpenDiPreds.get(i);
                        Map<Node, Number> srcPredToCost = srcDiPredToCost.get(i);
                        Map<Node, Number> tgtPredToCost = tgtDiPredToCost.get(i);
                        if (srcPredToCost == null) {
                            throw new RuntimeException("not implemented yet");
                        }

                        // Open predicates (across the whole frontier) that this transition accepts
                        Set<Node> openPassPreds = currOpenDiPreds.get(i).stream()
                                .filter(p -> transPredSet.contains(p))
                                .collect(Collectors.toSet());
                        Map<Node, Map<Node, Number>> joinSummaryFragment = joinSummaryService.fetch(openPassPreds, reverse);
                        if (joinSummaryFragment == null) {
                            throw new RuntimeException("Join summary fragment was null - should not happen");
                        }

                        for (Node openPassPred : openPassPreds) {
                            // The open set spans all frontier states, so this state may lack the predicate: count it as 0.
                            Number baseCost = srcPredToCost.getOrDefault(openPassPred, 0);
                            if (baseCost == null) {
                                throw new RuntimeException("No base cost for " + String.valueOf(openPassPred) + " - should not happen");
                            }
                            Map<Node, Number> joinPredToCost = joinSummaryFragment.getOrDefault(openPassPred, Collections.emptyMap());

                            // Checked against the live target map: earlier iterations may already have added predicates.
                            Set<Node> tgtContribPreds = joinPredToCost.keySet().stream()
                                    .filter(p -> !tgtPredToCost.containsKey(p))
                                    .collect(Collectors.toSet());

                            Map<Node, Number> tgtPredCostContrib = new HashMap<>();
                            for (Node p : tgtContribPreds) {
                                tgtPredCostContrib.put(p, baseCost.doubleValue() * joinPredToCost.get(p).doubleValue());
                            }
                            MapUtils.mergeMapsInPlace(tgtPredToCost, tgtPredCostContrib, (a, b) -> a.doubleValue() + b.doubleValue());

                            joinGraph.addVertex(openPassPred);
                            for (Node tgtPred : tgtContribPreds) {
                                joinGraph.addVertex(tgtPred);
                                joinGraph.addEdge(openPassPred, tgtPred);
                            }
                            nextOpenPreds.addAll(tgtContribPreds);
                        }
                        System.out.println("Join graph size now: " + joinGraph.edgeSet().size() + " after " + openPassPreds.size() + " predicates passing transition " + String.valueOf(trans));
                    }
                }
            }
            currOpenStates = nextOpenStates;
            currOpenDiPreds = nextOpenDiPreds;
        }
        return new NfaAnalysisResult<>(stateToDiPredToCost, joinGraph);
    }

    /**
     * Propagates costs across an epsilon transition. For both directions it merges the source's
     * predicate costs into the target via {@link MapUtils#mergeMapsInPlace} using
     * {@code Math.max} as the combiner. Predicates that were not yet present at the target are
     * added to {@code nextOpenDiPreds}.
     */
    private static void propagateEpsilon(Pair<Map<Node, Number>> srcDiPredToCost, Pair<Map<Node, Number>> tgtDiPredToCost, Pair<Set<Node>> nextOpenDiPreds) {
        for (int i = 0; i < 2; ++i) {
            Map<Node, Number> srcPredToCost = srcDiPredToCost.get(i);
            Map<Node, Number> tgtPredToCost = tgtDiPredToCost.get(i);
            // Must be computed before merging, as the merge adds these keys to the target.
            Set<Node> tgtContribPreds = srcPredToCost.keySet().stream()
                    .filter(p -> !tgtPredToCost.containsKey(p))
                    .collect(Collectors.toSet());
            MapUtils.mergeMapsInPlace(tgtPredToCost, srcPredToCost, (a, b) -> Math.max(a.doubleValue(), b.doubleValue()));
            nextOpenDiPreds.get(i).addAll(tgtContribPreds);
        }
    }

    /**
     * Unfinished: intended to trim predicates by walking backwards from the end states.
     *
     * <p>Currently it only restricts, in place, the predicate key sets of every end state in
     * {@code stateToDiPredToCost} to the predicates present in {@code initDiPredToCost} (per
     * direction). The backward propagation is not implemented. The previous version spun
     * forever in a no-op loop whenever the NFA had end states.
     *
     * @return currently always {@code null}
     */
    public static <S, T> Map<T, Double> trimPredicates(Nfa<S, T> nfa, Predicate<T> isEpsilon, Function<T, PredicateClass> transitionToPredicateClass, Pair<Map<Node, Number>> initDiPredToCost, Map<S, Pair<Map<Node, Number>>> stateToDiPredToCost, JoinSummaryService joinSummaryService) {
        for (S endState : nfa.getEndStates()) {
            Pair<Map<Node, Number>> diPredToCost = stateToDiPredToCost.get(endState);
            for (int i = 0; i < 2; ++i) {
                Set<Node> initPreds = initDiPredToCost.get(i).keySet();
                // Mutates the caller's map through the key-set view.
                diPredToCost.get(i).keySet().retainAll(initPreds);
            }
        }
        // TODO: propagate the restriction backwards along incoming transitions of the end states.
        return null;
    }

    /**
     * Determines which predicate directions are needed to follow the outgoing transitions of
     * {@code state}.
     *
     * @return a bit mask: bit {@code 1} is set if any outgoing transition has forward nodes,
     *         bit {@code 2} if any has backward nodes
     */
    public static <S, T> int determineRequiredPredicateDirectionsForRetrieval(Graph<S, T> graph, S state, Function<T, PredicateClass> toPredicateClass) {
        int result = 0;
        for (T edge : graph.outgoingEdgesOf(state)) {
            PredicateClass pc = toPredicateClass.apply(edge);
            if (!pc.getFwdNodes().isEmpty()) {
                result |= 1;
            }
            if (!pc.getBwdNodes().isEmpty()) {
                result |= 2;
            }
        }
        return result;
    }

    /**
     * Loads the join summary: for every resource with {@code o:sourcePredicate ?x} and
     * {@code o:targetPredicate ?y}, records {@code x -> y}. The query execution is closed
     * after the results are consumed.
     */
    public static BiHashMultimap<Node, Node> loadJoinSummary(QueryExecutionFactory qef) {
        BiHashMultimap<Node, Node> result = new BiHashMultimap<>();
        try (QueryExecution qe = qef.createQueryExecution("PREFIX o: <http://example.org/ontology/> SELECT ?x ?y { ?s o:sourcePredicate ?x ; o:targetPredicate ?y }")) {
            ResultSet rs = qe.execSelect();
            while (rs.hasNext()) {
                Binding binding = rs.nextBinding();
                result.put(binding.get(Vars.x), binding.get(Vars.y));
            }
        }
        return result;
    }

    /**
     * Loads the predicate summary: maps each {@code o:predicate} of an
     * {@code o:PredicateSummary} to its {@code o:freqTotal}, converted to {@code long}. The query
     * execution is closed after the results are consumed.
     *
     * @throws ClassCastException if a frequency literal's value is not a {@link Number}
     */
    public static Map<Node, Long> loadPredicateSummary(QueryExecutionFactory qef) {
        Map<Node, Long> result = new HashMap<>();
        try (QueryExecution qe = qef.createQueryExecution("PREFIX o: <http://example.org/ontology/> SELECT ?x ?y { ?s a o:PredicateSummary ; o:predicate ?x ; o:freqTotal ?y }")) {
            ResultSet rs = qe.execSelect();
            while (rs.hasNext()) {
                Binding binding = rs.nextBinding();
                Node x = binding.get(Vars.x);
                Node y = binding.get(Vars.y);
                result.put(x, ((Number) y.getLiteralValue()).longValue());
            }
        }
        return result;
    }

    /**
     * Unfinished: intended to collect the predicates referenced by the NFA that can join with
     * predicates on preceding transitions.
     *
     * <p>At present the loop only inspects the transitions resolved from the start states.
     * {@code nextStates} and {@code priorTransitions} are never populated, so {@code joins}
     * is never invoked. The method always returns {@code null}.
     *
     * @return currently always {@code null}
     */
    public static <S, T, P> Set<P> getReferencedPredicates(Nfa<S, T> nfa, Predicate<Map.Entry<P, P>> joins, Predicate<T> isEpsilon, Function<Set<T>, Set<P>> transitionsToPredicates) {
        Graph<S, T> graph = nfa.getGraph();
        Set<S> visited = new HashSet<>();
        Set<S> states = nfa.getStartStates();
        HashMultimap<S, P> priorTransitions = HashMultimap.create();
        boolean change = true;
        while (change) {
            Set<S> nextStates = new HashSet<>();
            Set<T> transitions = JGraphTUtils.resolveTransitions(graph, isEpsilon, states, false);
            for (T t : transitions) {
                S state = graph.getEdgeTarget(t);
                Set<P> nextPredicates = transitionsToPredicates.apply(Collections.singleton(t));
                Collection<P> priorPredicates = priorTransitions.get(state);
                Set<P> joiningPredicates = nextPredicates.stream()
                        .filter(np -> priorPredicates.stream().anyMatch(p -> joins.test(new AbstractMap.SimpleEntry<>(np, p))))
                        .collect(Collectors.toSet());
                // TODO: joiningPredicates is not used yet.
            }
            change = visited.addAll(states);
            states = nextStates;
        }
        return null;
    }
}

