package LBJ2.learn;

import LBJ2.classify.Classifier;
import LBJ2.classify.DiscreteFeature;
import LBJ2.classify.FeatureVector;
import LBJ2.classify.Score;
import LBJ2.classify.ScoreSet;
import java.io.PrintStream;
import java.util.Arrays;

/* loaded from: input_file:LBJ2/learn/AdaBoost.class */
/**
 * Boosting meta-learner that trains {@code rounds} clones of a prototype weak learner and
 * combines their predictions with an alpha-weighted vote.
 *
 * <p>Training is batch-only: each round draws a resample (with replacement) of the training
 * examples in proportion to the current example weights, trains a fresh clone of the weak
 * learner on it, measures the weighted error on the original examples, and re-weights the
 * examples so misclassified ones become more likely to be drawn next round.
 *
 * <p>{@link #scores(Object)} tracks at most two labels, so this implementation assumes a binary
 * discrete label; any label differing from the first one seen is accumulated into the second
 * slot.
 */
public class AdaBoost extends Learner {
    /** Bound keeping the weighted error strictly inside (0, 1) so alpha stays finite. */
    private static final double MIN_ERROR = 1e-10;

    /** Prototype weak learner; every boosting round trains a fresh clone of it. */
    protected Learner weakLearner;
    /** Number of boosting rounds, i.e. the number of weak learners trained. */
    protected int rounds;
    /** Weak learners produced by {@link #learn(Object[])}; {@code null} until trained. */
    private Learner[] weakLearners;
    /** Vote weight of each weak learner, parallel to {@link #weakLearners}. */
    private double[] alpha;

    /** Creates an unconfigured instance; {@link #weakLearner} is left {@code null}. */
    public AdaBoost() {
    }

    /**
     * Creates a booster.
     *
     * @param name the name of this learner, forwarded to the superclass constructor
     * @param learner the prototype weak learner to clone in every round
     * @param rounds the number of boosting rounds
     */
    public AdaBoost(String name, Learner learner, int rounds) {
        super(name);
        this.weakLearner = learner;
        this.rounds = rounds;
    }

    /** Sets the labeler on this learner and on the prototype weak learner (if present). */
    @Override
    public void setLabeler(Classifier classifier) {
        super.setLabeler(classifier);
        // The weak learner may not be set yet (no-arg constructor).
        if (weakLearner != null) {
            weakLearner.setLabeler(classifier);
        }
    }

    /** Sets the extractor on this learner and on the prototype weak learner (if present). */
    @Override
    public void setExtractor(Classifier classifier) {
        super.setExtractor(classifier);
        if (weakLearner != null) {
            weakLearner.setExtractor(classifier);
        }
    }

    /** Online training is unsupported; prints a notice to stderr and does nothing. */
    @Override
    public void learn(Object obj) {
        System.err.println("AdaBoost cannot be trained in an online fashion.");
        System.err.println("Use learn(Object[]) instead.");
    }

    /**
     * Runs {@link #rounds} rounds of boosting over {@code examples}, replacing any previously
     * trained weak learners and vote weights.
     *
     * @param examples the full training set; labels are obtained from {@code labeler}
     */
    @Override
    public void learn(Object[] examples) {
        int n = examples.length;
        double[] weights = new double[n];
        Arrays.fill(weights, 1.0 / n);
        weakLearners = new Learner[rounds];
        alpha = new double[rounds];

        for (int round = 0; round < rounds; round++) {
            weakLearners[round] = (Learner) weakLearner.clone();
            weakLearners[round].learn(resample(examples, weights));

            // Evaluate on the ORIGINAL examples: weights[j] belongs to examples[j], not to the
            // j-th element of the random resample.
            FeatureVector[] predictions = weakLearners[round].classify(examples);
            boolean[] correct = new boolean[n];
            double correctWeight = 0.0;
            for (int j = 0; j < n; j++) {
                if (predictions[j].equals(labeler.classify(examples[j]))) {
                    correct[j] = true;
                    correctWeight += weights[j];
                }
            }

            // Clamp so a perfect (error 0) or totally wrong (error 1) learner does not produce
            // an infinite alpha or NaN weights.
            double error = Math.min(Math.max(1.0 - correctWeight, MIN_ERROR), 1.0 - MIN_ERROR);
            double ratio = (1.0 - error) / error;
            alpha[round] = Math.log(ratio) / 2.0;

            if (round + 1 < rounds) {
                reweight(weights, correct, Math.sqrt(ratio));
            }
        }
    }

    /** Draws {@code examples.length} examples with replacement, in proportion to {@code weights}. */
    private static Object[] resample(Object[] examples, double[] weights) {
        Object[] sample = new Object[examples.length];
        for (int j = 0; j < sample.length; j++) {
            sample[j] = examples[sampleIndex(weights, Math.random())];
        }
        return sample;
    }

    /**
     * Returns the first index whose cumulative weight exceeds {@code r}. Falls back to the last
     * index when rounding leaves the weight total at or below {@code r}, instead of running off
     * the end of the array.
     */
    private static int sampleIndex(double[] weights, double r) {
        double cumulative = 0.0;
        for (int k = 0; k < weights.length; k++) {
            cumulative += weights[k];
            if (r < cumulative) {
                return k;
            }
        }
        return weights.length - 1;
    }

    /**
     * Divides the weight of correctly classified examples by {@code factor}, multiplies the rest
     * by it, then renormalizes so the weights sum to 1.
     */
    private static void reweight(double[] weights, boolean[] correct, double factor) {
        double total = 0.0;
        for (int j = 0; j < weights.length; j++) {
            weights[j] = correct[j] ? weights[j] / factor : weights[j] * factor;
            total += weights[j];
        }
        for (int j = 0; j < weights.length; j++) {
            weights[j] /= total;
        }
    }

    /** Discards the trained weak learners and their vote weights. */
    @Override
    public void forget() {
        this.weakLearners = null;
        this.alpha = null;
    }

    /**
     * Sums the alpha weights of the weak learners voting for each label.
     *
     * <p>Only the labels actually predicted are included, so the result holds one entry when
     * every weak learner agrees (and none when {@code rounds} is 0). Requires a trained model.
     */
    @Override
    public ScoreSet scores(Object example) {
        String firstLabel = null;
        String secondLabel = null;
        double firstScore = 0.0;
        double secondScore = 0.0;
        for (int i = 0; i < rounds; i++) {
            String value =
                ((DiscreteFeature) weakLearners[i].classify(example).firstFeature()).getValue();
            if (firstLabel == null) {
                firstLabel = value;
                firstScore = alpha[i];
            } else if (value.equals(firstLabel)) {
                firstScore += alpha[i];
            } else {
                // Binary assumption: every non-first label shares the second slot.
                secondLabel = value;
                secondScore += alpha[i];
            }
        }
        if (firstLabel == null) {
            return new ScoreSet(new String[0], new double[0]);
        }
        if (secondLabel == null) {
            return new ScoreSet(new String[] {firstLabel}, new double[] {firstScore});
        }
        return new ScoreSet(
            new String[] {firstLabel, secondLabel}, new double[] {firstScore, secondScore});
    }

    /**
     * Returns the label with the highest total vote; on a tie the later score in
     * {@link ScoreSet#toArray()} order wins.
     *
     * @throws IllegalStateException if there are no weak learners to vote
     */
    @Override
    public FeatureVector classify(Object example) {
        Score[] candidates = scores(example).toArray();
        if (candidates.length == 0) {
            throw new IllegalStateException("AdaBoost has no trained weak learners to vote.");
        }
        Score best = candidates[0];
        for (int k = 1; k < candidates.length; k++) {
            if (candidates[k].score >= best.score) {
                best = candidates[k];
            }
        }
        return new FeatureVector(
            new DiscreteFeature(
                containingPackage,
                name,
                best.value,
                valueIndexOf(best.value),
                (short) allowableValues().length));
    }

    /**
     * Writes the following, in order:
     * <ol>
     *   <li>the name;</li>
     *   <li>the comma-separated alphas, or {@code ---} when {@code rounds} is 0;</li>
     *   <li>the weak learner's class name;</li>
     *   <li>the prototype weak learner;</li>
     *   <li>each trained weak learner.</li>
     * </ol>
     * Assumes the model is trained when {@code rounds > 0}.
     */
    @Override
    public void write(PrintStream printStream) {
        printStream.println(name);
        if (rounds > 0) {
            printStream.print(alpha[0]);
            for (int i = 1; i < rounds; i++) {
                printStream.print(", " + alpha[i]);
            }
            printStream.println();
        } else {
            printStream.println("---");
        }
        printStream.println(weakLearner.getClass().getName());
        weakLearner.write(printStream);
        if (rounds > 0) {
            for (int i = 0; i < rounds; i++) {
                weakLearners[i].write(printStream);
            }
        }
    }

    /**
     * Returns a deep copy: the prototype, each trained weak learner and the alpha array are
     * cloned, so the copy can be retrained or relabeled without affecting this instance.
     */
    @Override
    public Object clone() {
        AdaBoost copy = null;
        try {
            copy = (AdaBoost) super.clone();
        } catch (Exception e) {
            System.err.println("Error cloning AdaBoost: " + e);
            System.exit(1);
        }
        if (copy.weakLearner != null) {
            copy.weakLearner = (Learner) copy.weakLearner.clone();
        }
        if (copy.weakLearners != null) {
            copy.weakLearners = copy.weakLearners.clone();
            for (int i = 0; i < copy.weakLearners.length; i++) {
                // Clone each element, not the array itself.
                if (copy.weakLearners[i] != null) {
                    copy.weakLearners[i] = (Learner) copy.weakLearners[i].clone();
                }
            }
        }
        if (copy.alpha != null) {
            copy.alpha = copy.alpha.clone();
        }
        return copy;
    }
}
