/*
 * Decompiled with CFR 0.152.
 */
package edu.berkeley.nlp.PCFGLA;

import edu.berkeley.nlp.PCFGLA.HierarchicalLexicon;
import edu.berkeley.nlp.PCFGLA.SimpleLexicon;
import edu.berkeley.nlp.PCFGLA.StateSetTreeList;
import edu.berkeley.nlp.PCFGLA.smoothing.Smoother;
import edu.berkeley.nlp.syntax.StateSet;
import edu.berkeley.nlp.syntax.Tree;
import java.util.Arrays;
import java.util.List;

/**
 * Lexicon that scores a (tag, word) pair from a per-tag table whose columns are
 * global word-indexer entries. The table holds both word forms and word signatures.
 *
 * <p>Words seen more than {@link #knownWordCount} times in training are scored by
 * their word entry alone. Rarer words, and words without a word entry, also have
 * their signature entry multiplied in.
 */
public class HierarchicalFullyConnectedLexicon
extends HierarchicalLexicon {
    private static final long serialVersionUID = 1L;
    /**
     * Training-count threshold. A word whose count in {@code wordCounter} is strictly
     * greater than this value is treated as known and scored without its signature.
     */
    protected int knownWordCount;

    /**
     * Creates an empty lexicon; {@link #init(StateSetTreeList)} is not called.
     *
     * @param numSubStates   number of substates per tag
     * @param knownWordCount see {@link #knownWordCount}
     */
    public HierarchicalFullyConnectedLexicon(short[] numSubStates, int knownWordCount) {
        super(numSubStates, 0.0);
        this.knownWordCount = knownWordCount;
    }

    /**
     * Creates a lexicon and immediately initializes it from {@code trainTrees}.
     *
     * <p>{@code smoothingCutoff}, {@code smoothParam} and {@code smoother} are accepted
     * for signature compatibility but are not used by this constructor.
     *
     * @param numSubStates   number of substates per tag
     * @param trainTrees     training trees used to build the word/signature indices
     * @param knownWordCount see {@link #knownWordCount}
     */
    public HierarchicalFullyConnectedLexicon(short[] numSubStates, int smoothingCutoff, double[] smoothParam, Smoother smoother, StateSetTreeList trainTrees, int knownWordCount) {
        this(numSubStates, knownWordCount);
        this.init(trainTrees);
    }

    /**
     * Creates a lexicon from {@code previousLexicon} via the superclass constructor.
     *
     * @param previousLexicon lexicon handed to {@code HierarchicalLexicon(SimpleLexicon)}
     * @param knownWordCount  see {@link #knownWordCount}
     */
    public HierarchicalFullyConnectedLexicon(SimpleLexicon previousLexicon, int knownWordCount) {
        super(previousLexicon);
        this.knownWordCount = knownWordCount;
    }

    /**
     * Returns a new lexicon built with the same substate counts and known-word
     * threshold through the {@code (numSubStates, knownWordCount)} constructor;
     * {@link #init(StateSetTreeList)} is not called on it.
     */
    public HierarchicalFullyConnectedLexicon newInstance() {
        return new HierarchicalFullyConnectedLexicon(this.numSubStates, this.knownWordCount);
    }

    /**
     * Builds the word and signature indices, the per-word training counts, the
     * per-tag column indexers and the (zero-filled) score tables from the trees.
     *
     * <p>Tags that never occur as a preterminal get a {@code null} entry in
     * {@code tagWordIndexer} and in {@code scores}.
     *
     * @param trainTrees training trees; they are also labelled via {@link #labelTrees}
     */
    public void init(StateSetTreeList trainTrees) {
        // Pass 1: index every word form, so wordCounter can be sized to the words only.
        for (Tree<StateSet> tree : trainTrees) {
            for (StateSet word : tree.getYield()) {
                this.wordIndexer.add(word.getWord());
            }
        }
        // Pass 2: count word occurrences and index the signature of every token.
        // Signatures are added after wordCounter is allocated, so a signature that is
        // not also a word form gets an index >= wordCounter.length.
        this.wordCounter = new int[this.wordIndexer.size()];
        for (Tree<StateSet> tree : trainTrees) {
            int position = 0;
            for (StateSet word : tree.getYield()) {
                String wordString = word.getWord();
                this.wordCounter[this.wordIndexer.indexOf(wordString)]++;
                this.wordIndexer.add(this.getSignature(wordString, position++));
            }
        }
        this.tagWordIndexer = new SimpleLexicon.IntegerIndexer[this.numStates];
        for (int tag = 0; tag < this.numStates; ++tag) {
            this.tagWordIndexer[tag] = new SimpleLexicon.IntegerIndexer(this.wordIndexer.size());
        }
        this.labelTrees(trainTrees);
        // Pass 3: register each token's word and signature index as a column of its tag.
        boolean[] lexTag = new boolean[this.numStates];
        for (Tree<StateSet> tree : trainTrees) {
            List<StateSet> tags = tree.getPreTerminalYield();
            int ind = 0;
            for (StateSet word : tree.getYield()) {
                short tag = tags.get(ind).getState();
                this.tagWordIndexer[tag].add(word.wordIndex);
                // labelTrees sets sigIndex to -1 for frequent words; NOTE(review): this
                // relies on IntegerIndexer.add tolerating -1 — confirm in SimpleLexicon.
                this.tagWordIndexer[tag].add(word.sigIndex);
                lexTag[tag] = true;
                ++ind;
            }
        }
        this.expectedCounts = new double[this.numStates][][];
        this.scores = new double[this.numStates][][];
        for (int tag = 0; tag < this.numStates; ++tag) {
            if (!lexTag[tag]) {
                // Tag never emits a word: drop its indexer and leave scores[tag] null.
                this.tagWordIndexer[tag] = null;
                continue;
            }
            this.scores[tag] = new double[this.numSubStates[tag]][this.tagWordIndexer[tag].size()];
        }
        this.nWords = this.wordIndexer.size();
    }

    /**
     * Scores a word for every substate of {@code tag}.
     *
     * <p>Entry {@code i} of the result is {@code scores[tag][i][wordColumn]}, or 1.0 when
     * the word has no column for this tag. Unless the word is known (see
     * {@link #knownWordCount}), it is then multiplied by
     * {@code scores[tag][i][sigColumn]} when the signature has a column for this tag.
     * Any negative global index is treated as absent.
     *
     * <p>{@code loc}, {@code noSmoothing} and {@code isSignature} are not used by this
     * implementation.
     *
     * @param globalWordIndex word-indexer index of the word, or negative if none
     * @param globalSigIndex  word-indexer index of the signature, or negative if none
     * @param tag             tag to score against; must have a non-null indexer
     * @return array of length {@code numSubStates[tag]}
     */
    public double[] score(int globalWordIndex, int globalSigIndex, short tag, int loc, boolean noSmoothing, boolean isSignature) {
        double[] res = new double[this.numSubStates[tag]];
        int wordColumn = globalWordIndex >= 0 ? this.tagWordIndexer[tag].indexOf(globalWordIndex) : -1;
        if (wordColumn >= 0) {
            for (int i = 0; i < res.length; ++i) {
                res[i] = this.scores[tag][i][wordColumn];
            }
        } else {
            Arrays.fill(res, 1.0);
        }
        // Bounds-checked: globalWordIndex may be a signature index beyond wordCounter.
        if (this.exceedsKnownWordCount(globalWordIndex)) {
            return res;
        }
        int sigColumn = globalSigIndex >= 0 ? this.tagWordIndexer[tag].indexOf(globalSigIndex) : -1;
        if (sigColumn >= 0) {
            for (int i = 0; i < res.length; ++i) {
                res[i] *= this.scores[tag][i][sigColumn];
            }
        }
        return res;
    }

    /**
     * Scores {@code stateSet}'s word for every substate of {@code tag}.
     *
     * <p>When {@code stateSet.wordIndex} is -2 its word and signature indices are first
     * resolved and cached on {@code stateSet}:
     * <ul>
     *   <li>{@code isSignature}: the word text is looked up as a signature only;</li>
     *   <li>known word or {@code noSmoothing}: word only, no signature;</li>
     *   <li>{@code knownWordCount > 0}: word plus its positional signature;</li>
     *   <li>otherwise: the signature is used in place of the word.</li>
     * </ul>
     */
    public double[] score(StateSet stateSet, short tag, boolean noSmoothing, boolean isSignature) {
        if (stateSet.wordIndex == -2) {
            String word = stateSet.getWord();
            if (isSignature) {
                stateSet.wordIndex = -1;
                stateSet.sigIndex = this.wordIndexer.indexOf(word);
            } else {
                stateSet.wordIndex = this.wordIndexer.indexOf(word);
                if (noSmoothing || this.exceedsKnownWordCount(stateSet.wordIndex)) {
                    stateSet.sigIndex = -1;
                } else if (this.knownWordCount > 0) {
                    stateSet.sigIndex = this.wordIndexer.indexOf(this.getSignature(word, stateSet.from));
                } else {
                    // The signature stands in for the word; there is no separate signature factor.
                    stateSet.wordIndex = this.wordIndexer.indexOf(this.getSignature(word, stateSet.from));
                    stateSet.sigIndex = -1;
                }
            }
        }
        return this.score(stateSet.wordIndex, stateSet.sigIndex, tag, stateSet.from, noSmoothing, isSignature);
    }

    /**
     * Sets {@code wordIndex} and {@code sigIndex} on every leaf of {@code trainTrees}.
     *
     * <p>Leaves whose word was counted no more than {@link #knownWordCount} times get
     * their positional signature indexed (added if new), and that signature is
     * registered as a column of the leaf's preterminal tag. More frequent words get
     * {@code sigIndex = -1}. Words missing from the counted vocabulary are reported
     * on standard output and left with their previous {@code sigIndex}.
     *
     * <p>NOTE(review): after {@link #init}, {@code tagWordIndexer} entries are null for
     * non-lexical tags, and each indexer was sized to the word-indexer size at that time.
     * Relabelling new trees that hit such a tag or introduce new signatures may
     * therefore fail — confirm before calling this outside {@code init}.
     */
    public void labelTrees(StateSetTreeList trainTrees) {
        for (Tree<StateSet> tree : trainTrees) {
            List<StateSet> tags = tree.getPreTerminalYield();
            int ind = 0;
            for (StateSet word : tree.getYield()) {
                word.wordIndex = this.wordIndexer.indexOf(word.getWord());
                if (word.wordIndex < 0 || word.wordIndex >= this.wordCounter.length) {
                    System.out.println("Have never seen this word before: " + word.getWord() + " " + word.wordIndex);
                    System.out.println(tree);
                } else if (this.wordCounter[word.wordIndex] <= this.knownWordCount) {
                    short tag = tags.get(ind).getState();
                    String sig = this.getSignature(word.getWord(), ind);
                    this.wordIndexer.add(sig);
                    word.sigIndex = this.wordIndexer.indexOf(sig);
                    this.tagWordIndexer[tag].add(word.sigIndex);
                } else {
                    word.sigIndex = -1;
                }
                ++ind;
            }
        }
    }

    /**
     * Returns true when {@code globalWordIndex} is a counted training word whose count
     * exceeds {@link #knownWordCount}. Indices outside {@code wordCounter} (negative, or
     * signatures indexed after the counts were allocated) are never known.
     */
    private boolean exceedsKnownWordCount(int globalWordIndex) {
        return globalWordIndex >= 0
                && globalWordIndex < this.wordCounter.length
                && this.wordCounter[globalWordIndex] > this.knownWordCount;
    }
}

