package LBJ2.learn;

import LBJ2.classify.Classifier;
import LBJ2.classify.DiscreteFeature;
import LBJ2.classify.FeatureVector;
import LBJ2.classify.ScoreSet;
import LBJ2.classify.ValueComparer;
import java.io.PrintStream;
import java.util.Arrays;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;

/* loaded from: input_file:LBJ2/learn/SparseNetworkLearner.class */
/**
 * Multi-class learner that keeps one {@link LinearThresholdUnit} per label value.
 *
 * <p>Whenever {@link #learn(Object)} encounters a label value that has no unit yet, it clones
 * {@link #baseLTU}, gives the clone a {@link ValueComparer} labeler for that value, and stores
 * it in {@link #network}. Every unit in the network is then trained on every example.
 * Prediction selects the label whose unit reports the highest score.
 */
public class SparseNetworkLearner extends Learner {
    /** Prototype unit that is cloned when the caller does not supply a base LTU. */
    public static final LinearThresholdUnit defaultBaseLTU = new SparsePerceptron();

    /** Prototype that is cloned to create the unit for each newly observed label value. */
    protected LinearThresholdUnit baseLTU;

    /**
     * Maps each label value (a {@code String}) to the {@link LinearThresholdUnit} trained for
     * it. Kept as a raw {@code HashMap} so the field declaration seen by subclasses is unchanged.
     */
    protected HashMap network;

    /** Parameter bundle for this learner; carries the base LTU prototype. */
    public static class Parameters extends LBJ2.learn.Parameters {
        /** Base LTU to use; initialized to a clone of {@link SparseNetworkLearner#defaultBaseLTU}. */
        public LinearThresholdUnit baseLTU = (LinearThresholdUnit) SparseNetworkLearner.defaultBaseLTU.clone();
    }

    /** Delegates to {@link #SparseNetworkLearner(String)} with {@code WekaWrapper.defaultAttributeString}. */
    public SparseNetworkLearner() {
        this(WekaWrapper.defaultAttributeString);
    }

    /**
     * Delegates to {@link #SparseNetworkLearner(String, LinearThresholdUnit)} with
     * {@code WekaWrapper.defaultAttributeString}.
     *
     * @param linearThresholdUnit the prototype unit to clone for each label value
     */
    public SparseNetworkLearner(LinearThresholdUnit linearThresholdUnit) {
        this(WekaWrapper.defaultAttributeString, linearThresholdUnit);
    }

    /**
     * Delegates to {@link #SparseNetworkLearner(String, Parameters)} with
     * {@code WekaWrapper.defaultAttributeString}.
     *
     * @param parameters supplies the base LTU
     */
    public SparseNetworkLearner(Parameters parameters) {
        this(WekaWrapper.defaultAttributeString, parameters);
    }

    /**
     * Creates a learner with an empty network whose base LTU is a clone of
     * {@link #defaultBaseLTU}. Note that this constructor does not go through
     * {@link #setLTU(LinearThresholdUnit)}, so the base LTU's name is left as cloned.
     *
     * @param str passed to the superclass constructor
     */
    public SparseNetworkLearner(String str) {
        super(str);
        this.baseLTU = (LinearThresholdUnit) defaultBaseLTU.clone();
        this.network = new HashMap();
    }

    /**
     * Creates a learner with an empty network and the given base LTU. A warning is printed to
     * {@code System.err} if the LTU's output type is not {@code "discrete"}.
     *
     * @param str                 passed to the superclass constructor
     * @param linearThresholdUnit the prototype unit to clone for each label value
     */
    public SparseNetworkLearner(String str, LinearThresholdUnit linearThresholdUnit) {
        super(str);
        warnIfNotDiscrete(linearThresholdUnit);
        setLTU(linearThresholdUnit);
        this.network = new HashMap();
    }

    /**
     * Creates a learner with an empty network whose base LTU is {@code parameters.baseLTU}.
     * Behaves exactly like {@link #SparseNetworkLearner(String, LinearThresholdUnit)}.
     *
     * @param str        passed to the superclass constructor
     * @param parameters supplies the base LTU
     */
    public SparseNetworkLearner(String str, Parameters parameters) {
        this(str, parameters.baseLTU);
    }

    /** Prints the two-line warning used by the constructors when an LTU is not discrete. */
    private static void warnIfNotDiscrete(LinearThresholdUnit ltu) {
        String outputType = ltu.getOutputType();
        if (!outputType.equals("discrete")) {
            System.err.println("LBJ WARNING: SparseNetworkLearner will only work with a LinearThresholdUnit that returns discrete.");
            System.err.println("             The given LTU, " + ltu.getClass().getName() + ", returns " + outputType + ".");
        }
    }

    /**
     * Installs the given unit as the base LTU and renames it to this learner's name followed by
     * {@code "$baseLTU"}. Units already present in the network are not touched.
     *
     * @param linearThresholdUnit the new prototype unit
     */
    public void setLTU(LinearThresholdUnit linearThresholdUnit) {
        this.baseLTU = linearThresholdUnit;
        this.baseLTU.name = this.name + "$baseLTU";
    }

    /**
     * Sets the labeler and propagates it to the {@link ValueComparer} labeler of every unit in
     * the network. A warning is printed to {@code System.err} when the runtime class name
     * contains {@code "SparseNetworkLearner"} and the labeler's output type is not
     * {@code "discrete"}.
     *
     * @param classifier the new label classifier
     */
    @Override
    public void setLabeler(Classifier classifier) {
        boolean isNetworkLearner = getClass().getName().contains("SparseNetworkLearner");
        if (isNetworkLearner && !classifier.getOutputType().equals("discrete")) {
            System.err.println("LBJ WARNING: SparseNetworkLearner will only work with a label classifier that returns discrete.");
            System.err.println("             The given label classifier, " + classifier.getClass().getName() + ", returns " + classifier.getOutputType() + ".");
        }
        super.setLabeler(classifier);
        // The network is a raw map, so its elements come back as Object and must be cast.
        for (Object unit : this.network.values()) {
            ((ValueComparer) ((LinearThresholdUnit) unit).getLabeler()).setLabeler(this.labeler);
        }
    }

    /**
     * Sets the feature extractor on this learner, on the base LTU, and on every unit in the
     * network.
     *
     * @param classifier the new feature extractor
     */
    @Override
    public void setExtractor(Classifier classifier) {
        super.setExtractor(classifier);
        this.baseLTU.setExtractor(classifier);
        for (Object unit : this.network.values()) {
            ((LinearThresholdUnit) unit).setExtractor(classifier);
        }
    }

    /**
     * Trains on one example. The example's label is the value of the first feature produced by
     * the labeler; if no unit exists for it yet, one is cloned from the base LTU and added to
     * the network. Every unit in the network (including a freshly added one) then learns from
     * the example.
     *
     * @param obj the example object
     */
    @Override
    public void learn(Object obj) {
        String label = ((DiscreteFeature) this.labeler.classify(obj).firstFeature()).getValue();
        if (!this.network.containsKey(label)) {
            LinearThresholdUnit fresh = (LinearThresholdUnit) this.baseLTU.clone();
            fresh.setLabeler(new ValueComparer(this.labeler, label));
            fresh.setExtractor(this.extractor);
            this.network.put(label, fresh);
        }
        for (Object unit : this.network.values()) {
            ((LinearThresholdUnit) unit).learn(obj);
        }
    }

    /** Forwards {@code doneLearning()} to every unit in the network. */
    @Override
    public void doneLearning() {
        for (Object unit : this.network.values()) {
            ((LinearThresholdUnit) unit).doneLearning();
        }
    }

    /** Discards every unit in the network; the base LTU is left as is. */
    @Override
    public void forget() {
        this.network.clear();
    }

    /**
     * Scores the example against every label in the network.
     *
     * @param obj the example object
     * @return a score set mapping each label to its unit's score minus that unit's threshold
     */
    @Override
    public ScoreSet scores(Object obj) {
        ScoreSet result = new ScoreSet();
        for (Object o : this.network.entrySet()) {
            Map.Entry entry = (Map.Entry) o;
            LinearThresholdUnit unit = (LinearThresholdUnit) entry.getValue();
            result.put((String) entry.getKey(), unit.score(obj) - unit.getThreshold());
        }
        return result;
    }

    /**
     * Predicts the label whose unit gives the strictly highest score.
     *
     * <p>Unlike {@link #scores(Object)}, the comparison here uses the raw score without
     * subtracting each unit's threshold.
     *
     * @param obj the example object
     * @return a vector holding a single discrete feature for the winning label, or an empty
     *         vector when no unit scores above negative infinity (e.g. the network is empty)
     */
    @Override
    public FeatureVector classify(Object obj) {
        double bestScore = Double.NEGATIVE_INFINITY;
        String bestLabel = null;
        for (Object o : this.network.entrySet()) {
            Map.Entry entry = (Map.Entry) o;
            double score = ((LinearThresholdUnit) entry.getValue()).score(obj);
            if (score > bestScore) {
                bestLabel = (String) entry.getKey();
                bestScore = score;
            }
        }
        if (bestLabel == null) {
            return new FeatureVector();
        }
        return new FeatureVector(new DiscreteFeature(this.containingPackage, this.name, bestLabel,
                valueIndexOf(bestLabel), (short) allowableValues().length));
    }

    /**
     * Picks the best-scoring label among a set of candidates, using raw unit scores.
     *
     * @param obj        the example object
     * @param collection candidate label strings; {@code null} or empty means every label in the
     *                   network is a candidate
     * @return the candidate with the strictly highest score, or {@code null} when no candidate
     *         has a unit in the network
     */
    public String valueOf(Object obj, Collection collection) {
        double bestScore = Double.NEGATIVE_INFINITY;
        String bestLabel = null;
        for (Iterator it = candidateLabels(collection); it.hasNext(); ) {
            String label = (String) it.next();
            LinearThresholdUnit unit = (LinearThresholdUnit) this.network.get(label);
            // Labels without a unit stay at negative infinity and can therefore never win.
            double score = unit == null ? Double.NEGATIVE_INFINITY : unit.score(obj);
            if (score > bestScore) {
                bestLabel = label;
                bestScore = score;
            }
        }
        return bestLabel;
    }

    /**
     * Scores the example against a set of candidate labels.
     *
     * @param obj        the example object
     * @param collection candidate label strings; {@code null} or empty means every label in the
     *                   network is a candidate
     * @return a score set mapping each candidate to its unit's score minus that unit's
     *         threshold, or to negative infinity when the candidate has no unit
     */
    public ScoreSet scores(Object obj, Collection collection) {
        ScoreSet result = new ScoreSet();
        for (Iterator it = candidateLabels(collection); it.hasNext(); ) {
            String label = (String) it.next();
            LinearThresholdUnit unit = (LinearThresholdUnit) this.network.get(label);
            double score = unit == null
                    ? Double.NEGATIVE_INFINITY
                    : unit.score(obj) - unit.getThreshold();
            result.put(label, score);
        }
        return result;
    }

    /** Returns an iterator over the given candidates, or over all network labels if none are given. */
    private Iterator candidateLabels(Collection collection) {
        if (collection == null || collection.isEmpty()) {
            return this.network.keySet().iterator();
        }
        return collection.iterator();
    }

    /**
     * Writes a textual representation: the base LTU's class name, the base LTU itself, then for
     * each label in ascending key order a {@code "label: <key>"} line followed by that label's
     * unit, and finally the line {@code "End of SparseNetworkLearner"}.
     *
     * @param printStream destination stream
     */
    @Override
    public void write(PrintStream printStream) {
        printStream.println(this.baseLTU.getClass().getName());
        this.baseLTU.write(printStream);
        Map.Entry[] entries = (Map.Entry[]) this.network.entrySet().toArray(new Map.Entry[0]);
        // Sort by label so the output order does not depend on HashMap iteration order.
        Arrays.sort(entries, new Comparator<Map.Entry>() {
            @Override
            public int compare(Map.Entry left, Map.Entry right) {
                return ((String) left.getKey()).compareTo((String) right.getKey());
            }
        });
        for (Map.Entry entry : entries) {
            printStream.println("label: " + entry.getKey());
            ((LinearThresholdUnit) entry.getValue()).write(printStream);
        }
        printStream.println("End of SparseNetworkLearner");
    }

    /**
     * Returns a copy of this learner whose base LTU and every network unit are themselves
     * clones, stored in a fresh map. If {@code super.clone()} throws, the error is reported on
     * {@code System.err} and the JVM is terminated with exit status 1.
     *
     * @return the cloned learner
     */
    @Override
    public Object clone() {
        SparseNetworkLearner copy = null;
        try {
            copy = (SparseNetworkLearner) super.clone();
        } catch (Exception e) {
            System.err.println("Error cloning SparseNetworkLearner: " + e);
            e.printStackTrace();
            System.exit(1);
        }
        copy.baseLTU = (LinearThresholdUnit) this.baseLTU.clone();
        copy.network = new HashMap();
        for (Object o : this.network.entrySet()) {
            Map.Entry entry = (Map.Entry) o;
            copy.network.put(entry.getKey(), ((LinearThresholdUnit) entry.getValue()).clone());
        }
        return copy;
    }
}
