/*
 * Decompiled with CFR 0.152.
 */
package edu.berkeley.nlp.discPCFG;

import edu.berkeley.nlp.discPCFG.EncodedDatum;
import edu.berkeley.nlp.discPCFG.Encoding;
import edu.berkeley.nlp.discPCFG.IndexLinearizer;
import edu.berkeley.nlp.discPCFG.ObjectiveFunction;
import edu.berkeley.nlp.math.DoubleArrays;
import edu.berkeley.nlp.math.SloppyMath;
import edu.berkeley.nlp.util.Pair;
import java.util.Arrays;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Regularized, weighted negative log-likelihood objective for a log-linear classifier whose
 * labels are split into sub-labels.
 *
 * <p>For each datum, the score of sub-label {@code c} is the sum, over the datum's active
 * features, of {@code weight(feature, c) * featureCount}. Scores are normalized into log
 * probabilities by {@link #getLogProbabilities}. {@link #calculate()} then evaluates
 *
 * <pre>
 *   - sum_datum sum_{c < labelWeights.length} labelWeights[c] * logP[substate0 + c]
 *   + ||x||^2 / (2 * sigma^2)
 * </pre>
 *
 * <p>Here {@code labelWeights} is {@link EncodedDatum#getWeights()} and {@code substate0} is
 * the first sub-label index of the datum's gold label.
 *
 * <p>Caching: {@link #valueAt} and {@link #derivativeAt} share one cached (value, gradient)
 * pair. The cache is reused only when a caller has flagged it as current via
 * {@link #isUpToDate(boolean)} <em>and</em> the requested point equals the point the cache was
 * computed at. Both methods clear the flag after every call. Instances keep this mutable
 * evaluation state without any synchronization.
 *
 * @param <F> first type argument forwarded to {@link Encoding} (presumably the feature type)
 * @param <L> second type argument forwarded to {@link Encoding} (presumably the label type)
 */
public class ProperNameObjectiveFunction<F, L>
implements ObjectiveFunction {
    /** Maps (feature index, sub-label index) pairs to positions in the flat weight/gradient vectors. */
    IndexLinearizer indexLinearizer;
    /** Supplies the total number of sub-labels and the first sub-label index of each label. */
    Encoding<F, L> encoding;
    /** Data the objective is summed over. */
    EncodedDatum[] data;
    /** Point at which {@link #calculate()} evaluates the objective. */
    double[] x;
    /** Scale of the L2 penalty {@code ||x||^2 / (2 * sigma^2)}. */
    double sigma;
    /** Objective value from the most recent computation. */
    double lastValue;
    /** Gradient from the most recent computation. */
    double[] lastDerivative;
    /** Private copy of the point {@link #lastValue} and {@link #lastDerivative} were computed at. */
    double[] lastX;
    /** Caller-controlled permission to reuse the cache; cleared by valueAt/derivativeAt. */
    boolean isUpToDate;

    /** No-op in this implementation. */
    @Override
    public void shutdown() {
    }

    /** No-op in this implementation. */
    public void updateGoldCountsNextRound() {
    }

    /**
     * Returns the number of parameters, i.e. the number of linear indexes exposed by the
     * {@link IndexLinearizer}.
     */
    @Override
    public int dimension() {
        return this.indexLinearizer.getNumLinearIndexes();
    }

    /**
     * Returns the regularized objective at {@code x}, reusing the cache when permitted, then
     * clears the up-to-date flag.
     */
    @Override
    public double valueAt(double[] x) {
        this.ensureCache(x);
        this.isUpToDate = false;
        return this.lastValue;
    }

    /**
     * Returns the gradient of the regularized objective at {@code x}, reusing the cache when
     * permitted, then clears the up-to-date flag. The returned array is the cached array itself,
     * not a copy.
     */
    @Override
    public double[] derivativeAt(double[] x) {
        this.ensureCache(x);
        this.isUpToDate = false;
        return this.lastDerivative;
    }

    /** Not implemented: always returns {@code null}. */
    @Override
    public double[] unregularizedDerivativeAt(double[] x) {
        return null;
    }

    /**
     * Recomputes the cached value and gradient unless the cache may be reused.
     *
     * <p>The cache may be reused only when the caller has set the up-to-date flag, something
     * has already been computed, and {@code x} equals the point that was computed.
     */
    private void ensureCache(double[] x) {
        // The flag alone is not trusted: it would otherwise hand back results for a different
        // point, or 0.0 / null when nothing has been computed yet.
        boolean reuseCache = this.isUpToDate
                && this.lastX != null
                && Arrays.equals(this.lastX, x);
        if (!reuseCache) {
            this.x = x;
            Pair<Double, double[]> valueAndDerivative = this.calculate();
            this.lastValue = valueAndDerivative.getFirst();
            this.lastDerivative = valueAndDerivative.getSecond();
            // Keep a private copy so in-place updates of the caller's array are detected.
            this.lastX = x.clone();
        }
    }

    /** Sets the point used by a subsequent direct call to {@link #calculate()}. */
    public void setX(double[] x) {
        this.x = x;
    }

    /**
     * Sets the flag that allows the next {@link #valueAt} / {@link #derivativeAt} call to reuse
     * the cached result for the same point.
     */
    public void isUpToDate(boolean b) {
        this.isUpToDate = b;
    }

    /**
     * Evaluates the regularized objective and its gradient at the current point {@link #x}.
     *
     * <p>Prints "In Calculate..." to standard output on entry. Prints a warning to standard
     * error for every datum whose sub-label probabilities do not sum to 1 within 0.001.
     *
     * @return pair of (objective value, freshly allocated gradient of length {@link #dimension()})
     */
    public Pair<Double, double[]> calculate() {
        double objective = 0.0;
        System.out.println("In Calculate...");
        double[] derivatives = DoubleArrays.constantArray(0.0, this.dimension());
        int numSubLabels = this.encoding.getNumSubLabels();
        for (EncodedDatum datum : this.data) {
            double[] logProbabilities = this.getLogProbabilities(datum, this.x, this.encoding, this.indexLinearizer);
            int goldLabel = datum.getLabelIndex();
            // One weight per gold sub-label; the array length determines how many are used.
            double[] labelWeights = datum.getWeights();
            int numGoldSubLabels = labelWeights.length;
            int goldOffset = this.encoding.getLabelSubindexBegin(goldLabel);
            for (int c = 0; c < numGoldSubLabels; ++c) {
                objective -= labelWeights[c] * logProbabilities[goldOffset + c];
            }
            double[] probabilities = new double[numSubLabels];
            double sum = 0.0;
            for (int c = 0; c < numSubLabels; ++c) {
                probabilities[c] = Math.exp(logProbabilities[c]);
                sum += probabilities[c];
            }
            // Sanity check on the normalization done in getLogProbabilities.
            if (Math.abs(sum - 1.0) > 0.001) {
                System.err.println("Probabilities do not sum to 1!");
            }
            // Gradient per active feature: model-expected counts minus weighted gold counts.
            for (int i = 0; i < datum.getNumActiveFeatures(); ++i) {
                int featureIndex = datum.getFeatureIndex(i);
                double featureCount = datum.getFeatureCount(i);
                for (int c = 0; c < numSubLabels; ++c) {
                    derivatives[this.indexLinearizer.getLinearIndex(featureIndex, c)] += featureCount * probabilities[c];
                }
                for (int c = 0; c < numGoldSubLabels; ++c) {
                    derivatives[this.indexLinearizer.getLinearIndex(featureIndex, goldOffset + c)] -= labelWeights[c] * featureCount;
                }
            }
        }
        // L2 penalty ||x||^2 / (2 sigma^2) and its gradient x / sigma^2.
        double sigma2 = this.sigma * this.sigma;
        double penalty = 0.0;
        for (int index = 0; index < this.x.length; ++index) {
            penalty += this.x[index] * this.x[index];
            derivatives[index] += this.x[index] / sigma2;
        }
        objective += penalty / (2.0 * sigma2);
        return new Pair<Double, double[]>(objective, derivatives);
    }

    /**
     * Computes the normalized log probability of every sub-label for one datum.
     *
     * <p>The unnormalized score of sub-label {@code j} is the sum, over the datum's active
     * features, of {@code weights[indexLinearizer.getLinearIndex(feature, j)] * featureCount}.
     * Every score is then reduced by {@code SloppyMath.logAdd} of all scores (presumably a
     * log-sum-exp), so their exponentials should sum to 1.
     *
     * <p>The method declares its own {@code <F, L>}, shadowing the class's, presumably to match
     * the {@link ObjectiveFunction} declaration.
     *
     * @param datum datum whose active features are scored
     * @param weights flat weight vector indexed through {@code indexLinearizer}
     * @param encoding supplies the number of sub-labels
     * @param indexLinearizer maps (feature index, sub-label index) to an index into {@code weights}
     * @return array of length {@code encoding.getNumSubLabels()} holding log probabilities
     */
    @Override
    public <F, L> double[] getLogProbabilities(EncodedDatum datum, double[] weights,
            Encoding<F, L> encoding, IndexLinearizer indexLinearizer) {
        int numSubLabels = encoding.getNumSubLabels();
        double[] logProbabilities = DoubleArrays.constantArray(0.0, numSubLabels);
        for (int i = 0; i < datum.getNumActiveFeatures(); ++i) {
            int featureIndex = datum.getFeatureIndex(i);
            double featureCount = datum.getFeatureCount(i);
            for (int j = 0; j < numSubLabels; ++j) {
                logProbabilities[j] += weights[indexLinearizer.getLinearIndex(featureIndex, j)] * featureCount;
            }
        }
        double logNormalizer = SloppyMath.logAdd(logProbabilities);
        for (int j = 0; j < numSubLabels; ++j) {
            logProbabilities[j] -= logNormalizer;
        }
        return logProbabilities;
    }

    /**
     * Creates the objective.
     *
     * @param encoding supplies sub-label counts and label-to-sub-label offsets
     * @param data data the objective is summed over
     * @param indexLinearizer maps (feature, sub-label) pairs to parameter indexes
     * @param sigma scale of the L2 penalty {@code ||x||^2 / (2 * sigma^2)}
     */
    public ProperNameObjectiveFunction(Encoding<F, L> encoding, EncodedDatum[] data, IndexLinearizer indexLinearizer, double sigma) {
        this.indexLinearizer = indexLinearizer;
        this.encoding = encoding;
        this.data = data;
        this.sigma = sigma;
    }
}

