package net.sansa_stack.query.spark.rdd.op;

import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.aksw.commons.util.algebra.GenericDag;
import org.aksw.jena_sparql_api.algebra.utils.OpUtils;
import org.aksw.jena_sparql_api.algebra.utils.OpVar;
import org.apache.jena.graph.NodeFactory;
import org.apache.jena.sparql.algebra.Op;
import org.apache.jena.sparql.algebra.op.OpDisjunction;
import org.apache.jena.sparql.algebra.op.OpJoin;
import org.apache.jena.sparql.algebra.op.OpLateral;
import org.apache.jena.sparql.algebra.op.OpService;
import org.apache.jena.sparql.core.Var;
import org.apache.jena.sparql.core.VarAlloc;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:net/sansa_stack/query/spark/rdd/op/CacheOptimizer.class */
/**
 * Heuristic optimizer that wraps shared, expensive sub-expressions of a SPARQL algebra
 * expression in a {@code SERVICE <rdd:cache>} operator.
 *
 * <p>The expression is first turned into a {@link GenericDag}. Every DAG variable that has
 * more than one parent is a cache candidate. Candidates whose assessed cost is below a
 * threshold are discarded. The remaining ones have their expression moved to a fresh
 * {@code <name>_cached} variable, and the original variable is redirected to a cache
 * service over it.
 */
public class CacheOptimizer {
    private static final Logger logger = LoggerFactory.getLogger(CacheOptimizer.class);

    /** IRI of the SERVICE operator that marks a sub-expression for caching. */
    private static final String CACHE_SERVICE_IRI = "rdd:cache";

    /** IRI of the SERVICE operator that denotes an RML source. */
    private static final String RML_SOURCE_SERVICE_IRI = "https://w3id.org/aksw/norse#rml.source";

    /** Minimum assessed cost a shared sub-expression needs for {@link #buildDag(Op)} to cache it. */
    public static final float DEFAULT_MIN_CACHE_IMPACT = 10000.0f;

    /** Cost assigned to RML sources, joins and disjunctions by {@link #assessCostContribution}. */
    private static final float EXPENSIVE_OP_COST = 1000000.0f;

    /**
     * Edge blocker used both when building the DAG and when traversing it.
     *
     * <p>NOTE(review): presumably GenericDag calls this for every (parent, child index, child)
     * edge and skips edges for which it returns {@code true}. Confirm against GenericDag.
     *
     * @param parent the parent operator
     * @param childIndex index of the child within {@code parent}
     * @param child the child operator (not inspected)
     * @return {@code true} if {@code parent} is a SERVICE, or if {@code parent} is a LATERAL
     *     and the edge is not its first (index 0) child
     */
    public static boolean defaultBlocker(Op parent, int childIndex, Op child) {
        return parent instanceof OpService || (parent instanceof OpLateral && childIndex != 0);
    }

    /**
     * Builds a DAG for {@code op} and inserts cache operators using
     * {@link #DEFAULT_MIN_CACHE_IMPACT} as the threshold.
     *
     * @param op the algebra expression to optimize
     * @return the collapsed DAG with cache services inserted
     */
    public static GenericDag<Op, Var> buildDag(Op op) {
        return buildDag(op, DEFAULT_MIN_CACHE_IMPACT);
    }

    /**
     * Builds a DAG for {@code op} and wraps every shared sub-expression whose assessed cost
     * is at least {@code minCacheImpact} in a {@code SERVICE <rdd:cache>} operator.
     * {@link GenericDag#collapse()} is called before the DAG is returned.
     *
     * @param op the algebra expression to optimize
     * @param minCacheImpact candidates with an assessed cost strictly below this value are
     *     not cached
     * @return the collapsed DAG with cache services inserted
     */
    public static GenericDag<Op, Var> buildDag(Op op, float minCacheImpact) {
        VarAlloc varAlloc = new VarAlloc("op");
        GenericDag<Op, Var> dag =
                new GenericDag<>(OpUtils.getOpOps(), varAlloc::allocVar, CacheOptimizer::defaultBlocker);
        dag.addRoot(op);

        Set<Var> cacheVars = findSharedVars(dag);
        removeLowImpactCandidates(dag, cacheVars, minCacheImpact);
        logger.info("Remaining cache nodes: {}", cacheVars);

        for (Var cacheVar : cacheVars) {
            wrapInCacheService(dag, cacheVar);
        }

        dag.collapse();
        logger.info("Roots: {}", dag.getRoots());
        if (logger.isDebugEnabled()) {
            dag.getVarToExpr().forEach((v, expr) -> logger.debug("{}={}", v, expr));
        }
        return dag;
    }

    /** Collects the DAG variables whose sub-expression has more than one parent. */
    private static Set<Var> findSharedVars(GenericDag<Op, Var> dag) {
        Set<Var> shared = new HashSet<>();
        dag.getChildToParent().asMap().forEach((child, parents) -> {
            if (parents.size() > 1) {
                shared.add(child);
            }
        });
        return shared;
    }

    /**
     * Removes from {@code candidates} (in place) every variable whose expression's assessed
     * cost is below {@code minCacheImpact}.
     */
    private static void removeLowImpactCandidates(
            GenericDag<Op, Var> dag, Set<Var> candidates, float minCacheImpact) {
        Set<Op> candidateOps = candidates.stream().map(dag::getExpr).collect(Collectors.toSet());
        Map<Op, Float> impact = assessCacheImpact(dag, candidateOps);
        candidates.removeIf(candidate -> {
            // A candidate the traversal never reached (e.g. one only reachable through a
            // blocked edge) has no cost entry. Treat it as 0 instead of throwing an NPE.
            Float assessed = impact.get(dag.getExpr(candidate));
            float cost = assessed == null ? 0.0f : assessed;
            boolean lowImpact = cost < minCacheImpact;
            if (lowImpact) {
                logger.info("Removing low impact cache candidate (cost={}): {}", cost, candidate);
            }
            return lowImpact;
        });
    }

    /**
     * Moves the expression of {@code cacheVar} to a fresh {@code <name>_cached} variable and
     * rebinds {@code cacheVar} to {@code SERVICE <rdd:cache> { ?<name>_cached }}.
     */
    private static void wrapInCacheService(GenericDag<Op, Var> dag, Var cacheVar) {
        Op expr = dag.getExpr(cacheVar);
        Var cachedVar = Var.alloc(cacheVar.getName() + "_cached");
        // Remove the old mapping first. NOTE(review): if varToExpr is a bidirectional map,
        // re-adding the same expression under a new key would otherwise be rejected.
        dag.getVarToExpr().remove(cacheVar);
        dag.getVarToExpr().put(cachedVar, expr);
        dag.getVarToExpr().put(cacheVar,
                new OpService(NodeFactory.createURI(CACHE_SERVICE_IRI), new OpVar(cachedVar), false));
    }

    /**
     * Assigns a cost to every operator reachable from the DAG roots (honouring
     * {@link #defaultBlocker}) using {@link #assessCostContribution}.
     *
     * <p>NOTE(review): this assumes {@code GenericDag.depthFirstTraverse} visits children
     * before their parents. If a child has not been visited yet, it contributes cost 0.
     *
     * @param dag the DAG to assess
     * @param cacheCandidates operators assumed to be cached; they contribute no cost to
     *     their parents
     * @return a new mutable map from each visited operator to its cost
     */
    public static Map<Op, Float> assessCacheImpact(GenericDag<Op, Var> dag, Set<Op> cacheCandidates) {
        Map<Op, Float> costs = new HashMap<>();
        Consumer<Op> assessor = op -> costs.put(op, assessCostContribution(op, costs, cacheCandidates));
        for (Op root : dag.getRoots()) {
            GenericDag.depthFirstTraverse(dag, null, 0, root, CacheOptimizer::defaultBlocker, assessor);
        }
        return costs;
    }

    /**
     * Computes the cost of a single operator.
     *
     * <p>RML sources, joins and disjunctions cost {@link #EXPENSIVE_OP_COST}. Every other
     * operator costs 1 plus the maximum cost of its sub-operators. A sub-operator contributes
     * 0 if it is in {@code cacheCandidates} or has no entry in {@code costs}.
     *
     * @param op the operator to assess
     * @param costs costs already computed for other operators (read only)
     * @param cacheCandidates operators assumed to be cached
     * @return the cost of {@code op}
     */
    public static float assessCostContribution(Op op, Map<Op, Float> costs, Set<Op> cacheCandidates) {
        if (isRmlSourceOp(op) || op instanceof OpJoin || op instanceof OpDisjunction) {
            return EXPENSIVE_OP_COST;
        }
        float maxChildCost = 0.0f;
        for (Op subOp : OpUtils.getSubOps(op)) {
            float childCost;
            if (cacheCandidates.contains(subOp)) {
                childCost = 0.0f;
            } else {
                Float known = costs.get(subOp);
                childCost = known == null ? 0.0f : known;
            }
            if (childCost > maxChildCost) {
                maxChildCost = childCost;
            }
        }
        return maxChildCost + 1.0f;
    }

    /** Returns whether {@code op} is a SERVICE with IRI {@code rdd:cache}. */
    public static boolean isCacheOp(Op op) {
        return OpUtils.isServiceWithIri(op, CACHE_SERVICE_IRI);
    }

    /** Returns whether {@code op} is a SERVICE with the norse {@code rml.source} IRI. */
    public static boolean isRmlSourceOp(Op op) {
        return OpUtils.isServiceWithIri(op, RML_SOURCE_SERVICE_IRI);
    }
}
