/*
 * Decompiled with CFR 0.152.
 */
package edu.berkeley.nlp.discPCFG;

import edu.berkeley.nlp.PCFGLA.BinaryRule;
import edu.berkeley.nlp.PCFGLA.Grammar;
import edu.berkeley.nlp.PCFGLA.HierarchicalAdaptiveBinaryRule;
import edu.berkeley.nlp.PCFGLA.HierarchicalAdaptiveGrammar;
import edu.berkeley.nlp.PCFGLA.HierarchicalAdaptiveLexicalRule;
import edu.berkeley.nlp.PCFGLA.HierarchicalAdaptiveUnaryRule;
import edu.berkeley.nlp.PCFGLA.HierarchicalFullyConnectedAdaptiveLexicon;
import edu.berkeley.nlp.PCFGLA.Rule;
import edu.berkeley.nlp.PCFGLA.SimpleLexicon;
import edu.berkeley.nlp.PCFGLA.SpanPredictor;
import edu.berkeley.nlp.PCFGLA.UnaryRule;
import edu.berkeley.nlp.discPCFG.HierarchicalLinearizer;
import edu.berkeley.nlp.syntax.StateSet;
import edu.berkeley.nlp.syntax.StateSetWithFeatures;
import edu.berkeley.nlp.util.ArrayUtil;
import java.util.List;

/**
 * Linearizer for the hierarchical adaptive model. It flattens the final-level parameters of a
 * {@link HierarchicalAdaptiveGrammar} and a {@link HierarchicalFullyConnectedAdaptiveLexicon}
 * into flat weight arrays and writes updated weights back into the model.
 *
 * <p>Layout of the combined weight vector: grammar rules receive identifiers starting at 0, and
 * lexical rules receive identifiers starting after the {@code nGrammarWeights} grammar weights.
 *
 * <p>The class name keeps its historical spelling ("Hiearchical"), because renaming it would
 * break existing references.
 */
public class HiearchicalAdaptiveLinearizer extends HierarchicalLinearizer {
    private static final long serialVersionUID = 1L;

    // Typed views of the model being linearized.
    // NOTE(review): these declarations hide any same-named fields in the superclass chain.
    // Confirm that no inherited code reads its own (unset) copies.
    HierarchicalAdaptiveGrammar grammar;
    HierarchicalFullyConnectedAdaptiveLexicon lexicon;

    /**
     * Creates a linearizer over the given grammar and lexicon.
     *
     * <p>The lexicon's explicit scores are computed at {@code fLevel}. The grammar's closed
     * viterbi/sum unary tables are pointed at its plain unary rule tables, and its
     * intermediate/CR arrays are rebuilt. The superclass mappings are then initialized.
     *
     * @param grammar the grammar; must be a {@link HierarchicalAdaptiveGrammar}
     * @param lexicon the lexicon; must be a {@link HierarchicalFullyConnectedAdaptiveLexicon}
     * @param sp span predictor stored on this linearizer
     * @param fLevel level passed to {@code explicitlyComputeScores}; stored as {@code finalLevel}
     * @throws ClassCastException if {@code grammar} or {@code lexicon} is not of the adaptive type
     */
    public HiearchicalAdaptiveLinearizer(Grammar grammar, SimpleLexicon lexicon, SpanPredictor sp, int fLevel) {
        this.grammar = (HierarchicalAdaptiveGrammar) grammar;
        lexicon.explicitlyComputeScores(fLevel);
        // Use the plain unary rule tables wherever the closed-rule tables are consulted.
        grammar.closedViterbiRulesWithParent = grammar.unaryRulesWithParent;
        grammar.closedSumRulesWithParent = grammar.unaryRulesWithParent;
        grammar.closedViterbiRulesWithChild = grammar.unaryRulesWithC;
        grammar.closedSumRulesWithChild = grammar.unaryRulesWithC;
        grammar.clearUnaryIntermediates();
        grammar.makeCRArrays();
        this.lexicon = (HierarchicalFullyConnectedAdaptiveLexicon) lexicon;
        this.spanPredictor = sp;
        this.finalLevel = fLevel;
        this.nSubstates = (int) ArrayUtil.max(grammar.numSubStates);
        this.init();
        this.computeMappings();
    }

    /** Returns the lexicon this linearizer operates on. */
    public SimpleLexicon getLexicon() {
        return this.lexicon;
    }

    /** Returns the grammar this linearizer operates on. */
    public Grammar getGrammar() {
        return this.grammar;
    }

    /**
     * Returns the final-level weights of all lexical rules as one flat array.
     *
     * @param update if true, first reassign every lexical rule's {@code identifier}, placing the
     *     lexicon weights after the {@code nGrammarWeights} grammar weights, and recount
     *     {@code nLexiconWeights}
     * @return an array of length {@code nLexiconWeights}, filled rule by rule in tag/word order
     */
    public double[] getLinearizedLexicon(boolean update) {
        if (update) {
            this.nLexiconWeights = 0;
            for (HierarchicalAdaptiveLexicalRule[] tagRules : this.lexicon.rules) {
                for (HierarchicalAdaptiveLexicalRule rule : tagRules) {
                    rule.identifier = this.nLexiconWeights + this.nGrammarWeights;
                    this.nLexiconWeights += rule.getFinalLevel().size();
                }
            }
        }
        double[] logProbs = new double[this.nLexiconWeights];
        int index = 0;
        for (HierarchicalAdaptiveLexicalRule[] tagRules : this.lexicon.rules) {
            for (HierarchicalAdaptiveLexicalRule rule : tagRules) {
                List<Double> vals = rule.getFinalLevel();
                for (Double val : vals) {
                    logProbs[index++] = val;
                }
            }
        }
        // Only reachable with index < length: an overflow throws inside the copy loop above.
        if (index != logProbs.length) {
            System.out.println("unequal length in lexicon");
        }
        return logProbs;
    }

    /**
     * Pushes weights back into every lexical rule and recomputes its explicit scores.
     *
     * @param logProbs weight vector indexed by each lexical rule's {@code identifier}
     * @param usingOnlyLastLevel forwarded to each rule's {@code explicitlyComputeScores}
     */
    public void delinearizeLexicon(double[] logProbs, boolean usingOnlyLastLevel) {
        for (HierarchicalAdaptiveLexicalRule[] tagRules : this.lexicon.rules) {
            for (HierarchicalAdaptiveLexicalRule rule : tagRules) {
                rule.updateScores(logProbs);
                rule.explicitlyComputeScores(this.finalLevel, usingOnlyLastLevel);
            }
        }
    }

    /**
     * Same as {@link #delinearizeLexicon(double[], boolean)} with
     * {@code usingOnlyLastLevel == false}.
     */
    public void delinearizeLexicon(double[] logProbs) {
        this.delinearizeLexicon(logProbs, false);
    }

    /**
     * Accumulates per-substate lexical weights into {@code counts}, then zeroes
     * {@code weights[0..nSubstates)}.
     *
     * <p>For a {@link StateSetWithFeatures}, every non-negative feature known to the tag's word
     * indexer contributes. Otherwise two lookups contribute:
     * <ul>
     *   <li>the signature index, if it is not -1 and is known to the tag, and
     *   <li>the word index, if it is known to the tag.
     * </ul>
     *
     * @param counts accumulator indexed by rule identifier plus substate mapping
     * @param stateSet the lexical item being scored
     * @param tag the tag whose lexical rules are used
     * @param weights per-substate weights; entries {@code [0, nSubstates)} are cleared on return
     * @param isGold add the weights when true, subtract them when false
     */
    public void increment(double[] counts, StateSet stateSet, int tag, double[] weights, boolean isGold) {
        if (stateSet instanceof StateSetWithFeatures) {
            StateSetWithFeatures stateSetF = (StateSetWithFeatures) stateSet;
            for (int f : stateSetF.features) {
                if (f < 0) {
                    continue;
                }
                int tagSpecificIndex = this.lexicon.tagWordIndexer[tag].indexOf(f);
                if (tagSpecificIndex >= 0) {
                    this.accumulateLexicalCounts(counts, this.lexicon.rules[tag][tagSpecificIndex], weights, isGold);
                }
            }
        } else {
            int globalSigIndex = stateSet.sigIndex;
            if (globalSigIndex != -1) {
                int tagSpecificSigIndex = this.lexicon.tagWordIndexer[tag].indexOf(globalSigIndex);
                if (tagSpecificSigIndex >= 0) {
                    this.accumulateLexicalCounts(counts, this.lexicon.rules[tag][tagSpecificSigIndex], weights, isGold);
                }
            }
            int tagSpecificWordIndex = this.lexicon.tagWordIndexer[tag].indexOf(stateSet.wordIndex);
            if (tagSpecificWordIndex >= 0) {
                this.accumulateLexicalCounts(counts, this.lexicon.rules[tag][tagSpecificWordIndex], weights, isGold);
            }
        }
        for (int i = 0; i < this.nSubstates; ++i) {
            weights[i] = 0.0;
        }
    }

    /**
     * Adds (gold) or subtracts (non-gold) {@code weights[i]} into
     * {@code counts[rule.identifier + rule.mapping[i]]} for every substate {@code i}.
     * Does not modify {@code weights}.
     */
    private void accumulateLexicalCounts(double[] counts, HierarchicalAdaptiveLexicalRule rule, double[] weights, boolean isGold) {
        int startIndexWord = rule.identifier;
        short[] mapping = rule.mapping;
        for (int i = 0; i < this.nSubstates; ++i) {
            int n = startIndexWord + mapping[i];
            if (isGold) {
                counts[n] += weights[i];
            } else {
                counts[n] -= weights[i];
            }
        }
    }

    /**
     * Accumulates the positive entries of {@code weights} into the binary rule's parameter slots.
     * Each consumed entry is zeroed.
     *
     * @throws ClassCastException if {@code rule} is not a {@link HierarchicalAdaptiveBinaryRule}
     */
    public void increment(double[] counts, BinaryRule rule, double[] weights, boolean isGold) {
        HierarchicalAdaptiveBinaryRule hr = (HierarchicalAdaptiveBinaryRule) rule;
        accumulatePositiveRuleCounts(counts, hr.identifier, hr.nParam, weights, isGold);
    }

    /**
     * Accumulates the positive entries of {@code weights} into the unary rule's parameter slots.
     * Each consumed entry is zeroed.
     *
     * @throws ClassCastException if {@code rule} is not a {@link HierarchicalAdaptiveUnaryRule}
     */
    public void increment(double[] counts, UnaryRule rule, double[] weights, boolean isGold) {
        HierarchicalAdaptiveUnaryRule hr = (HierarchicalAdaptiveUnaryRule) rule;
        accumulatePositiveRuleCounts(counts, hr.identifier, hr.nParam, weights, isGold);
    }

    /**
     * Handles each of {@code weights[0..nParam)} as follows:
     * <ul>
     *   <li>a strictly positive value is added to (gold) or subtracted from (non-gold)
     *       {@code counts[startIndex + i]}, and the weight entry is zeroed;
     *   <li>a zero, negative or NaN value is left untouched.
     * </ul>
     */
    private static void accumulatePositiveRuleCounts(double[] counts, int startIndex, int nParam, double[] weights, boolean isGold) {
        for (int curInd = 0; curInd < nParam; ++curInd) {
            double val = weights[curInd];
            if (!(val > 0.0)) {
                continue;
            }
            weights[curInd] = 0.0;
            int n = startIndex + curInd;
            if (isGold) {
                counts[n] += val;
            } else {
                counts[n] -= val;
            }
        }
    }

    /**
     * Pushes weights back into every binary and unary grammar rule, then performs the following
     * on the grammar:
     * <ul>
     *   <li>recomputes its explicit scores at {@code finalLevel};
     *   <li>points its closed-rule tables at the plain unary rule tables;
     *   <li>rebuilds its intermediate/CR arrays.
     * </ul>
     *
     * @param probs weight vector indexed by each rule's {@code identifier}
     */
    public void delinearizeGrammar(double[] probs) {
        for (BinaryRule bRule : this.grammar.binaryRuleMap.keySet()) {
            ((HierarchicalAdaptiveBinaryRule) bRule).updateScores(probs);
        }
        for (UnaryRule uRule : this.grammar.unaryRuleMap.keySet()) {
            ((HierarchicalAdaptiveUnaryRule) uRule).updateScores(probs);
        }
        this.grammar.explicitlyComputeScores(this.finalLevel);
        this.grammar.closedViterbiRulesWithParent = this.grammar.unaryRulesWithParent;
        this.grammar.closedSumRulesWithParent = this.grammar.unaryRulesWithParent;
        this.grammar.closedViterbiRulesWithChild = this.grammar.unaryRulesWithC;
        this.grammar.closedSumRulesWithChild = this.grammar.unaryRulesWithC;
        this.grammar.clearUnaryIntermediates();
        this.grammar.makeCRArrays();
    }

    /**
     * Returns the final-level weights of all grammar rules as one flat array.
     *
     * @param update if true, first reassign identifiers consecutively from 0: all binary rules
     *     first, then all unary rules. {@code nGrammarWeights} is recounted. A warning is
     *     printed for any binary rule whose parent is not a grammar tag.
     * @return an array of length {@code nGrammarWeights}
     */
    public double[] getLinearizedGrammar(boolean update) {
        if (update) {
            this.nGrammarWeights = 0;
            for (BinaryRule bRule : this.grammar.binaryRuleMap.keySet()) {
                HierarchicalAdaptiveBinaryRule hRule = (HierarchicalAdaptiveBinaryRule) bRule;
                if (!this.grammar.isGrammarTag[bRule.parentState]) {
                    System.out.println("Incorrect grammar tag");
                }
                bRule.identifier = this.nGrammarWeights;
                this.nGrammarWeights += hRule.nParam;
            }
            for (UnaryRule uRule : this.grammar.unaryRuleMap.keySet()) {
                HierarchicalAdaptiveUnaryRule hRule = (HierarchicalAdaptiveUnaryRule) uRule;
                uRule.identifier = this.nGrammarWeights;
                this.nGrammarWeights += hRule.nParam;
            }
        }
        double[] logProbs = new double[this.nGrammarWeights];
        for (BinaryRule bRule : this.grammar.binaryRuleMap.keySet()) {
            HierarchicalAdaptiveBinaryRule hRule = (HierarchicalAdaptiveBinaryRule) bRule;
            int ind = hRule.identifier;
            List<Double> vals = hRule.getFinalLevel();
            for (Double val : vals) {
                logProbs[ind++] = val;
            }
        }
        for (UnaryRule uRule : this.grammar.unaryRuleMap.keySet()) {
            HierarchicalAdaptiveUnaryRule hRule = (HierarchicalAdaptiveUnaryRule) uRule;
            // Self-loop unaries (child == parent) still get identifier slots above, but their
            // values are not copied, so those entries stay 0.0.
            // NOTE(review): confirm this is intended.
            if (uRule.childState == uRule.parentState) {
                continue;
            }
            int ind = hRule.identifier;
            List<Double> vals = hRule.getFinalLevel();
            for (Double val : vals) {
                logProbs[ind++] = val;
            }
        }
        return logProbs;
    }

    /**
     * Writes lexicon weights back from a combined weight vector, recomputing scores with
     * {@code usingOnlyLastLevel == true}.
     *
     * @param logWeights combined vector. Lexical rule identifiers, as assigned by
     *     {@link #getLinearizedLexicon(boolean)}, index past the first {@code nGrammarWeights}
     *     entries.
     */
    public void delinearizeLexiconWeights(double[] logWeights) {
        this.delinearizeLexicon(logWeights, true);
    }
}

