package org.dllearner.scripts;

import com.google.common.collect.Lists;
import java.io.File;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeMap;
import java.util.TreeSet;
import joptsimple.OptionParser;
import joptsimple.OptionSet;
import org.apache.commons.beanutils.PropertyUtils;
import org.apache.log4j.ConsoleAppender;
import org.apache.log4j.FileAppender;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.log4j.PatternLayout;
import org.dllearner.cli.CLI;
import org.dllearner.core.AbstractCELA;
import org.dllearner.core.AbstractReasonerComponent;
import org.dllearner.core.ComponentInitException;
import org.dllearner.core.owl.Description;
import org.dllearner.core.owl.Individual;
import org.dllearner.learningproblems.PosNegLP;
import org.dllearner.parser.ParseException;
import org.dllearner.utilities.Helper;
import org.dllearner.utilities.datastructures.TrainTestList;
import org.dllearner.utilities.statistics.Stat;

/* loaded from: input_file:org/dllearner/scripts/NestedCrossValidation.class */
/**
 * Nested cross validation for DL-Learner configurations.
 *
 * <p>For every outer fold, each value of a numeric learning-algorithm parameter is evaluated by an
 * inner cross validation on the outer training data (criterion: mean inner-fold accuracy). The best
 * value is then used to learn on the whole outer training data, and the result is evaluated on the
 * outer test data. Per-file and overall accuracy, F measure, precision and recall are logged.
 */
public class NestedCrossValidation {
    private static final Logger logger = Logger.getLogger(NestedCrossValidation.class.getName());
    private static File logFile = new File("log/nested-cv.log");

    /** Upper bound on the number of individuals listed per line in verbose output. */
    private static final int MAX_LISTED_INDIVIDUALS = 20;

    DecimalFormat df = new DecimalFormat();
    /** Outer-fold accuracy over all configuration files. */
    Stat globalAcc = new Stat();
    /** Outer-fold F measure over all configuration files. */
    Stat globalF = new Stat();
    /** Outer-fold recall over all configuration files. */
    Stat globalRecall = new Stat();
    /** Outer-fold precision over all configuration files. */
    Stat globalPrecision = new Stat();
    /**
     * Inner-fold accuracy per parameter value, accumulated over all configuration files. Sorted by
     * parameter value so summaries print in order and ties resolve to the smallest value.
     */
    Map<Double, Stat> globalParaStats = new TreeMap<>();

    /**
     * Command-line entry point; run with {@code -?} for the option table.
     *
     * <p>Required options: {@code -c} (comma separated config files), {@code -o} (outer folds),
     * {@code -i} (inner folds), {@code -p} (bean property of the learning algorithm to vary) and
     * {@code -r} (either a single value or a range {@code $x-$y}). {@code -s} sets the step size
     * (default 1.0) and {@code -v} enables listing of misclassified examples.
     */
    public static void main(String[] args) throws IOException, ComponentInitException, ParseException, org.dllearner.confparser.ParseException {
        OptionParser parser = new OptionParser();
        parser.acceptsAll(Arrays.asList("h", "?", "help"), "Show help.");
        parser.acceptsAll(Arrays.asList("c", "conf"), "The comma separated list of config files to be used.").withRequiredArg().describedAs("file1, file2, ...");
        parser.acceptsAll(Arrays.asList("v", "verbose"), "Be more verbose.");
        parser.acceptsAll(Arrays.asList("o", "outerfolds"), "Number of outer folds.").withRequiredArg().ofType(Integer.class).describedAs("#folds");
        parser.acceptsAll(Arrays.asList("i", "innerfolds"), "Number of inner folds.").withRequiredArg().ofType(Integer.class).describedAs("#folds");
        parser.acceptsAll(Arrays.asList("p", "parameter"), "Parameter to vary.").withRequiredArg();
        parser.acceptsAll(Arrays.asList("r", "pvalues", "range"), "Values of parameter. $x-$y can be used for integer ranges.").withRequiredArg();
        parser.acceptsAll(Arrays.asList("s", "stepsize", "steps"), "Step size of range.").withOptionalArg().ofType(Double.class).defaultsTo(Double.valueOf(1.0d), new Double[0]);

        OptionSet options;
        try {
            options = parser.parse(args);
        } catch (Exception e) {
            System.out.println("Error: " + e.getMessage() + ". Use -? to get help.");
            // Non-zero status: invalid command lines must not look like successful runs.
            System.exit(1);
            return;
        }
        if (options.has("?")) {
            parser.printHelpOn(System.out);
            return;
        }
        if (!options.has("c") || !options.has("o") || !options.has("i") || !options.has("p") || !options.has("r")) {
            parser.printHelpOn(System.out);
            System.out.println("\nYou need to specify the options c, i, o, p, r. Please consult the help table above.");
            return;
        }

        List<File> confFiles = new ArrayList<>();
        for (String name : ((String) options.valueOf("c")).split(",")) {
            String trimmed = name.trim();
            if (!trimmed.isEmpty()) {
                confFiles.add(new File(trimmed));
            }
        }
        int outerFolds = (Integer) options.valueOf("o");
        int innerFolds = (Integer) options.valueOf("i");
        String parameter = (String) options.valueOf("p");
        String[] range = ((String) options.valueOf("r")).split("-");
        double startValue = Double.parseDouble(range[0]);
        // A single value without "-" is a one-element range.
        double endValue = range.length > 1 ? Double.parseDouble(range[1]) : startValue;
        double stepSize = (Double) options.valueOf("s");
        boolean verbose = options.has("v");

        PatternLayout layout = new PatternLayout("%m%n");
        Logger rootLogger = Logger.getRootLogger();
        rootLogger.removeAllAppenders();
        rootLogger.addAppender(new ConsoleAppender(layout));
        rootLogger.setLevel(Level.ERROR);
        Logger.getLogger("org.dllearner.algorithms").setLevel(Level.INFO);
        Logger.getLogger("org.dllearner.scripts").setLevel(Level.INFO);
        FileAppender fileAppender = new FileAppender(layout, logFile.getPath(), false);
        rootLogger.addAppender(fileAppender);
        fileAppender.setThreshold(Level.INFO);
        java.util.logging.Logger.getLogger("").setLevel(java.util.logging.Level.WARNING);

        System.out.println("Warning: The script is not well tested yet. (No known bugs, but needs more testing.)");
        new NestedCrossValidation(confFiles, outerFolds, innerFolds, parameter, startValue, endValue, stepSize, verbose);
    }

    /** Runs the nested cross validation on a single configuration file. */
    public NestedCrossValidation(File confFile, int outerFolds, int innerFolds, String parameter, double startValue, double endValue, double stepSize, boolean verbose) throws ComponentInitException, ParseException, org.dllearner.confparser.ParseException, IOException {
        this(Collections.singletonList(confFile), outerFolds, innerFolds, parameter, startValue, endValue, stepSize, verbose);
    }

    /**
     * Runs the nested cross validation on every configuration file and logs an overall summary.
     *
     * @param confFiles  configuration files; each must define a positive/negative learning problem
     * @param outerFolds number of outer folds (must be positive)
     * @param innerFolds number of inner folds (must be positive)
     * @param parameter  bean property of the learning algorithm that is varied
     * @param startValue first parameter value
     * @param endValue   last parameter value (inclusive)
     * @param stepSize   distance between consecutive parameter values (must be positive)
     * @param verbose    whether to list misclassified examples
     * @throws IllegalArgumentException if a fold count or the parameter range is invalid, or the
     *                                  parameter cannot be set on the learning algorithm
     */
    public NestedCrossValidation(List<File> confFiles, int outerFolds, int innerFolds, String parameter, double startValue, double endValue, double stepSize, boolean verbose) throws ComponentInitException, ParseException, org.dllearner.confparser.ParseException, IOException {
        if (outerFolds < 1 || innerFolds < 1) {
            throw new IllegalArgumentException("Fold counts must be positive (outer: " + outerFolds + ", inner: " + innerFolds + ")");
        }
        // Validated up front so a bad range fails before any (expensive) configuration is loaded.
        List<Double> parameterValues = parameterValues(startValue, endValue, stepSize);
        for (File file : confFiles) {
            logger.info("++++++++++++++++++++++++++++++++++++++++++++++");
            logger.info(file.getPath());
            logger.info("++++++++++++++++++++++++++++++++++++++++++++++");
            validate(file, outerFolds, innerFolds, parameter, parameterValues, startValue, verbose);
        }
        logger.info("############################################");
        logger.info("############################################");
        logger.info("   Overall summary over parameter values:");
        selectBestValue(globalParaStats, startValue);
        logger.info("*******************");
        logger.info("* Overall Results *");
        logger.info("*******************");
        logger.info("accuracy: " + globalAcc.prettyPrint("%"));
        logger.info("F measure: " + globalF.prettyPrint("%"));
        logger.info("precision: " + globalPrecision.prettyPrint("%"));
        logger.info("recall: " + globalRecall.prettyPrint("%"));
    }

    /** Performs the nested cross validation for one configuration file and updates the global stats. */
    private void validate(File file, int outerFolds, int innerFolds, String parameter, List<Double> parameterValues, double defaultValue, boolean verbose) throws IOException, ComponentInitException {
        CLI cli = new CLI(file);
        cli.init();
        Object problem = cli.getLearningProblem();
        if (!(problem instanceof PosNegLP)) {
            System.out.println("Positive only learning not supported yet.");
            System.exit(1);
            return;
        }
        PosNegLP learningProblem = (PosNegLP) problem;

        // Fixed seeds keep the fold assignment reproducible across runs.
        List<Individual> posExamples = new ArrayList<Individual>(learningProblem.getPositiveExamples());
        Collections.shuffle(posExamples, new Random(1L));
        List<Individual> negExamples = new ArrayList<Individual>(learningProblem.getNegativeExamples());
        Collections.shuffle(negExamples, new Random(2L));

        AbstractReasonerComponent reasoner = cli.getReasonerComponent();
        reasoner.init();
        String baseURI = reasoner.getBaseURI();

        List<TrainTestList> posFolds = getFolds(posExamples, outerFolds);
        List<TrainTestList> negFolds = getFolds(negExamples, outerFolds);

        Stat accuracy = new Stat();
        Stat fMeasure = new Stat();
        Stat recall = new Stat();
        Stat precision = new Stat();

        for (int outer = 0; outer < outerFolds; outer++) {
            logger.info("Outer fold " + outer);
            TrainTestList posFold = posFolds.get(outer);
            TrainTestList negFold = negFolds.get(outer);

            // Inner folds depend only on the outer training data, not on the parameter value.
            List<TrainTestList> innerPosFolds = getFolds(posFold.getTrainList(), innerFolds);
            List<TrainTestList> innerNegFolds = getFolds(negFold.getTrainList(), innerFolds);

            Map<Double, Stat> paraStats = new TreeMap<>();
            for (double value : parameterValues) {
                logger.info("  Parameter value " + value + ":");
                Stat innerAccuracy = new Stat();
                for (int inner = 0; inner < innerFolds; inner++) {
                    logger.info("    Inner fold " + inner + ":");
                    TrainTestList innerPos = innerPosFolds.get(inner);
                    TrainTestList innerNeg = innerNegFolds.get(inner);
                    CLI innerCli = new CLI(file);
                    innerCli.init();
                    Description concept = learn(innerCli, innerPos.getTrainList(), innerNeg.getTrainList(), parameter, value);
                    Evaluation evaluation = evaluate(reasoner, concept, innerPos.getTestList(), innerNeg.getTestList());
                    innerAccuracy.addNumber(evaluation.accuracy);
                    logEvaluation(concept, evaluation, baseURI, verbose);
                    // The inner learner is discarded now; free its knowledge base.
                    innerCli.getReasonerComponent().releaseKB();
                }
                paraStats.put(value, innerAccuracy);
                Stat globalStat = globalParaStats.get(value);
                if (globalStat == null) {
                    globalStat = new Stat();
                    globalParaStats.put(value, globalStat);
                }
                globalStat.add(innerAccuracy);
            }
            logger.info("    Summary over parameter values:");
            double bestValue = selectBestValue(paraStats, defaultValue);

            logger.info("    Learn on Outer fold:");
            CLI outerCli = new CLI(file);
            outerCli.init();
            Description concept = learn(outerCli, posFold.getTrainList(), negFold.getTrainList(), parameter, bestValue);
            AbstractReasonerComponent outerReasoner = outerCli.getReasonerComponent();
            Evaluation evaluation = evaluate(outerReasoner, concept, posFold.getTestList(), negFold.getTestList());
            logEvaluation(concept, evaluation, baseURI, verbose);
            accuracy.addNumber(evaluation.accuracy);
            fMeasure.addNumber(evaluation.fMeasure);
            recall.addNumber(evaluation.recall);
            precision.addNumber(evaluation.precision);
            outerReasoner.releaseKB();
        }

        globalAcc.add(accuracy);
        globalF.add(fMeasure);
        globalPrecision.add(precision);
        globalRecall.add(recall);
        logger.info("*******************");
        logger.info("* Overall Results *");
        logger.info("*******************");
        logger.info("accuracy: " + accuracy.prettyPrint("%"));
        logger.info("F measure: " + fMeasure.prettyPrint("%"));
        logger.info("precision: " + precision.prettyPrint("%"));
        logger.info("recall: " + recall.prettyPrint("%"));
        reasoner.releaseKB();
    }

    /**
     * Restricts the learning problem of an initialised CLI to the given examples, sets the varied
     * parameter on its learning algorithm, and runs the algorithm.
     *
     * @return the best description found by the algorithm
     */
    private static Description learn(CLI cli, List<Individual> posExamples, List<Individual> negExamples, String parameter, double value) throws ComponentInitException {
        PosNegLP learningProblem = (PosNegLP) cli.getLearningProblem();
        learningProblem.setPositiveExamples(new TreeSet<Individual>(posExamples));
        learningProblem.setNegativeExamples(new TreeSet<Individual>(negExamples));
        AbstractCELA algorithm = (AbstractCELA) cli.getLearningAlgorithm();
        setParameter(algorithm, parameter, value);
        // Components must be re-initialised after their examples/properties were changed.
        learningProblem.init();
        algorithm.init();
        algorithm.start();
        return algorithm.getCurrentlyBestDescription();
    }

    /**
     * Sets a bean property of the learning algorithm.
     *
     * @throws IllegalArgumentException if the property cannot be set; continuing would silently
     *                                  evaluate the default value for every point of the sweep
     */
    private static void setParameter(AbstractCELA algorithm, String parameter, double value) {
        try {
            PropertyUtils.setSimpleProperty(algorithm, parameter, Double.valueOf(value));
        } catch (IllegalAccessException | NoSuchMethodException | InvocationTargetException e) {
            throw new IllegalArgumentException("Cannot set parameter '" + parameter + "' to " + value + " on " + algorithm.getClass().getName(), e);
        }
    }

    /** Classification quality of one hypothesis on one test split; all measures are percentages. */
    private static final class Evaluation {
        final double accuracy;
        final double precision;
        final double recall;
        final double fMeasure;
        /** Negative test examples the hypothesis covers. */
        final Set<Individual> falsePositives;
        /** Positive test examples the hypothesis does not cover. */
        final Set<Individual> falseNegatives;

        Evaluation(double accuracy, double precision, double recall, double fMeasure, Set<Individual> falsePositives, Set<Individual> falseNegatives) {
            this.accuracy = accuracy;
            this.precision = precision;
            this.recall = recall;
            this.fMeasure = fMeasure;
            this.falsePositives = falsePositives;
            this.falseNegatives = falseNegatives;
        }
    }

    /** Evaluates a hypothesis on positive/negative test examples using the given reasoner. */
    private static Evaluation evaluate(AbstractReasonerComponent reasoner, Description concept, List<Individual> testPos, List<Individual> testNeg) {
        SortedSet<Individual> posSet = new TreeSet<Individual>(testPos);
        SortedSet<Individual> negSet = new TreeSet<Individual>(testNeg);
        Set<Individual> posCovered = reasoner.hasType(concept, posSet);
        Set<Individual> posMissed = Helper.difference(posSet, posCovered);
        Set<Individual> negCovered = reasoner.hasType(concept, negSet);
        Set<Individual> negRejected = Helper.difference(negSet, negCovered);

        double accuracy = percent(posCovered.size() + negRejected.size(), posSet.size() + negSet.size());
        double precision = percent(posCovered.size(), posCovered.size() + negCovered.size());
        double recall = percent(posCovered.size(), posCovered.size() + posMissed.size());
        double fMeasure = precision + recall == 0 ? 0.0 : 2 * precision * recall / (precision + recall);
        return new Evaluation(accuracy, precision, recall, fMeasure, negCovered, posMissed);
    }

    /** Returns {@code 100 * part / whole} in floating point, or 0 when {@code whole} is 0. */
    private static double percent(int part, int whole) {
        return whole == 0 ? 0.0 : 100.0 * part / whole;
    }

    /** Logs the hypothesis and its measures, plus misclassified examples when verbose. */
    private void logEvaluation(Description concept, Evaluation evaluation, String baseURI, boolean verbose) {
        logger.info("      hypothesis: " + concept.toManchesterSyntaxString(baseURI, (Map<String, String>) null));
        logger.info("      accuracy: " + df.format(evaluation.accuracy) + "%");
        logger.info("      precision: " + df.format(evaluation.precision) + "%");
        logger.info("      recall: " + df.format(evaluation.recall) + "%");
        logger.info("      F measure: " + df.format(evaluation.fMeasure) + "%");
        if (verbose) {
            logger.info("      false positives (neg. examples classified as pos.): " + formatIndividualSet(evaluation.falsePositives, baseURI));
            logger.info("      false negatives (pos. examples classified as neg.): " + formatIndividualSet(evaluation.falseNegatives, baseURI));
        }
    }

    /**
     * Logs the statistic of each parameter value and returns the value with the highest mean.
     *
     * @param fallback returned when {@code stats} is empty
     */
    private double selectBestValue(Map<Double, Stat> stats, double fallback) {
        double bestValue = fallback;
        double bestMean = Double.NEGATIVE_INFINITY;
        for (Map.Entry<Double, Stat> entry : stats.entrySet()) {
            double value = entry.getKey();
            Stat stat = entry.getValue();
            logger.info("      value " + value + ": " + stat.prettyPrint("%"));
            if (stat.getMean() > bestMean) {
                bestValue = value;
                bestMean = stat.getMean();
            }
        }
        logger.info("      selected " + bestValue + " as best parameter value (criterion value " + df.format(bestMean) + "%)");
        return bestValue;
    }

    /**
     * Enumerates {@code start, start + step, ...} up to and including {@code end}.
     *
     * @throws IllegalArgumentException if {@code step} is not positive or a bound is not finite
     */
    private static List<Double> parameterValues(double start, double end, double step) {
        if (!(step > 0) || Double.isInfinite(step)) {
            throw new IllegalArgumentException("Step size must be positive and finite, but was " + step);
        }
        if (Double.isNaN(start) || Double.isInfinite(start) || Double.isNaN(end) || Double.isInfinite(end)) {
            throw new IllegalArgumentException("Parameter range must be finite, but was " + start + "-" + end);
        }
        List<Double> values = new ArrayList<>();
        // Multiplying instead of accumulating avoids rounding drift; the tolerance keeps an end
        // value such as 1.0 reachable with a step such as 0.1.
        double tolerance = step * 1e-9;
        for (int k = 0; ; k++) {
            double value = start + k * step;
            if (value > end + tolerance) {
                break;
            }
            values.add(value);
        }
        return values;
    }

    /**
     * Splits {@code list} into {@code i} consecutive test slices (boundaries from
     * {@link CrossValidation#calculateSplits}); each fold trains on the remaining elements.
     * Assumes {@code list} contains no duplicates.
     */
    public static List<TrainTestList> getFolds(List<Individual> list, int i) {
        List<TrainTestList> folds = new ArrayList<>();
        int[] splits = CrossValidation.calculateSplits(list.size(), i);
        for (int fold = 0; fold < i; fold++) {
            int from = fold == 0 ? 0 : splits[fold - 1];
            int to = splits[fold];
            // Copies, so folds do not depend on views of the caller's list.
            List<Individual> test = new ArrayList<Individual>(list.subList(from, to));
            List<Individual> train = new ArrayList<Individual>(list.size() - test.size());
            train.addAll(list.subList(0, from));
            train.addAll(list.subList(to, list.size()));
            folds.add(new TrainTestList(train, test));
        }
        return folds;
    }

    /** Renders up to {@value #MAX_LISTED_INDIVIDUALS} individuals, each followed by a space. */
    private static String formatIndividualSet(Set<Individual> set, String baseURI) {
        StringBuilder out = new StringBuilder();
        int count = 0;
        for (Individual individual : set) {
            if (count == MAX_LISTED_INDIVIDUALS) {
                break;
            }
            out.append(individual.toManchesterSyntaxString(baseURI, (Map<String, String>) null)).append(' ');
            count++;
        }
        return out.toString();
    }
}
