package org.dllearner.algorithms.qtl;

import com.google.common.collect.Sets;
import com.hp.hpl.jena.query.QueryExecution;
import com.hp.hpl.jena.query.ResultSet;
import com.hp.hpl.jena.rdf.model.Model;
import com.hp.hpl.jena.rdf.model.Statement;
import com.hp.hpl.jena.rdf.model.StmtIterator;
import com.hp.hpl.jena.util.iterator.Filter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.concurrent.TimeUnit;
import org.aksw.jena_sparql_api.cache.core.QueryExecutionFactoryCacheEx;
import org.aksw.jena_sparql_api.cache.h2.CacheUtilsH2;
import org.aksw.jena_sparql_api.core.QueryExecutionFactory;
import org.aksw.jena_sparql_api.http.QueryExecutionFactoryHttp;
import org.aksw.jena_sparql_api.model.QueryExecutionFactoryModel;
import org.apache.commons.collections15.ListUtils;
import org.apache.log4j.Logger;
import org.dllearner.algorithms.qtl.cache.QueryTreeCache;
import org.dllearner.algorithms.qtl.datastructures.QueryTree;
import org.dllearner.algorithms.qtl.datastructures.impl.QueryTreeImpl;
import org.dllearner.algorithms.qtl.exception.EmptyLGGException;
import org.dllearner.algorithms.qtl.exception.NegativeTreeCoverageExecption;
import org.dllearner.algorithms.qtl.exception.TimeOutException;
import org.dllearner.algorithms.qtl.filters.KeywordBasedQueryTreeFilter;
import org.dllearner.algorithms.qtl.filters.QueryTreeFilter;
import org.dllearner.algorithms.qtl.operations.NBR;
import org.dllearner.algorithms.qtl.operations.lgg.LGGGenerator;
import org.dllearner.algorithms.qtl.operations.lgg.LGGGeneratorImpl;
import org.dllearner.algorithms.qtl.util.SPARQLEndpointEx;
import org.dllearner.core.AbstractCELA;
import org.dllearner.core.AbstractLearningProblem;
import org.dllearner.core.ComponentAnn;
import org.dllearner.core.EvaluatedDescription;
import org.dllearner.core.LearningProblem;
import org.dllearner.core.LearningProblemUnsupportedException;
import org.dllearner.core.SparqlQueryLearningAlgorithm;
import org.dllearner.core.options.CommonConfigOptions;
import org.dllearner.core.options.ConfigOption;
import org.dllearner.core.options.IntegerConfigOption;
import org.dllearner.kb.LocalModelBasedSparqlEndpointKS;
import org.dllearner.kb.SparqlEndpointKS;
import org.dllearner.kb.sparql.CachingConciseBoundedDescriptionGenerator;
import org.dllearner.kb.sparql.ConciseBoundedDescriptionGenerator;
import org.dllearner.kb.sparql.ConciseBoundedDescriptionGeneratorImpl;
import org.dllearner.kb.sparql.SparqlEndpoint;
import org.dllearner.learningproblems.PosNegLP;
import org.dllearner.learningproblems.PosOnlyLP;
import org.dllearner.utilities.Helper;
import org.semanticweb.owlapi.model.OWLClassExpression;
import org.semanticweb.owlapi.model.OWLIndividual;
import org.springframework.beans.factory.annotation.Autowired;

/**
 * Query tree learner: builds one query tree per example resource, combines the
 * positive trees into a single LGG tree via {@link LGGGenerator} and exposes that
 * tree as a SPARQL query. When negative examples are present, {@link NBR} is used
 * to derive a follow-up question from the LGG and the negative trees.
 *
 * <p>Lifecycle: instances created with one of the learning-problem constructors
 * carry no tree cache, CBD generator or NBR operator until {@link #init()} has
 * been called. The other constructors wire those collaborators immediately, but
 * only {@link #init()} creates the {@link QueryExecutionFactory} that
 * {@link #getQuestion(List, List)} and {@link #start()} (with negative examples)
 * need for querying.
 */
@ComponentAnn(name = "query tree learner", shortName = "qtl", version = 0.8d)
public class QTL extends AbstractCELA implements SparqlQueryLearningAlgorithm {
    private static final Logger logger = Logger.getLogger(QTL.class);

    private LearningProblem learningProblem;
    private SparqlEndpointKS endpointKS;
    private SparqlEndpoint endpoint;
    private Model model;
    private QueryExecutionFactory qef;
    private String cacheDirectory;
    private QueryTreeCache treeCache;
    private LGGGenerator<String> lggGenerator;
    private NBR<String> nbr;
    private List<String> posExamples;
    private List<String> negExamples;
    private List<QueryTree<String>> posExampleTrees;
    private List<QueryTree<String>> negExampleTrees;
    private QueryTreeFilter queryTreeFilter;
    private ConciseBoundedDescriptionGenerator cbdGenerator;
    private int maxExecutionTimeInSeconds = 60;
    private int maxQueryTreeDepth = 2;
    private QueryTree<String> lgg;
    private SortedSet<String> lggInstances;
    /** Statements whose object URI starts with one of these prefixes are dropped from each CBD. */
    private Set<String> objectNamespacesToIgnore = new HashSet<>();
    /** Handed to the tree cache in {@link #init()}. */
    private Set<String> allowedNamespaces = new HashSet<>();
    /** Prefix map passed to the SPARQL rendering that {@link #start()} logs. */
    private Map<String, String> prefixes = new HashMap<>();
    private boolean enableNumericLiteralFilters = false;

    /**
     * Lists the configuration options of this component.
     *
     * @return the {@code maxExecutionTimeInSeconds} option (created with 10) and the
     *         {@code maxQueryTreeDepth} option (created with 2)
     */
    public static Collection<ConfigOption<?>> createConfigOptions() {
        List<ConfigOption<?>> options = new LinkedList<>();
        options.add(CommonConfigOptions.maxExecutionTimeInSeconds(10));
        options.add(new IntegerConfigOption("maxQueryTreeDepth", "recursion depth of query tree extraction", 2));
        return options;
    }

    /**
     * Creates a learner without a cache directory; see
     * {@link #QTL(AbstractLearningProblem, SparqlEndpointKS, String)}.
     *
     * @param abstractLearningProblem a {@link PosOnlyLP} or {@link PosNegLP}
     * @param sparqlEndpointKS the knowledge source to learn from
     * @throws LearningProblemUnsupportedException if the learning problem is of any other type
     */
    public QTL(AbstractLearningProblem abstractLearningProblem, SparqlEndpointKS sparqlEndpointKS) throws LearningProblemUnsupportedException {
        this(abstractLearningProblem, sparqlEndpointKS, null);
    }

    /**
     * Creates a learner for a learning problem. Only the arguments are stored here;
     * {@link #init()} must be called before the learner is used.
     *
     * @param abstractLearningProblem a {@link PosOnlyLP} or {@link PosNegLP}
     * @param sparqlEndpointKS the knowledge source to learn from
     * @param str cache directory, or {@code null} to run without a query cache
     * @throws LearningProblemUnsupportedException if the learning problem is of any other type
     */
    public QTL(AbstractLearningProblem abstractLearningProblem, SparqlEndpointKS sparqlEndpointKS, String str) throws LearningProblemUnsupportedException {
        if (!(abstractLearningProblem instanceof PosOnlyLP) && !(abstractLearningProblem instanceof PosNegLP)) {
            throw new LearningProblemUnsupportedException(abstractLearningProblem.getClass(), getClass());
        }
        this.learningProblem = abstractLearningProblem;
        this.endpointKS = sparqlEndpointKS;
        this.cacheDirectory = str;
    }

    /**
     * Creates a learner that works directly against a SPARQL endpoint.
     *
     * @param sPARQLEndpointEx the endpoint to query
     * @param str cache directory handed to the CBD generator and the NBR operator
     */
    public QTL(SPARQLEndpointEx sPARQLEndpointEx, String str) {
        this.endpoint = sPARQLEndpointEx;
        this.cacheDirectory = str;
        installComponents(
                new CachingConciseBoundedDescriptionGenerator(new ConciseBoundedDescriptionGeneratorImpl(sPARQLEndpointEx, str)),
                new NBR<String>(sPARQLEndpointEx, str));
    }

    /**
     * Creates a learner on top of a SPARQL knowledge source.
     *
     * @param sparqlEndpointKS the knowledge source whose endpoint is queried
     * @param str cache directory handed to the CBD generator and the NBR operator
     */
    public QTL(SparqlEndpointKS sparqlEndpointKS, String str) {
        this.endpointKS = sparqlEndpointKS;
        // Resolve the endpoint up front: the collaborators below used to be built
        // from the still-unassigned field and therefore always received null.
        this.endpoint = sparqlEndpointKS != null ? sparqlEndpointKS.getEndpoint() : null;
        this.cacheDirectory = str;
        installComponents(
                new CachingConciseBoundedDescriptionGenerator(new ConciseBoundedDescriptionGeneratorImpl(this.endpoint, str)),
                new NBR<String>(this.endpoint, str));
    }

    /**
     * Creates a learner that works on an in-memory Jena model.
     *
     * @param model the model holding the data to learn from
     */
    public QTL(Model model) {
        this.model = model;
        installComponents(
                new CachingConciseBoundedDescriptionGenerator(new ConciseBoundedDescriptionGeneratorImpl(model)),
                new NBR<String>(model));
    }

    /**
     * Shared constructor wiring: stores the given collaborators, applies the current
     * depth and timeout settings to them and creates empty tree lists.
     */
    private void installComponents(ConciseBoundedDescriptionGenerator generator, NBR<String> nbrOperator) {
        this.treeCache = new QueryTreeCache();
        this.cbdGenerator = generator;
        this.cbdGenerator.setRecursionDepth(this.maxQueryTreeDepth);
        this.lggGenerator = new LGGGeneratorImpl<String>();
        this.nbr = nbrOperator;
        this.nbr.setMaxExecutionTimeInSeconds(this.maxExecutionTimeInSeconds);
        this.posExampleTrees = new ArrayList<>();
        this.negExampleTrees = new ArrayList<>();
    }

    /**
     * Computes the LGG of the positive examples and asks the NBR operator for the
     * next question.
     *
     * @param list URIs of the positive examples
     * @param list2 URIs of the negative examples; may be empty but not {@code null}
     * @return the question returned by the NBR operator
     * @throws EmptyLGGException if the (optionally filtered) LGG is empty
     * @throws NegativeTreeCoverageExecption if the LGG subsumes the tree of a negative example
     * @throws TimeOutException propagated from the NBR operator
     */
    public String getQuestion(List<String> list, List<String> list2) throws EmptyLGGException, NegativeTreeCoverageExecption, TimeOutException {
        this.posExamples = list;
        this.negExamples = list2;
        generatePositiveExampleTrees();
        generateNegativeExampleTrees();
        if (list2.isEmpty()) {
            // Without negatives, hand NBR a single placeholder tree ("?" --dummy--> "?").
            QueryTreeImpl<String> placeholder = new QueryTreeImpl<String>("?");
            // The cast pins the (child, Object edge) overload, as in the original call.
            placeholder.addChild(new QueryTreeImpl<String>("?"), (Object) "dummy");
            this.negExampleTrees.add(placeholder);
        }
        this.lgg = this.lggGenerator.getLGG(this.posExampleTrees);
        if (this.queryTreeFilter != null) {
            this.lgg = this.queryTreeFilter.getFilteredQueryTree(this.lgg);
        }
        if (logger.isDebugEnabled()) {
            logger.debug("LGG: \n" + this.lgg.getStringRepresentation());
        }
        if (this.lgg.isEmpty()) {
            throw new EmptyLGGException();
        }
        int coveredIndex = coversNegativeQueryTree(this.lgg);
        if (coveredIndex != -1) {
            throw new NegativeTreeCoverageExecption(list2.get(coveredIndex));
        }
        this.lggInstances = getResources(this.lgg);
        this.nbr.setLGGInstances(this.lggInstances);
        return this.nbr.getQuestion(this.lgg, this.negExampleTrees, getKnownResources());
    }

    /**
     * Replaces the example URIs without recomputing anything.
     *
     * @param list URIs of the positive examples
     * @param list2 URIs of the negative examples
     */
    public void setExamples(List<String> list, List<String> list2) {
        this.posExamples = list;
        this.negExamples = list2;
    }

    /**
     * Sets the statement filter on the current tree cache. Note that {@link #init()}
     * creates a fresh tree cache, so a filter set before {@code init()} is not carried over.
     *
     * @param filter the statement filter to install
     */
    public void addStatementFilter(Filter<Statement> filter) {
        this.treeCache.setStatementFilter(filter);
    }

    /**
     * Sets the filter applied to the LGG after it has been computed; replaces any
     * previously set filter.
     *
     * @param queryTreeFilter the filter to apply, or {@code null} for none
     */
    public void addQueryTreeFilter(QueryTreeFilter queryTreeFilter) {
        this.queryTreeFilter = queryTreeFilter;
    }

    /**
     * Sets the execution time limit and forwards it to the NBR operator if one exists
     * already; otherwise {@link #init()} applies the stored value when it creates one.
     *
     * @param i the limit in seconds
     */
    public void setMaxExecutionTimeInSeconds(int i) {
        this.maxExecutionTimeInSeconds = i;
        if (this.nbr != null) {
            this.nbr.setMaxExecutionTimeInSeconds(i);
        }
    }

    /**
     * Sets the recursion depth that {@link #init()} passes to the CBD generator.
     *
     * @param i the new depth
     */
    public void setMaxQueryTreeDepth(int i) {
        this.maxQueryTreeDepth = i;
    }

    /** @return the configured query tree depth */
    public int getMaxQueryTreeDepth() {
        return this.maxQueryTreeDepth;
    }

    /**
     * @param map the prefix map used when {@link #start()} logs the generated SPARQL query
     */
    public void setPrefixes(Map<String, String> map) {
        this.prefixes = map;
    }

    /** @return the prefix map used when logging the generated SPARQL query */
    public Map<String, String> getPrefixes() {
        return this.prefixes;
    }

    /**
     * Returns the SPARQL rendering of the LGG, computing the LGG from the current
     * positive examples first if none is available yet.
     *
     * @return the SPARQL query string of the LGG
     */
    public String getSPARQLQuery() {
        if (this.lgg == null) {
            this.lgg = this.lggGenerator.getLGG(getQueryTrees(this.posExamples));
        }
        return this.lgg.toSPARQLQueryString();
    }

    /**
     * @param set namespace prefixes; statements whose object URI starts with one of
     *            them are removed from every CBD before the tree is built
     */
    public void setObjectNamespacesToIgnore(Set<String> set) {
        this.objectNamespacesToIgnore = set;
    }

    /**
     * Forwards the namespace restriction to the current CBD generator. Note that
     * {@link #init()} replaces the generator, so call this afterwards.
     *
     * @param list the namespaces to restrict to
     */
    public void setRestrictToNamespaces(List<String> list) {
        this.cbdGenerator.setRestrictToNamespaces(list);
    }

    /** Rebuilds the trees of all positive examples. */
    private void generatePositiveExampleTrees() {
        this.posExampleTrees.clear();
        this.posExampleTrees.addAll(getQueryTrees(this.posExamples));
    }

    /** Rebuilds the trees of all negative examples. */
    private void generateNegativeExampleTrees() {
        this.negExampleTrees.clear();
        this.negExampleTrees.addAll(getQueryTrees(this.negExamples));
    }

    /**
     * Builds one query tree per resource from its filtered CBD. Resources whose tree
     * cannot be built are logged and skipped, so the result may be shorter than the input.
     */
    private List<QueryTree<String>> getQueryTrees(List<String> list) {
        List<QueryTree<String>> trees = new ArrayList<>();
        for (String resource : list) {
            try {
                logger.debug("Generating tree for " + resource);
                Model cbd = this.cbdGenerator.getConciseBoundedDescription(resource);
                applyFilters(cbd);
                QueryTree<String> tree = this.treeCache.getQueryTree(resource, cbd);
                if (logger.isDebugEnabled()) {
                    logger.debug("Tree for resource " + resource);
                    logger.debug(tree.getStringRepresentation());
                }
                trees.add(tree);
            } catch (Exception e) {
                logger.error("Failed to create tree for resource " + resource + ".", e);
            }
        }
        return trees;
    }

    /**
     * Removes every statement whose object is a URI resource in one of the ignored
     * namespaces. Removal goes through the statement iterator, as before.
     */
    private void applyFilters(Model model) {
        StmtIterator statements = model.listStatements();
        while (statements.hasNext()) {
            Statement statement = statements.nextStatement();
            if (hasIgnoredObjectNamespace(statement)) {
                statements.remove();
            }
        }
    }

    /** Tells whether the statement's object is a URI starting with an ignored namespace. */
    private boolean hasIgnoredObjectNamespace(Statement statement) {
        if (!statement.getObject().isURIResource()) {
            return false;
        }
        String objectUri = statement.getObject().asResource().getURI();
        for (String namespace : this.objectNamespacesToIgnore) {
            if (objectUri.startsWith(namespace)) {
                return true;
            }
        }
        return false;
    }

    /** Returns the positive examples followed by the negative examples. */
    private List<String> getKnownResources() {
        return ListUtils.union(this.posExamples, this.negExamples);
    }

    /**
     * Finds the first negative example tree that is subsumed by the given tree.
     *
     * @return its index in the negative tree list, or -1 if none is subsumed
     */
    private int coversNegativeQueryTree(QueryTree<String> queryTree) {
        int index = 0;
        for (QueryTree<String> negativeTree : this.negExampleTrees) {
            if (negativeTree.isSubsumedBy(queryTree)) {
                return index;
            }
            index++;
        }
        return -1;
    }

    /**
     * Runs the tree's SPARQL query and collects the URIs bound to variable {@code x0}.
     *
     * @throws IllegalStateException if no query execution factory has been created yet
     */
    private SortedSet<String> getResources(QueryTree<String> queryTree) {
        if (this.qef == null) {
            throw new IllegalStateException("No QueryExecutionFactory available; init() must be called before querying.");
        }
        SortedSet<String> resources = new TreeSet<>();
        QueryExecution execution = this.qef.createQueryExecution(getDistinctSPARQLQuery(queryTree));
        try {
            ResultSet rows = execution.execSelect();
            while (rows.hasNext()) {
                resources.add(rows.next().getResource("x0").getURI());
            }
        } finally {
            // Release the execution even when the query or the iteration fails.
            execution.close();
        }
        return resources;
    }

    /** Returns the SPARQL query string of the given tree. */
    private String getDistinctSPARQLQuery(QueryTree<String> queryTree) {
        return queryTree.toSPARQLQueryString();
    }

    /**
     * Computes the LGG of the positive examples and logs the resulting SPARQL query.
     * If negative examples exist, additionally checks that the LGG covers none of them
     * and logs the question produced by the NBR operator; failures in that second
     * phase are logged rather than thrown.
     */
    @Override // org.dllearner.core.LearningAlgorithm
    public void start() {
        generatePositiveExampleTrees();
        this.lgg = this.lggGenerator.getLGG(this.posExampleTrees);
        if (this.queryTreeFilter != null) {
            this.lgg = this.queryTreeFilter.getFilteredQueryTree(this.lgg);
        }
        if (logger.isDebugEnabled()) {
            logger.debug("LGG: \n" + this.lgg.getStringRepresentation());
        }
        if (logger.isInfoEnabled()) {
            logger.info("Generated SPARQL query:\n" + this.lgg.toSPARQLQueryString(true, this.enableNumericLiteralFilters, this.prefixes));
        }
        if (this.negExamples.isEmpty()) {
            return;
        }
        generateNegativeExampleTrees();
        try {
            int coveredIndex = coversNegativeQueryTree(this.lgg);
            if (coveredIndex != -1) {
                throw new NegativeTreeCoverageExecption(this.negExamples.get(coveredIndex));
            }
            this.lggInstances = getResources(this.lgg);
            this.nbr.setLGGInstances(this.lggInstances);
            logger.info("Question:\n" + this.nbr.getQuestion(this.lgg, this.negExampleTrees, getKnownResources()));
        } catch (NegativeTreeCoverageExecption e) {
            logger.error("The LGG covers a negative example.", e);
        } catch (TimeOutException e) {
            logger.error("Timed out while computing the question.", e);
        }
    }

    /**
     * @param z whether numeric literal filters are requested when {@link #start()}
     *          renders the SPARQL query for logging
     */
    public void setEnableNumericLiteralFilters(boolean z) {
        this.enableNumericLiteralFilters = z;
    }

    /** @return whether numeric literal filters are enabled */
    public boolean isEnableNumericLiteralFilters() {
        return this.enableNumericLiteralFilters;
    }

    /**
     * Returns the single best query; the requested count is not used.
     *
     * @param i ignored
     * @return a one-element list holding {@link #getBestSPARQLQuery()}
     */
    @Override // org.dllearner.core.SparqlQueryLearningAlgorithm
    public List<String> getCurrentlyBestSPARQLQueries(int i) {
        return Collections.singletonList(getBestSPARQLQuery());
    }

    /**
     * Returns the SPARQL rendering of the current LGG. Requires that an LGG has been
     * computed, e.g. by {@link #start()}.
     *
     * @return the SPARQL query string of the LGG
     */
    @Override // org.dllearner.core.SparqlQueryLearningAlgorithm
    public String getBestSPARQLQuery() {
        return this.lgg.toSPARQLQueryString();
    }

    /**
     * Creates the query execution factory, reads the examples from the learning
     * problem and (re)creates the tree cache, CBD generator, NBR operator, LGG
     * generator and the example tree lists.
     */
    @Override // org.dllearner.core.Component
    public void init() {
        if (this.endpointKS == null) {
            this.qef = new QueryExecutionFactoryModel(this.model);
            this.cbdGenerator = new CachingConciseBoundedDescriptionGenerator(new ConciseBoundedDescriptionGeneratorImpl(this.model));
            this.nbr = new NBR<>(this.model);
        } else if (this.endpointKS.isRemote()) {
            this.qef = createRemoteQueryExecutionFactory(this.endpointKS.getEndpoint());
        } else {
            this.qef = new QueryExecutionFactoryModel(((LocalModelBasedSparqlEndpointKS) this.endpointKS).getModel());
        }

        if (this.learningProblem instanceof PosOnlyLP) {
            this.posExamples = convert(((PosOnlyLP) this.learningProblem).getPositiveExamples());
            this.negExamples = new ArrayList<>();
        } else if (this.learningProblem instanceof PosNegLP) {
            this.posExamples = convert(((PosNegLP) this.learningProblem).getPositiveExamples());
            this.negExamples = convert(((PosNegLP) this.learningProblem).getNegativeExamples());
        }

        this.treeCache = new QueryTreeCache();
        this.treeCache.addAllowedNamespaces(this.allowedNamespaces);

        if (this.endpointKS != null) {
            if (this.endpointKS instanceof LocalModelBasedSparqlEndpointKS) {
                // A local knowledge source is backed by a model, so both collaborators use it.
                Model localModel = (Model) ((LocalModelBasedSparqlEndpointKS) this.endpointKS).getModel();
                this.cbdGenerator = new CachingConciseBoundedDescriptionGenerator(new ConciseBoundedDescriptionGeneratorImpl(localModel));
                this.nbr = new NBR<>(localModel);
            } else {
                // Assign the endpoint before anything is built from it.
                this.endpoint = this.endpointKS.getEndpoint();
                this.cbdGenerator = new CachingConciseBoundedDescriptionGenerator(new ConciseBoundedDescriptionGeneratorImpl(this.endpoint, this.endpointKS.getCache()));
                this.nbr = new NBR<>(this.endpoint);
            }
        }
        this.nbr.setMaxExecutionTimeInSeconds(this.maxExecutionTimeInSeconds);
        this.cbdGenerator.setRecursionDepth(this.maxQueryTreeDepth);
        this.lggGenerator = new LGGGeneratorImpl<String>();
        this.posExampleTrees = new ArrayList<>();
        this.negExampleTrees = new ArrayList<>();
    }

    /**
     * Builds the HTTP query execution factory for a remote endpoint and wraps it in
     * an H2-backed cache (30 day lifetime) when a cache directory is configured.
     */
    private QueryExecutionFactory createRemoteQueryExecutionFactory(SparqlEndpoint remoteEndpoint) {
        QueryExecutionFactoryHttp httpFactory = new QueryExecutionFactoryHttp(remoteEndpoint.getURL().toString(), remoteEndpoint.getDefaultGraphURIs());
        if (this.cacheDirectory == null) {
            return httpFactory;
        }
        return new QueryExecutionFactoryCacheEx(httpFactory, CacheUtilsH2.createCacheFrontend(this.cacheDirectory, true, TimeUnit.DAYS.toMillis(30L)));
    }

    /** Maps each individual to its string ID, in the iteration order of the set. */
    private List<String> convert(Set<OWLIndividual> set) {
        List<String> ids = new ArrayList<>(set.size());
        for (OWLIndividual individual : set) {
            ids.add(individual.toStringID());
        }
        return ids;
    }

    /** @return the current LGG, or {@code null} if none has been computed yet */
    public QueryTree<String> getLgg() {
        return this.lgg;
    }

    /**
     * Replaces the learning problem; takes effect on the next {@link #init()}.
     *
     * @param learningProblem the learning problem to read examples from
     */
    @Override // org.dllearner.core.AbstractCELA, org.dllearner.core.LearningAlgorithm
    @Autowired
    public void setLearningProblem(LearningProblem learningProblem) {
        this.learningProblem = learningProblem;
    }

    /** @return the configured knowledge source, possibly {@code null} */
    public SparqlEndpointKS getEndpointKS() {
        return this.endpointKS;
    }

    /**
     * Replaces the knowledge source; takes effect on the next {@link #init()}.
     *
     * @param sparqlEndpointKS the knowledge source to learn from
     */
    @Autowired
    public void setEndpointKS(SparqlEndpointKS sparqlEndpointKS) {
        this.endpointKS = sparqlEndpointKS;
    }

    /** No-op: this implementation has nothing to stop. */
    @Override // org.dllearner.core.StoppableLearningAlgorithm
    public void stop() {
    }

    /** @return always {@code false} */
    @Override // org.dllearner.core.StoppableLearningAlgorithm
    public boolean isRunning() {
        return false;
    }

    /**
     * @return the LGG as an OWL class expression, or {@code null} if no LGG has been
     *         computed yet
     */
    @Override // org.dllearner.core.AbstractCELA
    public OWLClassExpression getCurrentlyBestDescription() {
        if (this.lgg == null) {
            return null;
        }
        return this.lgg.asOWLClassExpression();
    }

    /** @return always {@code null}; evaluated descriptions are not provided */
    @Override // org.dllearner.core.AbstractCELA
    public EvaluatedDescription getCurrentlyBestEvaluatedDescription() {
        return null;
    }

    /**
     * @param set the namespaces handed to the tree cache on the next {@link #init()}
     */
    public void setAllowedNamespaces(Set<String> set) {
        this.allowedNamespaces = set;
    }

    /**
     * Demo: learns a query from two DBpedia football clubs and prints the resulting
     * SPARQL query and class expression.
     *
     * @param strArr unused
     * @throws Exception if set-up or learning fails
     */
    public static void main(String[] strArr) throws Exception {
        Set<String> positiveUris = new HashSet<>();
        positiveUris.add("http://dbpedia.org/resource/Liverpool_F.C.");
        positiveUris.add("http://dbpedia.org/resource/Chelsea_F.C.");
        SparqlEndpointKS sparqlEndpointKS = new SparqlEndpointKS(SparqlEndpoint.getEndpointDBpedia());
        sparqlEndpointKS.init();
        PosOnlyLP posOnlyLP = new PosOnlyLP();
        posOnlyLP.setPositiveExamples(Helper.getIndividualSet(positiveUris));
        QTL qtl = new QTL(posOnlyLP, sparqlEndpointKS, "cache");
        qtl.setAllowedNamespaces(Sets.newHashSet("http://dbpedia.org/ontology/", "http://dbpedia.org/resource/"));
        qtl.addQueryTreeFilter(new KeywordBasedQueryTreeFilter(Arrays.asList("soccer club", "Premier League")));
        qtl.init();
        qtl.start();
        System.out.println(qtl.getBestSPARQLQuery());
        System.out.println(qtl.getCurrentlyBestDescription());
    }
}
