/*
 * Decompiled with CFR 0.152.
 */
package edu.berkeley.nlp.PCFGLA.smoothing;

import edu.berkeley.nlp.PCFGLA.BinaryCounterTable;
import edu.berkeley.nlp.PCFGLA.BinaryRule;
import edu.berkeley.nlp.PCFGLA.UnaryCounterTable;
import edu.berkeley.nlp.PCFGLA.UnaryRule;
import edu.berkeley.nlp.PCFGLA.smoothing.Smoother;
import edu.berkeley.nlp.syntax.Tree;
import edu.berkeley.nlp.util.Numberer;
import java.io.Serializable;
import java.util.List;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Smoother that mixes each substate's scores with those of sibling substates of the same state.
 *
 * <p>Each state has a square mixing matrix {@code diffWeights[state]}. A score vector {@code v}
 * over that state's substates is replaced by {@code v'} with
 * {@code v'[i] = sum_k diffWeights[state][i][k] * v[k]}. For binary and unary rule counts the
 * smoothed index is the parent substate (the innermost array index), using the rule's
 * {@code parentState}.
 */
public class SmoothAcrossParentBits implements Smoother, Serializable {
    private static final long serialVersionUID = 1L;
    /** Diagonal weight; set to {@code 1 - smooth} by the split-tree constructor. */
    double same;
    /** Per-state mixing matrices, indexed {@code [state][targetSubstate][sourceSubstate]}. */
    double[][][] diffWeights;
    /** Per-level decay factor used (together with a factor 1/2) by {@link #fillWeightsArray}. */
    double weightBasis = 0.5;
    /** Running sum of the off-diagonal weights written by {@link #fillWeightsArray}. */
    double totalWeight;

    /**
     * Returns a new smoother with the same parameters.
     *
     * <p>The copy shares the {@code diffWeights} arrays with this instance (shallow copy). Within
     * this class, {@link #updateWeights} and {@link #remapStates} install fresh outer arrays rather
     * than mutating shared ones.
     */
    @Override
    public SmoothAcrossParentBits copy() {
        return new SmoothAcrossParentBits(this.same, this.diffWeights, this.weightBasis, this.totalWeight);
    }

    /**
     * Builds uniform within-branch smoothing weights from each state's split tree.
     *
     * <p>A state whose yield contains only substate 0 gets the matrix {@code [[1.0]]}. Otherwise
     * the tree is descended through single-child nodes to its first branching node. For each of
     * that node's first two children, every pair of substates in the child's yield gets
     * {@code 1 - smooth} on the diagonal and {@code smooth / (n - 1)} off the diagonal, where
     * {@code n} is the number of substates in that child. Pairs lying in different branches keep
     * weight 0, so no smoothing happens across the top-level split. A branch that holds a single
     * substate therefore gets only {@code 1 - smooth} on its diagonal.
     *
     * @param smooth amount of weight moved from a substate to its siblings in the same branch
     * @param splitTrees split tree of each state; the leaves of each tree are substate indices
     */
    public SmoothAcrossParentBits(double smooth, Tree<Short>[] splitTrees) {
        this.same = 1.0 - smooth;
        int nStates = splitTrees.length;
        this.diffWeights = new double[nStates][][];
        for (int state = 0; state < nStates; ++state) {
            Tree<Short> splitTree = splitTrees[state];
            // The matrix must cover the largest substate index that appears in the yield.
            int nSubstates = 1;
            for (short substate : splitTree.getYield()) {
                nSubstates = Math.max(nSubstates, substate + 1);
            }
            double[][] weights = new double[nSubstates][nSubstates];
            this.diffWeights[state] = weights;
            if (nSubstates == 1) {
                weights[0][0] = 1.0;
                continue;
            }
            // Skip unary chains so that the children below are the top-level split.
            while (splitTree.getChildren().size() == 1) {
                splitTree = splitTree.getChildren().get(0);
            }
            for (int branch = 0; branch < 2; ++branch) {
                List<Short> substatesInBranch = splitTree.getChildren().get(branch).getYield();
                // Only used off the diagonal, i.e. when the branch holds at least two substates.
                double normalizedSmooth = smooth / (double) (substatesInBranch.size() - 1);
                for (short i : substatesInBranch) {
                    for (short j : substatesInBranch) {
                        weights[i][j] = (i == j) ? this.same : normalizedSmooth;
                    }
                }
            }
        }
    }

    /**
     * Creates a smoother from explicit parameters; the weight arrays are used as given (not copied).
     *
     * @param same2 diagonal weight
     * @param diffWeights2 per-state mixing matrices
     * @param weightBasis2 per-level decay factor used by {@link #fillWeightsArray}
     * @param totalWeight2 running off-diagonal weight total
     */
    public SmoothAcrossParentBits(double same2, double[][][] diffWeights2, double weightBasis2, double totalWeight2) {
        this.same = same2;
        this.diffWeights = diffWeights2;
        this.weightBasis = weightBasis2;
        this.totalWeight = totalWeight2;
    }

    /**
     * Smooths every unary and binary rule count over the rule's parent substates, replacing each
     * stored count array with a freshly allocated smoothed one.
     *
     * <p>Null innermost vectors are left null. Rows of the binary count arrays may have differing
     * lengths.
     */
    @Override
    public void smooth(UnaryCounterTable unaryCounter, BinaryCounterTable binaryCounter) {
        for (UnaryRule unaryRule : unaryCounter.keySet()) {
            double[][] scores = unaryCounter.getCount(unaryRule);
            short pState = unaryRule.parentState;
            double[][] smoothed = new double[scores.length][];
            for (int j = 0; j < scores.length; ++j) {
                if (scores[j] != null) {
                    smoothed[j] = this.smoothVector(pState, scores[j]);
                }
            }
            unaryCounter.setCount(unaryRule, smoothed);
        }
        for (BinaryRule binaryRule : binaryCounter.keySet()) {
            double[][][] scores = binaryCounter.getCount(binaryRule);
            short pState = binaryRule.parentState;
            double[][][] smoothed = new double[scores.length][][];
            for (int j = 0; j < scores.length; ++j) {
                // Allocate per row instead of assuming every row is as long as scores[0].
                smoothed[j] = new double[scores[j].length][];
                for (int l = 0; l < scores[j].length; ++l) {
                    if (scores[j][l] != null) {
                        smoothed[j][l] = this.smoothVector(pState, scores[j][l]);
                    }
                }
            }
            binaryCounter.setCount(binaryRule, smoothed);
        }
    }

    /**
     * Returns {@code diffWeights[state] * vector} as a new array of the same length.
     *
     * <p>The weight matrix is only read when {@code vector} is non-empty.
     */
    private double[] smoothVector(int state, double[] vector) {
        int n = vector.length;
        double[] result = new double[n];
        if (n == 0) {
            return result;
        }
        double[][] weights = this.diffWeights[state];
        for (int i = 0; i < n; ++i) {
            double sum = 0.0;
            for (int k = 0; k < n; ++k) {
                sum += weights[i][k] * vector[k];
            }
            result[i] = sum;
        }
        return result;
    }

    /**
     * Recursively writes the weights of {@code substate} towards the leaves of {@code subTree}.
     *
     * <p>The diagonal gets {@link #same}. Each other leaf gets {@code weight}, scaled by
     * {@code weightBasis / 2} for every branch taken that does not contain {@code substate}, and is
     * added to {@link #totalWeight}. Only the first two children of a branching node are visited.
     * This method is not currently called from within this class.
     */
    private void fillWeightsArray(short state, short substate, double weight, Tree<Short> subTree) {
        if (subTree.isLeaf()) {
            if (subTree.getLabel() == substate) {
                this.diffWeights[state][substate][substate] = this.same;
            } else {
                this.diffWeights[state][substate][subTree.getLabel().shortValue()] = weight;
                this.totalWeight += weight;
            }
            return;
        }
        if (subTree.getChildren().size() == 1) {
            this.fillWeightsArray(state, substate, weight, subTree.getChildren().get(0));
            return;
        }
        for (int branch = 0; branch < 2; ++branch) {
            Tree<Short> branchTree = subTree.getChildren().get(branch);
            List<Short> substatesInBranch = branchTree.getYield();
            if (substatesInBranch.contains(substate)) {
                this.fillWeightsArray(state, substate, weight, branchTree);
            } else {
                this.fillWeightsArray(state, substate, weight * this.weightBasis / 2.0, branchTree);
            }
        }
    }

    /**
     * Smooths {@code scores} in place using the mixing matrix of state {@code tag}.
     *
     * @param tag state whose matrix is applied
     * @param scores per-substate scores; overwritten with the smoothed values
     */
    @Override
    public void smooth(short tag, double[] scores) {
        double[] smoothed = this.smoothVector(tag, scores);
        System.arraycopy(smoothed, 0, scores, 0, scores.length);
    }

    /**
     * Projects the mixing matrices onto a new substate numbering and renormalizes each row.
     *
     * <p>{@code toSubstateMapping[state][0]} is the new number of substates of {@code state}, and
     * {@code toSubstateMapping[state][old + 1]} is the new index of old substate {@code old}.
     * Weights of old substates mapped to the same new pair are summed, and each new row is divided
     * by its total. A new row that received no weight is left unsmoothed (1.0 on its diagonal)
     * instead of being filled with 0/0 = NaN.
     *
     * @param toSubstateMapping per-state substate count followed by the old-to-new index map
     */
    @Override
    public void updateWeights(int[][] toSubstateMapping) {
        double[][][] newWeights = new double[toSubstateMapping.length][][];
        for (int state = 0; state < toSubstateMapping.length; ++state) {
            int[] mapping = toSubstateMapping[state];
            int nSub = mapping[0];
            double[][] projected = new double[nSub][nSub];
            newWeights[state] = projected;
            if (nSub == 1) {
                projected[0][0] = 1.0;
                continue;
            }
            double[][] oldWeights = this.diffWeights[state];
            double[] rowTotals = new double[nSub];
            for (int from = 0; from < oldWeights.length; ++from) {
                int newFrom = mapping[from + 1];
                for (int to = 0; to < oldWeights.length; ++to) {
                    int newTo = mapping[to + 1];
                    projected[newFrom][newTo] += oldWeights[from][to];
                    rowTotals[newFrom] += oldWeights[from][to];
                }
            }
            for (int row = 0; row < nSub; ++row) {
                if (rowTotals[row] == 0.0) {
                    // Nothing was accumulated into this row (e.g. no old substate maps to it).
                    projected[row][row] = 1.0;
                    continue;
                }
                for (int col = 0; col < nSub; ++col) {
                    projected[row][col] /= rowTotals[row];
                }
            }
        }
        this.diffWeights = newWeights;
    }

    /**
     * Returns a copy of this smoother whose matrices are indexed by {@code newNumberer}'s states.
     *
     * <p>States also known to {@code thisNumberer} reuse (share) their existing matrix. Other
     * states get the identity {@code [[1.0]]}, i.e. no smoothing, matching how single-substate
     * states are treated elsewhere in this class.
     */
    @Override
    public Smoother remapStates(Numberer thisNumberer, Numberer newNumberer) {
        SmoothAcrossParentBits remappedSmoother = this.copy();
        int nNewStates = newNumberer.size();
        remappedSmoother.diffWeights = new double[nNewStates][][];
        for (int s = 0; s < nNewStates; ++s) {
            short translatedState = this.translateState(s, newNumberer, thisNumberer);
            if (translatedState >= 0) {
                remappedSmoother.diffWeights[s] = this.diffWeights[translatedState];
            } else {
                remappedSmoother.diffWeights[s] = new double[][] {{1.0}};
            }
        }
        return remappedSmoother;
    }

    /**
     * Maps state index {@code state} of {@code baseNumberer} to the index of the same object in
     * {@code translationNumberer}, or returns -1 if {@code translationNumberer} has not seen it.
     */
    private short translateState(int state, Numberer baseNumberer, Numberer translationNumberer) {
        Object object = baseNumberer.object(state);
        if (translationNumberer.hasSeen(object)) {
            return (short) translationNumberer.number(object);
        }
        return -1;
    }
}

