package org.dllearner.cli;

import com.google.common.collect.Sets;
import java.io.File;
import java.lang.reflect.InvocationTargetException;
import java.text.DecimalFormat;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import org.dllearner.core.AbstractCELA;
import org.dllearner.core.AbstractClassExpressionLearningProblem;
import org.dllearner.core.AbstractReasonerComponent;
import org.dllearner.core.ComponentInitException;
import org.dllearner.core.StringRenderer;
import org.dllearner.learningproblems.Heuristics;
import org.dllearner.learningproblems.PosNegLP;
import org.dllearner.learningproblems.PosOnlyLP;
import org.dllearner.utilities.Files;
import org.dllearner.utilities.Helper;
import org.dllearner.utilities.owl.ManchesterOWLSyntaxOWLObjectRendererImplExt;
import org.dllearner.utilities.owl.OWLClassExpressionUtils;
import org.dllearner.utilities.statistics.Stat;
import org.semanticweb.owlapi.model.OWLClassExpression;
import org.semanticweb.owlapi.model.OWLIndividual;
import org.semanticweb.owlapi.util.SimpleShortFormProvider;

/* loaded from: input_file:org/dllearner/cli/CrossValidation.class */
/**
 * Runs an n-fold cross-validation of a class expression learning algorithm on a
 * {@link PosNegLP} or {@link PosOnlyLP} learning problem.
 *
 * <p>Per-fold reports and an aggregated summary (runtime, concept length, F-measure and
 * predictive accuracy, each on training and test data) are printed to standard out. When
 * {@link #writeToFile} is set, they are also appended to {@link #outputFile}.
 */
public class CrossValidation {
    /** File to which report lines are appended when {@link #writeToFile} is {@code true}. */
    public static File outputFile;
    /** If {@code true}, {@link #outputWriter(String)} also appends every line to {@link #outputFile}. */
    public static boolean writeToFile = false;
    /**
     * If {@code true}, folds are evaluated in a thread pool. This only takes effect when both the
     * learning problem and the learning algorithm implement {@link Cloneable}.
     */
    public static boolean multiThreaded = false;

    protected Stat runtime = new Stat();
    protected Stat accuracy = new Stat();
    protected Stat length = new Stat();
    protected Stat accuracyTraining = new Stat();
    protected Stat fMeasure = new Stat();
    protected Stat fMeasureTraining = new Stat();
    // The four statistics below are never written by this class.
    // NOTE(review): presumably they are filled by subclasses; confirm.
    protected Stat trainingCompletenessStat = new Stat();
    protected Stat trainingCorrectnessStat = new Stat();
    protected Stat testingCompletenessStat = new Stat();
    protected Stat testingCorrectnessStat = new Stat();
    DecimalFormat df = new DecimalFormat();

    /** Creates an instance without running a cross-validation. */
    public CrossValidation() {
    }

    /**
     * Runs a complete cross-validation, then prints per-fold results and an aggregated summary.
     *
     * <p>Positive and negative examples are shuffled with fixed seeds (1 and 2), so the fold
     * assignment is reproducible. In sequential mode the learning problem's example sets are
     * overwritten with each fold's training examples.
     *
     * <p>If neither the positive nor the negative examples number at least {@code folds}, a
     * message is printed and the JVM exits with status 0. Leave-one-out is not supported:
     * {@code leaveOneOut == true} makes the JVM exit with status 1.
     *
     * @param learningAlgorithm algorithm to evaluate (cloned per fold in multi-threaded mode)
     * @param learningProblem a {@link PosNegLP} or {@link PosOnlyLP}
     * @param reasoner reasoner used to classify the examples with the learned concept
     * @param folds number of folds
     * @param leaveOneOut request leave-one-out validation (unsupported, see above)
     * @throws IllegalArgumentException if the learning problem is neither PosNeg nor PosOnly
     */
    public CrossValidation(AbstractCELA learningAlgorithm, AbstractClassExpressionLearningProblem learningProblem,
            final AbstractReasonerComponent reasoner, int folds, boolean leaveOneOut) {
        StringRenderer.setRenderer(new ManchesterOWLSyntaxOWLObjectRendererImplExt());
        StringRenderer.setShortFormProvider(new SimpleShortFormProvider());

        Set<OWLIndividual> posExamples;
        Set<OWLIndividual> negExamples;
        if (learningProblem instanceof PosNegLP) {
            posExamples = ((PosNegLP) learningProblem).getPositiveExamples();
            negExamples = ((PosNegLP) learningProblem).getNegativeExamples();
        } else if (learningProblem instanceof PosOnlyLP) {
            // Read through PosOnlyLP. A PosNegLP cast here would always fail.
            posExamples = ((PosOnlyLP) learningProblem).getPositiveExamples();
            negExamples = new HashSet<>();
        } else {
            throw new IllegalArgumentException("Only PosNeg and PosOnly learning problems are supported");
        }

        if (!leaveOneOut && posExamples.size() < folds && negExamples.size() < folds) {
            System.out.println("The number of folds is higher than the number of positive/negative examples. This can result in empty test sets. Exiting.");
            System.exit(0);
        }
        if (leaveOneOut) {
            System.out.println("Leave-one-out not supported yet.");
            System.exit(1);
        }

        List<Fold> foldData = createFolds(posExamples, negExamples, folds);
        if (multiThreaded && learningProblem instanceof Cloneable && learningAlgorithm instanceof Cloneable) {
            runFoldsInParallel(learningAlgorithm, learningProblem, reasoner, foldData);
        } else {
            runFoldsSequentially(learningAlgorithm, learningProblem, reasoner, foldData);
        }
        printSummary(folds);
    }

    /** Shuffles the examples with fixed seeds and partitions them into {@code folds} train/test splits. */
    private static List<Fold> createFolds(Set<OWLIndividual> posExamples, Set<OWLIndividual> negExamples, int folds) {
        List<OWLIndividual> shuffledPos = new LinkedList<>(posExamples);
        List<OWLIndividual> shuffledNeg = new LinkedList<>(negExamples);
        Collections.shuffle(shuffledPos, new Random(1L));
        Collections.shuffle(shuffledNeg, new Random(2L));

        int[] splitsPos = calculateSplits(posExamples.size(), folds);
        int[] splitsNeg = calculateSplits(negExamples.size(), folds);
        List<Fold> result = new LinkedList<>();
        for (int fold = 0; fold < folds; fold++) {
            Set<OWLIndividual> testPos = getTestingSet(shuffledPos, splitsPos, fold);
            Set<OWLIndividual> testNeg = getTestingSet(shuffledNeg, splitsNeg, fold);
            result.add(new Fold(getTrainingSet(posExamples, testPos), getTrainingSet(negExamples, testNeg),
                    testPos, testNeg));
        }
        return result;
    }

    /** Sets the training examples on a PosNeg or PosOnly learning problem. Other types are left untouched. */
    private static void setTrainingExamples(AbstractClassExpressionLearningProblem learningProblem,
            Set<OWLIndividual> trainPos, Set<OWLIndividual> trainNeg) {
        if (learningProblem instanceof PosNegLP) {
            ((PosNegLP) learningProblem).setPositiveExamples(trainPos);
            ((PosNegLP) learningProblem).setNegativeExamples(trainNeg);
        } else if (learningProblem instanceof PosOnlyLP) {
            ((PosOnlyLP) learningProblem).setPositiveExamples(new TreeSet<OWLIndividual>(trainPos));
        }
    }

    /** Evaluates every fold one after another, reusing the given learning problem and algorithm. */
    private void runFoldsSequentially(AbstractCELA learningAlgorithm, AbstractClassExpressionLearningProblem learningProblem,
            AbstractReasonerComponent reasoner, List<Fold> foldData) {
        int currFold = 0;
        for (Fold fold : foldData) {
            setTrainingExamples(learningProblem, fold.trainPos, fold.trainNeg);
            validate(learningAlgorithm, learningProblem, reasoner, currFold++,
                    fold.trainPos, fold.trainNeg, fold.testPos, fold.testNeg);
        }
    }

    /**
     * Evaluates the folds in a thread pool. Each fold gets its own reflective {@code clone()} of
     * the learning problem and algorithm; the reasoner is shared. A fold whose cloning fails is
     * reported and skipped.
     */
    private void runFoldsInParallel(AbstractCELA learningAlgorithm, AbstractClassExpressionLearningProblem learningProblem,
            final AbstractReasonerComponent reasoner, List<Fold> foldData) {
        // availableProcessors() - 1 is 0 on a single-core machine, and newFixedThreadPool rejects that.
        int threads = Math.max(1, Runtime.getRuntime().availableProcessors() - 1);
        ExecutorService executor = Executors.newFixedThreadPool(threads);
        int foldIndex = 0;
        for (final Fold fold : foldData) {
            final int currFold = foldIndex++;
            try {
                final AbstractClassExpressionLearningProblem lpClone = (AbstractClassExpressionLearningProblem)
                        learningProblem.getClass().getMethod("clone").invoke(learningProblem);
                setTrainingExamples(lpClone, fold.trainPos, fold.trainNeg);
                final AbstractCELA laClone = (AbstractCELA)
                        learningAlgorithm.getClass().getMethod("clone").invoke(learningAlgorithm);
                executor.submit(new Runnable() {
                    @Override
                    public void run() {
                        try {
                            validate(laClone, lpClone, reasoner, currFold,
                                    fold.trainPos, fold.trainNeg, fold.testPos, fold.testNeg);
                        } catch (Exception e) {
                            e.printStackTrace();
                        }
                    }
                });
            } catch (IllegalAccessException | IllegalArgumentException | NoSuchMethodException
                    | SecurityException | InvocationTargetException e) {
                e.printStackTrace();
            }
        }
        executor.shutdown();
        try {
            executor.awaitTermination(1L, TimeUnit.DAYS);
        } catch (InterruptedException e) {
            e.printStackTrace();
            // Restore the interrupt flag so callers can still observe the interruption.
            Thread.currentThread().interrupt();
        }
    }

    /** Writes the aggregated statistics over all folds. */
    private void printSummary(int folds) {
        outputWriter("");
        outputWriter("Finished " + folds + "-folds cross-validation.");
        outputWriter("runtime: " + statOutput(df, runtime, "s"));
        outputWriter("length: " + statOutput(df, length, ""));
        outputWriter("F-Measure on training set: " + statOutput(df, fMeasureTraining, "%"));
        outputWriter("F-Measure: " + statOutput(df, fMeasure, "%"));
        outputWriter("predictive accuracy on training set: " + statOutput(df, accuracyTraining, "%"));
        outputWriter("predictive accuracy: " + statOutput(df, accuracy, "%"));
    }

    /**
     * Trains on one fold, evaluates the best concept found, and records and prints the results.
     *
     * <p>The steps are:
     * <ol>
     *   <li>Initialize the learning problem and algorithm. A {@link ComponentInitException} is
     *       printed and execution continues.</li>
     *   <li>Run the algorithm and take its currently best description.</li>
     *   <li>Classify the training and test examples with that concept via the reasoner.</li>
     *   <li>Add runtime, accuracy, F-measure and concept length to the aggregated {@link Stat}s.</li>
     *   <li>Write a per-fold report through {@link #outputWriter(String)}.</li>
     * </ol>
     * This method may run concurrently in multi-threaded mode.
     *
     * @param abstractCELA learning algorithm for this fold
     * @param abstractClassExpressionLearningProblem learning problem already holding the training examples
     * @param abstractReasonerComponent reasoner used for instance checks
     * @param i zero-based fold index (used in the report only)
     * @param set positive training examples
     * @param set2 negative training examples
     * @param set3 positive test examples
     * @param set4 negative test examples
     */
    public void validate(AbstractCELA abstractCELA, AbstractClassExpressionLearningProblem abstractClassExpressionLearningProblem,
            AbstractReasonerComponent abstractReasonerComponent, int i, Set<OWLIndividual> set, Set<OWLIndividual> set2,
            Set<OWLIndividual> set3, Set<OWLIndividual> set4) {
        AbstractCELA la = abstractCELA;
        AbstractClassExpressionLearningProblem lp = abstractClassExpressionLearningProblem;
        AbstractReasonerComponent rs = abstractReasonerComponent;
        Set<OWLIndividual> trainPos = set;
        Set<OWLIndividual> trainNeg = set2;
        Set<OWLIndividual> testPos = set3;
        Set<OWLIndividual> testNeg = set4;

        Set<?> trainPosNames = Helper.getStringSet(trainPos);
        Set<?> trainNegNames = Helper.getStringSet(trainNeg);
        String trainingExamplesText = "+" + new TreeSet<Object>(trainPosNames) + "\n"
                + "-" + new TreeSet<Object>(trainNegNames) + "\n";

        try {
            lp.init();
            la.setLearningProblem(lp);
            la.init();
        } catch (ComponentInitException e) {
            e.printStackTrace();
        }

        long startNanos = System.nanoTime();
        la.start();
        double runtimeSeconds = (System.nanoTime() - startNanos) / 1.0E9d;

        OWLClassExpression concept = la.getCurrentlyBestDescription();

        // Positive test examples the concept misses, and negative test examples it covers.
        Set<OWLIndividual> testPosErrors = Sets.difference(testPos, rs.hasType(concept, testPos));
        Set<OWLIndividual> testNegErrors = rs.hasType(concept, testNeg);

        int trainCorrectPos = getCorrectPosClassified(rs, concept, trainPos);
        int trainCorrectNeg = getCorrectNegClassified(rs, concept, trainNeg);
        double trainAccuracy = percentage(trainCorrectPos + trainCorrectNeg, trainPos.size() + trainNeg.size());

        int testCorrectPos = getCorrectPosClassified(rs, concept, testPos);
        int testCorrectNeg = getCorrectNegClassified(rs, concept, testNeg);
        double testAccuracy = percentage(testCorrectPos + testCorrectNeg, testPos.size() + testNeg.size());

        double trainFMeasure = fMeasurePercent(trainCorrectPos, trainPos.size(), rs.hasType(concept, trainNeg).size());
        double testFMeasure = fMeasurePercent(testCorrectPos, testPos.size(), testNegErrors.size());

        int conceptLength = OWLClassExpressionUtils.getLength(concept);

        // In multi-threaded mode several folds finish concurrently. Stat's thread-safety is not
        // visible here, and DecimalFormat is not thread-safe, so recording and reporting are serialized.
        synchronized (this) {
            runtime.addNumber(runtimeSeconds);
            accuracyTraining.addNumber(trainAccuracy);
            accuracy.addNumber(testAccuracy);
            fMeasureTraining.addNumber(trainFMeasure);
            fMeasure.addNumber(testFMeasure);
            this.length.addNumber(conceptLength);

            outputWriter(trainingExamplesText
                    + "test set errors pos: " + testPosErrors + "\n"
                    + "test set errors neg: " + testNegErrors + "\n"
                    + "fold " + i + ":\n"
                    + "  training: " + trainPosNames.size() + " positive and " + trainNegNames.size()
                    + " negative examples\n"
                    + "  testing: " + testCorrectPos + "/" + testPos.size() + " correct positives, "
                    + testCorrectNeg + "/" + testNeg.size() + " correct negatives\n"
                    + "  concept: " + concept.toString().replace("\n", " ") + "\n"
                    + "  accuracy: " + df.format(testAccuracy) + "% (" + df.format(trainAccuracy) + "% on training set)\n"
                    + "  length: " + df.format(conceptLength) + "\n"
                    + "  runtime: " + df.format(runtimeSeconds) + "s\n");
        }
    }

    /**
     * Returns {@code 100 * part / total}, computed in floating point.
     * The original code used int division, which truncated every accuracy to 0 or 100.
     */
    private static double percentage(int part, int total) {
        return 100.0d * ((double) part / total);
    }

    /**
     * Returns the F-measure, in percent, from the confusion counts of a concept.
     *
     * @param correctPos positive examples covered by the concept
     * @param totalPos number of positive examples (recall denominator); recall is 0 when this is 0
     * @param negAsPos negative examples wrongly covered by the concept
     */
    private static double fMeasurePercent(int correctPos, int totalPos, int negAsPos) {
        double recall = totalPos == 0 ? 0.0d : correctPos / (double) totalPos;
        double precision = correctPos + negAsPos == 0 ? 0.0d : correctPos / (double) (correctPos + negAsPos);
        return 100.0d * Heuristics.getFScore(recall, precision);
    }

    /** Returns how many of {@code set} are instances of the concept according to the reasoner. */
    public int getCorrectPosClassified(AbstractReasonerComponent abstractReasonerComponent, OWLClassExpression oWLClassExpression, Set<OWLIndividual> set) {
        return abstractReasonerComponent.hasType(oWLClassExpression, set).size();
    }

    /** Returns how many of {@code set} are NOT instances of the concept according to the reasoner. */
    protected int getCorrectNegClassified(AbstractReasonerComponent abstractReasonerComponent, OWLClassExpression oWLClassExpression, Set<OWLIndividual> set) {
        return set.size() - abstractReasonerComponent.hasType(oWLClassExpression, set).size();
    }

    /**
     * Returns the test examples of fold {@code i}.
     *
     * @param list shuffled examples
     * @param iArr exclusive end indices of each fold, as produced by {@link #calculateSplits(int, int)}
     * @param i zero-based fold index
     * @return a new mutable set holding {@code list[iArr[i-1], iArr[i])} (start 0 for the first fold)
     */
    public static Set<OWLIndividual> getTestingSet(List<OWLIndividual> list, int[] iArr, int i) {
        int fromIndex = i == 0 ? 0 : iArr[i - 1];
        int toIndex = iArr[i];
        return new HashSet<OWLIndividual>(list.subList(fromIndex, toIndex));
    }

    /**
     * Returns the examples of {@code set} that are not in {@code set2}, as an independent set.
     *
     * <p>The result is materialized instead of returning Guava's live {@code Sets.difference}
     * view. That view's {@code size()} iterates every element, and it stays backed by
     * {@code set}, which is typically the learning problem's own example set. Iteration order
     * of the view is preserved.
     */
    public static Set<OWLIndividual> getTrainingSet(Set<OWLIndividual> set, Set<OWLIndividual> set2) {
        return new LinkedHashSet<OWLIndividual>(Sets.difference(set, set2));
    }

    /**
     * Computes the exclusive end index of each fold when {@code i} examples are split into
     * {@code i2} folds. Entry {@code k-1} is {@code ceil(k * i / i2)}, so the last entry is {@code i}.
     *
     * @param i number of examples
     * @param i2 number of folds
     */
    public static int[] calculateSplits(int i, int i2) {
        int[] splits = new int[i2];
        for (int fold = 1; fold <= i2; fold++) {
            // Divide in floating point so ceil actually rounds up; int division would truncate first.
            splits[fold - 1] = (int) Math.ceil((double) fold * i / i2);
        }
        return splits;
    }

    /** Formats mean, standard deviation, min and max of {@code stat}, each followed by unit {@code str}. */
    public static String statOutput(DecimalFormat decimalFormat, Stat stat, String str) {
        return "av. " + decimalFormat.format(stat.getMean()) + str
                + " (deviation " + decimalFormat.format(stat.getStandardDeviation()) + str + "; "
                + "min " + decimalFormat.format(stat.getMin()) + str + "; "
                + "max " + decimalFormat.format(stat.getMax()) + str + ")";
    }

    /** Returns the predictive accuracy (in percent) collected over the folds. */
    public Stat getAccuracy() {
        return this.accuracy;
    }

    /** Returns the concept lengths collected over the folds. */
    public Stat getLength() {
        return this.length;
    }

    /** Returns the per-fold learning runtimes, in seconds. */
    public Stat getRuntime() {
        return this.runtime;
    }

    /** Prints {@code str} to standard out; when {@link #writeToFile} is set, first appends it to {@link #outputFile}. */
    public void outputWriter(String str) {
        if (writeToFile) {
            Files.appendToFile(outputFile, str + "\n");
        }
        System.out.println(str);
    }

    /** Returns the test-set F-measure (in percent) collected over the folds. */
    public Stat getfMeasure() {
        return this.fMeasure;
    }

    /** Returns the training-set F-measure (in percent) collected over the folds. */
    public Stat getfMeasureTraining() {
        return this.fMeasureTraining;
    }

    /** Training and test examples of a single fold. */
    private static final class Fold {
        final Set<OWLIndividual> trainPos;
        final Set<OWLIndividual> trainNeg;
        final Set<OWLIndividual> testPos;
        final Set<OWLIndividual> testNeg;

        Fold(Set<OWLIndividual> trainPos, Set<OWLIndividual> trainNeg,
                Set<OWLIndividual> testPos, Set<OWLIndividual> testNeg) {
            this.trainPos = trainPos;
            this.trainNeg = trainNeg;
            this.testPos = testPos;
            this.testNeg = testNeg;
        }
    }
}
