package org.aksw.qa.dataset_generator;

import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import java.io.File;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.aksw.jena_sparql_api.cache.h2.CacheUtilsH2;
import org.aksw.jena_sparql_api.core.FluentQueryExecutionFactory;
import org.aksw.jena_sparql_api.core.QueryExecutionFactory;
import org.aksw.jena_sparql_api.http.QueryExecutionHttpWrapper;
import org.aksw.qa.commons.datastructure.Entity;
import org.aksw.qa.commons.datastructure.IQuestion;
import org.aksw.qa.commons.load.Dataset;
import org.aksw.qa.commons.load.LoaderController;
import org.aksw.qa.commons.nlp.nerd.AGDISTIS;
import org.aksw.qa.commons.nlp.nerd.Spotlight;
import org.apache.jena.query.Query;
import org.apache.jena.query.QueryExecution;
import org.apache.jena.query.ResultSet;
import org.apache.jena.rdf.model.Model;
import org.apache.jena.rdf.model.NsIterator;
import org.apache.jena.rdf.model.RDFNode;
import org.apache.jena.rdf.model.Statement;
import org.apache.jena.rdf.model.StmtIterator;
import org.apache.jena.rdf.model.impl.ResourceImpl;
import org.apache.jena.sparql.vocabulary.FOAF;
import org.apache.jena.vocabulary.RDFS;
import org.dllearner.algorithms.qtl.QueryTreeUtils;
import org.dllearner.algorithms.qtl.datastructures.impl.RDFResourceTree;
import org.dllearner.algorithms.qtl.impl.QueryTreeFactory;
import org.dllearner.algorithms.qtl.impl.QueryTreeFactoryBaseInv;
import org.dllearner.algorithms.qtl.operations.lgg.LGGGenerator;
import org.dllearner.algorithms.qtl.operations.lgg.LGGGeneratorSimple;
import org.dllearner.algorithms.qtl.util.StopURIsDBpedia;
import org.dllearner.algorithms.qtl.util.StopURIsOWL;
import org.dllearner.algorithms.qtl.util.StopURIsRDFS;
import org.dllearner.algorithms.qtl.util.StopURIsSKOS;
import org.dllearner.algorithms.qtl.util.filters.NamespaceDropStatementFilter;
import org.dllearner.algorithms.qtl.util.filters.ObjectDropStatementFilter;
import org.dllearner.algorithms.qtl.util.filters.PredicateDropStatementFilter;
import org.dllearner.kb.sparql.ConciseBoundedDescriptionGenerator;
import org.dllearner.kb.sparql.ConciseBoundedDescriptionGeneratorImpl;
import org.dllearner.kb.sparql.SymmetricConciseBoundedDescriptionGeneratorImpl;
import org.json.simple.parser.ParseException;
import org.semanticweb.owlapi.model.IRI;
import org.semanticweb.owlapi.util.SimpleIRIShortFormProvider;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Generates SPARQL queries for natural-language questions from their known answers, and evaluates them.
 *
 * <p>For every question it:
 * <ol>
 *   <li>recognizes entities in the question text (Spotlight);</li>
 *   <li>disambiguates each answer string to a resource URI (AGDISTIS);</li>
 *   <li>fetches and filters the concise bounded description (CBD) of each answer resource;</li>
 *   <li>builds one query tree per answer and computes their least general generalisation (LGG);</li>
 *   <li>converts the LGG to a SPARQL query and scores its results against the answers.</li>
 * </ol>
 * Per-question scores are collected in {@link #evaluation} and can be written to a CSV file.
 *
 * <p>Not thread-safe: scores are accumulated in mutable instance fields.
 */
public class DatasetGenerator {
    private static final Logger LOGGER = LoggerFactory.getLogger(DatasetGenerator.class);
    private static final String NAMESPACE = "http://dbpedia.org/resource/";
    private static final String NAMESPACE2 = "http://dbpedia.org/property/";
    private static final String DBO_NAMESPACE = "http://dbpedia.org/ontology/";
    private static final String RDF_TYPE = "http://www.w3.org/1999/02/22-rdf-syntax-ns#type";
    private static final String CSV_FILE_NAME = "evaluation_datasetevaluator_avg.csv";
    private static final String CSV_HEADER =
            "Question,truepositive,falsepositive,falsenegative,recall,precision,f1,accuracy";
    /** When a question has exactly one answer, its CBD is cut down to this many triples. */
    private static final int MAX_SINGLE_ANSWER_CBD_SIZE = 100;
    /** CBDs larger than this are deferred and reduced to the predicates shared by all answers. */
    private static final int MAX_TREE_CBD_SIZE = 1000;

    private final ConciseBoundedDescriptionGenerator cbdGen;
    private final LGGGenerator lggGen;
    private final QueryExecutionFactory qef;
    private final AGDISTIS disambiguator = new AGDISTIS();
    private final Spotlight recognizer = new Spotlight();
    private final QueryTreeFactory qtf = new QueryTreeFactoryBaseInv();

    /** One row per question plus the trailing "average" and "overall" rows; columns as in the CSV header. */
    List<String[]> evaluation = new ArrayList<>();
    /** Number of questions whose generated query achieved F1 == 1. */
    double correct = 0.0d;
    /** Running sum of per-question F1 scores. */
    double f1sum = 0.0d;
    /** Running sum of per-question accuracy scores. */
    double accsum = 0.0d;

    /** Holds an answer resource whose CBD was too large to turn into a query tree immediately. */
    private static final class DeferredCbd {
        final String uri;
        final Model cbd;

        DeferredCbd(String uri, Model cbd) {
            this.uri = uri;
            this.cbd = cbd;
        }
    }

    /**
     * Creates a generator that fetches CBDs and evaluates queries through the given factory.
     *
     * @param queryExecutionFactory SPARQL access used for CBD retrieval and query evaluation
     */
    @SuppressWarnings("unchecked")
    public DatasetGenerator(QueryExecutionFactory queryExecutionFactory) {
        this.qef = queryExecutionFactory;
        // Only the symmetric CBD generator is used. The original code built a plain
        // ConciseBoundedDescriptionGeneratorImpl and immediately overwrote it.
        this.cbdGen = new SymmetricConciseBoundedDescriptionGeneratorImpl(queryExecutionFactory);
        // Drop schema/meta-vocabulary triples so the query trees only contain content statements.
        this.qtf.addDropFilters(
                new PredicateDropStatementFilter(StopURIsDBpedia.get()),
                new ObjectDropStatementFilter(StopURIsDBpedia.get()),
                new PredicateDropStatementFilter(
                        Sets.union(StopURIsRDFS.get(), Sets.newHashSet(RDFS.seeAlso.getURI()))),
                new PredicateDropStatementFilter(StopURIsOWL.get()),
                new ObjectDropStatementFilter(StopURIsOWL.get()),
                new PredicateDropStatementFilter(StopURIsSKOS.get()),
                new ObjectDropStatementFilter(StopURIsSKOS.get()),
                new NamespaceDropStatementFilter(Sets.newHashSet(
                        "http://purl.org/dc/terms/", "http://dbpedia.org/class/yago/", FOAF.getURI())),
                new PredicateDropStatementFilter(Sets.newHashSet(
                        "http://www.w3.org/2002/07/owl#equivalentClass",
                        "http://www.w3.org/2002/07/owl#disjointWith")));
        this.lggGen = new LGGGeneratorSimple();
    }

    /**
     * Generates and evaluates a query for every question, then appends the "average" and "overall" rows.
     *
     * @param map question text mapped to its answer strings
     */
    public void generate(Map<String, Set<String>> map) {
        map.forEach((question, answers) -> {
            LOGGER.info("Question:" + question);
            LOGGER.info("Answers:" + answers);
            LOGGER.info("###################################################################################");
            LOGGER.info("processing \"{}\" ...", question);
            Map<String, List<Entity>> questionEntities = recognize(question);
            Map<String, Optional<String>> answerUris = disambiguateAnswers(answers);
            if (answerUris.isEmpty()) {
                LOGGER.warn("Could not find the answer entities for {}", question);
            } else {
                generateSPARQLQuery(answerUris, questionEntities, question);
            }
        });
        // Avoid NaN averages when no questions were supplied.
        double questionCount = map.isEmpty() ? 1.0d : map.size();
        this.evaluation.add(new String[]{"average ", "0.0", "0.0", "0.0", "0.0", "0.0",
                (this.f1sum / questionCount) + "", (this.accsum / questionCount) + ""});
        this.evaluation.add(new String[]{"overall ", "-", "-", "-", "-", "-", "-",
                (this.correct / questionCount) + ""});
    }

    /**
     * Generates and evaluates queries for the English text of each question and its golden answers.
     * If two questions have the same English text, the first one wins.
     *
     * @param list the questions to process
     */
    public void generate(List<IQuestion> list) {
        Map<String, Set<String>> questionToAnswers = list.stream().collect(Collectors.toMap(
                question -> question.getLanguageToQuestion().get("en"),
                IQuestion::getGoldenAnswers,
                (first, second) -> first));
        generate(questionToAnswers);
    }

    /**
     * Runs entity recognition on the question text.
     * The result is the recognizer's map of entity lists. When the text mentions "Super Bowl 50",
     * a hard-coded entry for that resource is added under the key "Super Bowl".
     */
    private Map<String, List<Entity>> recognize(String question) {
        LOGGER.info("entity detection in question...");
        // NOTE(review): assumes the recognizer never returns null -- confirm against Spotlight.
        Map<String, List<Entity>> entities = this.recognizer.getEntities(question);
        if (question.contains("Super Bowl 50")) {
            // Hard-coded workaround for an entity the recognizer does not return.
            Entity superBowl = new Entity();
            superBowl.setLabel("Super Bowl 50");
            superBowl.getUris().add(new ResourceImpl("http://dbpedia.org/resource/Super_Bowl_50"));
            entities.put("Super Bowl", Lists.newArrayList(superBowl));
        }
        LOGGER.info("entities:{}", entities.entrySet().stream()
                .map(Object::toString)
                .collect(Collectors.joining("\n")));
        return entities;
    }

    /** Maps every answer string to its disambiguated URI. The value is empty when no entity was found. */
    private Map<String, Optional<String>> disambiguateAnswers(Set<String> answers) {
        return answers.stream().collect(Collectors.toMap(answer -> answer, this::disambiguate));
    }

    /**
     * Disambiguates a single answer string with AGDISTIS.
     *
     * @return the URI, or {@link Optional#empty()} if nothing was found or the call failed
     *     (never {@code null}, because {@code Collectors.toMap} rejects null values)
     */
    private Optional<String> disambiguate(String answer) {
        LOGGER.info("NED for {} ...", answer);
        try {
            String uri = (String) this.disambiguator
                    .runDisambiguation("<entity>" + answer + "</entity>")
                    .get(answer);
            if (uri == null) {
                LOGGER.warn("no entity found for {}", answer);
            } else {
                LOGGER.info("{} -> {}", answer, uri);
            }
            return Optional.ofNullable(uri);
        } catch (ParseException | IOException e) {
            LOGGER.error("NED failed for {}", answer, e);
            return Optional.empty();
        }
    }

    /**
     * Builds query trees from the CBDs of the resolved answer URIs.
     * It then takes their LGG as a SPARQL query and evaluates it. If no tree could be built,
     * a zero-score row is recorded instead.
     *
     * @param answerUris       answer string mapped to its (optional) URI
     * @param questionEntities entities recognized in the question text
     * @param question         the question text, used for logging and as the evaluation row label
     */
    private void generateSPARQLQuery(Map<String, Optional<String>> answerUris,
                                     Map<String, List<Entity>> questionEntities,
                                     String question) {
        Set<Entity> entities = questionEntities.values().stream()
                .flatMap(List::stream)
                .collect(Collectors.toSet());
        LOGGER.debug("question entity URIs: {}", entities.stream()
                .flatMap(entity -> entity.getUris().stream())
                .collect(Collectors.toSet()));
        LOGGER.debug("question entities: {}", entities);

        List<Set<String>> predicateSets = new ArrayList<>();
        List<DeferredCbd> deferred = new ArrayList<>();
        List<RDFResourceTree> trees = new ArrayList<>(answerUris.size());
        boolean singleAnswer = answerUris.size() == 1;

        List<String> resolvedUris = answerUris.values().stream()
                .filter(Optional::isPresent)
                .map(Optional::get)
                .collect(Collectors.toList());
        for (String uri : resolvedUris) {
            LOGGER.info(uri);
            Model cbd = this.cbdGen.getConciseBoundedDescription(uri);
            LOGGER.info("|cbd(" + uri + ")|=" + cbd.size() + " triples");
            if (singleAnswer && cbd.size() > MAX_SINGLE_ANSWER_CBD_SIZE) {
                truncate(cbd, MAX_SINGLE_ANSWER_CBD_SIZE);
            }
            filterCbd(cbd, uri, entities);
            if (!usesDBpediaVocabulary(cbd)) {
                continue;
            }
            Set<String> predicates = new HashSet<>();
            cbd.listStatements().forEachRemaining(st -> predicates.add(st.getPredicate().toString()));
            predicateSets.add(predicates);
            if (cbd.size() > MAX_TREE_CBD_SIZE) {
                // Too large for now; it is reduced to the shared predicates once all answers are seen.
                deferred.add(new DeferredCbd(uri, cbd));
                continue;
            }
            addQueryTree(uri, cbd, trees);
        }

        // Predicates that occur in the CBD of every answer that used DBpedia vocabulary.
        Set<String> commonPredicates = new HashSet<>();
        if (!predicateSets.isEmpty()) {
            commonPredicates.addAll(predicateSets.get(0));
            for (Set<String> predicates : predicateSets.subList(1, predicateSets.size())) {
                commonPredicates.retainAll(predicates);
            }
        }

        for (DeferredCbd pending : deferred) {
            Model model = pending.cbd;
            LOGGER.info("|cbd_filtered|=" + model.size() + " triples");
            model.remove(model.listStatements()
                    .filterDrop(st -> commonPredicates.contains(st.getPredicate().toString()))
                    .toList());
            LOGGER.info("|cbd_filtered|=" + model.size() + " triples");
            if (model.size() >= MAX_TREE_CBD_SIZE) {
                LOGGER.error("CBD for " + pending.uri + " too large, the tree generation was skipped");
                continue;
            }
            addQueryTree(pending.uri, model, trees);
        }

        if (trees.isEmpty()) {
            LOGGER.debug("No tree found for question: " + question);
            addZeroScoreRow(question);
            return;
        }
        Query query = QueryTreeUtils.toSPARQLQuery(this.lggGen.getLGG(trees));
        LOGGER.info("LGG query for \"{}\":\n{}", question, query);
        evaluate(answerUris.keySet(), query, question);
    }

    /** Replaces the contents of {@code model} with its first {@code limit} statements; requires size > limit. */
    private static void truncate(Model model, int limit) {
        Statement[] kept = new Statement[limit];
        StmtIterator it = model.listStatements();
        try {
            for (int i = 0; i < limit; i++) {
                kept[i] = it.nextStatement();
            }
        } finally {
            it.close();
        }
        model.removeAll();
        model.add(kept);
    }

    /**
     * Filters the CBD in place in two steps.
     * First it keeps statements whose subject or object mentions a question entity label
     * (all statements are kept when there are no entities). Then it drops {@code rdf:type}
     * statements whose object is not in the DBpedia ontology namespace.
     */
    private static void filterCbd(Model cbd, String uri, Set<Entity> entities) {
        Predicate<Statement> mentionsQuestionEntity = st -> entities.isEmpty() || entities.stream().anyMatch(entity -> {
            String label = entity.getLabel().toLowerCase(java.util.Locale.ROOT);
            return st.getSubject().toString().toLowerCase(java.util.Locale.ROOT).contains(label)
                    || st.getObject().toString().toLowerCase(java.util.Locale.ROOT).contains(label);
        });
        Predicate<Statement> isAllowedTypeStatement = st -> !RDF_TYPE.equals(st.getPredicate().toString())
                || st.getObject().toString().startsWith(DBO_NAMESPACE);
        LOGGER.info("|cbd_filtered(" + uri + ")|=" + cbd.size() + " triples");
        // filterDrop keeps the statements that FAIL the predicate, i.e. exactly those to remove.
        cbd.remove(cbd.listStatements().filterDrop(mentionsQuestionEntity).toList());
        LOGGER.info("|cbd_filtered(" + uri + ")|=" + cbd.size() + " triples");
        cbd.remove(cbd.listStatements().filterDrop(isAllowedTypeStatement).toList());
        LOGGER.info("|cbd_filtered(" + uri + ")|=" + cbd.size() + " triples");
    }

    /** Returns true if the model declares the DBpedia property or ontology namespace. */
    private static boolean usesDBpediaVocabulary(Model cbd) {
        NsIterator namespaces = cbd.listNameSpaces();
        try {
            while (namespaces.hasNext()) {
                String ns = namespaces.nextNs();
                if (NAMESPACE2.equals(ns) || DBO_NAMESPACE.equals(ns)) {
                    return true;
                }
            }
            return false;
        } finally {
            namespaces.close();
        }
    }

    /** Builds the query tree for {@code uri} from {@code cbd}, logs it, and appends it to {@code trees}. */
    private void addQueryTree(String uri, Model cbd, List<RDFResourceTree> trees) {
        RDFResourceTree tree = this.qtf.getQueryTree(uri, cbd);
        trees.add(tree);
        LOGGER.info(tree.getStringRepresentation(true));
    }

    /** Records a row with all scores set to 0.0 for {@code question}. */
    private void addZeroScoreRow(String question) {
        this.evaluation.add(new String[]{question, "0.0", "0.0", "0.0", "0.0", "0.0", "0.0", "0.0"});
    }

    /**
     * Executes {@code query} and scores the bindings of {@code ?s} against the expected answers.
     * The result is one evaluation row, and the running sums are updated.
     *
     * <p>A result counts as a true positive when its URI, with the DBpedia resource/property
     * namespace prefixes removed, is contained in {@code goldAnswers}.
     *
     * @param goldAnswers expected answer strings
     * @param query       the generated query
     * @param question    the question text, used as the row label
     */
    private void evaluate(Set<String> goldAnswers, Query query, String question) {
        LOGGER.debug(query.toString());
        LOGGER.debug("Question:" + question);
        double tp = 0.0d;
        double fp = 0.0d;
        try (QueryExecution execution = this.qef.createQueryExecution(query)) {
            ResultSet results = execution.execSelect();
            while (results.hasNext()) {
                RDFNode node = results.next().get("s");
                if (node != null && node.isResource()) {
                    String localName = node.asResource().toString().replace(NAMESPACE, "").replace(NAMESPACE2, "");
                    if (goldAnswers.contains(localName)) {
                        tp += 1.0d;
                    } else {
                        fp += 1.0d;
                    }
                }
            }
        } catch (Exception e) {
            // Remote endpoint failures are tolerated; whatever was counted so far is scored.
            LOGGER.error("Executing SPARQL query {} failed", query, e);
        }
        if (tp <= 0.0d && fp <= 0.0d) {
            LOGGER.debug("No results for question: " + question);
            addZeroScoreRow(question);
            return;
        }
        double fn = goldAnswers.size() - tp;
        double precision = tp / (tp + fp);
        double recall = tp / (tp + fn);
        double accuracy = tp / (tp + fn + fp);
        double f1 = (2.0d * precision * recall) / (precision + recall);
        if (f1 == 1.0d) {
            this.correct += 1.0d;
        }
        if (Double.isNaN(f1)) {
            f1 = 0.0d;
        }
        this.f1sum += f1;
        if (Double.isNaN(accuracy)) {
            accuracy = 0.0d;
        }
        this.accsum += accuracy;
        LOGGER.debug("Evaluation: tp:" + tp + " fp:" + fp + " fn:" + fn + " recall:" + recall
                + " precision:" + precision + " f1:" + f1 + " accuracy:" + accuracy);
        this.evaluation.add(new String[]{question, tp + "", fp + "", fn + "",
                recall + "", precision + "", f1 + "", accuracy + ""});
    }

    /**
     * Replaces the question's golden answers with the {@code ?uri} resource bindings of its SPARQL query.
     * Does nothing for a null question, a missing query, or a query containing "ASK".
     * Query failures propagate to the caller, and the answers are left unchanged in that case.
     */
    public static void updateGoldenAnswers(QueryExecutionFactory queryExecutionFactory, IQuestion iQuestion) {
        if (iQuestion == null || iQuestion.getSparqlQuery() == null || iQuestion.getSparqlQuery().contains("ASK")) {
            return;
        }
        Set<String> answers = new HashSet<>();
        try (QueryExecution execution = queryExecutionFactory.createQueryExecution(iQuestion.getSparqlQuery())) {
            ResultSet results = execution.execSelect();
            while (results.hasNext()) {
                RDFNode node = results.next().get("uri");
                if (node != null && node.isResource()) {
                    answers.add(node.asResource().getURI());
                }
            }
        }
        iQuestion.setGoldenAnswers(answers);
    }

    /** Joins the fields with commas, quoting any field that contains a comma, quote, or line break (RFC 4180). */
    private String convertToCSV(String[] strArr) {
        return Stream.of(strArr).map(DatasetGenerator::escapeCsvField).collect(Collectors.joining(","));
    }

    /** Quotes and escapes a single CSV field when needed; a null field becomes the empty string. */
    private static String escapeCsvField(String field) {
        if (field == null) {
            return "";
        }
        if (field.contains(",") || field.contains("\"") || field.contains("\n") || field.contains("\r")) {
            return "\"" + field.replace("\"", "\"\"") + "\"";
        }
        return field;
    }

    /**
     * Writes the header and all evaluation rows to {@value #CSV_FILE_NAME} in UTF-8.
     *
     * @throws IOException if the file cannot be opened or writing fails
     */
    private void writeCSVEvaluation() throws IOException {
        try (PrintWriter out = new PrintWriter(new File(CSV_FILE_NAME), "UTF-8")) {
            out.println(CSV_HEADER);
            for (String[] row : this.evaluation) {
                out.println(convertToCSV(row));
            }
            // PrintWriter never throws on write; surface errors explicitly.
            if (out.checkError()) {
                throw new IOException("Failed to write " + CSV_FILE_NAME);
            }
        }
    }

    /**
     * Runs the pipeline on the QALD-9 multilingual training set against the public DBpedia endpoint.
     * Only questions with answer type "resource" are used. Results are written to the CSV file.
     */
    public static void main(String[] strArr) throws IOException {
        QueryExecutionFactory queryExecutionFactory = FluentQueryExecutionFactory
                .http("http://dbpedia.org/sparql", Lists.newArrayList("http://dbpedia.org"))
                .config()
                .withPostProcessor(queryExecution -> ((QueryExecutionHttpWrapper) queryExecution)
                        .getDecoratee().setModelContentType("application/rdf+xml"))
                .withCache(CacheUtilsH2.createCacheFrontend("/tmp/qald/sparql", true, TimeUnit.DAYS.toMillis(30L)))
                .end()
                .create();
        // The filtered list was previously computed and discarded; it is now actually used.
        List<IQuestion> questions = LoaderController.load(Dataset.QALD9_Train_Multilingual).stream()
                .filter(question -> "resource".equals(question.getAnswerType()))
                .collect(Collectors.toList());
        for (IQuestion question : questions) {
            updateGoldenAnswers(queryExecutionFactory, question);
        }
        // Reduce golden-answer URIs to their short forms so they match the stripped result URIs.
        SimpleIRIShortFormProvider shortFormProvider = new SimpleIRIShortFormProvider();
        for (IQuestion question : questions) {
            question.setGoldenAnswers(question.getGoldenAnswers().stream()
                    .map(answer -> shortFormProvider.getShortForm(IRI.create(answer)))
                    .collect(Collectors.toSet()));
        }
        DatasetGenerator datasetGenerator = new DatasetGenerator(queryExecutionFactory);
        datasetGenerator.generate(questions);
        datasetGenerator.writeCSVEvaluation();
    }
}
