package org.dllearner.algorithms.decisiontrees.dsttdt;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;
import java.util.Set;
import java.util.SortedSet;
import java.util.Stack;
import java.util.TreeSet;

import org.dllearner.algorithms.decisiontrees.dsttdt.dst.MassFunction;
import org.dllearner.algorithms.decisiontrees.dsttdt.models.DSTDLTree;
import org.dllearner.algorithms.decisiontrees.dsttdt.models.EvidentialModel;
import org.dllearner.algorithms.decisiontrees.heuristics.TreeInductionHeuristics;
import org.dllearner.algorithms.decisiontrees.refinementoperators.DLTreesRefinementOperator;
import org.dllearner.algorithms.decisiontrees.utils.Couple;
import org.dllearner.algorithms.decisiontrees.utils.Npla;
import org.dllearner.algorithms.decisiontrees.utils.Split;
import org.dllearner.core.AbstractCELA;
import org.dllearner.core.AbstractClassExpressionLearningProblem;
import org.dllearner.core.AbstractReasonerComponent;
import org.dllearner.core.ComponentAnn;
import org.dllearner.core.ComponentInitException;
import org.dllearner.core.EvaluatedDescription;
import org.dllearner.core.annotations.OutVariable;
import org.dllearner.core.annotations.Unused;
import org.dllearner.core.config.ConfigOption;
import org.dllearner.learningproblems.PosNegUndLP;
import org.dllearner.refinementoperators.RefinementOperator;
import org.semanticweb.owlapi.model.OWLClassExpression;
import org.semanticweb.owlapi.model.OWLIndividual;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@ComponentAnn(name = "ETDT", shortName = "etdt", version = 1.0d, description = "An Evidence-based Terminological Decision Tree")
/**
 * Learner for an Evidence-based Terminological Decision Tree (ETDT).
 *
 * <p>Every node of the induced {@link DSTDLTree} carries a concept and a mass function over the
 * frame of discernment {-1, +1}. Here -1 is negative membership, +1 is positive membership, and
 * the whole frame stands for "unknown". Individuals are classified by walking the tree and
 * combining the mass functions of the leaves they reach.
 */
public class DSTTDTClassifier extends AbstractCELA {
    private static final Logger logger = LoggerFactory.getLogger(DSTTDTClassifier.class);

    /** Label of positive membership in the frame of discernment. */
    private static final int POSITIVE = 1;

    /** Label of negative membership in the frame of discernment. */
    private static final int NEGATIVE = -1;

    /** With non-specificity control on, a node is split only while its mass is below this non-specificity. */
    private static final double NON_SPECIFICITY_THRESHOLD = 0.1d;

    /** Value of {@link DLTreesRefinementOperator#getRo()} for which the best-scoring concept is selected. */
    private static final int RO_SELECT_BEST = 3;

    @OutVariable
    private DSTDLTree currentmodel;
    private boolean stop;

    @Unused
    protected OWLClassExpression classToDescribe;

    @ConfigOption(description = "instance of heuristic to use", defaultValue = "TreeInductionHeuristics")
    protected TreeInductionHeuristics heuristic;

    @ConfigOption(description = "refinement operator instance to use", defaultValue = "DLTreesRefinementOperator")
    protected RefinementOperator operator;

    // Initialised to the documented default so that an unconfigured instance behaves as advertised.
    @ConfigOption(defaultValue = "0.05", description = "Purity threshold for setting a leaf")
    protected double puritythreshold = 0.05d;

    // Initialised to the documented default; also used for the default refinement operator in init().
    @ConfigOption(defaultValue = "4", description = "value for limiting the number of generated concepts")
    protected int beam = 4;

    @ConfigOption(defaultValue = "false", description = "a flag to decide if further control on the purity measure should be made")
    protected boolean nonSpecifityControl;

    /** Prior probability of positive membership; computed in {@link #start()}. */
    protected double prPos;

    /** Prior probability of negative membership; computed in {@link #start()}. */
    protected double prNeg;

    public DSTTDTClassifier() {
    }

    public DSTTDTClassifier(AbstractClassExpressionLearningProblem abstractClassExpressionLearningProblem, AbstractReasonerComponent abstractReasonerComponent) {
        super(abstractClassExpressionLearningProblem, abstractReasonerComponent);
    }

    public boolean isNonSpecifityControl() {
        return this.nonSpecifityControl;
    }

    public void setNonSpecifityControl(boolean z) {
        this.nonSpecifityControl = z;
    }

    public double getPuritythreshold() {
        return this.puritythreshold;
    }

    public void setPuritythreshold(double d) {
        this.puritythreshold = d;
    }

    public int getBeam() {
        return this.beam;
    }

    public void setBeam(int i) {
        this.beam = i;
    }

    public OWLClassExpression getClassToDescribe() {
        return this.classToDescribe;
    }

    public void setClassToDescribe(OWLClassExpression oWLClassExpression) {
        this.classToDescribe = oWLClassExpression;
    }

    public TreeInductionHeuristics getHeuristic() {
        return this.heuristic;
    }

    public void setHeuristic(TreeInductionHeuristics treeInductionHeuristics) {
        this.heuristic = treeInductionHeuristics;
    }

    public RefinementOperator getOperator() {
        return this.operator;
    }

    public void setOperator(RefinementOperator refinementOperator) {
        this.operator = refinementOperator;
    }

    /**
     * Initialises the component and creates any collaborators that were not configured.
     *
     * <p>A missing heuristic becomes a {@link TreeInductionHeuristics}. A missing operator becomes
     * a {@link DLTreesRefinementOperator} using the configured {@link #beam}. The learning problem
     * must be a {@link PosNegUndLP} in that case.
     *
     * @throws ComponentInitException if a collaborator fails to initialise
     */
    @Override
    public void init() throws ComponentInitException {
        this.baseURI = this.reasoner.getBaseURI();
        this.prefixes = this.reasoner.getPrefixes();
        if (this.heuristic == null) {
            this.heuristic = new TreeInductionHeuristics();
            this.heuristic.setProblem(this.learningProblem);
            this.heuristic.setReasoner(this.reasoner);
            this.heuristic.init();
        }
        if (this.operator == null) {
            // Use the configured beam rather than a hard-coded 4, so the "beam" option takes effect.
            DLTreesRefinementOperator defaultOperator =
                    new DLTreesRefinementOperator((PosNegUndLP) this.learningProblem, getReasoner(), this.beam);
            defaultOperator.setReasoner(this.reasoner);
            defaultOperator.setBeam(this.beam);
            this.operator = defaultOperator;
            this.operator.init();
        }
        this.initialized = true;
    }

    /**
     * Induces an evidential terminological decision tree from the given examples.
     *
     * <p>The tree is built depth-first with an explicit work list. A node becomes a leaf when:
     * <ul>
     *   <li>it has no positive and no negative examples, or</li>
     *   <li>it is pure with respect to {@link #puritythreshold}, or</li>
     *   <li>non-specificity control is on and the node's mass is too non-specific.</li>
     * </ul>
     * Otherwise the node is split on a refinement selected by the heuristic. The resulting tree is
     * also stored as the current model.
     *
     * @param posExs positive examples
     * @param negExs negative examples
     * @param undExs examples with uncertain membership
     * @return the root of the induced tree
     */
    public DSTDLTree induceDSTDLTree(SortedSet<OWLIndividual> posExs, SortedSet<OWLIndividual> negExs, SortedSet<OWLIndividual> undExs) {
        DSTDLTree model = new DSTDLTree();
        // Each task pairs a still-empty (sub)tree with the examples that reach it.
        Deque<Couple<DSTDLTree, Npla>> pending = new ArrayDeque<>();
        pending.push(newTask(model, new Npla(posExs, negExs, undExs, this.beam, this.prPos, this.prNeg)));
        // Recently split nodes; their concepts are the starting point for the next refinement.
        Deque<DSTDLTree> lastTrees = new ArrayDeque<>();

        while (!pending.isEmpty()) {
            Couple<DSTDLTree, Npla> task = pending.pop();
            DSTDLTree node = task.getFirstElement();
            Npla examples = task.getSecondElement();
            SortedSet<OWLIndividual> pos = asIndividuals(examples.getFirst());
            SortedSet<OWLIndividual> neg = asIndividuals(examples.getSecond());
            SortedSet<OWLIndividual> und = asIndividuals(examples.getThird());
            int p = pos.size();
            int n = neg.size();
            int u = und.size();
            logger.debug("Learning problem\t p:{}\t n:{}\t u:{}\t prPos:{}\t prNeg:{}", p, n, u, this.prPos, this.prNeg);

            MassFunction<Integer> mass = buildNodeMass(p, n, u);

            if (p == 0 && n == 0) {
                // No informative examples: fall back to the class priors.
                node.setRoot(this.prPos >= this.prNeg ? this.dataFactory.getOWLThing() : this.dataFactory.getOWLNothing(), mass);
                continue;
            }

            double posRatio = (double) p / (p + n);
            double negRatio = (double) n / (p + n);
            if (negRatio == 0.0d && posRatio > this.puritythreshold) {
                node.setRoot(this.dataFactory.getOWLThing(), mass);
            } else if (posRatio == 0.0d && negRatio > this.puritythreshold) {
                node.setRoot(this.dataFactory.getOWLNothing(), mass);
            } else if (!this.nonSpecifityControl || mass.getNonSpecificityMeasureValue() < NON_SPECIFICITY_THRESHOLD) {
                expandNode(node, pos, neg, und, posRatio, negRatio, pending, lastTrees);
            } else if (posRatio > negRatio) {
                node.setRoot(this.dataFactory.getOWLThing(), mass);
            } else {
                node.setRoot(this.dataFactory.getOWLNothing(), mass);
            }
        }
        this.currentmodel = model;
        this.stop = true;
        return model;
    }

    /**
     * Builds the basic belief assignment of a node from its example counts.
     *
     * <p>{+1} gets the fraction of positive examples, {-1} the fraction of negative examples, and
     * the whole frame the fraction of uncertain ones. A node with no examples at all gets the class
     * priors and no mass on the frame. Floating-point division is used: integer division would
     * truncate every mass to 0 or 1 and throw on an empty node.
     */
    private MassFunction<Integer> buildNodeMass(int pos, int neg, int und) {
        ArrayList<Integer> frame = new ArrayList<>();
        frame.add(NEGATIVE);
        frame.add(POSITIVE);
        MassFunction<Integer> mass = new MassFunction<>(frame);
        ArrayList<Integer> positive = new ArrayList<>();
        positive.add(POSITIVE);
        ArrayList<Integer> negative = new ArrayList<>();
        negative.add(NEGATIVE);

        int total = pos + neg + und;
        if (total == 0) {
            mass.setValues(positive, this.prPos);
            mass.setValues(negative, this.prNeg);
            mass.setValues(frame, 0.0d);
        } else {
            double totalAsDouble = total;
            mass.setValues(positive, pos / totalAsDouble);
            mass.setValues(negative, neg / totalAsDouble);
            mass.setValues(frame, und / totalAsDouble);
        }
        return mass;
    }

    /**
     * Splits {@code node} on a refinement chosen by the heuristic.
     *
     * <p>The two new (empty) subtrees are queued so that the positive branch is processed next,
     * and {@code node} is recorded as the most recently split node.
     */
    private void expandNode(DSTDLTree node, SortedSet<OWLIndividual> pos, SortedSet<OWLIndividual> neg,
            SortedSet<OWLIndividual> und, double posRatio, double negRatio,
            Deque<Couple<DSTDLTree, Npla>> pending, Deque<DSTDLTree> lastTrees) {
        DLTreesRefinementOperator refinementOperator = (DLTreesRefinementOperator) this.operator;
        // NOTE(review): lastTrees is popped here, so a negative branch refines whichever node is
        // then on top (or owl:Thing), not necessarily its own parent. Kept as-is; confirm intent.
        OWLClassExpression toRefine = lastTrees.isEmpty()
                ? this.dataFactory.getOWLThing()
                : lastTrees.pop().getRoot();
        Set<OWLClassExpression> refinements = refinementOperator.refine(toRefine, pos, neg);
        logger.debug("Refinement:{}", refinements);
        OWLClassExpression[] candidates = refinements.toArray(new OWLClassExpression[0]);

        // NOTE(review): the best concept is chosen only when getRo() == 3; otherwise the "worst"
        // one is selected. Preserved from the original; confirm this is intended.
        Couple<OWLClassExpression, MassFunction> selected = refinementOperator.getRo() == RO_SELECT_BEST
                ? this.heuristic.selectBestConceptDST(candidates, pos, neg, und, this.prPos, this.prNeg)
                : this.heuristic.selectWorstConceptDST(candidates, pos, neg, und, this.prPos, this.prNeg);
        OWLClassExpression concept = selected.getFirstElement();

        // The first three sets feed the positive subtree, the last three the negative subtree.
        TreeSet<OWLIndividual> posBranchPos = new TreeSet<>();
        TreeSet<OWLIndividual> posBranchNeg = new TreeSet<>();
        TreeSet<OWLIndividual> posBranchUnd = new TreeSet<>();
        TreeSet<OWLIndividual> negBranchPos = new TreeSet<>();
        TreeSet<OWLIndividual> negBranchNeg = new TreeSet<>();
        TreeSet<OWLIndividual> negBranchUnd = new TreeSet<>();
        Split.split(concept, this.dataFactory, this.reasoner, pos, neg, und,
                posBranchPos, posBranchNeg, posBranchUnd, negBranchPos, negBranchNeg, negBranchUnd);

        node.setRoot(concept, selected.getSecondElement());
        DSTDLTree posTree = new DSTDLTree();
        DSTDLTree negTree = new DSTDLTree();
        node.setPosTree(posTree);
        node.setNegTree(negTree);

        // LIFO: push the negative branch first so the positive branch is expanded next.
        pending.push(newTask(negTree, new Npla(negBranchPos, negBranchNeg, negBranchUnd, this.beam, posRatio, negRatio)));
        pending.push(newTask(posTree, new Npla(posBranchPos, posBranchNeg, posBranchUnd, this.beam, posRatio, negRatio)));
        lastTrees.push(node);
    }

    /** Pairs a (sub)tree with the examples that reach it. */
    private static Couple<DSTDLTree, Npla> newTask(DSTDLTree tree, Npla examples) {
        Couple<DSTDLTree, Npla> task = new Couple<>();
        task.setFirstElement(tree);
        task.setSecondElement(examples);
        return task;
    }

    /** Narrows an element read from an {@link Npla} to the example-set type stored in it. */
    @SuppressWarnings("unchecked")
    private static SortedSet<OWLIndividual> asIndividuals(Object examples) {
        return (SortedSet<OWLIndividual>) examples;
    }

    /**
     * Walks the tree for one individual and collects a (label, BBA) vote for every leaf reached.
     *
     * <p>If the reasoner proves neither C(a) nor (not C)(a), both subtrees are explored. Where a
     * subtree is missing, the current node's BBA is recorded with that branch's label.
     */
    private void classifyExampleDST(List<Couple<Integer, MassFunction<Integer>>> votes, OWLIndividual individual, DSTDLTree tree) {
        Deque<DSTDLTree> toVisit = new ArrayDeque<>();
        toVisit.push(tree);
        while (!toVisit.isEmpty()) {
            DSTDLTree node = toVisit.pop();
            OWLClassExpression root = node.getRoot();
            MassFunction rootBBA = node.getRootBBA();
            if (root.equals(this.dataFactory.getOWLThing())) {
                addVote(votes, POSITIVE, rootBBA);
            } else if (root.equals(this.dataFactory.getOWLNothing())) {
                addVote(votes, NEGATIVE, rootBBA);
            } else if (this.reasoner.hasType(root, individual)) {
                followOrVote(toVisit, node.getPosSubTree(), votes, POSITIVE, rootBBA);
            } else if (!this.reasoner.hasType(this.dataFactory.getOWLObjectComplementOf(root), individual)) {
                // Membership undetermined: follow both branches.
                followOrVote(toVisit, node.getPosSubTree(), votes, POSITIVE, rootBBA);
                followOrVote(toVisit, node.getNegSubTree(), votes, NEGATIVE, rootBBA);
            } else {
                followOrVote(toVisit, node.getNegSubTree(), votes, NEGATIVE, rootBBA);
            }
        }
    }

    /** Schedules {@code subTree} for visiting, or records a vote when the subtree is absent. */
    private static void followOrVote(Deque<DSTDLTree> toVisit, DSTDLTree subTree,
            List<Couple<Integer, MassFunction<Integer>>> votes, int label, MassFunction bba) {
        if (subTree != null) {
            toVisit.push(subTree);
        } else {
            addVote(votes, label, bba);
        }
    }

    /** Appends a (label, BBA) vote to {@code votes}. */
    @SuppressWarnings("unchecked")
    private static void addVote(List<Couple<Integer, MassFunction<Integer>>> votes, int label, MassFunction bba) {
        Couple<Integer, MassFunction<Integer>> vote = new Couple<>();
        vote.setFirstElement(label);
        vote.setSecondElement(bba);
        votes.add(vote);
    }

    public DSTDLTree getCurrentmodel() {
        return this.currentmodel;
    }

    public void setCurrentmodel(DSTDLTree dSTDLTree) {
        this.currentmodel = dSTDLTree;
    }

    /**
     * Classifies an individual with the given tree.
     *
     * @return +1 (positive), -1 (negative) or 0 (undecided)
     */
    public int classifyExamplesDST(OWLIndividual oWLIndividual, DSTDLTree dSTDLTree) {
        return predict(getBBA(oWLIndividual, dSTDLTree));
    }

    /**
     * Computes the combined BBA for an individual.
     *
     * <p>The BBAs of all leaves reached in the model are combined with the first one via
     * {@link MassFunction#combineEvidences}.
     *
     * @param individual the individual to classify
     * @param model the model; must be a {@link DSTDLTree}
     * @return the combined mass function
     * @throws IllegalStateException if no leaf was reached
     */
    public MassFunction<Integer> getBBA(OWLIndividual individual, EvidentialModel model) {
        List<Couple<Integer, MassFunction<Integer>>> votes = new ArrayList<>();
        classifyExampleDST(votes, individual, (DSTDLTree) model);
        if (votes.isEmpty()) {
            throw new IllegalStateException("No leaf reached while classifying " + individual);
        }
        MassFunction<Integer> combined = votes.get(0).getSecondElement();
        if (votes.size() > 1) {
            MassFunction[] others = new MassFunction[votes.size() - 1];
            for (int i = 1; i < votes.size(); i++) {
                others[i - 1] = votes.get(i).getSecondElement();
            }
            combined = combined.combineEvidences(others);
        }
        return combined;
    }

    /**
     * Turns a BBA into a crisp label via its confirmation-function values.
     *
     * <p>If the value for the whole frame does not strictly exceed both singleton values, the
     * larger of {+1}/{-1} wins, with ties going to +1. Otherwise the comparison of the singletons
     * decides, and an exact tie yields 0.
     */
    private int predict(MassFunction<Integer> massFunction) {
        ArrayList<Integer> positive = new ArrayList<>();
        positive.add(POSITIVE);
        double posValue = massFunction.getConfirmationFunctionValue(positive);

        ArrayList<Integer> negative = new ArrayList<>();
        negative.add(NEGATIVE);
        double negValue = massFunction.getConfirmationFunctionValue(negative);

        ArrayList<Integer> frame = new ArrayList<>();
        frame.add(NEGATIVE);
        frame.add(POSITIVE);
        double unknownValue = massFunction.getConfirmationFunctionValue(frame);

        if (unknownValue <= posValue || unknownValue <= negValue) {
            return posValue >= negValue ? POSITIVE : NEGATIVE;
        }
        if (posValue > negValue) {
            return POSITIVE;
        }
        return posValue < negValue ? NEGATIVE : 0;
    }

    /**
     * Runs the learner on the configured {@link PosNegUndLP}.
     *
     * <p>The priors are the relative frequencies of positive vs. negative examples. They default
     * to 0.5/0.5 when there are none. Uncertain examples cancel out of the normalised ratio.
     */
    @Override
    public void start() {
        this.stop = false;
        PosNegUndLP posNegUndLP = (PosNegUndLP) this.learningProblem;
        SortedSet<OWLIndividual> posExs = (SortedSet<OWLIndividual>) posNegUndLP.getPositiveExamples();
        SortedSet<OWLIndividual> negExs = (SortedSet<OWLIndividual>) posNegUndLP.getNegativeExamples();
        SortedSet<OWLIndividual> undExs = (SortedSet<OWLIndividual>) posNegUndLP.getUncertainExamples();

        // Floating-point division: integer division collapsed the priors and threw on empty problems.
        int known = posExs.size() + negExs.size();
        if (known == 0) {
            this.prPos = 0.5d;
            this.prNeg = 0.5d;
        } else {
            this.prPos = (double) posExs.size() / known;
            this.prNeg = (double) negExs.size() / known;
        }
        logger.info("New learning problem prepared.");
        logger.info("Learning a tree");
        this.currentmodel = induceDSTDLTree(posExs, negExs, undExs);
        stop();
    }

    @Override
    public void stop() {
        this.stop = true;
    }

    @Override
    public boolean isRunning() {
        return !this.stop;
    }

    /** Returns the concept definition derived from the current model. */
    @Override
    public OWLClassExpression getCurrentlyBestDescription() {
        return DSTDLTree.deriveDefinition(this.currentmodel, false);
    }

    /** Not supported by this learner; always returns {@code null}. */
    @Override
    public EvaluatedDescription getCurrentlyBestEvaluatedDescription() {
        return null;
    }
}
