package org.dllearner.cli;

import java.io.File;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.TreeSet;

import org.dllearner.algorithms.qtl.QTL2;
import org.dllearner.algorithms.qtl.datastructures.QueryTree;
import org.dllearner.algorithms.qtl.datastructures.impl.QueryTreeImpl;
import org.dllearner.core.AbstractLearningProblem;
import org.dllearner.core.ComponentInitException;
import org.dllearner.core.IndividualReasoner;
import org.dllearner.learningproblems.Heuristics;
import org.dllearner.learningproblems.PosNegLP;
import org.dllearner.learningproblems.PosOnlyLP;
import org.dllearner.reasoning.SPARQLReasoner;
import org.dllearner.utilities.Files;
import org.dllearner.utilities.Helper;
import org.dllearner.utilities.datastructures.Datastructures;
import org.dllearner.utilities.owl.OWLClassExpressionUtils;
import org.dllearner.utilities.statistics.Stat;
import org.semanticweb.owlapi.model.OWLClassExpression;
import org.semanticweb.owlapi.model.OWLIndividual;

/* loaded from: input_file:org/dllearner/cli/SPARQLCrossValidation.class */
/**
 * k-fold cross-validation driver for the {@link QTL2} query-tree learner.
 *
 * <p>The whole evaluation runs inside the five-argument constructor. The examples of the
 * learning problem are shuffled with fixed seeds and split into folds. QTL2 is trained on
 * each fold's training split, and the learned query tree is evaluated on both the training
 * split and the held-out test split. Per-fold and aggregated results are written through
 * {@link #outputWriter(String)}. The aggregated statistics can be read afterwards through
 * the getters.
 */
public class SPARQLCrossValidation {
    /** If {@code true}, {@link #outputWriter(String)} also appends every line to {@link #outputFile}. */
    protected static boolean writeToFile = false;
    /** Target file for report lines when {@link #writeToFile} is enabled. */
    protected static File outputFile;
    /** Learner runtime per fold, in seconds. */
    protected Stat runtime = new Stat();
    /** Test-set accuracy per fold, in percent. */
    protected Stat accuracy = new Stat();
    /** Length of the learned description per fold. */
    protected Stat length = new Stat();
    /** Training-set accuracy per fold, in percent. */
    protected Stat accuracyTraining = new Stat();
    /** Test-set F-measure per fold, in percent. */
    protected Stat fMeasure = new Stat();
    /** Training-set F-measure per fold, in percent. */
    protected Stat fMeasureTraining = new Stat();
    // NOTE(review): the four completeness/correctness stats below are never filled by this class;
    // presumably kept for subclasses — confirm before removing.
    protected Stat trainingCompletenessStat = new Stat();
    protected Stat trainingCorrectnessStat = new Stat();
    protected Stat testingCompletenessStat = new Stat();
    protected Stat testingCorrectnessStat = new Stat();
    /** Strategy passed to query-tree subsumption checks when comparing literal nodes. */
    QueryTreeImpl.LiteralNodeSubsumptionStrategy literalNodeSubsumptionStrategy = QueryTreeImpl.LiteralNodeSubsumptionStrategy.INTERVAL;

    /** Creates an instance without running any evaluation. */
    public SPARQLCrossValidation() {
    }

    /**
     * Runs a k-fold cross-validation of {@code qtl2} and prints the per-fold and aggregated results.
     *
     * @param qtl2 the learner; it is re-initialised and started once per fold
     * @param abstractLearningProblem a {@link PosNegLP} or {@link PosOnlyLP}; its examples are
     *        overwritten with each fold's training examples
     * @param individualReasoner used only for a diagnostic printout: the number of positive
     *        training examples the reasoner classifies as instances of the learned description
     * @param i number of folds; must be at least 1
     * @param z leave-one-out flag; not supported, so {@code true} terminates the JVM with status 1
     * @throws IllegalArgumentException if the learning problem is neither PosNeg nor PosOnly,
     *         or if {@code i < 1}
     * @throws IllegalStateException if the learning problem or the learner fails to initialise
     */
    public SPARQLCrossValidation(QTL2 qtl2, AbstractLearningProblem abstractLearningProblem, IndividualReasoner individualReasoner, int i, boolean z) {
        final int folds = i;
        final boolean leaveOneOut = z;
        DecimalFormat decimalFormat = new DecimalFormat();

        Set<OWLIndividual> posExamples;
        Set<OWLIndividual> negExamples;
        if (abstractLearningProblem instanceof PosNegLP) {
            PosNegLP posNegLP = (PosNegLP) abstractLearningProblem;
            posExamples = posNegLP.getPositiveExamples();
            negExamples = posNegLP.getNegativeExamples();
        } else if (abstractLearningProblem instanceof PosOnlyLP) {
            // Must be cast to PosOnlyLP: casting to PosNegLP here always throws ClassCastException.
            posExamples = ((PosOnlyLP) abstractLearningProblem).getPositiveExamples();
            negExamples = new HashSet<>();
        } else {
            throw new IllegalArgumentException("Only PosNeg and PosOnly learning problems are supported");
        }

        // Shuffle copies with fixed seeds so the fold assignment is reproducible between runs.
        List<OWLIndividual> shuffledPos = new ArrayList<>(posExamples);
        List<OWLIndividual> shuffledNeg = new ArrayList<>(negExamples);
        Collections.shuffle(shuffledPos, new Random(1L));
        Collections.shuffle(shuffledNeg, new Random(2L));

        if (!leaveOneOut && posExamples.size() < folds && negExamples.size() < folds) {
            System.out.println("The number of folds is higher than the number of positive/negative examples. This can result in empty test sets. Exiting.");
            System.exit(0);
        }
        if (leaveOneOut) {
            System.out.println("Leave-one-out not supported yet.");
            System.exit(1);
        }
        if (folds < 1) {
            throw new IllegalArgumentException("number of folds must be at least 1, got " + folds);
        }

        // Build every fold before any training starts. Setting a fold's examples on the learning
        // problem must not influence how the remaining folds are derived from posExamples/negExamples.
        int[] posSplits = calculateSplits(posExamples.size(), folds);
        int[] negSplits = calculateSplits(negExamples.size(), folds);
        List<Set<OWLIndividual>> trainingSetsPos = new ArrayList<>(folds);
        List<Set<OWLIndividual>> trainingSetsNeg = new ArrayList<>(folds);
        List<Set<OWLIndividual>> testSetsPos = new ArrayList<>(folds);
        List<Set<OWLIndividual>> testSetsNeg = new ArrayList<>(folds);
        for (int fold = 0; fold < folds; fold++) {
            Set<OWLIndividual> testPos = getTestingSet(shuffledPos, posSplits, fold);
            Set<OWLIndividual> testNeg = getTestingSet(shuffledNeg, negSplits, fold);
            testSetsPos.add(testPos);
            testSetsNeg.add(testNeg);
            trainingSetsPos.add(getTrainingSet(posExamples, testPos));
            trainingSetsNeg.add(getTrainingSet(negExamples, testNeg));
        }

        for (int fold = 0; fold < folds; fold++) {
            evaluateFold(qtl2, abstractLearningProblem, individualReasoner, fold,
                    trainingSetsPos.get(fold), trainingSetsNeg.get(fold),
                    testSetsPos.get(fold), testSetsNeg.get(fold), decimalFormat);
        }

        outputWriter("");
        outputWriter("Finished " + folds + "-folds cross-validation.");
        outputWriter("runtime: " + statOutput(decimalFormat, this.runtime, "s"));
        outputWriter("length: " + statOutput(decimalFormat, this.length, ""));
        outputWriter("F-Measure on training set: " + statOutput(decimalFormat, this.fMeasureTraining, "%"));
        outputWriter("F-Measure: " + statOutput(decimalFormat, this.fMeasure, "%"));
        outputWriter("predictive accuracy on training set: " + statOutput(decimalFormat, this.accuracyTraining, "%"));
        outputWriter("predictive accuracy: " + statOutput(decimalFormat, this.accuracy, "%"));
    }

    /**
     * Trains on one fold's training examples and evaluates the result on that fold's training and
     * test examples. The results are recorded in the statistics fields and the fold report is written.
     */
    private void evaluateFold(QTL2 qtl2, AbstractLearningProblem learningProblem, IndividualReasoner reasoner,
            int fold, Set<OWLIndividual> trainPos, Set<OWLIndividual> trainNeg,
            Set<OWLIndividual> testPos, Set<OWLIndividual> testNeg, DecimalFormat decimalFormat) {
        setTrainingExamples(learningProblem, trainPos, trainNeg);
        try {
            learningProblem.init();
            qtl2.init();
        } catch (ComponentInitException e) {
            // Continuing would run the learner with the previous fold's configuration and report
            // meaningless numbers, so fail loudly and keep the cause.
            throw new IllegalStateException("Initialisation failed in fold " + fold, e);
        }

        long start = System.nanoTime();
        qtl2.start();
        double seconds = (System.nanoTime() - start) / 1.0E9d;
        this.runtime.addNumber(seconds);

        OWLClassExpression description = qtl2.getCurrentlyBestDescription();
        System.out.println(description);

        Set<OWLIndividual> posErrors = Helper.difference(testPos, hasType(testPos, qtl2));
        Set<OWLIndividual> negErrors = hasType(testNeg, qtl2);
        outputWriter("test set errors pos: " + posErrors);
        outputWriter("test set errors neg: " + negErrors);
        // Diagnostic only: positive training examples the reasoner classifies as instances of the description.
        System.out.println(getCorrectPosClassified(reasoner, description, trainPos));

        int trainCorrectPos = getCorrectPosClassified(trainPos, qtl2);
        int trainCorrectNeg = getCorrectNegClassified(trainNeg, qtl2);
        double trainAccuracy = percent(trainCorrectPos + trainCorrectNeg, trainPos.size() + trainNeg.size());
        this.accuracyTraining.addNumber(trainAccuracy);

        int testCorrectPos = getCorrectPosClassified(testPos, qtl2);
        int testCorrectNeg = getCorrectNegClassified(testNeg, qtl2);
        double testAccuracy = percent(testCorrectPos + testCorrectNeg, testPos.size() + testNeg.size());
        this.accuracy.addNumber(testAccuracy);

        this.fMeasureTraining.addNumber(fMeasurePercent(trainCorrectPos, trainPos.size(), trainCorrectNeg, trainNeg.size()));
        this.fMeasure.addNumber(fMeasurePercent(testCorrectPos, testPos.size(), testCorrectNeg, testNeg.size()));

        int descriptionLength = OWLClassExpressionUtils.getLength(description);
        this.length.addNumber(descriptionLength);

        outputWriter("fold " + fold + ":");
        outputWriter("  training: " + trainPos.size() + " positive and " + trainNeg.size() + " negative examples");
        outputWriter("  testing: " + testCorrectPos + "/" + testPos.size() + " correct positives, " + testCorrectNeg + "/" + testNeg.size() + " correct negatives");
        outputWriter("  concept: " + description);
        outputWriter("  accuracy: " + decimalFormat.format(testAccuracy) + "% (" + decimalFormat.format(trainAccuracy) + "% on training set)");
        outputWriter("  length: " + decimalFormat.format(descriptionLength));
        outputWriter("  runtime: " + decimalFormat.format(seconds) + "s");
    }

    /** Replaces the examples of the learning problem with the given training examples. */
    private static void setTrainingExamples(AbstractLearningProblem learningProblem,
            Set<OWLIndividual> trainPos, Set<OWLIndividual> trainNeg) {
        if (learningProblem instanceof PosNegLP) {
            ((PosNegLP) learningProblem).setPositiveExamples(trainPos);
            ((PosNegLP) learningProblem).setNegativeExamples(trainNeg);
        } else if (learningProblem instanceof PosOnlyLP) {
            ((PosOnlyLP) learningProblem).setPositiveExamples(new TreeSet<>(trainPos));
        }
    }

    /** Returns {@code 100 * numerator / denominator}, computed in floating point rather than integer division. */
    private static double percent(int numerator, int denominator) {
        return 100.0d * numerator / denominator;
    }

    /**
     * Returns the F-measure in percent. Recall is {@code correctPos / posTotal}. Precision is
     * {@code correctPos / (correctPos + falsePositives)}, or 0 when that denominator is 0.
     */
    private static double fMeasurePercent(int correctPos, int posTotal, int correctNeg, int negTotal) {
        double recall = correctPos / (double) posTotal;
        int falsePositives = negTotal - correctNeg;
        double precision = correctPos + falsePositives == 0 ? 0.0d : correctPos / (double) (correctPos + falsePositives);
        return 100.0d * Heuristics.getFScore(recall, precision);
    }

    /**
     * Counts the individuals in {@code set} that the reasoner reports as instances of the given class expression.
     */
    protected int getCorrectPosClassified(IndividualReasoner individualReasoner, OWLClassExpression oWLClassExpression, Set<OWLIndividual> set) {
        return individualReasoner.hasType(oWLClassExpression, set).size();
    }

    /**
     * Returns the individuals of {@code set} covered by QTL2's best solution. An individual is
     * covered when its query tree from QTL2's tree cache is subsumed by the solution's tree.
     */
    protected Set<OWLIndividual> hasType(Set<OWLIndividual> set, QTL2 qtl2) {
        // Raw QueryTree mirrors the API as used here; its type parameter is not visible from this file.
        QueryTree solution = qtl2.getBestSolution().getTree();
        Set<OWLIndividual> covered = new HashSet<>();
        for (OWLIndividual individual : set) {
            if (qtl2.getTreeCache().getQueryTree(individual.toStringID()).isSubsumedBy(solution, this.literalNodeSubsumptionStrategy)) {
                covered.add(individual);
            }
        }
        return covered;
    }

    /**
     * Counts the positive examples covered by QTL2's best solution. For every uncovered example,
     * "POS NOT COVERED" and a dump of the example's query tree are printed.
     */
    protected int getCorrectPosClassified(Set<OWLIndividual> set, QTL2 qtl2) {
        QueryTree solution = qtl2.getBestSolution().getTree();
        int covered = 0;
        for (OWLIndividual individual : set) {
            QueryTree exampleTree = qtl2.getTreeCache().getQueryTree(individual.toStringID());
            if (exampleTree.isSubsumedBy(solution, this.literalNodeSubsumptionStrategy)) {
                covered++;
            } else {
                System.out.println("POS NOT COVERED");
                exampleTree.dump();
            }
        }
        return covered;
    }

    /**
     * Counts the individuals in {@code set} that the SPARQL reasoner does NOT report as instances of
     * the class expression. These are the correctly rejected negatives.
     */
    protected int getCorrectNegClassified(SPARQLReasoner sPARQLReasoner, OWLClassExpression oWLClassExpression, Set<OWLIndividual> set) {
        return set.size() - sPARQLReasoner.hasType(oWLClassExpression, set).size();
    }

    /** Counts the negative examples that are NOT covered by QTL2's best solution (see {@link #hasType}). */
    protected int getCorrectNegClassified(Set<OWLIndividual> set, QTL2 qtl2) {
        return set.size() - hasType(set, qtl2).size();
    }

    /**
     * Returns the test examples of fold {@code i}. These are the elements of {@code list} between
     * the previous fold's end index (0 for the first fold) and {@code iArr[i]}.
     *
     * @param list the shuffled examples
     * @param iArr exclusive end index of each fold, as produced by {@link #calculateSplits(int, int)}
     * @param i zero-based fold index
     */
    public static Set<OWLIndividual> getTestingSet(List<OWLIndividual> list, int[] iArr, int i) {
        int from = (i == 0) ? 0 : iArr[i - 1];
        return new HashSet<>(list.subList(from, iArr[i]));
    }

    /** Returns all examples of {@code set} that are not in the test set {@code set2}. */
    public static Set<OWLIndividual> getTrainingSet(Set<OWLIndividual> set, Set<OWLIndividual> set2) {
        return Helper.difference(set, set2);
    }

    /**
     * Computes the exclusive end index of each of {@code i2} folds over {@code i} examples.
     * Entry {@code k-1} is {@code ceil(k * i / i2)}, so the last entry is always {@code i}.
     * The product is computed in {@code long} and the quotient in floating point. In the original
     * int arithmetic the product could overflow and the quotient was truncated before the ceil.
     *
     * @param i number of examples
     * @param i2 number of folds
     */
    public static int[] calculateSplits(int i, int i2) {
        int[] splits = new int[i2];
        for (int fold = 1; fold <= i2; fold++) {
            splits[fold - 1] = (int) Math.ceil((fold * (long) i) / (double) i2);
        }
        return splits;
    }

    /** Formats mean, standard deviation, minimum and maximum of {@code stat}, each followed by the unit {@code str}. */
    public static String statOutput(DecimalFormat decimalFormat, Stat stat, String str) {
        return "av. " + decimalFormat.format(stat.getMean()) + str
                + " (deviation " + decimalFormat.format(stat.getStandardDeviation()) + str + "; "
                + "min " + decimalFormat.format(stat.getMin()) + str + "; "
                + "max " + decimalFormat.format(stat.getMax()) + str + ")";
    }

    /** Returns the per-fold test-set accuracies in percent. */
    public Stat getAccuracy() {
        return this.accuracy;
    }

    /** Returns the per-fold lengths of the learned description. */
    public Stat getLength() {
        return this.length;
    }

    /** Returns the per-fold learner runtimes in seconds. */
    public Stat getRuntime() {
        return this.runtime;
    }

    /** Prints {@code str} to stdout. If {@link #writeToFile} is set, it is first appended to {@link #outputFile}. */
    protected void outputWriter(String str) {
        if (writeToFile) {
            Files.appendToFile(outputFile, str + "\n");
        }
        System.out.println(str);
    }

    /** Returns the per-fold test-set F-measures in percent. */
    public Stat getfMeasure() {
        return this.fMeasure;
    }

    /** Returns the per-fold training-set F-measures in percent. */
    public Stat getfMeasureTraining() {
        return this.fMeasureTraining;
    }
}
