package org.dllearner.algorithms.qtl.experiments;

import com.google.common.base.Charsets;
import com.google.common.io.Files;
import com.mxgraph.layout.orthogonal.mxOrthogonalLayout;
import com.mxgraph.swing.mxGraphComponent;
import com.mxgraph.util.mxCellRenderer;
import com.mxgraph.util.mxConstants;
import com.mxgraph.util.mxRectangle;
import com.mxgraph.util.png.mxPngEncodeParam;
import com.mxgraph.util.png.mxPngImageEncoder;
import com.mxgraph.view.mxGraph;
import com.mxgraph.view.mxStylesheet;
import java.awt.Color;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.text.DecimalFormat;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.aksw.jena_sparql_api.core.QueryExecutionFactory;
import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;
import org.apache.jena.graph.Node;
import org.apache.jena.graph.Triple;
import org.apache.jena.query.ParameterizedSparqlString;
import org.apache.jena.query.Query;
import org.apache.jena.query.QueryExecution;
import org.apache.jena.query.QueryFactory;
import org.dllearner.algorithms.qtl.experiments.SPARQLUtils;
import org.dllearner.kb.sparql.CBDStructureTree;
import org.dllearner.kb.sparql.SparqlEndpoint;
import org.dllearner.kb.sparql.TreeBasedConciseBoundedDescriptionGenerator;
import org.dllearner.utilities.ProgressBar;
import org.dllearner.utilities.QueryUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/dllearner/algorithms/qtl/experiments/BenchmarkDescriptionGenerator.class */
/**
 * Base class for generating a description (e.g. a table) of a SPARQL query benchmark.
 * <p>
 * For every benchmark query it determines:
 * <ul>
 * <li>the query type;</li>
 * <li>the longest triple-pattern path starting at the first projected variable;</li>
 * <li>the number of query results;</li>
 * <li>CBD size statistics.</li>
 * </ul>
 * The collected data is passed to {@link #addRow(QueryData)}. Subclasses choose the concrete
 * output format by implementing the {@code begin*}/{@code end*}/{@code addRow} callbacks.
 */
public abstract class BenchmarkDescriptionGenerator {
    private static final Logger LOGGER = LoggerFactory.getLogger(BenchmarkDescriptionGenerator.class);
    private final QueryExecutionFactory qef;
    private final TreeBasedConciseBoundedDescriptionGenerator cbdGen;
    private final QueryUtils utils = new QueryUtils();
    // NOTE(review): nothing in this class ever sets this flag to false. As a result, the
    // SELECT-based counting in countCBDTriples() is unreachable and every recorded CBD size is -1.
    private boolean useConstruct = true;
    /** Queries whose string form contains any of these tokens are skipped. */
    protected Set<String> skipQueryTokens = new HashSet<>();
    private CBDStructureTree defaultCbdStructure;

    /**
     * The data collected for a single benchmark query. One instance is passed to
     * {@link #addRow(QueryData)} per non-skipped query.
     */
    public class QueryData {
        /** The query id (from the benchmark file or the map key). */
        final String id;
        final Query query;
        final SPARQLUtils.QueryType queryType;
        /** Longest triple-pattern path from the projected variable; -1 if not IN/OUT or cyclic. */
        final int maxTreeDepth;
        /** Number of results returned by {@code SPARQLUtils.getResult} for the query. */
        final int nrOfInstances;
        final DescriptiveStatistics optimalCBDSizeStats;
        final DescriptiveStatistics defaultCBDSizesStats;

        /**
         * @param str                    the query id
         * @param query                  the query
         * @param queryType              the query type
         * @param i                      the longest path value (see {@link #maxTreeDepth})
         * @param i2                     the number of query results
         * @param descriptiveStatistics  CBD size statistics computed for the optimal CBD structure
         * @param descriptiveStatistics2 CBD size statistics computed for the default CBD structure
         */
        public QueryData(String str, Query query, SPARQLUtils.QueryType queryType, int i, int i2, DescriptiveStatistics descriptiveStatistics, DescriptiveStatistics descriptiveStatistics2) {
            this.id = str;
            this.query = query;
            this.queryType = queryType;
            this.maxTreeDepth = i;
            this.nrOfInstances = i2;
            this.optimalCBDSizeStats = descriptiveStatistics;
            this.defaultCBDSizesStats = descriptiveStatistics2;
        }
    }

    /**
     * @param queryExecutionFactory factory used to run the benchmark queries and CBD queries
     */
    public BenchmarkDescriptionGenerator(QueryExecutionFactory queryExecutionFactory) {
        this.qef = queryExecutionFactory;
        this.cbdGen = new TreeBasedConciseBoundedDescriptionGenerator(queryExecutionFactory);
    }

    /** Forwards the workaround flag to the CBD generator. */
    public void setWorkaroundEnabled(boolean z) {
        this.cbdGen.setWorkaround(z);
    }

    /** Forwards the endpoint to the CBD generator. */
    public void setEndpoint(SparqlEndpoint sparqlEndpoint) {
        this.cbdGen.setEndpoint(sparqlEndpoint);
    }

    /** Called once before any table output. */
    protected abstract void beginDocument();

    /** Called once after all table output. */
    protected abstract void endDocument();

    /** Called once before the first row. */
    protected abstract void beginTable();

    /** Called once per non-skipped benchmark query. */
    protected abstract void addRow(QueryData queryData);

    /** Called once after the last row. */
    protected abstract void endTable();

    /**
     * Reads one SPARQL query per line from a UTF-8 file and generates the benchmark description.
     * Blank lines are ignored. Queries are processed in file order.
     *
     * @param file the benchmark file
     * @param z    if {@code true}, each line has the form {@code <id>,<query>}; otherwise ids
     *             "1", "2", ... are assigned in order of the non-blank lines
     * @throws IllegalArgumentException if {@code z} is set and a line has no comma
     * @throws Exception                on I/O, query parsing or query execution failures
     */
    public void generateBenchmarkDescription(File file, boolean z) throws Exception {
        // LinkedHashMap keeps the file order; ids must be unique, otherwise queries would be
        // silently overwritten.
        Map<String, Query> id2Query = new LinkedHashMap<>();
        int lineNumber = 0;
        int nextId = 1;
        for (String line : Files.readLines(file, Charsets.UTF_8)) {
            lineNumber++;
            if (line.trim().isEmpty()) {
                continue;
            }
            String id;
            String queryString;
            if (z) {
                int separator = line.indexOf(',');
                if (separator < 0) {
                    throw new IllegalArgumentException(
                            "Line " + lineNumber + " of " + file + " lacks the '<id>,' prefix: " + line);
                }
                id = line.substring(0, separator);
                queryString = line.substring(separator + 1);
            } else {
                id = String.valueOf(nextId++);
                queryString = line;
            }
            id2Query.put(id, QueryFactory.create(queryString));
        }
        generateBenchmarkDescription(id2Query);
    }

    /**
     * Generates the benchmark description for the given queries, emitting one row per query
     * (in the map's iteration order) unless its string form contains a skip token.
     *
     * @param map query id to query
     * @throws Exception on query execution failures
     */
    public void generateBenchmarkDescription(Map<String, Query> map) throws Exception {
        beginDocument();
        beginTable();
        for (Map.Entry<String, Query> entry : map.entrySet()) {
            String id = entry.getKey();
            Query query = entry.getValue();
            String queryString = query.toString();
            if (skipQueryTokens.stream().anyMatch(queryString::contains)) {
                continue;
            }
            System.out.println(query);
            SPARQLUtils.QueryType queryType = SPARQLUtils.getQueryType(query);
            int longestPath = getLongestPath(query);
            List<String> result = SPARQLUtils.getResult(qef, query);
            // Optimal first, then default: this keeps the order of the console output.
            DescriptiveStatistics optimalStats = determineOptimalCBDSizes(query, result);
            DescriptiveStatistics defaultStats = determineDefaultCBDSizes(query, result);
            addRow(new QueryData(id, query, queryType, longestPath, result.size(), optimalStats, defaultStats));
        }
        endTable();
        endDocument();
    }

    /** Adds (does not replace) tokens that cause a query to be skipped. */
    public void setSkipQueryTokens(Collection<String> collection) {
        this.skipQueryTokens.addAll(collection);
    }

    /**
     * Returns the number of triple-pattern hops along the longest path that starts at the first
     * projected variable. IN queries follow incoming patterns (subject to subject). OUT queries
     * follow outgoing patterns (object to object). Only variable nodes are followed.
     *
     * @return the longest path length, or -1 for other query types or cyclic patterns
     */
    private int getLongestPath(Query query) {
        SPARQLUtils.QueryType queryType = SPARQLUtils.getQueryType(query);
        if (queryType == SPARQLUtils.QueryType.IN) {
            return longestPath(query, query.getProjectVars().get(0),
                    node -> utils.extractIncomingTriplePatterns(query, node), Triple::getSubject);
        }
        if (queryType == SPARQLUtils.QueryType.OUT) {
            return longestPath(query, query.getProjectVars().get(0),
                    node -> utils.extractOutgoingTriplePatterns(query, node), Triple::getObject);
        }
        return -1;
    }

    /**
     * Expands the root node level by level until no further patterns are found and returns the
     * number of levels. Returns -1 if the patterns contain a cycle.
     */
    private int longestPath(Query query, Node root, Function<Node, Set<Triple>> patternsOf,
                            Function<Triple, Node> nextNode) {
        Set<Triple> frontier = patternsOf.apply(root);
        Set<Triple> seen = new HashSet<>();
        int depth = 0;
        while (!frontier.isEmpty()) {
            depth++;
            seen.addAll(frontier);
            // A walk of `depth` patterns uses only patterns in `seen`. If it is longer than
            // |seen|, some pattern repeats, so the walk contains a cycle and the expansion
            // would never terminate.
            if (depth > seen.size()) {
                LOGGER.warn("Triple patterns form a cycle, longest path is undefined for query:\n{}", query);
                return -1;
            }
            frontier = frontier.stream()
                    .map(nextNode)
                    .filter(Node::isVariable)
                    .map(patternsOf)
                    .flatMap(Set::stream)
                    .collect(Collectors.toSet());
        }
        return depth;
    }

    /** Sets the CBD structure that is printed before computing the default CBD sizes. */
    public void setDefaultCbdStructure(CBDStructureTree cBDStructureTree) {
        this.defaultCbdStructure = cBDStructureTree;
    }

    private CBDStructureTree getDefaultCBDStructureTree() {
        return this.defaultCbdStructure;
    }

    /**
     * Prints the default CBD structure (if set) and collects CBD sizes for the resources.
     */
    private DescriptiveStatistics determineDefaultCBDSizes(Query query, List<String> resources) {
        CBDStructureTree structure = getDefaultCBDStructureTree();
        if (structure != null) {
            System.out.println(structure.toStringVerbose());
        } else {
            // The structure is only printed here, so a missing one must not abort the run.
            LOGGER.warn("No default CBD structure set");
        }
        return countCBDSizes(resources);
    }

    /**
     * Prints the optimal CBD structure of the query and collects CBD sizes for the resources.
     */
    private DescriptiveStatistics determineOptimalCBDSizes(Query query, List<String> resources) {
        System.out.println(QueryUtils.getOptimalCBDStructure(query).toStringVerbose());
        return countCBDSizes(resources);
    }

    /** Collects one CBD size value per resource, updating a progress bar as it goes. */
    private DescriptiveStatistics countCBDSizes(List<String> resources) {
        DescriptiveStatistics stats = new DescriptiveStatistics();
        ProgressBar progressBar = new ProgressBar();
        int total = resources.size();
        int processed = 0;
        for (String resource : resources) {
            stats.addValue(countCBDTriples(resource));
            progressBar.update(++processed, total);
        }
        return stats;
    }

    /**
     * Returns the {@code cnt} value of {@code SPARQLUtils.CBD_TEMPLATE_DEPTH3} for the resource.
     * Returns -1 when {@code useConstruct} is set or when the query fails.
     */
    private long countCBDTriples(String resource) {
        if (useConstruct) {
            return -1;
        }
        ParameterizedSparqlString template = SPARQLUtils.CBD_TEMPLATE_DEPTH3.copy();
        template.setIri("uri", resource);
        try (QueryExecution qe = qef.createQueryExecution(template.toString())) {
            return qe.execSelect().next().getLiteral("cnt").getInt();
        } catch (Exception e) {
            LOGGER.error("Failed to determine the CBD size of {}", resource, e);
            return -1;
        }
    }

    /**
     * Renders the triple patterns of the query as a graph and writes it as a PNG file.
     * Each distinct node becomes one vertex and each triple pattern becomes one edge.
     * I/O errors are logged, not thrown.
     */
    private void exportGraph(Query query, File file) {
        mxGraph graph = new mxGraph();
        Object parent = graph.getDefaultParent();
        graph.getModel().beginUpdate();
        try {
            Set<Triple> triplePatterns = utils.extractTriplePattern(query);
            Map<Node, Object> vertices = new HashMap<>();
            // computeIfAbsent, unlike putIfAbsent, only inserts a vertex for unseen nodes.
            Function<Node, Object> newVertex = node -> graph.insertVertex(parent, null,
                    node.toString(query.getPrefixMapping()), 20.0d, 20.0d, 40.0d, 30.0d);
            for (Triple tp : triplePatterns) {
                vertices.computeIfAbsent(tp.getSubject(), newVertex);
                vertices.computeIfAbsent(tp.getObject(), newVertex);
            }
            for (Triple tp : triplePatterns) {
                graph.insertEdge(parent, null, tp.getPredicate().toString(query.getPrefixMapping()),
                        vertices.get(tp.getSubject()), vertices.get(tp.getObject()));
            }
        } finally {
            graph.getModel().endUpdate();
        }

        // NOTE(review): this component is never used. It is kept only in case its construction
        // has side effects on the graph; verify and remove if it does not.
        new mxGraphComponent(graph);
        new mxOrthogonalLayout(graph).execute(graph.getDefaultParent());

        Map<String, Object> edgeStyle = new HashMap<>();
        edgeStyle.put(mxConstants.STYLE_SHAPE, "connector");
        edgeStyle.put(mxConstants.STYLE_ENDARROW, "classic");
        edgeStyle.put(mxConstants.STYLE_STROKECOLOR, "#000000");
        edgeStyle.put(mxConstants.STYLE_FONTCOLOR, "#000000");
        edgeStyle.put(mxConstants.STYLE_LABEL_BACKGROUNDCOLOR, "#ffffff");
        Map<String, Object> vertexStyle = new HashMap<>();
        vertexStyle.put(mxConstants.STYLE_SHAPE, "ellipse");
        vertexStyle.put(mxConstants.STYLE_VERTICAL_ALIGN, "bottom");
        mxStylesheet stylesheet = new mxStylesheet();
        stylesheet.setDefaultEdgeStyle(edgeStyle);
        stylesheet.setDefaultVertexStyle(vertexStyle);
        graph.setStylesheet(stylesheet);

        BufferedImage image = mxCellRenderer.createBufferedImage(graph, (Object[]) null, 1.0d, Color.WHITE, true, (mxRectangle) null);
        if (image == null) {
            // Nothing was rendered (e.g. no triple patterns). Encoding params can't be derived.
            LOGGER.warn("Empty graph, no image written to {}", file);
            return;
        }
        mxPngEncodeParam param = mxPngEncodeParam.getDefaultEncodeParam(image);
        try (FileOutputStream out = new FileOutputStream(file)) {
            new mxPngImageEncoder(out, param).encode(image);
        } catch (IOException e) {
            LOGGER.error("Failed to write graph image to {}", file, e);
        }
    }
}
