package LBJ2.learn;

import LBJ2.classify.DiscreteFeature;
import LBJ2.classify.FeatureVector;
import LBJ2.classify.MultiValueComparer;
import LBJ2.classify.Score;
import LBJ2.learn.SparseNetworkLearner;
import java.util.Iterator;

/**
 * A {@link SparseNetworkLearner} for examples that may carry several labels at once.
 *
 * <p>The learner keeps one {@link LinearThresholdUnit} per label value in {@code network}. Units
 * created by {@link #learn(Object)} are clones of {@code baseLTU}. Every unit in the network is
 * trained on every example. {@link #classify(Object)} emits one feature for every label whose
 * score is non-negative, so it may return zero, one or many features.
 */
public class MultiLabelLearner extends SparseNetworkLearner {

    /**
     * Parameters for {@link MultiLabelLearner}. This class declares nothing beyond those of
     * {@link SparseNetworkLearner.Parameters}.
     */
    public static class Parameters extends SparseNetworkLearner.Parameters {
    }

    /** Creates a learner named {@code WekaWrapper.defaultAttributeString}. */
    public MultiLabelLearner() {
        this(WekaWrapper.defaultAttributeString);
    }

    /**
     * Creates a learner named {@code WekaWrapper.defaultAttributeString}.
     *
     * @param linearThresholdUnit the unit passed to the superclass constructor (presumably the
     *     prototype cloned for each label; confirm in {@link SparseNetworkLearner})
     */
    public MultiLabelLearner(LinearThresholdUnit linearThresholdUnit) {
        this(WekaWrapper.defaultAttributeString, linearThresholdUnit);
    }

    /**
     * Creates a learner named {@code WekaWrapper.defaultAttributeString}.
     *
     * @param parameters the parameters forwarded to the superclass constructor
     */
    public MultiLabelLearner(Parameters parameters) {
        this(WekaWrapper.defaultAttributeString, parameters);
    }

    /**
     * Creates a learner with the given name.
     *
     * @param str the learner's name, forwarded to the superclass constructor
     */
    public MultiLabelLearner(String str) {
        super(str);
    }

    /**
     * Creates a learner with the given name and unit.
     *
     * @param str the learner's name, forwarded to the superclass constructor
     * @param linearThresholdUnit the unit forwarded to the superclass constructor
     */
    public MultiLabelLearner(String str, LinearThresholdUnit linearThresholdUnit) {
        super(str, linearThresholdUnit);
    }

    /**
     * Creates a learner with the given name and parameters.
     *
     * @param str the learner's name, forwarded to the superclass constructor
     * @param parameters the parameters forwarded to the superclass constructor
     */
    public MultiLabelLearner(String str, Parameters parameters) {
        super(str, parameters);
    }

    /**
     * Returns the LBJ output type of this classifier.
     *
     * @return {@code "discrete%"}
     */
    @Override // LBJ2.classify.Classifier
    public String getOutputType() {
        return "discrete%";
    }

    /**
     * Trains the network on one example.
     *
     * <p>First, each label that {@code labeler} reports for {@code obj} and that has no unit yet
     * gets a fresh unit:
     * <ul>
     *   <li>the unit is a clone of {@code baseLTU};</li>
     *   <li>its labeler is a {@link MultiValueComparer} over {@code labeler} and that label value;</li>
     *   <li>its extractor is this learner's {@code extractor}.</li>
     * </ul>
     * Then every unit in the network, old and new, learns from {@code obj}.
     *
     * @param obj the training example
     */
    @Override // LBJ2.learn.SparseNetworkLearner, LBJ2.learn.Learner
    public void learn(Object obj) {
        Iterator<?> labels = this.labeler.classify(obj).iterator();
        while (labels.hasNext()) {
            String value = ((DiscreteFeature) labels.next()).getValue();
            if (!this.network.containsKey(value)) {
                LinearThresholdUnit unit = (LinearThresholdUnit) this.baseLTU.clone();
                // NOTE(review): presumably this makes the unit treat an example as positive exactly
                // when `value` is among the labeler's outputs; confirm in MultiValueComparer.
                unit.setLabeler(new MultiValueComparer(this.labeler, value));
                unit.setExtractor(this.extractor);
                this.network.put(value, unit);
            }
        }

        // Every unit sees every example. Units for labels absent from `obj` presumably treat it
        // as a negative example via their MultiValueComparer labeler.
        Iterator<?> units = this.network.values().iterator();
        while (units.hasNext()) {
            ((LinearThresholdUnit) units.next()).learn(obj);
        }
    }

    /**
     * Classifies an example into zero or more labels.
     *
     * <p>Every label whose score from {@link #scores(Object)} is {@code >= 0} (zero included) is
     * emitted as a {@link DiscreteFeature} named after this learner.
     *
     * @param obj the example to classify
     * @return a vector holding one feature per label with a non-negative score; it may be empty
     */
    @Override // LBJ2.learn.SparseNetworkLearner, LBJ2.classify.Classifier
    public FeatureVector classify(Object obj) {
        Score[] labelScores = scores(obj).toArray();
        FeatureVector featureVector = new FeatureVector();
        // The number of allowable values does not change during this call. It is looked up at
        // most once, and only if some label is actually emitted. -1 means "not yet fetched".
        int valueCount = -1;
        for (Score labelScore : labelScores) {
            if (labelScore.score < 0.0d) {
                continue;
            }
            if (valueCount < 0) {
                valueCount = allowableValues().length;
            }
            featureVector.addFeature(new DiscreteFeature(this.containingPackage, this.name,
                    labelScore.value, valueIndexOf(labelScore.value), (short) valueCount));
        }
        return featureVector;
    }
}
