/*
 * Decompiled with CFR 0.152.
 */
package edu.berkeley.nlp.classify;

import edu.berkeley.nlp.classify.BasicFeatureVector;
import edu.berkeley.nlp.classify.BasicLabeledFeatureVector;
import edu.berkeley.nlp.classify.Encoding;
import edu.berkeley.nlp.classify.FeatureExtractor;
import edu.berkeley.nlp.classify.FeatureVector;
import edu.berkeley.nlp.classify.IndexLinearizer;
import edu.berkeley.nlp.classify.LabeledFeatureVector;
import edu.berkeley.nlp.classify.LabeledInstance;
import edu.berkeley.nlp.classify.ProbabilisticClassifier;
import edu.berkeley.nlp.classify.ProbabilisticClassifierFactory;
import edu.berkeley.nlp.math.DifferentiableFunction;
import edu.berkeley.nlp.math.DoubleArrays;
import edu.berkeley.nlp.math.LBFGSMinimizer;
import edu.berkeley.nlp.math.SloppyMath;
import edu.berkeley.nlp.util.Counter;
import edu.berkeley.nlp.util.Indexer;
import edu.berkeley.nlp.util.Logger;
import edu.berkeley.nlp.util.Pair;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * A multinomial logistic-regression ("maximum entropy") classifier.
 *
 * <p>There is one weight per (feature, label) pair. The score of a label for an input is the
 * count-weighted sum of the weights of the input's active features for that label. The label
 * distribution is the softmax of those scores. Weights are trained by {@link Factory}, which
 * minimizes the L2-regularized negative conditional log-likelihood with L-BFGS.
 *
 * <p>The feature extractor is {@code transient}: after deserialization it must be restored with
 * {@link #setFeatureExtractor(FeatureExtractor)} before classifying raw inputs.
 *
 * @param <I> raw input type handed to the feature extractor
 * @param <F> feature type
 * @param <L> label type
 */
public class MaximumEntropyClassifier<I, F, L>
implements ProbabilisticClassifier<I, L>,
Serializable {
    private static final long serialVersionUID = 1L;
    /** Trained weights, addressed through {@link #indexLinearizer}. */
    private final double[] weights;
    /** Mapping between features/labels and their integer indexes. */
    private final Encoding<F, L> encoding;
    /** Maps a (featureIndex, labelIndex) pair to a position in {@link #weights}. */
    private final IndexLinearizer indexLinearizer;
    /** Not serialized; null after deserialization until explicitly set again. */
    private transient FeatureExtractor<I, F> featureExtractor;

    /**
     * Creates a classifier from already-trained parameters.
     *
     * @param weights weight vector indexed by {@code indexLinearizer}
     * @param encoding feature/label encoding the weights were trained with
     * @param indexLinearizer maps (feature, label) index pairs into {@code weights}
     * @param featureExtractor turns raw inputs into feature counts
     */
    public MaximumEntropyClassifier(double[] weights, Encoding<F, L> encoding, IndexLinearizer indexLinearizer, FeatureExtractor<I, F> featureExtractor) {
        this.weights = weights;
        this.encoding = encoding;
        this.indexLinearizer = indexLinearizer;
        this.featureExtractor = featureExtractor;
    }

    /** Sets the feature extractor, e.g. to restore it after deserialization. */
    public void setFeatureExtractor(FeatureExtractor<I, F> featureExtractor) {
        this.featureExtractor = featureExtractor;
    }

    /**
     * Computes normalized log-probabilities of every label for an encoded datum.
     *
     * @return array indexed by label index; {@code exp} of the entries sums to 1
     */
    private static <F, L> double[] getLogProbabilities(EncodedDatum datum, double[] weights, Encoding<F, L> encoding, IndexLinearizer indexLinearizer) {
        int numLabels = encoding.getNumLabels();
        double[] logProbabilities = new double[numLabels];
        for (int labelIndex = 0; labelIndex < numLabels; ++labelIndex) {
            double activation = 0.0;
            for (int num = 0; num < datum.getNumActiveFeatures(); ++num) {
                int linearFeatureIndex = indexLinearizer.getLinearIndex(datum.getFeatureIndex(num), labelIndex);
                activation += weights[linearFeatureIndex] * datum.getFeatureCount(num);
            }
            logProbabilities[labelIndex] = activation;
        }
        // Normalize in log space by subtracting log(sum_l exp(score_l)).
        double logSumProb = SloppyMath.logAdd(logProbabilities);
        for (int labelIndex = 0; labelIndex < numLabels; ++labelIndex) {
            logProbabilities[labelIndex] -= logSumProb;
        }
        return logProbabilities;
    }

    /**
     * Returns the label distribution for a raw input.
     *
     * @throws IllegalStateException if no feature extractor is set (e.g. after deserialization)
     */
    @Override
    public Counter<L> getProbabilities(I input) {
        if (this.featureExtractor == null) {
            // The extractor is transient, so it is missing after deserialization.
            throw new IllegalStateException("No feature extractor set; call setFeatureExtractor() first");
        }
        BasicFeatureVector<F> featureVector = new BasicFeatureVector<F>(this.featureExtractor.extractFeatures(input));
        return this.getProbabilities(featureVector);
    }

    /** Returns the label distribution for an already-extracted feature vector. */
    private Counter<L> getProbabilities(FeatureVector<F> featureVector) {
        EncodedDatum encodedDatum = EncodedDatum.encodeDatum(featureVector, this.encoding);
        double[] logProbabilities = MaximumEntropyClassifier.getLogProbabilities(encodedDatum, this.weights, this.encoding, this.indexLinearizer);
        return this.logProbabilityArrayToProbabilityCounter(logProbabilities);
    }

    /** Exponentiates log-probabilities and keys them by their label objects. */
    private Counter<L> logProbabilityArrayToProbabilityCounter(double[] logProbabilities) {
        Counter<L> probabilityCounter = new Counter<L>();
        for (int labelIndex = 0; labelIndex < logProbabilities.length; ++labelIndex) {
            L label = this.encoding.getLabel(labelIndex);
            probabilityCounter.setCount(label, Math.exp(logProbabilities[labelIndex]));
        }
        return probabilityCounter;
    }

    /** Returns the most probable label for a raw input. */
    @Override
    public L getLabel(I input) {
        return this.getProbabilities(input).argMax();
    }

    /** Small demo: trains on three toy animals and prints the distribution for a fourth. */
    public static void main(String[] args) {
        LabeledInstance<String[], String> datum1 = new LabeledInstance<String[], String>("cat", new String[]{"fuzzy", "claws", "small"});
        LabeledInstance<String[], String> datum2 = new LabeledInstance<String[], String>("bear", new String[]{"fuzzy", "claws", "big"});
        LabeledInstance<String[], String> datum3 = new LabeledInstance<String[], String>("cat", new String[]{"claws", "medium"});
        LabeledInstance<String[], String> datum4 = new LabeledInstance<String[], String>("cat", new String[]{"claws", "small"});
        List<LabeledInstance<String[], String>> trainingData = new ArrayList<LabeledInstance<String[], String>>();
        trainingData.add(datum1);
        trainingData.add(datum2);
        trainingData.add(datum3);
        List<LabeledInstance<String[], String>> testData = new ArrayList<LabeledInstance<String[], String>>();
        testData.add(datum4);
        FeatureExtractor<String[], String> featureExtractor = new FeatureExtractor<String[], String>(){
            private static final long serialVersionUID = 8296036312980792350L;

            @Override
            public Counter<String> extractFeatures(String[] featureArray) {
                return new Counter<String>(Arrays.asList(featureArray));
            }
        };
        Factory<String[], String, String> maximumEntropyClassifierFactory = new Factory<String[], String, String>(1.0, 20, featureExtractor);
        ProbabilisticClassifier<String[], String> maximumEntropyClassifier = maximumEntropyClassifierFactory.trainClassifier(trainingData);
        for (LabeledInstance<String[], String> testInstance : testData) {
            System.out.println("Probabilities on test instance: " + maximumEntropyClassifier.getProbabilities(testInstance.getInput()));
        }
    }

    /**
     * A datum in index form: parallel arrays of active feature indexes and their counts, plus
     * the gold label index ({@code -1} when unlabeled).
     */
    public static class EncodedDatum {
        /** Gold label index, or -1 for a datum built by {@link #encodeDatum}. */
        int labelIndex;
        int[] featureIndexes;
        double[] featureCounts;

        /**
         * Encodes the features of a vector, dropping features the encoding does not know
         * (those for which {@code getFeatureIndex} is negative). The label index is set to -1.
         */
        public static <F, L> EncodedDatum encodeDatum(FeatureVector<F> featureVector, Encoding<F, L> encoding) {
            Counter<F> features = featureVector.getFeatures();
            Counter<F> knownFeatures = new Counter<F>();
            for (F feature : features.keySet()) {
                if (encoding.getFeatureIndex(feature) >= 0) {
                    knownFeatures.incrementCount(feature, features.getCount(feature));
                }
            }
            int numActiveFeatures = knownFeatures.keySet().size();
            int[] featureIndexes = new int[numActiveFeatures];
            double[] featureCounts = new double[numActiveFeatures];
            int i = 0;
            for (F feature : knownFeatures.keySet()) {
                featureIndexes[i] = encoding.getFeatureIndex(feature);
                featureCounts[i] = knownFeatures.getCount(feature);
                ++i;
            }
            return new EncodedDatum(-1, featureIndexes, featureCounts);
        }

        /** Encodes a labeled vector: its features as in {@link #encodeDatum}, plus its label index. */
        public static <F, L> EncodedDatum encodeLabeledDatum(LabeledFeatureVector<F, L> labeledDatum, Encoding<F, L> encoding) {
            EncodedDatum encodedDatum = EncodedDatum.encodeDatum(labeledDatum, encoding);
            encodedDatum.labelIndex = encoding.getLabelIndex(labeledDatum.getLabel());
            return encodedDatum;
        }

        public int getLabelIndex() {
            return this.labelIndex;
        }

        public int getNumActiveFeatures() {
            return this.featureCounts.length;
        }

        public int getFeatureIndex(int num) {
            return this.featureIndexes[num];
        }

        public double getFeatureCount(int num) {
            return this.featureCounts[num];
        }

        public EncodedDatum(int labelIndex, int[] featureIndexes, double[] featureCounts) {
            this.labelIndex = labelIndex;
            this.featureIndexes = featureIndexes;
            this.featureCounts = featureCounts;
        }
    }

    /**
     * Training objective: the negative conditional log-likelihood of the data plus a Gaussian
     * prior {@code sum_i w_i^2 / (2 sigma^2)}. The value and gradient for the most recently
     * evaluated point are cached, so {@code valueAt} and {@code derivativeAt} on the same point
     * share one computation.
     */
    public static class ObjectiveFunction<F, L>
    implements DifferentiableFunction {
        IndexLinearizer indexLinearizer;
        Encoding<F, L> encoding;
        EncodedDatum[] data;
        double sigma;
        double lastValue;
        double[] lastDerivative;
        /** Private copy of the last evaluated point (never aliases a caller's array). */
        double[] lastX;

        @Override
        public int dimension() {
            return this.indexLinearizer.getNumLinearIndexes();
        }

        @Override
        public double valueAt(double[] x) {
            this.ensureCache(x);
            return this.lastValue;
        }

        @Override
        public double[] derivativeAt(double[] x) {
            this.ensureCache(x);
            return this.lastDerivative;
        }

        /** Recomputes the cached value and gradient unless {@code x} equals the cached point. */
        private void ensureCache(double[] x) {
            if (this.requiresUpdate(this.lastX, x)) {
                Pair<Double, double[]> currentValueAndDerivative = this.calculate(x);
                this.lastValue = currentValueAndDerivative.getFirst();
                this.lastDerivative = currentValueAndDerivative.getSecond();
                // Store a copy: if the caller later mutates x in place, comparing against the
                // same array object would wrongly report the cache as up to date.
                this.lastX = x.clone();
            }
        }

        /** True when nothing is cached yet or the point differs (in length or any element). */
        private boolean requiresUpdate(double[] lastX, double[] x) {
            return lastX == null || !Arrays.equals(lastX, x);
        }

        /** Computes the regularized negative log-likelihood and its gradient at {@code x}. */
        private Pair<Double, double[]> calculate(double[] x) {
            int numLabels = this.encoding.getNumLabels();
            double logLikelihood = 0.0;
            double[] derivatives = DoubleArrays.constantArray(0.0, this.dimension());
            double[] classActivations = new double[numLabels];
            double[] classPosteriors = new double[numLabels];
            for (EncodedDatum datum : this.data) {
                int numActiveFeatures = datum.getNumActiveFeatures();
                // Unnormalized log-score of every label for this datum.
                for (int labelIndex = 0; labelIndex < numLabels; ++labelIndex) {
                    double activation = 0.0;
                    for (int num = 0; num < numActiveFeatures; ++num) {
                        int linearFeatureIndex = this.indexLinearizer.getLinearIndex(datum.getFeatureIndex(num), labelIndex);
                        activation += x[linearFeatureIndex] * datum.getFeatureCount(num);
                    }
                    classActivations[labelIndex] = activation;
                }
                double logSumActivation = SloppyMath.logAdd(classActivations);
                int correctLabelIndex = datum.getLabelIndex();
                logLikelihood += classActivations[correctLabelIndex] - logSumActivation;
                for (int labelIndex = 0; labelIndex < numLabels; ++labelIndex) {
                    classPosteriors[labelIndex] = SloppyMath.exp(classActivations[labelIndex] - logSumActivation);
                }
                // Log-likelihood gradient: observed feature counts minus expected counts.
                for (int num = 0; num < numActiveFeatures; ++num) {
                    int featureIndex = datum.getFeatureIndex(num);
                    double featureCount = datum.getFeatureCount(num);
                    derivatives[this.indexLinearizer.getLinearIndex(featureIndex, correctLabelIndex)] += featureCount;
                    for (int labelIndex = 0; labelIndex < numLabels; ++labelIndex) {
                        derivatives[this.indexLinearizer.getLinearIndex(featureIndex, labelIndex)] -= classPosteriors[labelIndex] * featureCount;
                    }
                }
            }
            // The minimizer needs the NEGATIVE log-likelihood, so flip value and gradient.
            double objective = -logLikelihood;
            DoubleArrays.scale(derivatives, -1.0);
            // Gaussian prior with variance sigma^2 (L2 penalty).
            double sigmaSquared = this.sigma * this.sigma;
            for (int i = 0; i < x.length; ++i) {
                double weight = x[i];
                objective += weight * weight / (2.0 * sigmaSquared);
                derivatives[i] += weight / sigmaSquared;
            }
            return new Pair<Double, double[]>(objective, derivatives);
        }

        /**
         * @param sigma standard deviation of the Gaussian prior; must be positive
         * @throws IllegalArgumentException if {@code sigma} is not positive (or is NaN)
         */
        public ObjectiveFunction(Encoding<F, L> encoding, EncodedDatum[] data, IndexLinearizer indexLinearizer, double sigma) {
            if (!(sigma > 0.0)) {
                throw new IllegalArgumentException("sigma must be positive, got " + sigma);
            }
            this.indexLinearizer = indexLinearizer;
            this.encoding = encoding;
            this.data = data;
            this.sigma = sigma;
        }

        /** NOTE(review): not implemented; always returns null. */
        public double[] unregularizedDerivativeAt(double[] x) {
            return null;
        }
    }

    /**
     * Trains {@link MaximumEntropyClassifier}s with L-BFGS.
     *
     * @param <I> raw input type
     * @param <F> feature type
     * @param <L> label type
     */
    public static class Factory<I, F, L>
    implements ProbabilisticClassifierFactory<I, L> {
        /** Standard deviation of the Gaussian prior on the weights. */
        double sigma;
        /** Iteration budget handed to the L-BFGS minimizer. */
        int iterations;
        FeatureExtractor<I, F> featureExtractor;

        @Override
        public ProbabilisticClassifier<I, L> trainClassifier(List<LabeledInstance<I, L>> trainingData) {
            return this.trainClassifier(trainingData, true);
        }

        /**
         * Trains a classifier on the given labeled instances.
         *
         * @param trainingData labeled examples; must be non-empty
         * @param verbose whether to log progress tracks and minimizer output
         * @throws IllegalArgumentException if {@code trainingData} is empty
         */
        public ProbabilisticClassifier<I, L> trainClassifier(List<LabeledInstance<I, L>> trainingData, boolean verbose) {
            if (trainingData.isEmpty()) {
                throw new IllegalArgumentException("trainingData must not be empty");
            }
            if (verbose) {
                Logger.i().startTrack("Building encoding");
            }
            // Extract features once; both the encoding and the encoded data are built from them.
            List<LabeledFeatureVector<F, L>> featureVectors = this.extractFeatureVectors(trainingData);
            Encoding<F, L> encoding = this.buildEncoding(featureVectors);
            IndexLinearizer indexLinearizer = this.buildIndexLinearizer(encoding);
            double[] initialWeights = this.buildInitialWeights(indexLinearizer);
            EncodedDatum[] data = this.encodeData(featureVectors, encoding);
            if (verbose) {
                Logger.i().endTrack();
            }
            LBFGSMinimizer minimizer = new LBFGSMinimizer(this.iterations);
            ObjectiveFunction<F, L> objective = new ObjectiveFunction<F, L>(encoding, data, indexLinearizer, this.sigma);
            if (verbose) {
                Logger.i().startTrack("Training weights");
            }
            double[] weights = minimizer.minimize(objective, initialWeights, 1.0E-4, verbose);
            if (verbose) {
                Logger.i().endTrack();
            }
            return new MaximumEntropyClassifier<I, F, L>(weights, encoding, indexLinearizer, this.featureExtractor);
        }

        /** All-zero starting point, one weight per (feature, label) pair. */
        private double[] buildInitialWeights(IndexLinearizer indexLinearizer) {
            return DoubleArrays.constantArray(0.0, indexLinearizer.getNumLinearIndexes());
        }

        private IndexLinearizer buildIndexLinearizer(Encoding<F, L> encoding) {
            return new IndexLinearizer(encoding.getNumFeatures(), encoding.getNumLabels());
        }

        /** Runs the feature extractor over every instance, keeping each instance's label. */
        private List<LabeledFeatureVector<F, L>> extractFeatureVectors(List<LabeledInstance<I, L>> data) {
            List<LabeledFeatureVector<F, L>> featureVectors = new ArrayList<LabeledFeatureVector<F, L>>(data.size());
            for (LabeledInstance<I, L> labeledInstance : data) {
                Counter<F> features = this.featureExtractor.extractFeatures(labeledInstance.getInput());
                featureVectors.add(new BasicLabeledFeatureVector<F, L>(labeledInstance.getLabel(), features));
            }
            return featureVectors;
        }

        /** Indexes every label and feature that occurs in the data. */
        private Encoding<F, L> buildEncoding(List<LabeledFeatureVector<F, L>> data) {
            Indexer<F> featureIndexer = new Indexer<F>();
            Indexer<L> labelIndexer = new Indexer<L>();
            for (LabeledFeatureVector<F, L> labeledDatum : data) {
                labelIndexer.getIndex(labeledDatum.getLabel());
                for (F feature : labeledDatum.getFeatures().keySet()) {
                    featureIndexer.getIndex(feature);
                }
            }
            return new Encoding<F, L>(featureIndexer, labelIndexer);
        }

        /** Converts every labeled vector to index form under {@code encoding}. */
        private EncodedDatum[] encodeData(List<LabeledFeatureVector<F, L>> data, Encoding<F, L> encoding) {
            EncodedDatum[] encodedData = new EncodedDatum[data.size()];
            for (int i = 0; i < encodedData.length; ++i) {
                encodedData[i] = EncodedDatum.encodeLabeledDatum(data.get(i), encoding);
            }
            return encodedData;
        }

        /**
         * @param sigma standard deviation of the Gaussian prior on the weights
         * @param iterations iteration budget for L-BFGS
         * @param featureExtractor turns raw inputs into feature counts
         */
        public Factory(double sigma, int iterations, FeatureExtractor<I, F> featureExtractor) {
            this.sigma = sigma;
            this.iterations = iterations;
            this.featureExtractor = featureExtractor;
        }
    }
}

