package LBJ2.learn;

import LBJ2.classify.DiscreteArrayFeature;
import LBJ2.classify.DiscreteFeature;
import LBJ2.classify.Feature;
import LBJ2.classify.FeatureVector;
import LBJ2.classify.RealArrayFeature;
import LBJ2.classify.RealFeature;
import LBJ2.learn.SparsePerceptron;
import LBJ2.learn.SparseWeightVector;
import java.io.PrintStream;
import java.util.HashMap;
import java.util.Map;

/* loaded from: input_file:LBJ2/learn/SparseAveragedPerceptron.class */
public class SparseAveragedPerceptron extends SparsePerceptron {
    /** Prototype weight vector; constructors and {@link Parameters} store a clone of it. */
    public static final AveragedWeightVector defaultWeightVector;

    /**
     * Running sum of {@code examples * delta} over every bias update. Together with
     * {@code bias} it yields the averaged bias {@code (examples * bias - averagedBias) / examples}
     * used by {@link #score(Object)}.
     */
    protected double averagedBias;

    /** Assertion switch read by {@link #learn(Object)}; set in the static initializer. */
    static final boolean $assertionsDisabled;

    /** Cached {@code Class} object for this class, resolved lazily via {@link #class$(String)}. */
    static Class class$LBJ2$learn$SparseAveragedPerceptron;

    /**
     * A {@link SparseWeightVector} whose {@code double[]} values are laid out as pairs of slots.
     * <ul>
     *   <li>An even index {@code i} holds the current weight {@code w}.</li>
     *   <li>{@code i + 1} accumulates {@code u += examples * delta} every time that weight
     *       changes.</li>
     * </ul>
     * The averaged weight over all examples seen so far is then
     * {@code (examples * w - u) / examples}.
     */
    public static class AveragedWeightVector extends SparseWeightVector {
        /** Number of examples seen; incremented by {@link #scaledAdd} and {@link #correctExample}. */
        protected int examples;

        /**
         * Weight iterator for {@link AveragedWeightVector}. {@link #getWeight()} returns the
         * averaged weight, {@link #getSimpleWeight()} the current (non-averaged) weight, and
         * {@link #setWeight(double, double)} maintains the accumulator slot.
         */
        public class AveragedWeightIterator extends SparseWeightVector.WeightIterator {
            private final AveragedWeightVector this$0;

            // NOTE(review): decompiled shape -- the explicit outer-instance argument and this$0
            // field mirror javac's synthetic members; confirm against WeightIterator's constructor.
            public AveragedWeightIterator(AveragedWeightVector averagedWeightVector, FeatureVector featureVector) {
                super(averagedWeightVector, featureVector);
                this.this$0 = averagedWeightVector;
            }

            /**
             * Returns the array holding the current feature's slot pair, or {@code null} when no
             * weight has been stored for the feature yet.
             */
            private double[] slotArray() {
                if (!this.isDiscrete) {
                    if (this.currentFeature.fromArray()) {
                        return this.weightArray;
                    }
                    return (double[]) this.this$0.weights.get(new RealFeature(this.currentFeature.getPackage(), this.currentFeature.getIdentifier(), 0.0d));
                }
                int total = this.currentFeature.totalValues();
                if (total < 2) {
                    return (double[]) this.this$0.weights.get(this.currentFeature);
                }
                if (!this.currentFeature.fromArray() && total == 2) {
                    return (double[]) this.this$0.weights.get(new DiscreteFeature(this.currentFeature.getPackage(), this.currentFeature.getIdentifier(), WekaWrapper.defaultAttributeString));
                }
                return this.weightArray;
            }

            /** Returns the even index of the current feature's weight slot within {@link #slotArray()}. */
            private int slotIndex() {
                if (!this.isDiscrete) {
                    return this.currentFeature.fromArray() ? 2 * this.arrayIndex : 0;
                }
                int total = this.currentFeature.totalValues();
                if (total < 2) {
                    return 0;
                }
                if (this.currentFeature.fromArray()) {
                    if (total == 2) {
                        return 2 * this.arrayIndex;
                    }
                    return 2 * ((this.arrayIndex * total) + ((DiscreteFeature) this.currentFeature).getValueIndex());
                }
                return total == 2 ? 0 : 2 * ((DiscreteFeature) this.currentFeature).getValueIndex();
            }

            /**
             * Returns the current (non-averaged) weight of the current feature.
             *
             * @return the weight, or {@code null} if none has been stored for this feature
             */
            public Double getSimpleWeight() {
                double[] slots = slotArray();
                if (slots == null) {
                    return null;
                }
                return new Double(slots[slotIndex()]);
            }

            /**
             * Returns the averaged weight of the current feature.
             *
             * @return the averaged weight, or {@code null} if none has been stored for this feature
             */
            @Override // LBJ2.learn.SparseWeightVector.WeightIterator
            public Double getWeight() {
                double[] slots = slotArray();
                if (slots == null) {
                    return null;
                }
                int i = slotIndex();
                return new Double(this.this$0.averagedWeight(slots[i], slots[i + 1]));
            }

            /** Sets the current feature's weight, using 0 as the initial value for new slots. */
            @Override // LBJ2.learn.SparseWeightVector.WeightIterator
            public void setWeight(double d) {
                setWeight(d, 0.0d);
            }

            /**
             * Sets the current feature's weight to {@code d} and records the change in the
             * accumulator slot. Storage is created on first use.
             *
             * @param d  the new weight
             * @param d2 initial weight for every weight slot of newly created storage
             */
            @Override // LBJ2.learn.SparseWeightVector.WeightIterator
            public void setWeight(double d, double d2) {
                double[] slots;
                if (!this.isDiscrete) {
                    if (this.currentFeature.fromArray()) {
                        if (this.weightArray == null) {
                            RealArrayFeature realArrayFeature = (RealArrayFeature) this.currentFeature;
                            this.weightArray = newArraySlots(new RealFeature(realArrayFeature.getPackage(), realArrayFeature.getIdentifier(), 0.0d), realArrayFeature.getArrayLength(), d2);
                        }
                        slots = this.weightArray;
                    } else {
                        slots = scalarSlot(new RealFeature(this.currentFeature.getPackage(), this.currentFeature.getIdentifier(), 0.0d), d2);
                    }
                } else {
                    int total = this.currentFeature.totalValues();
                    if (total < 2) {
                        slots = scalarSlot(this.currentFeature, d2);
                    } else if (!this.currentFeature.fromArray()) {
                        if (total == 2) {
                            slots = scalarSlot(new DiscreteFeature(this.currentFeature.getPackage(), this.currentFeature.getIdentifier(), WekaWrapper.defaultAttributeString), d2);
                        } else {
                            if (this.weightArray == null) {
                                this.weightArray = newArraySlots(new DiscreteFeature(this.currentFeature.getPackage(), this.currentFeature.getIdentifier(), WekaWrapper.defaultAttributeString), total, d2);
                            }
                            slots = this.weightArray;
                        }
                    } else {
                        if (this.weightArray == null) {
                            DiscreteArrayFeature discreteArrayFeature = (DiscreteArrayFeature) this.currentFeature;
                            // Boolean array features need one slot pair per element; others one per element/value.
                            int pairs = total == 2 ? discreteArrayFeature.getArrayLength() : discreteArrayFeature.getArrayLength() * discreteArrayFeature.totalValues();
                            this.weightArray = newArraySlots(new DiscreteFeature(discreteArrayFeature.getPackage(), discreteArrayFeature.getIdentifier(), WekaWrapper.defaultAttributeString), pairs, d2);
                        }
                        slots = this.weightArray;
                    }
                }
                updateSlot(slots, slotIndex(), d);
            }

            /**
             * Looks up the single slot pair stored under {@code key}, creating and storing
             * {@code {initial, 0.0}} if absent. The pair is created before it is stored, so
             * the map never receives {@code null} and index 1 always exists.
             */
            private double[] scalarSlot(Feature key, double initial) {
                double[] slots = (double[]) this.this$0.weights.get(key);
                if (slots == null) {
                    slots = new double[]{initial, 0.0d};
                    key.intern();
                    this.this$0.weights.put(key, slots);
                }
                return slots;
            }

            /**
             * Allocates and stores under {@code key} an array of {@code pairs} slot pairs. Every
             * weight slot starts at {@code initial` and every accumulator slot at 0.
             */
            private double[] newArraySlots(Feature key, int pairs, double initial) {
                double[] slots = new double[2 * pairs];
                if (initial != 0.0d) {
                    for (int i = 0; i < slots.length; i += 2) {
                        slots[i] = initial;
                    }
                }
                key.intern();
                this.this$0.weights.put(key, slots);
                return slots;
            }

            /**
             * Writes {@code weight} into {@code slots[i]} and adds {@code examples * delta} to the
             * accumulator at {@code slots[i + 1]}.
             */
            private void updateSlot(double[] slots, int i, double weight) {
                double delta = weight - slots[i];
                slots[i] = weight;
                slots[i + 1] += this.this$0.examples * delta;
            }
        }

        /** Creates an empty vector backed by a new {@link HashMap}. */
        public AveragedWeightVector() {
            super(new HashMap());
        }

        /** Creates a vector backed by the given map of feature to slot-pair arrays. */
        public AveragedWeightVector(HashMap hashMap) {
            super(hashMap);
        }

        /** Counts an example that needed no weight update. */
        public void correctExample() {
            this.examples++;
        }

        /** Returns the number of examples counted so far. */
        public int getExamples() {
            return this.examples;
        }

        /**
         * Computes the averaged weight from a current weight and its accumulator. With no
         * examples there is no history to average over, so the current weight is returned
         * instead of dividing by zero.
         */
        private double averagedWeight(double weight, double weightedSum) {
            if (this.examples == 0) {
                return weight;
            }
            return ((this.examples * weight) - weightedSum) / this.examples;
        }

        /** Returns an {@link AveragedWeightIterator} over the given features. */
        @Override // LBJ2.learn.SparseWeightVector
        public SparseWeightVector.WeightIterator weightIterator(FeatureVector featureVector) {
            return new AveragedWeightIterator(this, featureVector);
        }

        /** Dot product with the current (non-averaged) weights, using 0 for unseen features. */
        public double simpleDot(FeatureVector featureVector) {
            return simpleDot(featureVector, 0.0d);
        }

        /**
         * Dot product with the current (non-averaged) weights.
         *
         * @param d weight assumed for features with no stored weight
         */
        public double simpleDot(FeatureVector featureVector, double d) {
            SparseWeightVector.WeightIterator weightIterator = weightIterator(featureVector);
            double sum = 0.0d;
            while (weightIterator.hasNext()) {
                weightIterator.next();
                Double simpleWeight = ((AveragedWeightIterator) weightIterator).getSimpleWeight();
                sum += (simpleWeight == null ? d : simpleWeight.doubleValue()) * weightIterator.getCurrentFeatureStrength();
            }
            return sum;
        }

        /**
         * Adds {@code d} times each feature's strength to its current weight and then counts one
         * example.
         *
         * @param d2 initial weight for features with no stored weight
         */
        @Override // LBJ2.learn.SparseWeightVector
        public void scaledAdd(FeatureVector featureVector, double d, double d2) {
            SparseWeightVector.WeightIterator weightIterator = weightIterator(featureVector);
            while (weightIterator.hasNext()) {
                weightIterator.next();
                Double simpleWeight = ((AveragedWeightIterator) weightIterator).getSimpleWeight();
                weightIterator.setWeight((simpleWeight == null ? d2 : simpleWeight.doubleValue()) + (weightIterator.getCurrentFeatureStrength() * d), d2);
            }
            this.examples++;
        }

        /**
         * Returns a copy with cloned slot arrays (keys are shared) and the same example count.
         * The count must be copied, or the clone's averaged weights would be computed over
         * zero examples.
         */
        @Override // LBJ2.learn.SparseWeightVector
        public Object clone() {
            AveragedWeightVector copy = new AveragedWeightVector();
            copy.examples = this.examples;
            for (Map.Entry entry : this.weights.entrySet()) {
                copy.weights.put(entry.getKey(), ((double[]) entry.getValue()).clone());
            }
            return copy;
        }

        /** Returns the example count on its own line, followed by the superclass representation. */
        @Override // LBJ2.learn.SparseWeightVector
        public String toString() {
            return new StringBuffer().append(this.examples).append("\n").append(super.toString()).toString();
        }
    }

    /** Parameters whose weight vector is a clone of {@link #defaultWeightVector}. */
    public static class Parameters extends SparsePerceptron.Parameters {
        public Parameters() {
            this.weightVector = (AveragedWeightVector) SparseAveragedPerceptron.defaultWeightVector.clone();
        }
    }

    public SparseAveragedPerceptron() {
        this(WekaWrapper.defaultAttributeString);
    }

    public SparseAveragedPerceptron(double d) {
        this(WekaWrapper.defaultAttributeString, d);
    }

    public SparseAveragedPerceptron(double d, double d2) {
        this(WekaWrapper.defaultAttributeString, d, d2);
    }

    public SparseAveragedPerceptron(double d, double d2, double d3) {
        this(WekaWrapper.defaultAttributeString, d, d2, d3);
    }

    public SparseAveragedPerceptron(double d, double d2, double d3, double d4) {
        this(WekaWrapper.defaultAttributeString, d, d2, d3, d4);
    }

    public SparseAveragedPerceptron(Parameters parameters) {
        this(WekaWrapper.defaultAttributeString, parameters);
    }

    /** Creates a learner with learning rate 0.1 and all other settings 0. */
    public SparseAveragedPerceptron(String str) {
        this(str, 0.1d);
    }

    public SparseAveragedPerceptron(String str, double d) {
        this(str, d, 0.0d);
    }

    public SparseAveragedPerceptron(String str, double d, double d2) {
        this(str, d, d2, 0.0d);
    }

    /** Uses {@code d3} for both of the last two superclass settings. */
    public SparseAveragedPerceptron(String str, double d, double d2, double d3) {
        this(str, d, d2, d3, d3);
    }

    /**
     * Passes all settings to the superclass with a {@code null} weight vector, then installs a
     * clone of {@link #defaultWeightVector}.
     */
    public SparseAveragedPerceptron(String str, double d, double d2, double d3, double d4) {
        super(str, d, d2, d3, d4, null);
        this.weightVector = (AveragedWeightVector) defaultWeightVector.clone();
    }

    public SparseAveragedPerceptron(String str, Parameters parameters) {
        super(str, parameters);
    }

    /**
     * Scores an example as the weight vector's {@code dot} plus the averaged bias. The bias is
     * only added once at least one example has been counted.
     */
    @Override // LBJ2.learn.LinearThresholdUnit
    public double score(Object obj) {
        double dot = this.weightVector.dot(this.extractor.classify(obj), this.initialWeight);
        int examples = ((AveragedWeightVector) this.weightVector).getExamples();
        if (examples > 0) {
            dot += ((examples * this.bias) - this.averagedBias) / examples;
        }
        return dot;
    }

    /** Raises bias and weights by the learning rate, recording the change for averaging. */
    @Override // LBJ2.learn.SparsePerceptron, LBJ2.learn.LinearThresholdUnit
    public void promote(Object obj) {
        this.bias += this.learningRate;
        this.averagedBias += ((AveragedWeightVector) this.weightVector).getExamples() * this.learningRate;
        this.weightVector.scaledAdd(this.extractor.classify(obj), this.learningRate, this.initialWeight);
    }

    /** Lowers bias and weights by the learning rate, recording the change for averaging. */
    @Override // LBJ2.learn.SparsePerceptron, LBJ2.learn.LinearThresholdUnit
    public void demote(Object obj) {
        this.bias -= this.learningRate;
        this.averagedBias -= ((AveragedWeightVector) this.weightVector).getExamples() * this.learningRate;
        this.weightVector.scaledAdd(this.extractor.classify(obj), -this.learningRate, this.initialWeight);
    }

    /**
     * Trains on one example using the current (non-averaged) weights plus {@code bias}.
     * <ul>
     *   <li>A positive label scoring below {@code threshold + positiveThickness} is promoted.</li>
     *   <li>A negative label scoring at or above {@code threshold - negativeThickness} is
     *       demoted.</li>
     *   <li>Any other example is only counted.</li>
     * </ul>
     */
    @Override // LBJ2.learn.LinearThresholdUnit, LBJ2.learn.Learner
    public void learn(Object obj) {
        Feature firstFeature = this.labeler.classify(obj).firstFeature();
        if (!$assertionsDisabled && firstFeature == null) {
            throw new AssertionError("An LTU's label classifier must always produce the same feature.");
        }
        if (!$assertionsDisabled && !(firstFeature instanceof DiscreteFeature)) {
            throw new AssertionError("An LTU's label classifier must always produce a single discrete feature.");
        }
        DiscreteFeature discreteFeature = (DiscreteFeature) firstFeature;
        if (!$assertionsDisabled && !discreteFeature.valueEquals(this.allowableValues[0]) && !discreteFeature.valueEquals(this.allowableValues[1])) {
            throw new AssertionError("Example has unallowed label value.");
        }
        // Positive when the label is the second allowable value (by index, or by value if unindexed).
        boolean positive = discreteFeature.getValueIndex() == 1 || (discreteFeature.getValueIndex() == -1 && discreteFeature.valueEquals(this.allowableValues[1]));
        double simpleDot = ((AveragedWeightVector) this.weightVector).simpleDot(this.extractor.classify(obj), this.initialWeight) + this.bias;
        if (positive && simpleDot < this.threshold + this.positiveThickness) {
            promote(obj);
        } else if (positive || simpleDot < this.threshold - this.negativeThickness) {
            ((AveragedWeightVector) this.weightVector).correctExample();
        } else {
            demote(obj);
        }
    }

    /** Resets superclass state and the averaged-bias accumulator. */
    @Override // LBJ2.learn.LinearThresholdUnit, LBJ2.learn.Learner
    public void forget() {
        super.forget();
        this.averagedBias = 0.0d;
    }

    /** Writes a comma-separated header line of the learner's settings, then the weight vector. */
    @Override // LBJ2.learn.SparsePerceptron, LBJ2.learn.Learner
    public void write(PrintStream printStream) {
        printStream.println(new StringBuffer().append(this.name).append(": ").append(this.learningRate).append(", ").append(this.initialWeight).append(", ").append(this.threshold).append(", ").append(this.positiveThickness).append(", ").append(this.negativeThickness).append(", ").append(this.bias).append(", ").append(this.averagedBias).toString());
        this.weightVector.write(printStream);
    }

    /** Resolves a class by name, converting a missing class into {@link NoClassDefFoundError}. */
    static Class class$(String str) {
        try {
            return Class.forName(str);
        } catch (ClassNotFoundException e) {
            // initCause returns Throwable; cast back so no checked exception has to be declared.
            throw (NoClassDefFoundError) new NoClassDefFoundError(str).initCause(e);
        }
    }

    static {
        Class cls;
        if (class$LBJ2$learn$SparseAveragedPerceptron == null) {
            cls = class$("LBJ2.learn.SparseAveragedPerceptron");
            class$LBJ2$learn$SparseAveragedPerceptron = cls;
        } else {
            cls = class$LBJ2$learn$SparseAveragedPerceptron;
        }
        $assertionsDisabled = !cls.desiredAssertionStatus();
        defaultWeightVector = new AveragedWeightVector();
    }
}
