package org.dllearner.cli;

import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.SortedSet;
import org.apache.log4j.Logger;
import org.dllearner.algorithms.ParCEL.ParCELAbstract;
import org.dllearner.algorithms.ParCEL.ParCELPosNegLP;
import org.dllearner.algorithms.ParCELEx.ParCELExAbstract;
import org.dllearner.core.AbstractCELA;
import org.dllearner.core.AbstractReasonerComponent;
import org.dllearner.core.ComponentInitException;
import org.dllearner.core.owl.Description;
import org.dllearner.core.owl.Individual;
import org.dllearner.learningproblems.Heuristics;
import org.dllearner.learningproblems.PosNegLP;
import org.dllearner.utilities.Files;
import org.dllearner.utilities.Helper;
import org.dllearner.utilities.statistics.Stat;

/**
 * Cross validation for learners of the ParCEL family.
 *
 * <p>In addition to the accuracy / completeness / correctness statistics kept by
 * {@link CrossValidation}, this class records the number of (compacted) partial definitions and
 * their average length per fold and, for {@link ParCELExAbstract} learners, whether a fold was
 * terminated by partial definitions or by counter partial definitions. The k-fold cross
 * validation may be repeated several times; results are reported per fold, per repetition and
 * aggregated over all repetitions.
 */
public class ParCELCrossValidation extends CrossValidation {

    /** Returned by {@link #evaluateFold} when no termination reason was reported. */
    private static final int TERMINATED_OTHERWISE = 0;
    /** Returned by {@link #evaluateFold} when ParCELEx terminated by partial definitions. */
    private static final int TERMINATED_BY_PARTIAL_DEFINITIONS = 1;
    /** Returned by {@link #evaluateFold} when ParCELEx terminated by counter partial definitions. */
    private static final int TERMINATED_BY_COUNTER_DEFINITIONS = 2;

    /** Number of compacted partial definitions per fold (of the most recent repetition). */
    protected Stat noOfPartialDef = new Stat();
    /** Average partial definition length per fold (of the most recent repetition). */
    protected Stat partialDefinitionLength = new Stat();
    Logger logger = Logger.getLogger(getClass());
    /** When true, the fold loop stops before evaluating the next fold. */
    protected boolean interupted = false;

    /**
     * Delegates to the corresponding {@link CrossValidation} constructor with the same arguments.
     *
     * @param abstractCELA the learning algorithm
     * @param posNegLP the learning problem
     * @param abstractReasonerComponent the reasoner
     * @param i number of folds
     * @param z whether leave-one-out is requested
     */
    public ParCELCrossValidation(AbstractCELA abstractCELA, PosNegLP posNegLP,
            AbstractReasonerComponent abstractReasonerComponent, int i, boolean z) {
        super(abstractCELA, posNegLP, abstractReasonerComponent, i, z);
    }

    /**
     * Runs a {@code i}-fold cross validation {@code i2} times and writes the results through
     * {@link #outputWriter(String)}.
     *
     * <p>The fold partition is computed once (examples shuffled with fixed seeds) and reused in
     * every repetition, so repetitions only differ if the learner itself is non-deterministic.
     * The algorithm must be a {@link ParCELAbstract}; otherwise a {@link ClassCastException} is
     * thrown when the first fold is evaluated.
     *
     * @param abstractCELA the learning algorithm (expected to be a {@link ParCELAbstract})
     * @param parCELPosNegLP the learning problem; its example sets are replaced for every fold
     * @param abstractReasonerComponent the reasoner used to classify the examples
     * @param i number of folds
     * @param z leave-one-out flag (not supported: the JVM exits if set)
     * @param i2 number of repetitions of the whole cross validation
     */
    public ParCELCrossValidation(AbstractCELA abstractCELA, ParCELPosNegLP parCELPosNegLP,
            AbstractReasonerComponent abstractReasonerComponent, int i, boolean z, int i2) {
        final int folds = i;
        final boolean leaveOneOut = z;
        final int noOfRuns = i2;
        DecimalFormat df = new DecimalFormat();

        // Shuffle copies of the examples with fixed seeds so the fold assignment is reproducible.
        Set<Individual> positiveExamples = parCELPosNegLP.getPositiveExamples();
        List<Individual> shuffledPos = new ArrayList<Individual>(positiveExamples);
        Collections.shuffle(shuffledPos, new Random(1L));
        Set<Individual> negativeExamples = parCELPosNegLP.getNegativeExamples();
        List<Individual> shuffledNeg = new ArrayList<Individual>(negativeExamples);
        Collections.shuffle(shuffledNeg, new Random(2L));

        if (!leaveOneOut && positiveExamples.size() < folds && negativeExamples.size() < folds) {
            System.out.println("The number of folds is higher than the number of positive/negative examples. This can result in empty test sets. Exiting.");
            System.exit(0);
        }
        if (leaveOneOut) {
            System.out.println("Leave-one-out not supported yet.");
            System.exit(1);
            return;
        }

        // Pre-compute training and testing sets for every fold.
        int[] posSplits = calculateSplits(positiveExamples.size(), folds);
        int[] negSplits = calculateSplits(negativeExamples.size(), folds);
        List<Set<Individual>> trainingSetsPos = new ArrayList<Set<Individual>>();
        List<Set<Individual>> trainingSetsNeg = new ArrayList<Set<Individual>>();
        List<Set<Individual>> testSetsPos = new ArrayList<Set<Individual>>();
        List<Set<Individual>> testSetsNeg = new ArrayList<Set<Individual>>();
        for (int fold = 0; fold < folds; fold++) {
            Set<Individual> testPos = getTestingSet(shuffledPos, posSplits, fold);
            Set<Individual> testNeg = getTestingSet(shuffledNeg, negSplits, fold);
            testSetsPos.add(testPos);
            testSetsNeg.add(testNeg);
            trainingSetsPos.add(getTrainingSet(positiveExamples, testPos));
            trainingSetsNeg.add(getTrainingSet(negativeExamples, testNeg));
        }

        // Aggregates (mean / max / min of each repetition) over all repetitions.
        Summary runtimeSummary = new Summary();
        Summary partialDefCountSummary = new Summary();
        Summary partialDefLengthSummary = new Summary();
        Summary lengthSummary = new Summary();
        Summary trainingAccuracySummary = new Summary();
        Summary trainingCorrectnessSummary = new Summary();
        Summary trainingCompletenessSummary = new Summary();
        Summary testingAccuracySummary = new Summary();
        Summary testingCorrectnessSummary = new Summary();
        Summary testingCompletenessSummary = new Summary();

        // Termination counters are cumulative over all repetitions.
        int terminatedByPartialDefs = 0;
        int terminatedByCounterDefs = 0;

        for (int run = 0; run < noOfRuns; run++) {
            // Per-repetition statistics start fresh for every repetition.
            this.runtime = new Stat();
            this.noOfPartialDef = new Stat();
            this.partialDefinitionLength = new Stat();
            this.length = new Stat();
            this.accuracyTraining = new Stat();
            this.trainingCorrectnessStat = new Stat();
            this.trainingCompletenessStat = new Stat();
            this.accuracy = new Stat();
            this.testingCorrectnessStat = new Stat();
            this.testingCompletenessStat = new Stat();
            this.fMeasureTraining = new Stat();
            this.fMeasure = new Stat();

            for (int fold = 0; fold < folds; fold++) {
                if (this.interupted) {
                    outputWriter("Cross validation has been interupted");
                    return;
                }
                int termination = evaluateFold(abstractCELA, parCELPosNegLP, abstractReasonerComponent,
                        df, fold, trainingSetsPos.get(fold), trainingSetsNeg.get(fold),
                        testSetsPos.get(fold), testSetsNeg.get(fold));
                if (termination == TERMINATED_BY_COUNTER_DEFINITIONS) {
                    terminatedByCounterDefs++;
                } else if (termination == TERMINATED_BY_PARTIAL_DEFINITIONS) {
                    terminatedByPartialDefs++;
                }
            }

            // Report the repetition (1-based for human readers).
            int runNumber = run + 1;
            outputWriter("");
            outputWriter("Finished the " + runNumber + getOrderUnit(runNumber) + " of a " + folds + "-folds cross-validation.");
            outputWriter("runtime: " + statOutput(df, this.runtime, "s"));
            outputWriter("#partial definitions: " + statOutput(df, this.noOfPartialDef, ""));
            outputWriter("avg. partial definition length: " + statOutput(df, this.partialDefinitionLength, ""));
            outputWriter("length: " + statOutput(df, this.length, ""));
            outputWriter("predictive accuracy on training set: " + statOutput(df, this.accuracyTraining, "%") + " - corr/comp: " + statOutput(df, this.trainingCorrectnessStat, "%") + "/" + statOutput(df, this.trainingCompletenessStat, "%"));
            outputWriter("predictive accuracy: " + statOutput(df, this.accuracy, "%") + " - corr/comp: " + statOutput(df, this.testingCorrectnessStat, "%") + "/" + statOutput(df, this.testingCompletenessStat, "%"));
            if (abstractCELA instanceof ParCELExAbstract) {
                outputWriter("terminated by: partial def.: " + terminatedByPartialDefs + "; counter partial def.: " + terminatedByCounterDefs);
            }

            runtimeSummary.add(this.runtime);
            partialDefCountSummary.add(this.noOfPartialDef);
            partialDefLengthSummary.add(this.partialDefinitionLength);
            lengthSummary.add(this.length);
            trainingAccuracySummary.add(this.accuracyTraining);
            trainingCorrectnessSummary.add(this.trainingCorrectnessStat);
            trainingCompletenessSummary.add(this.trainingCompletenessStat);
            testingAccuracySummary.add(this.accuracy);
            testingCorrectnessSummary.add(this.testingCorrectnessStat);
            testingCompletenessSummary.add(this.testingCompletenessStat);
        }

        outputWriter("");
        outputWriter("Finished " + noOfRuns + " times of the " + folds + "-folds cross-validations");
        outputWriter("runtime: " + summaryOutput(df, runtimeSummary, "s"));
        outputWriter("number of partial definitions: " + summaryOutput(df, partialDefCountSummary, ""));
        outputWriter("avg. partial definition length: " + summaryOutput(df, partialDefLengthSummary, ""));
        outputWriter("definition length: " + summaryOutput(df, lengthSummary, ""));
        outputWriter("accuracy on training set:" + summaryOutput(df, trainingAccuracySummary, "%"));
        outputWriter("correctness on training set: " + summaryOutput(df, trainingCorrectnessSummary, "%"));
        outputWriter("completeness on training set: " + summaryOutput(df, trainingCompletenessSummary, "%"));
        outputWriter("accuracy on testing set: " + summaryOutput(df, testingAccuracySummary, "%"));
        outputWriter("correctness on testing set: " + summaryOutput(df, testingCorrectnessSummary, "%"));
        outputWriter("completeness on testing set: " + summaryOutput(df, testingCompletenessSummary, "%"));
        if (abstractCELA instanceof ParCELExAbstract) {
            outputWriter("terminated by: partial def.: " + terminatedByPartialDefs + "; counter partial def.: " + terminatedByCounterDefs);
        }
    }

    /**
     * Learns on one fold's training sets, evaluates the learned union description on the
     * training and testing sets, records the per-fold statistics and writes the fold report.
     *
     * @return one of {@link #TERMINATED_BY_PARTIAL_DEFINITIONS},
     *         {@link #TERMINATED_BY_COUNTER_DEFINITIONS} or {@link #TERMINATED_OTHERWISE}
     */
    private int evaluateFold(AbstractCELA la, ParCELPosNegLP lp, AbstractReasonerComponent rs,
            DecimalFormat df, int fold, Set<Individual> trainPos, Set<Individual> trainNeg,
            Set<Individual> testPos, Set<Individual> testNeg) {
        lp.setPositiveExamples(trainPos);
        lp.setNegativeExamples(trainNeg);
        try {
            lp.init();
            la.init();
        } catch (ComponentInitException e) {
            // Best effort (as before): report the problem and still try to run the fold.
            logger.error("Initialisation failed for fold " + fold, e);
        }

        long start = System.nanoTime();
        try {
            la.start();
        } catch (OutOfMemoryError e) {
            System.out.println("out of memory at " + ((System.nanoTime() - start) / 1000000000L) + "s");
        }
        long elapsedNanos = System.nanoTime() - start;
        this.runtime.addNumber(elapsedNanos / 1.0E9d);

        ParCELAbstract parcel = (ParCELAbstract) la;
        Description concept = parcel.getUnionCurrenlyBestDescription();

        // Misclassified examples: positives not covered, negatives covered.
        Set<Individual> trainPosMissed = Helper.difference(trainPos, rs.hasType(concept, trainPos));
        SortedSet<Individual> trainNegCovered = rs.hasType(concept, trainNeg);
        outputWriter("training set errors pos (" + trainPosMissed.size() + "): " + trainPosMissed);
        outputWriter("training set errors neg (" + trainNegCovered.size() + "): " + trainNegCovered);
        Set<Individual> testPosMissed = Helper.difference(testPos, rs.hasType(concept, testPos));
        SortedSet<Individual> testNegCovered = rs.hasType(concept, testNeg);
        outputWriter("test set errors pos: " + testPosMissed);
        outputWriter("test set errors neg: " + testNegCovered);

        int trainPosCorrect = getCorrectPosClassified(rs, concept, trainPos);
        int trainNegCorrect = getCorrectNegClassified(rs, concept, trainNeg);
        double trainAccuracy = percentage(trainPosCorrect + trainNegCorrect, trainPos.size() + trainNeg.size());
        double trainCompleteness = percentage(trainPosCorrect, trainPos.size());
        double trainCorrectness = percentage(trainNegCorrect, trainNeg.size());
        this.accuracyTraining.addNumber(trainAccuracy);
        this.trainingCompletenessStat.addNumber(trainCompleteness);
        this.trainingCorrectnessStat.addNumber(trainCorrectness);

        int testPosCorrect = getCorrectPosClassified(rs, concept, testPos);
        int testNegCorrect = getCorrectNegClassified(rs, concept, testNeg);
        double testAccuracy = percentage(testPosCorrect + testNegCorrect, testPos.size() + testNeg.size());
        double testCompleteness = percentage(testPosCorrect, testPos.size());
        double testCorrectness = percentage(testNegCorrect, testNeg.size());
        this.accuracy.addNumber(testAccuracy);
        this.testingCompletenessStat.addNumber(testCompleteness);
        this.testingCorrectnessStat.addNumber(testCorrectness);

        this.fMeasureTraining.addNumber(fMeasurePercentage(trainPosCorrect, trainPos.size(), trainNegCovered.size()));
        this.fMeasure.addNumber(fMeasurePercentage(testPosCorrect, testPos.size(), testNegCovered.size()));
        this.length.addNumber(concept.getLength());

        outputWriter("fold " + fold + ":");
        outputWriter("  training: " + trainPosCorrect + "/" + trainPos.size() + " positive and " + trainNegCorrect + "/" + trainNeg.size() + " negative examples");
        outputWriter("  testing: " + testPosCorrect + "/" + testPos.size() + " correct positives, " + testNegCorrect + "/" + testNeg.size() + " correct negatives");
        outputWriter("  concept: " + concept);
        outputWriter("  accuracy: " + df.format(testAccuracy) + "(corr/comp:" + testCorrectness + "/" + testCompleteness + ")% --- " + df.format(trainAccuracy) + " (corr/comp:" + trainCorrectness + "/" + trainCompleteness + ")% on training set)");
        outputWriter("  length: " + df.format(concept.getLength()));
        outputWriter("  runtime: " + df.format(elapsedNanos / 1.0E9d) + "s");

        int noOfPartialDefs = parcel.getNoOfCompactedPartialDefinition();
        this.noOfPartialDef.addNumber(noOfPartialDefs);
        outputWriter("  number of partial definitions: " + noOfPartialDefs + "/" + parcel.getNumberOfPartialDefinitions());
        // Floating-point division; guard against folds that produced no partial definition.
        double avgPartialDefLength = noOfPartialDefs == 0 ? 0.0d : concept.getLength() / (double) noOfPartialDefs;
        this.partialDefinitionLength.addNumber(avgPartialDefLength);
        outputWriter("  avarage partial definition length: " + avgPartialDefLength);

        int termination = TERMINATED_OTHERWISE;
        if (la instanceof ParCELExAbstract) {
            ParCELExAbstract parcelEx = (ParCELExAbstract) la;
            outputWriter("  number of partial definitions for each type: 1:" + parcelEx.getNumberOfPartialDefinitions(1) + "; 2:" + parcelEx.getNumberOfPartialDefinitions(2) + "; 3:" + parcelEx.getNumberOfPartialDefinitions(3) + "; 4:" + parcelEx.getNumberOfPartialDefinitions(4));
            // Counter partial definitions appear negated ("NOT ") in the rendered description.
            outputWriter("  number of counter partial definition used: " + (concept.toString().split("NOT ").length - 1) + "/" + parcelEx.getNumberOfCounterPartialDefinitionUsed());
            if (parcelEx.terminatedByCounterDefinitions()) {
                outputWriter("  terminated by counter partial definitions");
                termination = TERMINATED_BY_COUNTER_DEFINITIONS;
            } else if (parcelEx.terminatedByPartialDefinitions()) {
                outputWriter("  terminated by partial definitions");
                termination = TERMINATED_BY_PARTIAL_DEFINITIONS;
            } else {
                outputWriter("  neither terminated by partial definition nor counter partial definition");
            }
        }
        return termination;
    }

    /** Returns {@code 100 * part / whole} in floating point (NaN/Infinity when {@code whole} is 0). */
    private static double percentage(int part, int whole) {
        return 100.0d * part / whole;
    }

    /**
     * Returns the F-score (in percent) from true-positive, positive and false-positive counts.
     * Precision is defined as 0 when nothing was classified as positive.
     */
    private static double fMeasurePercentage(int truePositives, int positives, int falsePositives) {
        double recall = truePositives / (double) positives;
        double precision = truePositives + falsePositives == 0
                ? 0.0d
                : truePositives / (double) (truePositives + falsePositives);
        return 100.0d * Heuristics.getFScore(recall, precision);
    }

    /** Collects the mean, max and min of one per-repetition statistic over all repetitions. */
    private static final class Summary {
        final Stat mean = new Stat();
        final Stat max = new Stat();
        final Stat min = new Stat();

        void add(Stat perRun) {
            mean.addNumber(perRun.getMean());
            max.addNumber(perRun.getMax());
            min.addNumber(perRun.getMin());
        }
    }

    /** Formats a {@link Summary} as three tab-indented avg./max./min. lines. */
    private String summaryOutput(DecimalFormat df, Summary summary, String unit) {
        return "\n\t avg.: " + statOutput(df, summary.mean, unit)
                + "\n\t max.: " + statOutput(df, summary.max, unit)
                + "\n\t min.: " + statOutput(df, summary.min, unit);
    }

    /**
     * Returns the English ordinal suffix for {@code i} ("st", "nd", "rd" or "th"),
     * treating 11, 12 and 13 (and 111, 112, ...) as "th".
     */
    private String getOrderUnit(int i) {
        int lastTwoDigits = i % 100;
        if (lastTwoDigits >= 11 && lastTwoDigits <= 13) {
            return "th";
        }
        switch (lastTwoDigits % 10) {
            case 1:
                return "st";
            case 2:
                return "nd";
            case 3:
                return "rd";
            default:
                return "th";
        }
    }

    /**
     * Logs {@code str} at INFO level and, if file output is enabled, appends it to the output file.
     */
    @Override // org.dllearner.cli.CrossValidation
    protected void outputWriter(String str) {
        // The 5-argument constructor delegates to super(...); if the superclass constructor calls
        // this override, the logger field has not been initialised yet, so fall back to a new one.
        Logger log = (this.logger != null) ? this.logger : Logger.getLogger(getClass());
        log.info(str);
        if (writeToFile) {
            Files.appendToFile(outputFile, str + "\n");
        }
    }
}
