package org.dllearner.algorithms.decisiontrees.heuristics;

import java.util.SortedSet;
import java.util.TreeSet;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.dllearner.algorithms.decisiontrees.dsttdt.dst.DSTUtils;
import org.dllearner.algorithms.decisiontrees.dsttdt.dst.MassFunction;
import org.dllearner.algorithms.decisiontrees.utils.Couple;
import org.dllearner.core.AbstractClassExpressionLearningProblem;
import org.dllearner.core.AbstractReasonerComponent;
import org.dllearner.learningproblems.PosNegUndLP;
import org.semanticweb.owlapi.model.OWLClassExpression;
import org.semanticweb.owlapi.model.OWLDataFactory;
import org.semanticweb.owlapi.model.OWLIndividual;
import org.semanticweb.owlapi.model.OWLObjectComplementOf;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import uk.ac.manchester.cs.owl.owlapi.OWLDataFactoryImpl;

/* loaded from: input_file:lib/components-core-1.3.0-jena3-SNAPSHOT.jar:org/dllearner/algorithms/decisiontrees/heuristics/TreeInductionHeuristics.class */
/**
 * Heuristics used to pick a test concept while growing a terminological decision tree.
 *
 * <p>Every selection method receives an array of candidate class expressions together with three
 * sets of individuals (labelled {@code p}, {@code n} and {@code u} in the debug output). Each
 * candidate is scored from the 3x3 table of counts produced by {@link #getSplitCounts}: for each
 * of the three input sets, how many individuals the reasoner proves to be instances of the
 * candidate, instances of its complement, or neither.
 *
 * <p>All selection methods read {@code concepts[0]} unconditionally, so an empty candidate array
 * results in an {@link ArrayIndexOutOfBoundsException}. A reasoner must have been supplied through
 * {@link #setReasoner} before any selection or split method is called.
 */
public class TreeInductionHeuristics {
    private static final Logger logger = LoggerFactory.getLogger(TreeInductionHeuristics.class);

    // Indices into the int[9] returned by getSplitCounts. The first word names the input set the
    // individual came from, the suffix names the outcome of the instance check against the concept.
    protected static final int UNCERTAIN_INSTANCE_CHECK_UNC = 8;
    protected static final int NEGATIVE_INSTANCE_CHECK_UNC = 7;
    protected static final int POSITIVE_INSTANCE_CHECK_UNC = 6;
    protected static final int UNCERTAIN_INSTANCE_CHECK_FALSE = 5;
    protected static final int NEGATIVE_INSTANCE_CHECK_FALSE = 4;
    protected static final int POSITIVE_INSTANCE_CHECK_FALSE = 3;
    protected static final int UNCERTAIN_INSTANCE_CHECK_TRUE = 2;
    protected static final int NEGATIVE_INSTANCE_CHECK_TRUE = 1;
    protected static final int POSITIVE_INSTANCE_CHECK_TRUE = 0;

    private final OWLDataFactory dataFactory = new OWLDataFactoryImpl();
    private AbstractReasonerComponent reasoner;
    private PosNegUndLP problem;

    /**
     * @return the learning problem currently stored, or {@code null} if none was set
     */
    public AbstractClassExpressionLearningProblem getProblem() {
        return this.problem;
    }

    /**
     * Stores the given learning problem if it is a {@link PosNegUndLP}.
     *
     * <p>Any other kind of problem is ignored and the previously stored problem is kept.
     *
     * @param abstractClassExpressionLearningProblem the problem to store
     */
    public void setProblem(AbstractClassExpressionLearningProblem abstractClassExpressionLearningProblem) {
        if (abstractClassExpressionLearningProblem instanceof PosNegUndLP) {
            this.problem = (PosNegUndLP) abstractClassExpressionLearningProblem;
        }
    }

    /**
     * @return the reasoner used for instance checks, or {@code null} if none was set
     */
    public AbstractReasonerComponent getReasoner() {
        return this.reasoner;
    }

    /**
     * @param abstractReasonerComponent the reasoner used for all instance checks
     */
    public void setReasoner(AbstractReasonerComponent abstractReasonerComponent) {
        this.reasoner = abstractReasonerComponent;
    }

    /**
     * @param posNegUndLP the learning problem to store
     */
    public void setProblem(PosNegUndLP posNegUndLP) {
        this.problem = posNegUndLP;
    }

    /** No initialisation is required; present for component-style life cycles. */
    public void init() {
    }

    /**
     * Returns the candidate with the <em>lowest</em> {@link #gain} value; ties keep the earliest
     * candidate.
     *
     * <p>NOTE(review): choosing the minimum of a quantity called "gain" is unusual, but it is the
     * existing behaviour and callers may rely on it - confirm before changing.
     *
     * <p>Progress for the first candidate and the final choice are written to standard output;
     * everything else goes to the debug log.
     *
     * @param concepts candidate class expressions; must contain at least one element
     * @param posExs individuals reported under {@code p} (positive examples)
     * @param negExs individuals reported under {@code n} (negative examples)
     * @param undExs individuals reported under {@code u} (uncertain examples)
     * @param prPos factor applied to the positive count inside {@link #gini}
     * @param prNeg factor applied to the negative count inside {@link #gini}
     * @return the selected element of {@code concepts}
     */
    public OWLClassExpression selectBestConcept(OWLClassExpression[] concepts, SortedSet<OWLIndividual> posExs, SortedSet<OWLIndividual> negExs, SortedSet<OWLIndividual> undExs, double prPos, double prNeg) {
        return selectByGain(concepts, posExs, negExs, undExs, prPos, prNeg, true);
    }

    /**
     * Returns the candidate with the lowest {@link #ccp} score; ties keep the earliest candidate.
     *
     * @param concepts candidate class expressions; must contain at least one element
     * @param posExs individuals reported under {@code p} (positive examples)
     * @param negExs individuals reported under {@code n} (negative examples)
     * @param undExs individuals reported under {@code u} (uncertain examples)
     * @param prPos forwarded to the score function, which currently does not use it
     * @param prNeg forwarded to the score function, which currently does not use it
     * @return the selected element of {@code concepts}
     */
    public OWLClassExpression selectBestConceptCCP(OWLClassExpression[] concepts, SortedSet<OWLIndividual> posExs, SortedSet<OWLIndividual> negExs, SortedSet<OWLIndividual> undExs, double prPos, double prNeg) {
        int chosen = 0;
        int[] counts = getSplitCounts(concepts[0], posExs, negExs, undExs);
        logTransposedCounts(0, counts);
        double chosenScore = ccp(counts, prPos, prNeg);
        if (logger.isDebugEnabled()) {
            // SLF4J does not understand printf patterns, so format explicitly.
            logger.debug(String.format("%+10e", chosenScore));
            logger.debug("{}", concepts[0]);
        }
        for (int i = 1; i < concepts.length; i++) {
            // Score candidate i itself (the previous code re-scored candidate 0 on every pass,
            // so no candidate other than the first could ever be selected).
            counts = getSplitCounts(concepts[i], posExs, negExs, undExs);
            logTransposedCounts(i, counts);
            double candidateScore = ccp(counts, prPos, prNeg);
            if (logger.isDebugEnabled()) {
                logger.debug(candidateScore + "\n");
                logger.debug("{}", concepts[i]);
            }
            if (candidateScore < chosenScore) {
                chosen = i;
                chosenScore = candidateScore;
            }
        }
        logger.debug("best gain:" + chosenScore + " \t split #" + chosen);
        return concepts[chosen];
    }

    /**
     * Computes the CCP-style score of a split from its count table.
     *
     * <p>Only the four positive/negative x true/false cells are read; each is incremented by one
     * before use. The arithmetic is kept exactly as it was, including its evaluation order.
     *
     * @param counts the table returned by {@link #getSplitCounts}
     * @param prPos unused
     * @param prNeg unused
     * @return the score; may be {@code NaN} when one of the logarithm arguments is zero
     */
    private double ccp(int[] counts, double prPos, double prNeg) {
        double posTrue = counts[POSITIVE_INSTANCE_CHECK_TRUE] + 1;
        double negTrue = counts[NEGATIVE_INSTANCE_CHECK_TRUE] + 1;
        double negFalse = counts[NEGATIVE_INSTANCE_CHECK_FALSE] + 1;

        double tpr = posTrue + negTrue != 0.0d ? posTrue / (posTrue + negTrue) : 1.0d;
        double fpr = negTrue + negFalse != 0.0d ? (negTrue + 0.5d) / (negTrue + negFalse) : 1.0d;

        double remainder = (2.0d - tpr) - fpr;
        double tprComplement = remainder != 0.0d ? (1.0d - tpr) / remainder : 1.0d;
        double fprComplement = remainder != 0.0d ? (1.0d - fpr) / remainder : 1.0d;

        double rateSum = tpr + fpr;
        double tprShare = tpr / rateSum;
        double fprShare = fpr / rateSum;

        double firstTerm = (-rateSum) * ((tprShare * Math.log(tprShare)) - (fprShare * Math.log(fprShare)));
        double secondTerm = ((2.0d - tprComplement) - fprComplement)
                * ((tprComplement * Math.log(tprComplement)) - (fprComplement * Math.log(fprComplement)));
        return firstTerm - secondTerm;
    }

    /**
     * Computes the gini-based score of a split from its count table.
     *
     * <p>NOTE(review): the branch weights are built from the total positive and total negative
     * counts rather than from the sizes of the true/false branches, and the last term is added
     * rather than subtracted. Both are preserved as found - confirm against the intended formula.
     *
     * @param counts the table returned by {@link #getSplitCounts}
     * @param prPos factor applied to positive counts inside {@link #gini}
     * @param prNeg factor applied to negative counts inside {@link #gini}
     * @return the score; {@code NaN} when all contributing counts are zero
     */
    private double gain(int[] counts, double prPos, double prNeg) {
        double positives = counts[POSITIVE_INSTANCE_CHECK_TRUE] + counts[POSITIVE_INSTANCE_CHECK_FALSE];
        double negatives = counts[NEGATIVE_INSTANCE_CHECK_TRUE] + counts[NEGATIVE_INSTANCE_CHECK_FALSE];
        double unknowns = counts[POSITIVE_INSTANCE_CHECK_UNC] + counts[NEGATIVE_INSTANCE_CHECK_UNC]
                + counts[UNCERTAIN_INSTANCE_CHECK_TRUE] + counts[UNCERTAIN_INSTANCE_CHECK_FALSE];
        double total = positives + negatives + unknowns;

        double startImpurity = gini(positives, negatives, prPos, prNeg);
        double trueImpurity = gini(counts[POSITIVE_INSTANCE_CHECK_TRUE], counts[NEGATIVE_INSTANCE_CHECK_TRUE], prPos, prNeg);
        double falseImpurity = gini(counts[POSITIVE_INSTANCE_CHECK_FALSE], counts[NEGATIVE_INSTANCE_CHECK_FALSE], prPos, prNeg);
        double unknownImpurity = gini(
                counts[POSITIVE_INSTANCE_CHECK_UNC] + counts[UNCERTAIN_INSTANCE_CHECK_TRUE],
                counts[NEGATIVE_INSTANCE_CHECK_UNC] + counts[UNCERTAIN_INSTANCE_CHECK_FALSE],
                prPos, prNeg);

        return ((startImpurity - ((positives / total) * trueImpurity))
                - ((negatives / total) * falseImpurity))
                - ((-(unknowns / total)) * unknownImpurity);
    }

    /**
     * Smoothed gini-style impurity of a two-class count pair.
     *
     * @param d first (positive) count
     * @param d2 second (negative) count
     * @param d3 factor applied to the first count
     * @param d4 factor applied to the second count
     * @return {@code 1 - a^2 - b^2} with {@code a = 3*d*d3/(d+d2+3)} and {@code b = 3*d2*d4/(d+d2+3)}
     */
    static double gini(double d, double d2, double d3, double d4) {
        double total = d + d2;
        double first = ((d * 3) * d3) / (total + 3);
        double second = ((d2 * 3) * d4) / (total + 3);
        return (1.0d - (first * first)) - (second * second);
    }

    /**
     * Builds the 3x3 count table for one candidate concept.
     *
     * @return an array of nine counts indexed by the {@code *_INSTANCE_CHECK_*} constants
     */
    private int[] getSplitCounts(OWLClassExpression concept, SortedSet<OWLIndividual> posExs, SortedSet<OWLIndividual> negExs, SortedSet<OWLIndividual> undExs) {
        SortedSet<OWLIndividual> posTrue = new TreeSet<OWLIndividual>();
        SortedSet<OWLIndividual> negTrue = new TreeSet<OWLIndividual>();
        SortedSet<OWLIndividual> undTrue = new TreeSet<OWLIndividual>();
        SortedSet<OWLIndividual> posFalse = new TreeSet<OWLIndividual>();
        SortedSet<OWLIndividual> negFalse = new TreeSet<OWLIndividual>();
        SortedSet<OWLIndividual> undFalse = new TreeSet<OWLIndividual>();
        SortedSet<OWLIndividual> posUnknown = new TreeSet<OWLIndividual>();
        SortedSet<OWLIndividual> negUnknown = new TreeSet<OWLIndividual>();
        SortedSet<OWLIndividual> undUnknown = new TreeSet<OWLIndividual>();

        splitGroup(concept, posExs, posTrue, posFalse, posUnknown);
        splitGroup(concept, negExs, negTrue, negFalse, negUnknown);
        splitGroup(concept, undExs, undTrue, undFalse, undUnknown);

        int[] counts = new int[9];
        counts[POSITIVE_INSTANCE_CHECK_TRUE] = posTrue.size();
        counts[NEGATIVE_INSTANCE_CHECK_TRUE] = negTrue.size();
        counts[UNCERTAIN_INSTANCE_CHECK_TRUE] = undTrue.size();
        counts[POSITIVE_INSTANCE_CHECK_FALSE] = posFalse.size();
        counts[NEGATIVE_INSTANCE_CHECK_FALSE] = negFalse.size();
        counts[UNCERTAIN_INSTANCE_CHECK_FALSE] = undFalse.size();
        counts[POSITIVE_INSTANCE_CHECK_UNC] = posUnknown.size();
        counts[NEGATIVE_INSTANCE_CHECK_UNC] = negUnknown.size();
        counts[UNCERTAIN_INSTANCE_CHECK_UNC] = undUnknown.size();
        return counts;
    }

    /**
     * Distributes the three input sets over the caller-supplied "true" and "false" sets.
     *
     * <p>Individuals for which neither the concept nor its complement can be proven are collected
     * in temporary sets and discarded.
     *
     * @param concept the test concept
     * @param posExs positive examples to distribute
     * @param negExs negative examples to distribute
     * @param undExs uncertain examples to distribute
     * @param posExsT receives positives that are instances of {@code concept}
     * @param negExsT receives negatives that are instances of {@code concept}
     * @param undExsT receives uncertain examples that are instances of {@code concept}
     * @param posExsF receives positives that are instances of the complement
     * @param negExsF receives negatives that are instances of the complement
     * @param undExsF receives uncertain examples that are instances of the complement
     */
    protected void split(OWLClassExpression concept, SortedSet<OWLIndividual> posExs, SortedSet<OWLIndividual> negExs, SortedSet<OWLIndividual> undExs, SortedSet<OWLIndividual> posExsT, SortedSet<OWLIndividual> negExsT, SortedSet<OWLIndividual> undExsT, SortedSet<OWLIndividual> posExsF, SortedSet<OWLIndividual> negExsF, SortedSet<OWLIndividual> undExsF) {
        splitGroup(concept, posExs, posExsT, posExsF, new TreeSet<OWLIndividual>());
        splitGroup(concept, negExs, negExsT, negExsF, new TreeSet<OWLIndividual>());
        splitGroup(concept, undExs, undExsT, undExsF, new TreeSet<OWLIndividual>());
    }

    /**
     * Sorts each individual of {@code examples} into exactly one of the three target sets,
     * depending on whether the reasoner proves membership in the concept, in its complement, or
     * in neither.
     */
    private void splitGroup(OWLClassExpression concept, SortedSet<OWLIndividual> examples, SortedSet<OWLIndividual> trueExs, SortedSet<OWLIndividual> falseExs, SortedSet<OWLIndividual> unknownExs) {
        OWLObjectComplementOf complement = this.dataFactory.getOWLObjectComplementOf(concept);
        for (OWLIndividual individual : examples) {
            if (this.reasoner.hasType(concept, individual)) {
                trueExs.add(individual);
            } else if (this.reasoner.hasType(complement, individual)) {
                falseExs.add(individual);
            } else {
                unknownExs.add(individual);
            }
        }
    }

    /**
     * Returns the candidate whose mass function has the lowest non-specificity value, paired with
     * that mass function. On ties the <em>latest</em> candidate wins ({@code <=} comparison).
     *
     * @param concepts candidate class expressions; must contain at least one element
     * @param posExs individuals reported under {@code p} (positive examples)
     * @param negExs individuals reported under {@code n} (negative examples)
     * @param undExs individuals reported under {@code u} (uncertain examples)
     * @param prPos unused
     * @param prNeg unused
     * @return the selected concept and its mass function
     */
    public Couple<OWLClassExpression, MassFunction> selectBestConceptDST(OWLClassExpression[] concepts, SortedSet<OWLIndividual> posExs, SortedSet<OWLIndividual> negExs, SortedSet<OWLIndividual> undExs, double prPos, double prNeg) {
        return selectByNonSpecificity(concepts, posExs, negExs, undExs, true);
    }

    /**
     * Returns the candidate whose mass function has the highest non-specificity value, paired
     * with that mass function. On ties the <em>latest</em> candidate wins ({@code >=} comparison).
     *
     * @param concepts candidate class expressions; must contain at least one element
     * @param posExs individuals reported under {@code p} (positive examples)
     * @param negExs individuals reported under {@code n} (negative examples)
     * @param undExs individuals reported under {@code u} (uncertain examples)
     * @param prPos unused
     * @param prNeg unused
     * @return the selected concept and its mass function
     */
    public Couple<OWLClassExpression, MassFunction> selectWorstConceptDST(OWLClassExpression[] concepts, SortedSet<OWLIndividual> posExs, SortedSet<OWLIndividual> negExs, SortedSet<OWLIndividual> undExs, double prPos, double prNeg) {
        return selectByNonSpecificity(concepts, posExs, negExs, undExs, false);
    }

    /**
     * Returns the candidate with the <em>highest</em> {@link #gain} value; ties keep the earliest
     * candidate. Output behaviour matches {@link #selectBestConcept}, including the
     * {@code "best gain: "} prefix of the final line.
     *
     * @param concepts candidate class expressions; must contain at least one element
     * @param posExs individuals reported under {@code p} (positive examples)
     * @param negExs individuals reported under {@code n} (negative examples)
     * @param undExs individuals reported under {@code u} (uncertain examples)
     * @param prPos factor applied to the positive count inside {@link #gini}
     * @param prNeg factor applied to the negative count inside {@link #gini}
     * @return the selected element of {@code concepts}
     */
    public OWLClassExpression selectWorstConcept(OWLClassExpression[] concepts, SortedSet<OWLIndividual> posExs, SortedSet<OWLIndividual> negExs, SortedSet<OWLIndividual> undExs, double prPos, double prNeg) {
        return selectByGain(concepts, posExs, negExs, undExs, prPos, prNeg, false);
    }

    /**
     * Shared implementation of the gain-based selections.
     *
     * @param preferLower {@code true} to keep the strictly lowest gain, {@code false} to keep the
     *     strictly highest
     */
    private OWLClassExpression selectByGain(OWLClassExpression[] concepts, SortedSet<OWLIndividual> posExs, SortedSet<OWLIndividual> negExs, SortedSet<OWLIndividual> undExs, double prPos, double prNeg, boolean preferLower) {
        int chosen = 0;
        int[] counts = getSplitCounts(concepts[0], posExs, negExs, undExs);
        if (logger.isDebugEnabled()) {
            logger.debug("#0  " + concepts[0] + "\t p:" + counts[0] + "n:" + counts[1] + "u:" + counts[2] + "\t p:" + counts[3] + " n:" + counts[4] + " u:" + counts[5] + "\t p:" + counts[6] + " n:" + counts[7] + " u:" + counts[8] + "\t ");
        }
        double chosenGain = gain(counts, prPos, prNeg);
        System.out.printf("%+10e\n", Double.valueOf(chosenGain));
        System.out.println(concepts[0]);
        for (int i = 1; i < concepts.length; i++) {
            counts = getSplitCounts(concepts[i], posExs, negExs, undExs);
            if (logger.isDebugEnabled()) {
                logger.debug("#" + i + "   " + concepts[i] + "   p: " + counts[0] + "n:" + counts[1] + "u:" + counts[2] + "\t p:" + counts[3] + " n:" + counts[4] + " u:" + counts[5] + "\t p:" + counts[6] + " n:" + counts[7] + " u:" + counts[8] + "\t ");
            }
            double candidateGain = gain(counts, prPos, prNeg);
            if (logger.isDebugEnabled()) {
                logger.debug(candidateGain + "\n");
                logger.debug("{}", concepts[i]);
            }
            boolean replace = preferLower ? candidateGain < chosenGain : candidateGain > chosenGain;
            if (replace) {
                chosen = i;
                chosenGain = candidateGain;
            }
        }
        // print, not printf: the rendered expression is data and may contain '%', which printf
        // would try to interpret as a conversion and fail on.
        System.out.print("best gain: " + chosenGain + " \t split " + concepts[chosen]);
        return concepts[chosen];
    }

    /**
     * Shared implementation of the non-specificity based (DST) selections.
     *
     * @param preferLower {@code true} to keep the lowest value ({@code <=}), {@code false} to keep
     *     the highest ({@code >=})
     */
    private Couple<OWLClassExpression, MassFunction> selectByNonSpecificity(OWLClassExpression[] concepts, SortedSet<OWLIndividual> posExs, SortedSet<OWLIndividual> negExs, SortedSet<OWLIndividual> undExs, boolean preferLower) {
        int chosen = 0;
        int[] counts = getSplitCounts(concepts[0], posExs, negExs, undExs);
        logTransposedCounts(0, counts);
        MassFunction<Integer> chosenBba = buildBba(counts);
        double chosenValue = chosenBba.getNonSpecificityMeasureValue();
        // Result unused; call retained in case it has side effects - NOTE(review): confirm.
        chosenBba.getConfusionMeasure();
        if (logger.isDebugEnabled()) {
            logger.debug(String.format("%+10e", chosenValue));
        }
        System.out.println(concepts[0]);
        for (int i = 1; i < concepts.length; i++) {
            counts = getSplitCounts(concepts[i], posExs, negExs, undExs);
            logTransposedCounts(i, counts);
            MassFunction<Integer> candidateBba = buildBba(counts);
            double candidateValue = candidateBba.getNonSpecificityMeasureValue();
            // Result unused; call retained in case it has side effects - NOTE(review): confirm.
            candidateBba.getGlobalUncertaintyMeasure();
            if (logger.isDebugEnabled()) {
                logger.debug(String.format("%+10e", candidateValue));
                logger.debug("{}", concepts[i]);
            }
            boolean replace = preferLower ? candidateValue <= chosenValue : candidateValue >= chosenValue;
            if (replace) {
                chosen = i;
                chosenValue = candidateValue;
                chosenBba = candidateBba;
            }
        }
        logger.debug("best gain: {} \t split #{}", Double.valueOf(chosenValue), Integer.valueOf(chosen));
        Couple<OWLClassExpression, MassFunction> result = new Couple<>();
        result.setFirstElement(concepts[chosen]);
        result.setSecondElement(chosenBba);
        return result;
    }

    /** Builds the mass function for one count table, using the same cell sums as before. */
    private MassFunction<Integer> buildBba(int[] counts) {
        return DSTUtils.getBBA(
                counts[POSITIVE_INSTANCE_CHECK_TRUE] + counts[NEGATIVE_INSTANCE_CHECK_TRUE],
                counts[POSITIVE_INSTANCE_CHECK_FALSE] + counts[NEGATIVE_INSTANCE_CHECK_FALSE],
                counts[POSITIVE_INSTANCE_CHECK_UNC] + counts[NEGATIVE_INSTANCE_CHECK_UNC] + counts[UNCERTAIN_INSTANCE_CHECK_TRUE] + counts[UNCERTAIN_INSTANCE_CHECK_FALSE]);
    }

    /** Writes the count table of candidate {@code index} to the debug log, column by column. */
    private void logTransposedCounts(int index, int[] counts) {
        if (logger.isDebugEnabled()) {
            logger.debug("#" + index + "\t p:" + counts[0] + "n:" + counts[3] + "u:" + counts[6] + "\t p:" + counts[1] + " n:" + counts[4] + " u:" + counts[7] + "\t p:" + counts[2] + " n:" + counts[5] + " u:" + counts[8] + "\t ");
        }
    }
}
