package LBJ2.learn;

import LBJ2.classify.Classifier;
import LBJ2.classify.DiscreteFeature;
import LBJ2.classify.Feature;
import LBJ2.classify.FeatureVector;
import LBJ2.classify.FeatureVectorReturner;
import LBJ2.classify.LabelVectorReturner;
import LBJ2.classify.ScoreSet;
import java.io.PrintStream;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.Map;

/* loaded from: input_file:LBJ2/learn/MuxLearner.class */
/**
 * Multiplexes examples over independent copies of a base learner.
 *
 * <p>Each example carries one or more discrete <i>selector</i> features. The
 * value of a selector chooses which copy of {@link #baseLearner} handles the
 * example; a new copy is cloned the first time a value is seen during
 * training. Selectors come from one of two places:
 * <ul>
 *   <li>If the extractor supports {@code getCompositeChildren()}, its first
 *       child is the selector classifier and the remaining children produce the
 *       features handed to the selected learner.</li>
 *   <li>Otherwise the leading features of the extractor's output are the
 *       selectors (one per label while learning, exactly one while predicting)
 *       and the remaining features are handed to the selected learner.</li>
 * </ul>
 * Predicting for a selector value never seen in training yields
 * {@link #defaultPrediction} with score 1.0.
 */
public class MuxLearner extends Learner {
    /** Base learner used when the caller does not supply one. */
    public static final Learner defaultBaseLearner = new SparsePerceptron();
    /** Default prediction used when the caller does not supply one. */
    public static final String defaultDefaultPrediction = null;

    /** Prototype that is cloned once per distinct selector value. */
    protected Learner baseLearner;
    /** Maps each selector value ({@code String}) to the {@link Learner} trained for it. */
    protected HashMap network;
    /** Prediction returned for selector values that have no trained learner. */
    protected String defaultPrediction;
    /** Selector classifier, or {@code null} when the extractor is not composite. */
    protected Classifier select;
    /** Composite children of the extractor after {@link #select}, or {@code null}. */
    protected LinkedList compositeChildren;

    /** Parameters for constructing a {@link MuxLearner}. */
    public static class Parameters extends LBJ2.learn.Parameters {
        /** Prototype for the per-selector learners; defaults to a clone of {@link MuxLearner#defaultBaseLearner}. */
        public Learner baseLearner = (Learner) MuxLearner.defaultBaseLearner.clone();
        /** Prediction for unseen selector values; defaults to {@link MuxLearner#defaultDefaultPrediction}. */
        public String defaultPrediction = MuxLearner.defaultDefaultPrediction;
    }

    /** Creates a learner named {@code WekaWrapper.defaultAttributeString} with default settings. */
    public MuxLearner() {
        this(WekaWrapper.defaultAttributeString);
    }

    /** Creates a learner named {@code WekaWrapper.defaultAttributeString} using {@code learner} as the base. */
    public MuxLearner(Learner learner) {
        this(WekaWrapper.defaultAttributeString, learner);
    }

    /**
     * Creates a learner named {@code WekaWrapper.defaultAttributeString}.
     *
     * @param learner the base learner prototype.
     * @param str     the prediction for unseen selector values.
     */
    public MuxLearner(Learner learner, String str) {
        this(WekaWrapper.defaultAttributeString, learner, str);
    }

    /** Creates a learner named {@code WekaWrapper.defaultAttributeString} from {@code parameters}. */
    public MuxLearner(Parameters parameters) {
        this(WekaWrapper.defaultAttributeString, parameters);
    }

    /** Creates a learner named {@code str} with a clone of {@link #defaultBaseLearner} as the base. */
    public MuxLearner(String str) {
        this(str, (Learner) defaultBaseLearner.clone());
    }

    /** Creates a learner named {@code str} using {@code learner} as the base. */
    public MuxLearner(String str, Learner learner) {
        this(str, learner, defaultDefaultPrediction);
    }

    /**
     * Creates a learner.
     *
     * @param str     the name of this learner.
     * @param learner the base learner prototype (configured via {@link #setBase}).
     * @param str2    the prediction for unseen selector values.
     */
    public MuxLearner(String str, Learner learner, String str2) {
        super(str);
        setBase(learner);
        this.defaultPrediction = str2;
        this.network = new HashMap();
    }

    /**
     * Creates a learner named {@code str} from {@code parameters}.
     *
     * <p>The base learner goes through {@link #setBase}, as in every other
     * constructor, so it is configured to consume the feature vectors this
     * class builds.
     */
    public MuxLearner(String str, Parameters parameters) {
        super(str);
        setBase(parameters.baseLearner);
        this.defaultPrediction = parameters.defaultPrediction;
        this.network = new HashMap();
    }

    /**
     * Installs {@code learner} as the prototype for the per-selector learners.
     *
     * <p>The learner is mutated: it takes on this learner's name and package
     * and receives a {@link LabelVectorReturner} labeler and a
     * {@link FeatureVectorReturner} extractor. Its clones are trained on the
     * {@link FeatureVector}s built by {@link #learn}, so presumably these
     * returners read labels and features straight out of those vectors.
     *
     * @param learner the new base learner; must not be {@code null}.
     */
    public void setBase(Learner learner) {
        this.baseLearner = learner;
        this.baseLearner.containingPackage = this.containingPackage;
        this.baseLearner.name = this.name;
        this.baseLearner.setLabeler(new LabelVectorReturner());
        this.baseLearner.setExtractor(new FeatureVectorReturner());
    }

    /**
     * Sets the extractor and decides where selectors come from.
     *
     * <p>If the extractor supports {@code getCompositeChildren()}, its first
     * child becomes {@link #select} and the rest are kept in
     * {@link #compositeChildren}. Otherwise both are cleared, and selectors are
     * taken from the leading features of the extractor's output.
     */
    @Override // LBJ2.learn.Learner
    public void setExtractor(Classifier classifier) {
        super.setExtractor(classifier);
        try {
            this.compositeChildren = this.extractor.getCompositeChildren();
            this.select = (Classifier) this.compositeChildren.removeFirst();
        } catch (UnsupportedOperationException e) {
            this.compositeChildren = null;
            this.select = null;
        }
    }

    /**
     * Trains on one example.
     *
     * <p>The i-th selector value is paired with the i-th label. The example's
     * features, carrying that single label, are given to the learner for that
     * selector value, which is cloned from {@link #baseLearner} on first use.
     * With no labels, nothing is learned.
     *
     * @param obj the example.
     */
    @Override // LBJ2.learn.Learner
    public void learn(Object obj) {
        FeatureVector labels = this.labeler.classify(obj);
        Split parts = split(obj, labels.size());
        assert parts.selections.size() == labels.size()
                : "MuxLearner ERROR: Learner selections and labels have differing sizes: "
                        + labels + ", " + parts.selections;

        FeatureVector example = parts.features;
        Iterator labelIt = labels.iterator();
        for (Iterator selIt = parts.selections.iterator(); selIt.hasNext(); ) {
            example.addLabel((Feature) labelIt.next());
            String key = ((DiscreteFeature) selIt.next()).getValue();
            Learner learner = (Learner) this.network.get(key);
            if (learner == null) {
                learner = (Learner) this.baseLearner.clone();
                this.network.put(key, learner);
            }
            learner.learn(example);
            // The vector just handed over is left untouched; the next selector
            // gets a fresh copy whose label list starts empty again.
            example = (FeatureVector) example.clone();
            example.labels.clear();
        }
    }

    /** Discards every per-selector learner; the base learner prototype is kept. */
    @Override // LBJ2.learn.Learner
    public void forget() {
        this.network.clear();
    }

    /**
     * Returns the scores of the learner chosen by the first selector value of
     * {@code obj}. If that value was never seen in training, the result
     * contains only {@link #defaultPrediction} with score 1.0.
     */
    @Override // LBJ2.learn.Learner
    public ScoreSet scores(Object obj) {
        Split parts = split(obj, 1);
        Learner learner = learnerFor(parts);
        if (learner == null) {
            return new ScoreSet(new String[]{this.defaultPrediction}, new double[]{1.0d});
        }
        return learner.scores(parts.features);
    }

    /**
     * Classifies {@code obj} with the learner chosen by its first selector
     * value. If that value was never seen in training, returns a single
     * {@link DiscreteFeature} holding {@link #defaultPrediction}.
     */
    @Override // LBJ2.classify.Classifier
    public FeatureVector classify(Object obj) {
        Split parts = split(obj, 1);
        Learner learner = learnerFor(parts);
        if (learner == null) {
            return new FeatureVector(new DiscreteFeature(this.containingPackage, this.name, this.defaultPrediction, valueIndexOf(this.defaultPrediction), (short) allowableValues().length));
        }
        return learner.classify(parts.features);
    }

    /**
     * Writes every per-selector learner, in ascending order of selector value,
     * each preceded by a {@code "select: <value>"} line, followed by
     * {@code "End of MuxLearner"}.
     */
    @Override // LBJ2.learn.Learner
    public void write(PrintStream printStream) {
        // Keys are the String selector values; natural order is String.compareTo.
        Object[] keys = this.network.keySet().toArray();
        Arrays.sort(keys);
        for (int i = 0; i < keys.length; i++) {
            printStream.println("select: " + keys[i]);
            ((Learner) this.network.get(keys[i])).write(printStream);
        }
        printStream.println("End of MuxLearner");
    }

    /**
     * Returns a copy whose base learner and per-selector learners are
     * themselves clones. {@link #select} and {@link #compositeChildren} are
     * shared with the original.
     */
    @Override // LBJ2.classify.Classifier
    public Object clone() {
        MuxLearner copy = null;
        try {
            copy = (MuxLearner) super.clone();
        } catch (Exception e) {
            System.err.println("Error cloning MuxLearner: " + e);
            e.printStackTrace();
            System.exit(1);
        }
        copy.baseLearner = (Learner) this.baseLearner.clone();
        copy.network = new HashMap();
        for (Iterator it = this.network.entrySet().iterator(); it.hasNext(); ) {
            Map.Entry entry = (Map.Entry) it.next();
            copy.network.put(entry.getKey(), ((Learner) entry.getValue()).clone());
        }
        return copy;
    }

    /** The selector features of one example and the features routed to the selected learner(s). */
    private static final class Split {
        final FeatureVector selections;
        final FeatureVector features;

        Split(FeatureVector selections, FeatureVector features) {
            this.selections = selections;
            this.features = features;
        }
    }

    /**
     * Separates the selector features of {@code obj} from the features that
     * are handed to the selected learner(s).
     *
     * @param obj   the example.
     * @param count how many leading extractor features are selectors; only
     *              used when there is no {@link #select} classifier.
     * @return the selectors and a vector of the remaining features that the
     *         caller may modify.
     */
    private Split split(Object obj, int count) {
        if (this.select != null) {
            FeatureVector selections = this.select.classify(obj);
            FeatureVector features = new FeatureVector();
            for (Iterator it = this.compositeChildren.iterator(); it.hasNext(); ) {
                features.addFeatures(((Classifier) it.next()).classify(obj));
            }
            return new Split(selections, features);
        }
        // Work on a clone: the selector features are removed from it below.
        FeatureVector features = (FeatureVector) this.extractor.classify(obj).clone();
        FeatureVector selections = new FeatureVector();
        for (int i = 0; i < count; i++) {
            selections.addFeature((Feature) features.features.removeFirst());
        }
        return new Split(selections, features);
    }

    /** Returns the learner keyed by the first selector value in {@code parts}, or {@code null} if none. */
    private Learner learnerFor(Split parts) {
        String value = ((DiscreteFeature) parts.selections.firstFeature()).getValue();
        return (Learner) this.network.get(value);
    }
}
