package org.dllearner.cli.unife;

import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.TreeSet;

import mpi.MPI;
import org.apache.commons.io.FilenameUtils;
import org.apache.log4j.Logger;
import org.dllearner.algorithms.probabilistic.parameter.unife.edge.AbstractEDGE;
import org.dllearner.algorithms.probabilistic.structure.unife.leap.AbstractLEAP;
import org.dllearner.cli.CrossValidation;
import org.dllearner.core.AbstractReasonerComponent;
import org.dllearner.core.ComponentInitException;
import org.dllearner.core.probabilistic.unife.AbstractPSLA;
import org.dllearner.learningproblems.ClassLearningProblem;
import org.dllearner.learningproblems.PosNegLP;
import org.dllearner.learningproblems.PosOnlyLP;
import org.dllearner.utils.unife.OWLUtils;
import org.dllearner.utils.unife.ReflectionHelper;
import org.semanticweb.owlapi.apibinding.OWLManager;
import org.semanticweb.owlapi.model.AxiomType;
import org.semanticweb.owlapi.model.OWLClass;
import org.semanticweb.owlapi.model.OWLClassAssertionAxiom;
import org.semanticweb.owlapi.model.OWLDataFactory;
import org.semanticweb.owlapi.model.OWLIndividual;
import org.semanticweb.owlapi.model.OWLOntologyCreationException;
import org.semanticweb.owlapi.model.OWLOntologyStorageException;
import unife.bundle.utilities.BundleUtilities;
import unife.edge.mpi.MPIUtilities;

/* loaded from: input_file:org/dllearner/cli/unife/LEAPCrossValidation.class */
/**
 * Cross validation driver for probabilistic structure learners ({@link AbstractPSLA}).
 *
 * <p>All work happens in the constructor. The examples of the learner's learning problem are
 * shuffled with fixed seeds and split into folds. For every fold, the learning problem is
 * reconfigured with the training examples, the parameter learner and the structure learner are
 * re-initialized and run, and the fold's test and training examples are saved as OWL/XML files.
 * Under MPI, only the master process saves these files.
 */
public class LEAPCrossValidation extends CrossValidation {
    private static final Logger logger = Logger.getLogger(LEAPCrossValidation.class);

    /** Templates (base name + extension) for the per-fold training example files. */
    private static final String POS_EXAMPLES_FILE = "posExamples.owl";
    private static final String NEG_EXAMPLES_FILE = "negExamples.owl";

    /** Format string handed to {@code OWLUtils.saveAxioms}. */
    private static final String SAVE_FORMAT = "OWLXML";

    /**
     * Builds the folds and runs the whole cross validation.
     *
     * <p>Where the examples come from depends on the learning problem type:
     * <ul>
     *   <li>{@link PosNegLP}: its positive and negative examples.</li>
     *   <li>{@link PosOnlyLP}: its positive examples only.</li>
     *   <li>{@link ClassLearningProblem}: the reflectively read {@code classInstances} and
     *       {@code superClassInstances}. If there are fewer super-class instances than folds,
     *       every instance of owl:Thing that is not a positive example is used as a negative
     *       instead.</li>
     * </ul>
     *
     * <p>For fold {@code k} (1-based), the learner's output file is set to
     * {@code <output base>k.<output extension>}. The master also saves
     * {@code posTestExamples<k>.owl}, {@code negTestExamples<k>.owl}, {@code posExamples<k>.owl}
     * and {@code negExamples<k>.owl}.
     *
     * <p>Fatal configuration or initialization problems terminate the JVM via
     * {@link System#exit(int)}.
     *
     * @param abstractPSLA the structure learner to evaluate. Its output file is rewritten for every
     *     fold and is left pointing at the last fold's file.
     * @param i the number of folds. It must not exceed the number of positive examples, nor
     *     (except for PosOnly problems) the number of negative examples.
     * @param z leave-one-out flag. This mode is not supported: {@code true} exits with status 1.
     * @param z2 whether this is an MPI run. If so, only the MPI master saves the example files.
     * @throws IllegalArgumentException if the learning problem type is not supported
     * @throws OWLOntologyStorageException propagated from the OWL calls in the fold loop
     *     (presumably from saving the example files)
     * @throws OWLOntologyCreationException propagated from the OWL calls in the fold loop
     */
    public LEAPCrossValidation(AbstractPSLA abstractPSLA, int i, boolean z, boolean z2)
            throws OWLOntologyStorageException, OWLOntologyCreationException {
        // Without MPI this process is always responsible for writing the example files.
        boolean isMaster = !z2 || MPIUtilities.isMaster(MPI.COMM_WORLD);
        // Held as Object: the supported problem types are distinguished with instanceof, and each
        // branch casts to the concrete type it needs.
        Object learningProblem = abstractPSLA.getLearningProblem();
        logger.debug("Setting cross validation");

        Set<OWLIndividual> positives;
        Set<OWLIndividual> negatives;
        if (learningProblem instanceof PosNegLP) {
            PosNegLP posNegLP = (PosNegLP) learningProblem;
            positives = posNegLP.getPositiveExamples();
            negatives = posNegLP.getNegativeExamples();
        } else if (learningProblem instanceof PosOnlyLP) {
            positives = ((PosOnlyLP) learningProblem).getPositiveExamples();
            negatives = new HashSet<>();
        } else if (learningProblem instanceof ClassLearningProblem) {
            positives = new HashSet<>();
            negatives = new HashSet<>();
            try {
                positives = new HashSet<>(readIndividualList(learningProblem, "classInstances"));
                negatives = new HashSet<>(readIndividualList(learningProblem, "superClassInstances"));
                if (negatives.size() < i) {
                    logger.info("The number of folds is higher than the number of negative examples. Selecting the instances of Thing which are non instances of ClasstoDescribe as negative Examples");
                    // Copy (keeping the reasoner's iteration order) before removeAll: the set
                    // returned by the reasoner is not ours to mutate.
                    negatives = new LinkedHashSet<>(((ClassLearningProblem) learningProblem).getReasoner()
                            .getIndividuals(OWLManager.getOWLDataFactory().getOWLThing()));
                    negatives.removeAll(positives);
                }
            } catch (Exception e) {
                logger.error("Cannot get positive and negative individuals for the cross validation", e);
                System.exit(-2);
            }
        } else {
            throw new IllegalArgumentException("Only ClassLearningProblem, PosNeg and PosOnly learning problems are supported");
        }

        if (z) {
            logger.error("Leave-one-out not supported yet.");
            System.exit(1);
        }
        // A PosOnly problem has no negatives by construction, so only its positives must cover
        // every fold. Previously this check rejected every PosOnly problem.
        boolean needsNegatives = !(learningProblem instanceof PosOnlyLP);
        if (positives.size() < i || (needsNegatives && negatives.size() < i)) {
            logger.error("The number of folds is higher than the number of positive/negative examples. This can result in empty test sets. Exiting.");
            // Non-zero status: this is a failure, not a successful run.
            System.exit(1);
        }

        // Fixed seeds make the fold assignment reproducible for a given input order.
        List<OWLIndividual> shuffledPositives = new ArrayList<>(positives);
        List<OWLIndividual> shuffledNegatives = new ArrayList<>(negatives);
        Collections.shuffle(shuffledPositives, new Random(1L));
        Collections.shuffle(shuffledNegatives, new Random(2L));
        int[] positiveSplits = calculateSplits(positives.size(), i);
        int[] negativeSplits = calculateSplits(negatives.size(), i);

        List<Set<OWLIndividual>> trainingPositives = new ArrayList<>();
        List<Set<OWLIndividual>> trainingNegatives = new ArrayList<>();
        List<Set<OWLIndividual>> testingPositives = new ArrayList<>();
        List<Set<OWLIndividual>> testingNegatives = new ArrayList<>();
        for (int fold = 0; fold < i; fold++) {
            Set<OWLIndividual> foldTestPositives = getTestingSet(shuffledPositives, positiveSplits, fold);
            Set<OWLIndividual> foldTestNegatives = getTestingSet(shuffledNegatives, negativeSplits, fold);
            testingPositives.add(foldTestPositives);
            testingNegatives.add(foldTestNegatives);
            trainingPositives.add(getTrainingSet(positives, foldTestPositives));
            trainingNegatives.add(getTrainingSet(negatives, foldTestNegatives));
        }

        String outputFile = abstractPSLA.getOutputFile();
        String outputBase = FilenameUtils.removeExtension(outputFile);
        String outputExtension = FilenameUtils.getExtension(outputFile);
        logger.debug("Performing Cross Validation");
        for (int fold = 0; fold < i; fold++) {
            int foldNumber = fold + 1;
            logger.debug("Current Fold: " + foldNumber);
            configureLearningProblem(learningProblem, trainingPositives.get(fold), trainingNegatives.get(fold));

            AbstractEDGE edge = abstractPSLA.getLearningParameterAlgorithm();
            try {
                // NOTE(review): the copy is discarded. The call is kept in case copyOntology has
                // side effects that the learner relies on; confirm and drop it if it has none.
                BundleUtilities.copyOntology(edge.getSourcesOntology());
            } catch (OWLOntologyCreationException e) {
                // Best effort, as before: a failed copy does not abort the fold.
                logger.error("Cannot copy the sources ontology of the parameter learner", e);
            }
            abstractPSLA.setOutputFile(outputBase + foldNumber + "." + outputExtension);
            try {
                edge.init();
                abstractPSLA.init();
            } catch (ComponentInitException e) {
                // start() on a component whose init failed cannot produce a meaningful fold.
                logger.error(e.getLocalizedMessage(), e);
                System.exit(-2);
            }
            abstractPSLA.start();
            if (isMaster) {
                saveFoldExamples(learningProblem, abstractPSLA, edge,
                        testingPositives.get(fold), testingNegatives.get(fold), foldNumber);
            }
        }
    }

    /** Reads a {@code List<OWLIndividual>} private field of {@code problem} via reflection. */
    @SuppressWarnings("unchecked")
    private static List<OWLIndividual> readIndividualList(Object problem, String fieldName) throws Exception {
        return (List<OWLIndividual>) ReflectionHelper.getPrivateField(problem, fieldName);
    }

    /** Installs one fold's training examples into the learning problem (exits on failure). */
    private static void configureLearningProblem(Object learningProblem,
            Set<OWLIndividual> trainingPositives, Set<OWLIndividual> trainingNegatives) {
        if (learningProblem instanceof PosNegLP) {
            PosNegLP posNegLP = (PosNegLP) learningProblem;
            posNegLP.setPositiveExamples(trainingPositives);
            posNegLP.setNegativeExamples(trainingNegatives);
            try {
                posNegLP.init();
            } catch (ComponentInitException e) {
                logger.error(e.getLocalizedMessage(), e);
                System.exit(-2);
            }
        } else if (learningProblem instanceof PosOnlyLP) {
            PosOnlyLP posOnlyLP = (PosOnlyLP) learningProblem;
            posOnlyLP.setPositiveExamples(new TreeSet<>(trainingPositives));
            try {
                posOnlyLP.init();
            } catch (ComponentInitException e) {
                logger.error(e.getLocalizedMessage(), e);
                System.exit(-2);
            }
        } else if (learningProblem instanceof ClassLearningProblem) {
            // init() is not called here. The fields are overwritten directly, and the class
            // instances are read back as Lists elsewhere, so Lists are written back.
            // NOTE(review): negatedClassInstances gets a TreeSet. That value fits a Set, SortedSet
            // or TreeSet field declaration; confirm against ClassLearningProblem.
            try {
                ReflectionHelper.setPrivateField(learningProblem, "classInstances", new ArrayList<>(trainingPositives));
                ReflectionHelper.setPrivateField(learningProblem, "superClassInstances", new ArrayList<>(trainingNegatives));
                ReflectionHelper.setPrivateField(learningProblem, "negatedClassInstances", new TreeSet<>(trainingNegatives));
            } catch (Exception e) {
                logger.error("Cannot set positive and negative individuals for the cross validation", e);
                System.exit(-2);
            }
        }
    }

    /**
     * Saves one fold's test and training examples as class-assertion axioms.
     *
     * <p>The class used is the class to describe for a {@link ClassLearningProblem}, and the
     * learner's dummy class otherwise. Both positive and negative examples are asserted as members
     * of that class; only the file name tells them apart.
     */
    private static void saveFoldExamples(Object learningProblem, AbstractPSLA psla, AbstractEDGE edge,
            Set<OWLIndividual> testPositives, Set<OWLIndividual> testNegatives, int foldNumber)
            throws OWLOntologyStorageException, OWLOntologyCreationException {
        Set<OWLClassAssertionAxiom> trainingPosAxioms = edge.getPositiveExampleAxioms();
        Set<OWLClassAssertionAxiom> trainingNegAxioms = edge.getNegativeExampleAxioms();
        OWLDataFactory factory = OWLManager.getOWLDataFactory();
        OWLClass exampleClass;
        if (learningProblem instanceof ClassLearningProblem) {
            exampleClass = ((ClassLearningProblem) learningProblem).getClassToDescribe();
            // Re-assert the parameter learner's example individuals against the class to describe.
            trainingPosAxioms = reassert(trainingPosAxioms, exampleClass, factory);
            trainingNegAxioms = reassert(trainingNegAxioms, exampleClass, factory);
        } else {
            exampleClass = ((AbstractLEAP) psla).getDummyClass();
        }
        String posExtension = FilenameUtils.getExtension(POS_EXAMPLES_FILE);
        String negExtension = FilenameUtils.getExtension(NEG_EXAMPLES_FILE);
        OWLUtils.saveAxioms(assertMembership(testPositives, exampleClass, factory),
                "posTestExamples" + foldNumber + "." + posExtension, SAVE_FORMAT);
        OWLUtils.saveAxioms(assertMembership(testNegatives, exampleClass, factory),
                "negTestExamples" + foldNumber + "." + negExtension, SAVE_FORMAT);
        OWLUtils.saveAxioms(trainingPosAxioms,
                FilenameUtils.removeExtension(POS_EXAMPLES_FILE) + foldNumber + "." + posExtension, SAVE_FORMAT);
        OWLUtils.saveAxioms(trainingNegAxioms,
                FilenameUtils.removeExtension(NEG_EXAMPLES_FILE) + foldNumber + "." + negExtension, SAVE_FORMAT);
    }

    /** Rewrites each class-assertion axiom so that it asserts its individual into {@code target}. */
    private static Set<OWLClassAssertionAxiom> reassert(Set<OWLClassAssertionAxiom> axioms,
            OWLClass target, OWLDataFactory factory) {
        Set<OWLClassAssertionAxiom> result = new HashSet<>();
        for (OWLClassAssertionAxiom axiom : axioms) {
            if (axiom.isOfType(AxiomType.CLASS_ASSERTION)) {
                result.add(factory.getOWLClassAssertionAxiom(target, axiom.getIndividual()));
            }
        }
        return result;
    }

    /** Builds one class-assertion axiom {@code cls(individual)} per individual. */
    private static Set<OWLClassAssertionAxiom> assertMembership(Set<OWLIndividual> individuals,
            OWLClass cls, OWLDataFactory factory) {
        Set<OWLClassAssertionAxiom> result = new HashSet<>();
        for (OWLIndividual individual : individuals) {
            result.add(factory.getOWLClassAssertionAxiom(cls, individual));
        }
        return result;
    }

    /** Counts the individuals of {@code set} that the reasoner reports as instances of {@code oWLClass}. */
    protected int getCorrectPosClassified(AbstractReasonerComponent abstractReasonerComponent, OWLClass oWLClass, Set<OWLIndividual> set) {
        return abstractReasonerComponent.hasType(oWLClass, set).size();
    }

    /** Counts the individuals of {@code set} that the reasoner does NOT report as instances of {@code oWLClass}. */
    protected int getCorrectNegClassified(AbstractReasonerComponent abstractReasonerComponent, OWLClass oWLClass, Set<OWLIndividual> set) {
        return set.size() - abstractReasonerComponent.hasType(oWLClass, set).size();
    }

    /**
     * Returns the test examples of fold {@code i}: the elements of {@code list} in
     * {@code [iArr[i - 1], iArr[i])}, where fold 0 starts at index 0.
     *
     * <p>{@code iArr} is assumed to hold cumulative end indices, as produced by
     * {@code calculateSplits}. TODO confirm.
     *
     * @param list the shuffled examples
     * @param iArr cumulative split end indices, one per fold
     * @param i zero-based fold index
     * @return a new mutable set with the fold's test examples
     */
    public static Set<OWLIndividual> getTestingSet(List<OWLIndividual> list, int[] iArr, int i) {
        int from = (i == 0) ? 0 : iArr[i - 1];
        return new HashSet<>(list.subList(from, iArr[i]));
    }
}
