/*
 * Decompiled with CFR 0.152.
 */
package edu.berkeley.nlp.PCFGLA;

import edu.berkeley.nlp.PCFGLA.Featurizer;
import edu.berkeley.nlp.PCFGLA.Grammar;
import edu.berkeley.nlp.PCFGLA.GrammarTrainer;
import edu.berkeley.nlp.PCFGLA.Lexicon;
import edu.berkeley.nlp.PCFGLA.SimpleLexicon;
import edu.berkeley.nlp.PCFGLA.StateSetTreeList;
import edu.berkeley.nlp.PCFGLA.smoothing.Smoother;
import edu.berkeley.nlp.math.CachingDifferentiableFunction;
import edu.berkeley.nlp.math.DifferentiableFunction;
import edu.berkeley.nlp.math.DoubleArrays;
import edu.berkeley.nlp.math.LBFGSMinimizer;
import edu.berkeley.nlp.math.SloppyMath;
import edu.berkeley.nlp.syntax.StateSet;
import edu.berkeley.nlp.syntax.Tree;
import edu.berkeley.nlp.util.ArrayUtil;
import edu.berkeley.nlp.util.Counter;
import edu.berkeley.nlp.util.Indexer;
import edu.berkeley.nlp.util.Pair;
import edu.berkeley.nlp.util.ScalingTools;
import java.io.Serializable;
import java.util.List;
import java.util.Random;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
public class FeaturizedLexicon
implements Lexicon,
Serializable {
    /** Accumulated soft counts; allocated as [tag][substate][word index] by resetCounts(). */
    private double[][][] expectedCounts;
    /** Per-word scores in the same [tag][substate][word index] layout; logarithms when isLogarithmMode is set. */
    private double[][][] scores;
    /** Per [tag][substate] result of SloppyMath.logAdd over the feature-weight sums; set by projectWeightsToScores(). */
    private double[][] normalizers;
    /** Occurrence count of every word index in the training trees; filled by init(). */
    public int[] wordCounter;
    /** Occurrence counts indexed [tag][word index]; filled by init(). */
    private int[][] tagWordCounts;
    /** For each tag, the word indices for which the featurizer returned at least one feature; rebuilt by refeaturize(). */
    private int[][] tagWordsWithFeatures;
    private static final long serialVersionUID = 3L;
    /** Number of substates of every tag. */
    public short[] numSubStates;
    /** Number of tags; initialised to numSubStates.length in the constructor. */
    int numStates;
    /** Size of the word indexer at the end of the last init() call. */
    int nWords;
    // NOTE(review): threshold is never read or written in this class — confirm whether it is still needed.
    double threshold;
    /** Whether scores currently holds logarithms (see logarithmMode()). */
    boolean isLogarithmMode;
    // NOTE(review): useVarDP is never read in this class — confirm whether it is still needed.
    boolean useVarDP = false;
    /** Maps word strings to the word indices used by all per-word arrays. */
    private Indexer<String> wordIndexer = new Indexer();
    /** Feature indices for every [tag][substate][word index]; rebuilt by refeaturize(). */
    public int[][][][] indexedFeatures;
    /** Stored by setSmoother() and handed on by copyLexicon(); not consulted by the scoring code in this class. */
    Smoother smoother;
    /** Produces the feature strings for a (word, tag) pair. */
    private Featurizer featurizer;
    /** Maps feature strings to positions in featureWeights. */
    private Indexer<String> featureIndex = new Indexer();
    /** One weight per indexed feature; fitted by optimize(). */
    private double[] featureWeights;
    /** Multiplier applied to the quadratic regularization term. */
    private double regularizationConstant = 1.0;
    /** Not serialized; getMinimizer() recreates it when it is null. */
    private transient LBFGSMinimizer minimizer = new LBFGSMinimizer();
    // NOTE(review): not referenced by name in this file; the regularization methods add 3.0 to every
    // weight, which presumably is -PRIOR_MEAN inlined by the compiler — confirm.
    private static final double PRIOR_MEAN = -3.0;

    /**
     * Creates a lexicon and immediately initialises it from the given training trees.
     *
     * @param numSubStates number of substates for every tag
     * @param featurizer   source of the feature strings for (word, tag) pairs
     * @param trainTrees   trees passed to {@link #init(StateSetTreeList)}
     */
    public FeaturizedLexicon(short[] numSubStates, Featurizer featurizer, StateSetTreeList trainTrees) {
        this(numSubStates, featurizer);
        this.init(trainTrees);
    }

    /**
     * Creates an empty lexicon: no words are indexed and no count or score tables are
     * allocated until {@link #init(StateSetTreeList)} is called.
     *
     * @param numSubStates number of substates for every tag; its length becomes the tag count
     * @param featurizer   source of the feature strings for (word, tag) pairs
     */
    public FeaturizedLexicon(short[] numSubStates, Featurizer featurizer) {
        this.featurizer = featurizer;
        this.numSubStates = numSubStates;
        this.numStates = numSubStates.length;
        this.wordIndexer = new Indexer<String>();
        this.isLogarithmMode = false;
        // Cap the optimizer at 20 iterations per optimize() call.
        this.minimizer.setMaxIterations(20);
    }

    /**
     * Returns the L-BFGS minimizer used by {@link #optimize()}, creating it on demand.
     *
     * <p>The minimizer field is transient, so it is {@code null} on a deserialized instance.
     * The replacement is configured with the same 20-iteration cap the constructor applies,
     * so that a deserialized lexicon optimizes exactly like a freshly constructed one.
     *
     * @return the (possibly newly created) minimizer; never {@code null}
     */
    public LBFGSMinimizer getMinimizer() {
        if (this.minimizer == null) {
            LBFGSMinimizer fresh = new LBFGSMinimizer();
            // Keep in sync with the iteration cap set in the constructor.
            fresh.setMaxIterations(20);
            this.minimizer = fresh;
        }
        return this.minimizer;
    }

    /**
     * Turns a feature-weight vector into per-word scores.
     *
     * <p>For every tag/substate the weights of each word's features are summed, the sums are
     * combined with {@code SloppyMath.logAdd} into a normalizer (stored in {@code normalizers}),
     * and each word's entry becomes {@code exp(sum - normalizer)}. Words without features for
     * the tag keep a score of 0. As a side effect the lexicon is switched out of logarithm mode.
     *
     * @param weights one weight per indexed feature
     * @return scores indexed [tag][substate][word index]
     * @throws RuntimeException if a word listed for a tag has no features for some substate
     */
    private double[][][] projectWeightsToScores(double[] weights) {
        double[][][] result = new double[this.numStates][][];
        for (int tag = 0; tag < this.numStates; ++tag) {
            result[tag] = new double[this.numSubStates[tag]][];
            this.normalizers[tag] = new double[this.numSubStates[tag]];
            final int substateCount = this.expectedCounts[tag].length;
            for (int sub = 0; sub < substateCount; ++sub) {
                final double[] row = new double[this.wordIndexer.size()];
                result[tag][sub] = row;
                final int[] activeWords = this.tagWordsWithFeatures[tag];
                final double[] unnormalized = new double[activeWords.length];
                for (int k = 0; k < activeWords.length; ++k) {
                    final int word = activeWords[k];
                    final int[] featureIds = this.indexedFeatures[tag][sub][word];
                    if (featureIds.length == 0) {
                        throw new RuntimeException("Shouldn't be here!");
                    }
                    double total = 0.0;
                    for (int i = 0; i < featureIds.length; ++i) {
                        total += weights[featureIds[i]];
                    }
                    row[word] = total;
                    unnormalized[k] = total;
                }
                final double normalizer = SloppyMath.logAdd(unnormalized);
                this.normalizers[tag][sub] = normalizer;
                for (int word : activeWords) {
                    row[word] = Math.exp(row[word] - normalizer);
                }
            }
        }
        this.isLogarithmMode = false;
        return result;
    }

    /**
     * Builds the function minimized by {@link #optimize()}: the negative expected
     * log-likelihood of the given counts under the scores produced by
     * {@link #projectWeightsToScores(double[])}, plus the regularization term.
     *
     * @param expectedCounts counts indexed [tag][substate][word index]
     * @return a differentiable function over the feature-weight vector
     */
    private DifferentiableFunction objective(final double[][][] expectedCounts) {
        // Log of the total count per tag/substate, taken over the words that have features.
        final double[][] logCountTotals = new double[expectedCounts.length][];
        for (int tag = 0; tag < this.numStates; ++tag) {
            final double[] totals = new double[this.numSubStates[tag]];
            logCountTotals[tag] = totals;
            for (int sub = 0; sub < totals.length; ++sub) {
                double sum = 0.0;
                for (int word : this.tagWordsWithFeatures[tag]) {
                    sum += expectedCounts[tag][sub][word];
                }
                totals[sub] = Math.log(sum);
            }
        }
        return new CachingDifferentiableFunction(){

            @Override
            public int dimension() {
                return FeaturizedLexicon.this.featureWeights.length;
            }

            /** Value only; answers from the superclass cache when x is already cached. */
            @Override
            public double valueAt(double[] x) {
                if (this.isCached(x)) {
                    return super.valueAt(x);
                }
                final double[][][] probs = FeaturizedLexicon.this.projectWeightsToScores(x);
                double logLikelihood = 0.0;
                for (int tag = 0; tag < FeaturizedLexicon.this.numStates; ++tag) {
                    final int substateCount = expectedCounts[tag].length;
                    for (int sub = 0; sub < substateCount; ++sub) {
                        for (int word : FeaturizedLexicon.this.tagWordsWithFeatures[tag]) {
                            final double count = expectedCounts[tag][sub][word];
                            if (count > 0.0) {
                                logLikelihood += count * Math.log(probs[tag][sub][word]);
                            }
                        }
                    }
                }
                return -logLikelihood + FeaturizedLexicon.this.regularizationValue(x);
            }

            /** Value and gradient in a single pass over all tag/substate/word triples. */
            @Override
            protected Pair<Double, double[]> calculate(double[] x) {
                final double[] grad = new double[x.length];
                final double[][][] probs = FeaturizedLexicon.this.projectWeightsToScores(x);
                double logLikelihood = 0.0;
                for (int tag = 0; tag < FeaturizedLexicon.this.numStates; ++tag) {
                    final int substateCount = expectedCounts[tag].length;
                    for (int sub = 0; sub < substateCount; ++sub) {
                        final double logTotal = logCountTotals[tag][sub];
                        for (int word : FeaturizedLexicon.this.tagWordsWithFeatures[tag]) {
                            final double count = expectedCounts[tag][sub][word];
                            final double logProb = Math.log(probs[tag][sub][word]);
                            // Observed count minus (total count * model probability).
                            final double residual = count - Math.exp(logTotal + logProb);
                            if (count > 0.0) {
                                logLikelihood += count * logProb;
                            }
                            for (int feature : FeaturizedLexicon.this.indexedFeatures[tag][sub][word]) {
                                grad[feature] -= residual;
                            }
                        }
                    }
                }
                final double[] totalGradient = DoubleArrays.add(grad, FeaturizedLexicon.this.regularizationGradient(x));
                final double totalValue = -logLikelihood + FeaturizedLexicon.this.regularizationValue(x);
                return Pair.makePair(totalValue, totalGradient);
            }
        };
    }

    /**
     * Gradient of the regularization term at {@code x}.
     *
     * @param x current feature weights
     * @return the weights shifted by 3.0 and scaled by the regularization constant
     */
    private double[] regularizationGradient(double[] x) {
        // Presumably shifts every weight by -PRIOR_MEAN (3.0) so the penalty is centred on the prior mean.
        return DoubleArrays.multiply(DoubleArrays.add(x, 3.0), this.regularizationConstant);
    }

    /**
     * Value of the regularization term: half the inner product of the shifted weights
     * with themselves, times the regularization constant.
     *
     * @param weights current feature weights
     * @return the penalty added to the negative log-likelihood
     */
    private double regularizationValue(double[] weights) {
        // Same 3.0 shift as in regularizationGradient.
        final double[] shifted = DoubleArrays.add(weights, 3.0);
        final double squaredNorm = DoubleArrays.innerProduct(shifted, shifted);
        return squaredNorm * 0.5 * this.regularizationConstant;
    }

    /**
     * Rebuilds the feature index from scratch: asks the featurizer for the features of every
     * (tag, word) pair, records their indices per substate, and collects for each tag the words
     * that received at least one feature. The weight vector is replaced by a zero vector
     * whenever its length no longer matches the number of indexed features.
     */
    private void refeaturize() {
        this.indexedFeatures = new int[this.numStates][][][];
        this.featureIndex = new Indexer<String>();
        this.tagWordsWithFeatures = new int[this.numStates][];
        for (int tag = 0; tag < this.numStates; ++tag) {
            final short substateCount = this.numSubStates[tag];
            final SimpleLexicon.IntegerIndexer wordsWithFeatures = new SimpleLexicon.IntegerIndexer(this.wordIndexer.size());
            this.indexedFeatures[tag] = new int[substateCount][this.wordIndexer.size()][];
            for (int wordId = 0; wordId < this.wordIndexer.size(); ++wordId) {
                final String word = this.wordIndexer.getObject(wordId);
                final List<String>[] perSubstate = this.featurizer.featurize(word, tag, substateCount, this.wordCounter[wordId], this.tagWordCounts[tag][wordId]);
                for (int sub = 0; sub < substateCount; ++sub) {
                    final List<String> names = perSubstate[sub];
                    final int[] ids = new int[names.size()];
                    for (int i = 0; i < ids.length; ++i) {
                        ids[i] = this.featureIndex.getIndex(names.get(i));
                    }
                    this.indexedFeatures[tag][sub][wordId] = ids;
                    if (names.size() > 0) {
                        wordsWithFeatures.add(wordId);
                    }
                }
            }
            final int[] active = new int[wordsWithFeatures.size()];
            for (int j = 0; j < wordsWithFeatures.size(); ++j) {
                active[j] = wordsWithFeatures.get(j);
            }
            this.tagWordsWithFeatures[tag] = active;
        }
        final int featureCount = this.featureIndex.size();
        if (this.featureWeights == null || this.featureWeights.length != featureCount) {
            this.featureWeights = new double[featureCount];
        }
    }

    /**
     * Fits the feature weights to the accumulated expected counts and recomputes the scores.
     * Features are re-indexed first; the minimizer starts from the current weight vector.
     */
    @Override
    public void optimize() {
        this.refeaturize();
        final LBFGSMinimizer lbfgs = this.getMinimizer();
        final DifferentiableFunction target = this.objective(this.expectedCounts);
        lbfgs.dumpHistory();
        final double[] fitted = lbfgs.minimize(target, this.featureWeights, 1.0E-5, true);
        this.featureWeights = fitted;
        this.scores = this.projectWeightsToScores(fitted);
    }

    /**
     * Scores a word given only its string form by wrapping it in a temporary {@code StateSet}
     * spanning {@code [pos, pos + 1)} and delegating to the {@code StateSet} overload.
     *
     * @return one score per substate of {@code tag}
     */
    @Override
    public double[] score(String word, short tag, int pos, boolean noSmoothing, boolean isSignature) {
        final StateSet probe = new StateSet(tag, 1, word, (short)pos, (short)(pos + 1));
        // -2 makes the delegate look the word up in the word indexer.
        probe.wordIndex = -2;
        probe.sigIndex = -2;
        return this.score(probe, tag, noSmoothing, isSignature);
    }

    /**
     * Returns one score per substate of {@code tag} for the word held by {@code stateSet}.
     *
     * <p>A word index below 1 is (re)resolved through the word indexer and written back to the
     * state set. Indexed words are read from the score table. Other words are featurized on the
     * fly: known features contribute their weight, unknown features contribute -300, and the sum
     * is normalized with the stored per-substate normalizer (exponentiated unless the lexicon is
     * in logarithm mode).
     */
    @Override
    public double[] score(StateSet stateSet, short tag, boolean noSmoothing, boolean isSignature) {
        final int substateCount = this.numSubStates[tag];
        final double[] result = new double[substateCount];
        int wordId = stateSet.wordIndex;
        if (wordId < 1) {
            wordId = this.wordIndexer.indexOf(stateSet.getWord());
            stateSet.wordIndex = wordId;
        }
        if (wordId >= 0) {
            for (int sub = 0; sub < substateCount; ++sub) {
                result[sub] = this.scores[tag][sub][wordId];
            }
            return result;
        }
        final List<String>[] featureNames = this.featurizer.featurize(stateSet.getWord(), tag, this.numSubStates[tag], 0, 0);
        for (int sub = 0; sub < substateCount; ++sub) {
            double total = 0.0;
            for (String name : featureNames[sub]) {
                final int featureId = this.featureIndex.indexOf(name);
                if (featureId < 0) {
                    total += -300.0;
                } else {
                    total += this.featureWeights[featureId];
                }
            }
            if (this.isLogarithmMode()) {
                result[sub] = total - this.normalizers[tag][sub];
            } else {
                result[sub] = Math.exp(total - this.normalizers[tag][sub]);
            }
        }
        return result;
    }

    /**
     * Returns the word itself; this lexicon does not map words to signatures and ignores
     * {@code sentencePosition}.
     */
    @Override
    public String getSignature(String word, int sentencePosition) {
        return word;
    }

    /** Returns the current value of the logarithm-mode flag. */
    @Override
    public boolean isLogarithmMode() {
        return this.isLogarithmMode;
    }

    /**
     * Replaces every stored score by its natural logarithm, in place, and marks the lexicon as
     * being in logarithm mode. Does nothing if the lexicon is already in logarithm mode.
     *
     * <p>Per-tag tables (or rows) that are {@code null} are skipped: {@code splitAllStates} and
     * {@code mergeStates} build score arrays starting at tag 1, so the entry for tag 0 may be
     * unallocated until the next {@code optimize()}.
     */
    @Override
    public void logarithmMode() {
        if (this.isLogarithmMode) {
            return;
        }
        // scores is laid out as [tag][substate][word index].
        for (double[][] tagScores : this.scores) {
            if (tagScores == null) {
                continue;
            }
            for (double[] substateScores : tagScores) {
                if (substateScores == null) {
                    continue;
                }
                for (int word = 0; word < substateScores.length; ++word) {
                    substateScores[word] = Math.log(substateScores[word]);
                }
            }
        }
        this.isLogarithmMode = true;
    }

    /**
     * Initialises the lexicon from the training trees: indexes every word, counts word and
     * (tag, word) occurrences, allocates fresh count/score tables and labels the tree leaves
     * with their word indices.
     *
     * @param trainTrees the training trees
     */
    public void init(StateSetTreeList trainTrees) {
        // Pass 1: index the vocabulary.
        for (Tree<StateSet> tree : trainTrees) {
            for (StateSet leaf : tree.getYield()) {
                this.wordIndexer.add(leaf.getWord());
            }
        }
        this.wordCounter = new int[this.wordIndexer.size()];
        this.tagWordCounts = new int[this.numStates][this.wordIndexer.size()];
        // Pass 2: count words and tag/word pairs; leaf i is paired with pre-terminal i.
        for (Tree<StateSet> tree : trainTrees) {
            final List<StateSet> preTerminals = tree.getPreTerminalYield();
            final List<StateSet> leaves = tree.getYield();
            int position = 0;
            for (StateSet leaf : leaves) {
                final int wordId = this.wordIndexer.indexOf(leaf.getWord());
                this.wordCounter[wordId]++;
                this.tagWordCounts[preTerminals.get(position).getState()][wordId]++;
                ++position;
            }
        }
        this.resetCounts();
        this.nWords = this.wordIndexer.size();
        this.labelTrees(trainTrees);
    }

    /**
     * Allocates fresh, zero-filled expected-count, score and normalizer tables sized by the
     * current substate counts and vocabulary size.
     */
    public void resetCounts() {
        final int tagCount = this.numStates;
        this.expectedCounts = new double[tagCount][][];
        this.scores = new double[tagCount][][];
        this.normalizers = new double[tagCount][];
        for (int tag = 0; tag < tagCount; ++tag) {
            final short substateCount = this.numSubStates[tag];
            this.expectedCounts[tag] = new double[substateCount][this.wordIndexer.size()];
            this.normalizers[tag] = new double[substateCount];
            this.scores[tag] = new double[substateCount][this.wordIndexer.size()];
        }
    }

    /**
     * Stores on every leaf of every tree the index of its word in the word indexer and sets
     * its signature index to -1.
     *
     * @param trainTrees trees whose leaves are labelled in place
     */
    public void labelTrees(StateSetTreeList trainTrees) {
        for (Tree<StateSet> tree : trainTrees) {
            for (StateSet leaf : tree.getYield()) {
                leaf.wordIndex = this.wordIndexer.indexOf(leaf.getWord());
                leaf.sigIndex = -1;
            }
        }
    }

    /**
     * Not implemented by this lexicon.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public double[] scoreWord(StateSet stateSet, int tag) {
        throw new UnsupportedOperationException("Not supported yet.");
    }

    /**
     * Not implemented by this lexicon.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public double[] scoreSignature(StateSet stateSet, int tag) {
        throw new UnsupportedOperationException("Not supported yet.");
    }

    /**
     * Adds the contribution of one training tree to the expected counts.
     *
     * <p>With {@code randomness == -1} each (tag, substate, word) count grows by the outside
     * score times the old lexicon's score, rescaled by the sentence score; when the combined
     * scale factor is infinite the same quantity is computed through logarithms. With
     * {@code randomness == 0} every substate receives weight 1; any other value adds random
     * noise proportional to {@code randomness} percent on top of 1. A tree whose inside score
     * at the root is 0 is reported and skipped.
     */
    @Override
    public void trainTree(Tree<StateSet> trainTree, double randomness, Lexicon oldLexicon, boolean secondHalf, boolean noSmoothing, int unkThreshold) {
        final boolean useOldLexicon = randomness == -1.0;
        double sentenceScore = 0.0;
        if (useOldLexicon) {
            sentenceScore = trainTree.getLabel().getIScore(0);
            if (sentenceScore == 0.0) {
                System.out.println("Something is wrong with this tree. I will skip it.");
                return;
            }
        }
        final int sentenceScale = trainTree.getLabel().getIScale();
        final List<StateSet> words = trainTree.getYield();
        final List<StateSet> tags = trainTree.getPreTerminalYield();
        for (int position = 0; position < words.size(); ++position) {
            final int nSubStates = tags.get(position).numSubStates();
            final short tag = tags.get(position).getState();
            final String word = words.get(position).getWord();
            final int wordId = this.wordIndexer.indexOf(word);
            double[] oldScores = null;
            if (useOldLexicon) {
                oldScores = oldLexicon.score(word, tag, position, noSmoothing, false);
            }
            final StateSet preTerminal = tags.get(position);
            final double scale = ScalingTools.calcScaleFactor(preTerminal.getOScale() - sentenceScale) / sentenceScore;
            for (int substate = 0; substate < nSubStates; ++substate) {
                final double weight;
                if (useOldLexicon) {
                    if (!Double.isInfinite(scale)) {
                        weight = preTerminal.getOScore(substate) * oldScores[substate] * scale;
                    } else {
                        // The scale factor overflowed: evaluate the same product in log space.
                        weight = Math.exp(Math.log(ScalingTools.SCALE) * (double)(preTerminal.getOScale() - sentenceScale) - Math.log(sentenceScore) + Math.log(preTerminal.getOScore(substate)) + Math.log(oldScores[substate]));
                    }
                } else if (randomness == 0.0) {
                    weight = 1.0;
                } else {
                    weight = GrammarTrainer.RANDOM.nextDouble() * randomness / 100.0 + 1.0;
                }
                if (weight == 0.0) {
                    continue;
                }
                this.expectedCounts[tag][substate][wordId] += weight;
            }
        }
    }

    /**
     * Stores the given smoother. The scoring code in this class does not consult it; it is
     * only handed on by {@code copyLexicon()}.
     */
    @Override
    public void setSmoother(Smoother smoother) {
        this.smoother = smoother;
    }

    /**
     * Returns a copy of this lexicon in which every tag except tag 0 has twice as many
     * substates. Substate {@code s} of the original becomes substates {@code 2s} and
     * {@code 2s + 1} of the copy.
     *
     * <p>Both children inherit the parent's score, except in mode 2 where both receive the
     * same random value in {@code [1.0, 1.01)}. Expected counts of the copy start at zero.
     * Tag 0 keeps a single substate; its tables are allocated too (carrying over its first
     * score row when one exists) so that the returned lexicon has no unallocated per-tag
     * entries.
     *
     * @param counts                   not used by this implementation
     * @param moreSubstatesThanCounts  not used by this implementation
     * @param mode                     2 selects random initialisation of the split scores
     * @return the split lexicon
     */
    @Override
    public FeaturizedLexicon splitAllStates(int[] counts, boolean moreSubstatesThanCounts, int mode) {
        FeaturizedLexicon splitLex = this.copyLexicon();
        final int tagCount = this.numSubStates.length;
        short[] newNumSubStates = new short[tagCount];
        if (tagCount > 0) {
            // Tag 0 is never split.
            newNumSubStates[0] = 1;
        }
        for (int i = 1; i < tagCount; ++i) {
            newNumSubStates[i] = (short)(this.numSubStates[i] * 2);
        }
        Random random = GrammarTrainer.RANDOM;
        splitLex.numSubStates = newNumSubStates;
        final int vocabularySize = this.wordIndexer.size();
        double[][][] newScores = new double[this.scores.length][][];
        double[][][] newExpCounts = new double[this.scores.length][][];
        for (int tag = 0; tag < this.expectedCounts.length; ++tag) {
            newScores[tag] = new double[newNumSubStates[tag]][vocabularySize];
            newExpCounts[tag] = new double[newNumSubStates[tag]][vocabularySize];
            if (tag == 0) {
                // Unsplit tag: keep its existing scores if this lexicon has any for it.
                double[][] old = this.scores[0];
                if (old != null && old.length > 0 && old[0] != null) {
                    System.arraycopy(old[0], 0, newScores[0][0], 0, Math.min(old[0].length, vocabularySize));
                }
                continue;
            }
            for (int substate = 0; substate < this.numSubStates[tag]; ++substate) {
                final int even = 2 * substate;
                final int odd = even + 1;
                for (int word = 0; word < this.scores[tag][substate].length; ++word) {
                    final double value = mode == 2
                            ? 1.0 + random.nextDouble() / 100.0
                            : this.scores[tag][substate][word];
                    newScores[tag][odd][word] = value;
                    newScores[tag][even][word] = value;
                }
            }
        }
        splitLex.scores = newScores;
        splitLex.expectedCounts = newExpCounts;
        return splitLex;
    }

    /**
     * Merges sibling substates in place. The new substate counts, the old-to-new mapping and
     * the partner lists come from {@code Grammar.calculateMergeArrays}. Substates are visited
     * in sibling pairs: a pair selected for merging gets the merge-weight-weighted average of
     * its two scores (a zero weight sum is treated as 1), other pairs are copied through the
     * mapping. Tags with a single substate get a fresh zero-filled score table, and all
     * expected counts are reset to zero at the new sizes.
     *
     * @param mergeThesePairs passed through to {@code Grammar.calculateMergeArrays}
     * @param mergeWeights    weight of each substate, indexed [tag][substate]
     */
    @Override
    public void mergeStates(boolean[][][] mergeThesePairs, double[][] mergeWeights) {
        final int tagCount = this.numSubStates.length;
        final short[] mergedCounts = new short[tagCount];
        final short[][] mapping = new short[tagCount][];
        final short[][][] partners = new short[tagCount][][];
        Grammar.calculateMergeArrays(mergeThesePairs, mergedCounts, mapping, partners, this.numSubStates);
        final double[][][] mergedScores = new double[this.scores.length][][];
        for (int tag = 1; tag < this.expectedCounts.length; ++tag) {
            mergedScores[tag] = new double[mergedCounts[tag]][this.wordIndexer.size()];
            if (this.numSubStates[tag] == 1) {
                continue;
            }
            for (int word = 0; word < this.expectedCounts[tag][0].length; ++word) {
                for (int i = 0; i < this.numSubStates[tag]; i += 2) {
                    final short[] group = partners[tag][i];
                    if (group.length != 2) {
                        // Pair is kept apart: copy both siblings to their new positions.
                        mergedScores[tag][mapping[tag][i]][word] = this.scores[tag][i][word];
                        mergedScores[tag][mapping[tag][i + 1]][word] = this.scores[tag][i + 1][word];
                        continue;
                    }
                    final short first = group[0];
                    final short second = group[1];
                    final double firstWeight = mergeWeights[tag][first];
                    final double secondWeight = mergeWeights[tag][second];
                    double weightSum = firstWeight + secondWeight;
                    if (weightSum == 0.0) {
                        weightSum = 1.0;
                    }
                    mergedScores[tag][mapping[tag][i]][word] = (firstWeight * this.scores[tag][first][word] + secondWeight * this.scores[tag][second][word]) / weightSum;
                }
            }
        }
        this.numSubStates = mergedCounts;
        this.scores = mergedScores;
        for (int tag = 0; tag < this.numStates; ++tag) {
            this.expectedCounts[tag] = new double[mergedCounts[tag]][this.wordIndexer.size()];
        }
    }

    /**
     * Not implemented by this lexicon, even though a smoother can be stored via
     * {@code setSmoother}.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public Smoother getSmoother() {
        throw new UnsupportedOperationException("Not supported yet.");
    }

    /**
     * Not implemented by this lexicon.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public double[] getSmoothingParams() {
        throw new UnsupportedOperationException("Not supported yet.");
    }

    /**
     * Projects this lexicon onto a coarser substate inventory and refits the feature weights.
     *
     * <p>{@code toSubstateMapping[tag][0]} is the number of target substates of a tag and
     * {@code toSubstateMapping[tag][s + 1]} the target of substate {@code s}. Each target score
     * is the sum of the source scores weighted by {@code condProbs[mapping[tag][s]]}. The
     * projected scores (exponentiated when this lexicon is in logarithm mode) become the
     * expected counts of the new lexicon, which is then optimized.
     *
     * @return the projected, re-optimized lexicon
     */
    @Override
    public FeaturizedLexicon projectLexicon(double[] condProbs, int[][] mapping, int[][] toSubstateMapping) {
        final short[] projectedCounts = new short[this.numSubStates.length];
        for (int state = 0; state < this.numSubStates.length; ++state) {
            projectedCounts[state] = (short)toSubstateMapping[state][0];
        }
        final FeaturizedLexicon projected = this.copyLexicon();
        final double[][][] projectedScores = new double[this.scores.length][][];
        for (int tag = 0; tag < this.scores.length; ++tag) {
            final double[][] target = new double[projectedCounts[tag]][this.wordIndexer.size()];
            projectedScores[tag] = target;
            for (int word = 0; word < this.scores[tag][0].length; ++word) {
                for (int substate = 0; substate < this.numSubStates[tag]; ++substate) {
                    final int to = toSubstateMapping[tag][substate + 1];
                    target[to][word] += condProbs[mapping[tag][substate]] * this.scores[tag][substate][word];
                }
            }
        }
        projected.numStates = projectedScores.length;
        projected.numSubStates = projectedCounts;
        projected.scores = projectedScores;
        projected.expectedCounts = new double[this.numStates][][];
        for (int tag = 0; tag < this.numStates; ++tag) {
            final double[][] counts = new double[projectedCounts[tag]][this.wordIndexer.size()];
            projected.expectedCounts[tag] = counts;
            for (int substate = 0; substate < projectedCounts[tag]; ++substate) {
                for (int word = 0; word < this.wordIndexer.size(); ++word) {
                    final double projectedScore = projectedScores[tag][substate][word];
                    if (this.isLogarithmMode) {
                        counts[substate][word] = Math.exp(projectedScore);
                    } else {
                        counts[substate][word] = projectedScore;
                    }
                }
            }
        }
        projected.optimize();
        return projected;
    }

    /**
     * Returns a copy of this lexicon with zeroed expected counts.
     *
     * <p>Scores, word and tag/word counts, the per-tag word lists, feature weights and
     * normalizers are cloned. The word indexer, feature index, indexed features, featurizer
     * and smoother are shared by reference. The logarithm-mode flag is copied along with the
     * scores so that the copy interprets the cloned score values the same way as this lexicon.
     *
     * @return the copy
     */
    @Override
    public FeaturizedLexicon copyLexicon() {
        FeaturizedLexicon copy = new FeaturizedLexicon(this.numSubStates, this.featurizer);
        copy.expectedCounts = new double[this.numStates][][];
        copy.scores = ArrayUtil.clone(this.scores);
        // The constructor resets the flag to false; the cloned scores may be logarithms.
        copy.isLogarithmMode = this.isLogarithmMode;
        copy.wordIndexer = this.wordIndexer;
        for (int tag = 0; tag < this.numStates; ++tag) {
            copy.expectedCounts[tag] = new double[this.numSubStates[tag]][this.wordIndexer.size()];
        }
        copy.nWords = this.nWords;
        copy.smoother = this.smoother;
        copy.numStates = this.numStates;
        copy.numSubStates = this.numSubStates;
        copy.wordCounter = (int[])this.wordCounter.clone();
        copy.tagWordCounts = ArrayUtil.clone(this.tagWordCounts);
        copy.tagWordsWithFeatures = ArrayUtil.clone(this.tagWordsWithFeatures);
        copy.featureWeights = ArrayUtil.clone(this.featureWeights);
        copy.normalizers = ArrayUtil.clone(this.normalizers);
        copy.featureIndex = this.featureIndex;
        copy.indexedFeatures = this.indexedFeatures;
        return copy;
    }

    /** Intentionally does nothing in this lexicon; both arguments are ignored. */
    @Override
    public void removeUnlikelyTags(double threshold, double exponent) {
    }

    /**
     * Not implemented by this lexicon.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public double getPruningThreshold() {
        throw new UnsupportedOperationException("Not supported yet.");
    }

    /** Intentionally does nothing in this lexicon; the threshold is ignored. */
    @Override
    public void tieRareWordStats(int threshold) {
    }

    /**
     * Not implemented by this lexicon; the raw per-index counts are available through the
     * public {@code wordCounter} array instead.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public Counter<String> getWordCounter() {
        throw new UnsupportedOperationException("Not supported yet.");
    }

    /**
     * Not implemented by this lexicon.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void explicitlyComputeScores(int finalLevel) {
        throw new UnsupportedOperationException("Not supported yet.");
    }
}

