package edu.berkeley.nlp.discPCFG;

import edu.berkeley.nlp.math.DoubleArrays;
import edu.berkeley.nlp.math.SloppyMath;
import edu.berkeley.nlp.util.Pair;

/* loaded from: input_file:edu/berkeley/nlp/discPCFG/ProperNameObjectiveFunction.class */
/**
 * Regularized negative log-likelihood objective for a log-linear model over sub-labels.
 *
 * <p>For each datum, every sub-label {@code k} is scored as the sum over the datum's active
 * features of {@code featureCount * x[getLinearIndex(feature, k)]}. The scores are normalized
 * into log-probabilities with a log-sum-exp. The datum then contributes
 * {@code -sum_j weights[j] * logP[begin + j]}, where {@code begin} is the first sub-label index
 * of the datum's gold label. An L2 penalty {@code ||x||^2 / (2 * sigma^2)} is added to the
 * summed data term.
 *
 * <p>Caching protocol: {@link #valueAt} and {@link #derivativeAt} recompute on every call unless
 * {@link #isUpToDate(boolean)} was called with {@code true} beforehand. In that case the
 * previously computed result is reused once. Both methods clear the flag after use.
 */
public class ProperNameObjectiveFunction<F, L> implements ObjectiveFunction {
    IndexLinearizer indexLinearizer;
    Encoding<F, L> encoding;
    EncodedDatum[] data;
    /** Parameter vector evaluated by {@link #calculate()}. */
    double[] x;
    /** Scale of the L2 penalty; the penalty is {@code ||x||^2 / (2 * sigma^2)}. */
    double sigma;
    /** Objective value from the most recent computation. */
    double lastValue;
    /** Gradient from the most recent computation; {@code null} until the first computation. */
    double[] lastDerivative;
    /** Parameter vector used for the most recent computation. */
    double[] lastX;
    /** When true, the next valueAt/derivativeAt call reuses the cached result. */
    boolean isUpToDate;

    /**
     * Creates the objective.
     *
     * @param encoding sub-label encoding (number of sub-labels, sub-label ranges per label)
     * @param encodedDatumArr training data
     * @param indexLinearizer maps (feature, sub-label) pairs to positions in the parameter vector
     * @param d sigma of the L2 penalty; NOTE(review): a value of 0 yields infinite/NaN results
     */
    public ProperNameObjectiveFunction(Encoding<F, L> encoding, EncodedDatum[] encodedDatumArr, IndexLinearizer indexLinearizer, double d) {
        this.indexLinearizer = indexLinearizer;
        this.encoding = encoding;
        this.data = encodedDatumArr;
        this.sigma = d;
    }

    /** No-op: this objective holds nothing that needs releasing. */
    @Override // edu.berkeley.nlp.discPCFG.ObjectiveFunction
    public void shutdown() {
    }

    /** No-op for this objective. */
    public void updateGoldCountsNextRound() {
    }

    /** Returns the number of parameters, as reported by the index linearizer. */
    @Override // edu.berkeley.nlp.math.Function
    public int dimension() {
        return this.indexLinearizer.getNumLinearIndexes();
    }

    /**
     * Returns the regularized objective at {@code dArr}, recomputing unless the cache was marked
     * up to date. Clears the up-to-date flag afterwards.
     */
    @Override // edu.berkeley.nlp.math.Function
    public double valueAt(double[] dArr) {
        ensureCache(dArr);
        this.isUpToDate = false;
        return this.lastValue;
    }

    /**
     * Returns the gradient of the regularized objective at {@code dArr}, recomputing unless the
     * cache was marked up to date. Clears the up-to-date flag afterwards.
     */
    @Override // edu.berkeley.nlp.math.DifferentiableFunction
    public double[] derivativeAt(double[] dArr) {
        ensureCache(dArr);
        this.isUpToDate = false;
        return this.lastDerivative;
    }

    /** Not supported by this objective; always returns {@code null}. */
    @Override // edu.berkeley.nlp.math.DifferentiableRegularizableFunction
    public double[] unregularizedDerivativeAt(double[] dArr) {
        return null;
    }

    /**
     * Recomputes value and gradient at {@code point}, unless the caller marked the cache up to
     * date AND a cached result actually exists. The existence check prevents returning a null
     * gradient when the flag is set before any computation has run.
     */
    private void ensureCache(double[] point) {
        if (this.isUpToDate && this.lastDerivative != null) {
            return;
        }
        this.x = point;
        Pair<Double, double[]> valueAndGradient = calculate();
        this.lastValue = valueAndGradient.getFirst().doubleValue();
        this.lastDerivative = valueAndGradient.getSecond();
        this.lastX = point;
    }

    /** Sets the parameter vector used by {@link #calculate()}. */
    public void setX(double[] dArr) {
        this.x = dArr;
    }

    /** Marks whether the cached value/gradient may be reused by the next evaluation. */
    public void isUpToDate(boolean z) {
        this.isUpToDate = z;
    }

    /**
     * Computes the regularized objective and its gradient at the current {@link #x}.
     *
     * @return pair of (objective value, gradient of length {@link #dimension()})
     */
    public Pair<Double, double[]> calculate() {
        System.out.println("In Calculate...");
        double[] gradient = DoubleArrays.constantArray(0.0d, dimension());
        int numSubLabels = this.encoding.getNumSubLabels();
        double negLogLikelihood = 0.0d;

        for (EncodedDatum datum : this.data) {
            double[] logProbs = getLogProbabilities(datum, this.x, this.encoding, this.indexLinearizer);
            double[] weights = datum.getWeights();
            int goldBegin = this.encoding.getLabelSubindexBegin(datum.getLabelIndex());

            // Data term: weighted negative log-probability of the gold label's sub-labels.
            double totalWeight = 0.0d;
            for (int j = 0; j < weights.length; j++) {
                negLogLikelihood -= weights[j] * logProbs[goldBegin + j];
                totalWeight += weights[j];
            }

            double[] probs = new double[numSubLabels];
            double probSum = 0.0d;
            for (int k = 0; k < numSubLabels; k++) {
                probs[k] = Math.exp(logProbs[k]);
                probSum += probs[k];
            }
            // Sanity check on the normalization done in getLogProbabilities.
            if (Math.abs(probSum - 1.0d) > 0.001d) {
                System.err.println("Probabilities do not sum to 1!");
            }

            for (int f = 0; f < datum.getNumActiveFeatures(); f++) {
                int featureIndex = datum.getFeatureIndex(f);
                double featureCount = datum.getFeatureCount(f);
                // Expected counts. The data term is linear in the weights, so its derivative
                // scales p(k) by the total gold weight. This keeps the gradient consistent with
                // the value even when the weights do not sum to exactly 1. When they do, the
                // scale is featureCount * 1.0 == featureCount.
                double expectedScale = featureCount * totalWeight;
                for (int k = 0; k < numSubLabels; k++) {
                    gradient[this.indexLinearizer.getLinearIndex(featureIndex, k)] += expectedScale * probs[k];
                }
                // Observed (weighted gold) counts.
                for (int j = 0; j < weights.length; j++) {
                    gradient[this.indexLinearizer.getLinearIndex(featureIndex, goldBegin + j)] -= weights[j] * featureCount;
                }
            }
        }

        // L2 penalty ||x||^2 / (2 sigma^2); its gradient is x / sigma^2.
        double sigmaSquared = this.sigma * this.sigma;
        double squaredNorm = 0.0d;
        for (int i = 0; i < this.x.length; i++) {
            squaredNorm += this.x[i] * this.x[i];
            gradient[i] += this.x[i] / sigmaSquared;
        }
        double objective = negLogLikelihood + (squaredNorm / (2.0d * sigmaSquared));
        return new Pair<>(Double.valueOf(objective), gradient);
    }

    /**
     * Computes normalized log-probabilities of all sub-labels for one datum.
     *
     * <p>Each sub-label's score is the sum over active features of
     * {@code featureCount * dArr[getLinearIndex(feature, k)]}. The scores are then shifted by
     * their log-sum-exp so that the exponentiated values sum to 1.
     *
     * @return array of length {@code encoding.getNumSubLabels()} holding log p(k)
     */
    @Override // edu.berkeley.nlp.discPCFG.ObjectiveFunction
    public <F, L> double[] getLogProbabilities(EncodedDatum encodedDatum, double[] dArr, Encoding<F, L> encoding, IndexLinearizer indexLinearizer) {
        int numSubLabels = encoding.getNumSubLabels();
        double[] logProbs = DoubleArrays.constantArray(0.0d, numSubLabels);
        for (int f = 0; f < encodedDatum.getNumActiveFeatures(); f++) {
            int featureIndex = encodedDatum.getFeatureIndex(f);
            double featureCount = encodedDatum.getFeatureCount(f);
            for (int k = 0; k < numSubLabels; k++) {
                logProbs[k] += dArr[indexLinearizer.getLinearIndex(featureIndex, k)] * featureCount;
            }
        }
        double logNormalizer = SloppyMath.logAdd(logProbs);
        for (int k = 0; k < numSubLabels; k++) {
            logProbs[k] -= logNormalizer;
        }
        return logProbs;
    }
}
