package experimental.analyzer.simple;

import cc.mallet.optimize.Optimizable;
import experimental.analyzer.AnalyzerTag;
import experimental.analyzer.simple.SimpleAnalyzer;
import experimental.analyzer.simple.SimpleAnalyzerTrainer;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.Map;
import marmot.morph.mapper.latin.LdtMorphTag;
import marmot.util.Mutable;
import marmot.util.Numerics;

/* loaded from: input_file:experimental/analyzer/simple/SimpleAnalyzerObjective.class */
/**
 * Penalized log-likelihood objective over the weight vector of a {@link SimpleAnalyzerModel},
 * exposed to MALLET optimizers through {@link Optimizable.ByGradientValue}.
 *
 * <p>The value is the sum of the per-instance log-likelihoods minus {@code penalty * ||w||^2}, and
 * the gradient is accumulated alongside it. Both are recomputed over all instances by
 * {@link #update()}, which {@link #setParameters(double[])} and {@link #setParameter(int, double)}
 * call.
 *
 * <p>Supported modes:
 * <ul>
 *   <li>{@code binary}: every tag is an independent logistic decision (gold tags contribute
 *       log sigmoid(score), all other tags log(1 - sigmoid(score)));</li>
 *   <li>{@code classifier}: a softmax over all tags in which each gold tag contributes its
 *       log-probability, optionally modified by a {@link SimpleAnalyzerTrainer.PairConstraint}.</li>
 * </ul>
 *
 * <p>Evaluation mutates shared scratch buffers and temporarily swaps the model's weight array, so
 * an instance must not be evaluated from several threads at once.
 */
public class SimpleAnalyzerObjective implements Optimizable.ByGradientValue {
    private final SimpleAnalyzerModel model_;
    private final Collection<SimpleAnalyzerInstance> instances_;
    /** Parameter vector obtained from the model; re-installed via setWeights before scoring. */
    private final double[] weights_;
    /** Gradient of the objective with respect to {@code weights_}. */
    private final double[] gradient_;
    /** L2 regularization strength. */
    private final double penalty_;
    private final SimpleAnalyzer.Mode mode_;
    /** Scratch buffer: per-tag model scores of the instance currently being processed. */
    private final double[] scores_;
    /** Scratch buffer: per-tag updates (log-likelihood derivatives w.r.t. the scores). */
    private final double[] updates_;
    /** Related-tag weights per tag; only read in classifier mode with a pair constraint. */
    private final Map<AnalyzerTag, Map<AnalyzerTag, Mutable<Double>>> relative_counts_;
    private final SimpleAnalyzerTrainer.PairConstraint pair_constraint_;
    /** Objective value at the current weights. */
    private double value_;

    /**
     * Creates the objective. The value and gradient stay zero until {@link #update()} (or one of
     * the parameter setters) is called.
     *
     * @param penalty L2 strength; {@code penalty * ||w||^2} is subtracted from the value
     * @param model model whose weight vector ({@code getWeights()}) is optimized
     * @param instances training instances summed over by {@link #update()}
     * @param mode scoring mode; only {@code binary} and {@code classifier} are supported
     * @param relativeCounts per tag, weights of related tags; only read in classifier mode when
     *     {@code pairConstraint} is not {@code none}
     * @param pairConstraint how {@code relativeCounts} modifies the classifier objective
     */
    public SimpleAnalyzerObjective(double penalty, SimpleAnalyzerModel model, Collection<SimpleAnalyzerInstance> instances, SimpleAnalyzer.Mode mode, Map<AnalyzerTag, Map<AnalyzerTag, Mutable<Double>>> relativeCounts, SimpleAnalyzerTrainer.PairConstraint pairConstraint) {
        this.model_ = model;
        this.instances_ = instances;
        this.weights_ = model.getWeights();
        this.gradient_ = new double[this.weights_.length];
        this.penalty_ = penalty;
        this.mode_ = mode;
        this.relative_counts_ = relativeCounts;
        int numTags = model.getNumTags();
        this.scores_ = new double[numTags];
        this.updates_ = new double[numTags];
        this.pair_constraint_ = pairConstraint;
    }

    /**
     * Recomputes value and gradient at the current weights: sums every instance's log-likelihood,
     * then subtracts the L2 penalty {@code penalty * w^2} (gradient {@code -2 * penalty * w}).
     */
    public void update() {
        value_ = 0.0d;
        Arrays.fill(gradient_, 0.0d);
        for (SimpleAnalyzerInstance instance : instances_) {
            update(instance, 1.0d, false);
        }
        for (int i = 0; i < weights_.length; i++) {
            double w = weights_[i];
            value_ -= penalty_ * w * w;
            gradient_[i] -= 2.0d * penalty_ * w;
        }
    }

    /**
     * Adds one instance's log-likelihood to the running value and passes its per-tag updates,
     * multiplied by {@code scale}, to {@code model.update(instance, updates)}.
     *
     * <p>When {@code updateWeightsDirectly} is false, the model is temporarily pointed at the
     * gradient array, so {@code model.update} writes into the gradient. When it is true, the
     * model keeps the weight array and the updates are applied to the weights themselves
     * (presumably for online/stochastic training — confirm with callers). In both cases the model
     * is pointed back at the weights afterwards, even if scoring or updating throws.
     * {@code scale} affects only the updates, not the accumulated value.
     *
     * @param instance instance to score
     * @param scale factor applied to the per-tag updates
     * @param updateWeightsDirectly whether the updates go to the weights instead of the gradient
     * @throws IllegalStateException if the mode is neither {@code binary} nor {@code classifier}
     */
    public void update(SimpleAnalyzerInstance instance, double scale, boolean updateWeightsDirectly) {
        Arrays.fill(scores_, 0.0d);
        Arrays.fill(updates_, 0.0d);
        int numTags = model_.getNumTags();
        model_.setWeights(weights_);
        try {
            model_.score(instance, scores_);
            value_ += instanceLogLikelihood(instance, numTags);
            if (!updateWeightsDirectly) {
                // Redirect model.update's writes into the gradient accumulator.
                model_.setWeights(gradient_);
            }
            if (!Numerics.approximatelyEqual(scale, 1.0d)) {
                for (int tag = 0; tag < numTags; tag++) {
                    updates_[tag] *= scale;
                }
            }
            model_.update(instance, updates_);
        } finally {
            model_.setWeights(weights_);
        }
    }

    /** Dispatches on the mode; fills {@code updates_} and returns the instance's log-likelihood. */
    private double instanceLogLikelihood(SimpleAnalyzerInstance instance, int numTags) {
        if (mode_ == SimpleAnalyzer.Mode.binary) {
            return binaryUpdate(scores_, updates_, numTags, instance);
        }
        if (mode_ == SimpleAnalyzer.Mode.classifier) {
            return classifierUpdate(scores_, updates_, numTags, instance);
        }
        throw new IllegalStateException("Unsupported mode: " + mode_);
    }

    /**
     * Softmax objective: each gold tag contributes {@code score[gold] - logZ}, and the updates are
     * {@code [tag is gold] - numGold * p(tag)}. This assumes Numerics.sumLogProb(a, b) is
     * log(exp(a) + exp(b)) — confirm.
     *
     * <p>With a pair constraint, gold tags are processed through their related tags:
     * {@code weighted} adds {@code weight * score[related]} to the value and {@code weight} to the
     * update of every related tag. Any other non-{@code none} constraint zeroes the update of
     * related tags, so they are not pushed down. Tags that are themselves gold are never zeroed,
     * so the result does not depend on the order of the gold tags.
     */
    private double classifierUpdate(double[] scores, double[] updates, int numTags, SimpleAnalyzerInstance instance) {
        int numGold = instance.getTagIndexes().size();
        double logZ = Double.NEGATIVE_INFINITY;
        for (int tag = 0; tag < numTags; tag++) {
            logZ = Numerics.sumLogProb(scores[tag], logZ);
        }
        double value = 0.0d - (numGold * logZ);
        for (int tag = 0; tag < numTags; tag++) {
            updates[tag] = (-numGold) * Math.exp(scores[tag] - logZ);
        }

        if (pair_constraint_ == SimpleAnalyzerTrainer.PairConstraint.none) {
            for (int gold : instance.getTagIndexes()) {
                value += scores[gold];
                updates[gold] += 1.0d;
            }
            return value;
        }

        for (int gold : instance.getTagIndexes()) {
            Map<AnalyzerTag, Mutable<Double>> related = relative_counts_.get(model_.getTagTable().toSymbol(Integer.valueOf(gold)));
            if (related != null) {
                for (Map.Entry<AnalyzerTag, Mutable<Double>> entry : related.entrySet()) {
                    int index = model_.getTagTable().toIndex(entry.getKey());
                    double weight = entry.getValue().get().doubleValue();
                    if (pair_constraint_ == SimpleAnalyzerTrainer.PairConstraint.weighted) {
                        value += weight * scores[index];
                        updates[index] += weight;
                    } else if (!instance.getTagIndexes().contains(index)) {
                        // Suppress only non-gold related tags; zeroing a gold tag here would erase
                        // its +1 if that tag was processed earlier in this loop.
                        updates[index] = 0.0d;
                    }
                }
            }
            if (related == null || pair_constraint_ == SimpleAnalyzerTrainer.PairConstraint.simple) {
                value += scores[gold];
                updates[gold] += 1.0d;
            }
        }
        return value;
    }

    /**
     * Independent logistic objective per tag. Every tag contributes {@code -log(1 + exp(score))},
     * and gold tags additionally contribute {@code score}. The updates are
     * {@code [tag is gold] - sigmoid(score)}. This assumes Numerics.sumLogProb(a, b) is
     * log(exp(a) + exp(b)) — confirm.
     */
    private double binaryUpdate(double[] scores, double[] updates, int numTags, SimpleAnalyzerInstance instance) {
        double value = 0.0d;
        for (int tag = 0; tag < numTags; tag++) {
            double logNorm = Numerics.sumLogProb(scores[tag], 0.0d);
            value -= logNorm;
            updates[tag] = -Math.exp(scores[tag] - logNorm);
        }
        for (int gold : instance.getTagIndexes()) {
            value += scores[gold];
            updates[gold] += 1.0d;
        }
        return value;
    }

    /** Returns the number of weights being optimized. */
    public int getNumParameters() {
        return weights_.length;
    }

    /** Returns the weight at {@code index}. */
    public double getParameter(int index) {
        return weights_[index];
    }

    /** Copies all weights into {@code buffer}, which must hold at least getNumParameters() values. */
    public void getParameters(double[] buffer) {
        System.arraycopy(weights_, 0, buffer, 0, weights_.length);
    }

    /**
     * Sets the weight at {@code index} and recomputes value and gradient. This is a full pass over
     * the training instances; prefer {@link #setParameters(double[])} for bulk changes.
     */
    public void setParameter(int index, double value) {
        weights_[index] = value;
        update();
    }

    /** Replaces all weights with the first getNumParameters() values of {@code params}, then recomputes value and gradient. */
    public void setParameters(double[] params) {
        System.arraycopy(params, 0, weights_, 0, weights_.length);
        update();
    }

    /** Returns the objective value computed by the last {@link #update()}. */
    public double getValue() {
        return value_;
    }

    /** Copies the gradient computed by the last {@link #update()} into {@code buffer}. */
    public void getValueGradient(double[] buffer) {
        System.arraycopy(gradient_, 0, buffer, 0, gradient_.length);
    }
}
