package edu.berkeley.nlp.scripts;

import edu.berkeley.nlp.PCFGLA.Grammar;
import edu.berkeley.nlp.PCFGLA.StateSetTreeList;
import edu.berkeley.nlp.PCFGLA.smoothing.NoSmoothing;
import edu.berkeley.nlp.discPCFG.ParsingObjectiveFunction;
import edu.berkeley.nlp.math.DoubleArrays;
import edu.berkeley.nlp.parser.EnglishPennTreebankParseEvaluator;
import edu.berkeley.nlp.syntax.StateSet;
import edu.berkeley.nlp.syntax.Tree;
import edu.berkeley.nlp.syntax.Trees;
import edu.berkeley.nlp.util.Filter;
import edu.berkeley.nlp.util.Numberer;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.UnsupportedEncodingException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:edu/berkeley/nlp/scripts/GermanSharedTask.class */
/**
 * Grammatical-function labeling script for German treebank data.
 *
 * <p>Node labels in the input trees have the form {@code CATEGORY} or {@code CATEGORY-GF}, where GF is
 * the text after the first '-' (presumably the grammatical function). {@link #extractGrammar} builds a
 * {@link Grammar} whose states are the bare categories and whose substates are the GF suffixes seen with
 * each category. {@link #guessGF} then uses the grammar's inside scores to pick a GF suffix for every
 * node of a tree whose labels have had their GF removed.
 */
public class GermanSharedTask {

    /** Input trees to label when no second command-line argument is given. */
    private static final String DEFAULT_TREES_TO_LABEL = "/Users/petrov/Data/german_st/tueba/tueba_tmp";

    /** Reference trees to evaluate against when no third command-line argument is given. */
    private static final String DEFAULT_REFERENCE_TREES = "/Users/petrov/Data/german_st/tueba/data02.mrg";

    /** Selects nodes whose label starts with "@"; handed to {@code Trees.spliceNodes} before output/evaluation. */
    private static final Filter<String> AT_LABEL_FILTER = new Filter<String>() {
        @Override
        public boolean accept(String label) {
            return label.startsWith("@");
        }
    };

    /** Numbers bare categories (labels with the GF suffix removed); taken from the global "tags" numberer. */
    Numberer tagNumberer;

    /** Indexed by category number: a numberer over the GF suffixes seen with that category (GF index == substate). */
    List<Numberer> substateNumberers;

    /**
     * Builds a grammar from GF-annotated trees, with one state per bare category and one substate per
     * GF suffix observed with it. Resets {@link #substateNumberers}.
     *
     * @param trees training trees whose labels are {@code CATEGORY} or {@code CATEGORY-GF}
     * @return the tallied and optimized grammar
     */
    public Grammar extractGrammar(List<Tree<String>> trees) {
        this.tagNumberer = Numberer.getGlobalNumberer("tags");
        this.substateNumberers = new ArrayList<>();
        short[] numSubStates = countSymbols(trees);
        StateSetTreeList stateSetTrees =
                new StateSetTreeList(stripOffGF(trees), numSubStates, false, this.tagNumberer);
        return createGrammar(stateSetTrees, trees, numSubStates);
    }

    /**
     * Labels each tree of {@code treesToLabel} with GF suffixes and prints a labeled-constituent
     * evaluation against {@code referenceTrees}.
     *
     * <p>The two lists are aligned by yield length: for every tree to label, reference trees are consumed
     * in order until one with the same number of words is found; mismatching reference trees are skipped.
     * Evaluation stops (with a warning) when the reference trees run out.
     *
     * @param grammar        grammar produced by {@link #extractGrammar}
     * @param treesToLabel   trees whose structure is used; their GF suffixes are stripped before labeling
     * @param referenceTrees GF-annotated trees supplying the comparison tree and the pre-terminal tags
     */
    private void checkGrammar(Grammar grammar, List<Tree<String>> treesToLabel, List<Tree<String>> referenceTrees) {
        EnglishPennTreebankParseEvaluator.LabeledConstituentEval labeledConstituentEval =
                new EnglishPennTreebankParseEvaluator.LabeledConstituentEval(
                        new HashSet<String>(Arrays.asList("ROOT", "PSEUDO")),
                        new HashSet<String>(Arrays.asList("''", "``", ".", ":", ",")));
        int referenceIndex = 0;
        Iterator<Tree<StateSet>> it =
                new StateSetTreeList(stripOffGF(treesToLabel), grammar.numSubStates, false, this.tagNumberer).iterator();
        while (it.hasNext()) {
            Tree<StateSet> stateSetTree = it.next();
            int length = stateSetTree.getYield().size();

            // Consume reference trees until one has the same yield length (bounded by the list size).
            Tree<String> reference = null;
            while (referenceIndex < referenceTrees.size()) {
                Tree<String> candidate = referenceTrees.get(referenceIndex++);
                if (candidate.getYield().size() == length) {
                    reference = candidate;
                    break;
                }
            }
            if (reference == null) {
                System.err.println("No reference tree left with yield length " + length + "; stopping evaluation.");
                break;
            }

            Tree<String> guess = guessGF(stateSetTree, grammar, reference.getPreTerminalYield());
            // NOTE(review): the reference tree is passed first and the guess second; confirm this matches
            // LabeledConstituentEval.evaluate's parameter order, otherwise precision and recall are swapped.
            labeledConstituentEval.evaluate(
                    Trees.spliceNodes(reference, AT_LABEL_FILTER),
                    Trees.spliceNodes(guess, AT_LABEL_FILTER));
        }
        labeledConstituentEval.display(true);
    }

    /**
     * Labels each tree with GF suffixes and prints it (after splicing out "@" nodes) to stdout.
     *
     * @param grammar      grammar produced by {@link #extractGrammar}
     * @param trees        trees to label; their GF suffixes are stripped first
     * @param tagSequences for each tree, its pre-terminal tags (possibly carrying a GF suffix)
     */
    private void labelTrees(Grammar grammar, List<Tree<String>> trees, List<List<String>> tagSequences) {
        int index = 0;
        Iterator<Tree<StateSet>> it =
                new StateSetTreeList(stripOffGF(trees), grammar.numSubStates, false, this.tagNumberer).iterator();
        while (it.hasNext()) {
            Tree<String> labeled = guessGF(it.next(), grammar, tagSequences.get(index++));
            System.out.println(Trees.spliceNodes(labeled, AT_LABEL_FILTER) + "\n");
        }
    }

    /**
     * Computes inside scores for {@code tree} and reads back a GF-labeled copy, starting from root substate 0.
     * Mutates the scores (and possibly pre-terminal labels) of {@code tree}.
     */
    private Tree<String> guessGF(Tree<StateSet> tree, Grammar grammar, List<String> tags) {
        doInsideScores(tree, grammar, tags);
        return extractBestViterbiDerivation(grammar, tree, 0);
    }

    /**
     * Returns copies of {@code trees} in which every non-leaf label is truncated at its first '-'.
     * Leaves (words) are left untouched, so hyphenated words keep their hyphens.
     */
    private List<Tree<String>> stripOffGF(List<Tree<String>> trees) {
        List<Tree<String>> stripped = new ArrayList<>(trees.size());
        for (Tree<String> tree : trees) {
            Tree<String> copy = tree.shallowClone();
            for (Tree<String> node : copy.getPostOrderTraversal()) {
                if (node.isLeaf()) {
                    continue;
                }
                String label = node.getLabel();
                int dash = label.indexOf('-');
                if (dash != -1) {
                    node.setLabel(label.substring(0, dash));
                }
            }
            stripped.add(copy);
        }
        return stripped;
    }

    /**
     * Creates a grammar and tallies every state-set tree into it. Before tallying, each tree's scores
     * are set so that only the substate of its gold GF is active.
     *
     * @param stateSetTrees  GF-stripped trees as state sets
     * @param annotatedTrees the same trees, in the same order, still carrying GF suffixes
     * @param numSubStates   substate count per category
     */
    private Grammar createGrammar(StateSetTreeList stateSetTrees, List<Tree<String>> annotatedTrees, short[] numSubStates) {
        Grammar grammar = new Grammar(numSubStates, false, new NoSmoothing(), null, -1.0d);
        int index = 0;
        Iterator<Tree<StateSet>> it = stateSetTrees.iterator();
        while (it.hasNext()) {
            Tree<StateSet> stateSetTree = it.next();
            setScores(stateSetTree, annotatedTrees.get(index++));
            grammar.tallyStateSetTree(stateSetTree, grammar);
        }
        grammar.optimize(0.0d);
        return grammar;
    }

    /**
     * Recursively sets inside and outside score 1.0 (scale 0) on the substate of each node's GF suffix.
     * Prints "Mismatch!" when the two trees disagree on a node's child count and then recurses only into
     * the children both trees have.
     */
    private void setScores(Tree<StateSet> tree, Tree<String> annotated) {
        if (annotated.isLeaf()) {
            return;
        }
        String gf = splitLabel(annotated.getLabel())[1];
        StateSet label = tree.getLabel();
        int substate = substateNumbererFor(label.getState()).number(gf);
        label.setIScore(substate, 1.0d);
        label.setIScale(0);
        label.setOScore(substate, 1.0d);
        label.setOScale(0);

        List<Tree<String>> annotatedChildren = annotated.getChildren();
        List<Tree<StateSet>> children = tree.getChildren();
        if (annotatedChildren.size() != children.size()) {
            System.err.println("Mismatch!");
        }
        int shared = Math.min(annotatedChildren.size(), children.size());
        for (int c = 0; c < shared; c++) {
            setScores(children.get(c), annotatedChildren.get(c));
        }
    }

    /**
     * Numbers every category and GF suffix in {@code trees} and returns, per category number, how many
     * distinct GF suffixes (substates) it has. A category known to the global numberer but absent from
     * these trees gets a single GF-less ("") substate instead of zero.
     */
    private short[] countSymbols(List<Tree<String>> trees) {
        for (Tree<String> tree : trees) {
            processTree(tree);
        }
        short[] numSubStates = new short[this.tagNumberer.total()];
        for (int tag = 0; tag < numSubStates.length; tag++) {
            Numberer gfs = substateNumbererFor(tag);
            if (gfs.total() == 0) {
                gfs.number("");
            }
            numSubStates[tag] = (short) gfs.total();
        }
        return numSubStates;
    }

    /** Numbers the category and GF suffix of {@code tree} and, recursively, of all its non-leaf descendants. */
    private void processTree(Tree<String> tree) {
        String[] parts = splitLabel(tree.getLabel());
        int tag = this.tagNumberer.number(parts[0]);
        substateNumbererFor(tag).number(parts[1]);
        for (Tree<String> child : tree.getChildren()) {
            if (!child.isLeaf()) {
                processTree(child);
            }
        }
    }

    /**
     * Returns the GF numberer for category {@code tag}. Grows {@link #substateNumberers} as needed, so
     * category numbers that were assigned elsewhere in the global numberer cannot leave gaps.
     */
    private Numberer substateNumbererFor(int tag) {
        while (this.substateNumberers.size() <= tag) {
            this.substateNumberers.add(new Numberer());
        }
        return this.substateNumberers.get(tag);
    }

    /**
     * Splits a label on '-'. Index 0 is the category and index 1 the GF suffix ("" when absent); any
     * further '-' parts are kept at index 2 onwards.
     */
    private String[] splitLabel(String label) {
        String[] parts = label.split("-");
        if (parts.length >= 2) {
            return parts;
        }
        // split() drops trailing empty strings, so e.g. "-" yields an empty array; pad to two entries.
        return new String[] {parts.length == 0 ? "" : parts[0], ""};
    }

    /**
     * Reads a GF-labeled string tree back from a state-set tree whose inside scores have been computed.
     *
     * <p>For each binary/unary node, the first child substate(s) whose rule score times child inside
     * score(s) {@link #matches} the parent's inside score are followed. When none match, the child
     * falls back to substate 0. If the requested substate has score {@code -Infinity}, the arg-max
     * substate is used instead. A trailing "^g" is removed from non-pre-terminal categories.
     *
     * @param substate substate of {@code tree}'s root to follow; -1 means "none found", treated as 0
     * @throws Error if a node has more than two children
     */
    Tree<String> extractBestViterbiDerivation(Grammar grammar, Tree<StateSet> tree, int substate) {
        if (tree.isLeaf()) {
            return new Tree<>(tree.getLabel().getWord());
        }
        if (substate == -1) {
            substate = 0;
        }
        if (tree.isPreTerminal()) {
            List<Tree<String>> word = new ArrayList<>();
            word.add(extractBestViterbiDerivation(grammar, tree.getChildren().get(0), -1));
            short tagState = tree.getLabel().getState();
            return new Tree<>(withGF((String) this.tagNumberer.object(tagState), tagState, substate), word);
        }

        StateSet label = tree.getLabel();
        short state = label.getState();
        List<Tree<StateSet>> children = tree.getChildren();
        List<Tree<String>> derivedChildren = new ArrayList<>();
        double parentScore = label.getIScore(substate);
        if (parentScore == Double.NEGATIVE_INFINITY) {
            parentScore = DoubleArrays.max(label.getIScores());
            substate = DoubleArrays.argMax(label.getIScores());
        }
        switch (children.size()) {
            case 1: {
                int childSubstate = findUnaryChildSubstate(grammar, state, children.get(0).getLabel(), substate, parentScore);
                derivedChildren.add(extractBestViterbiDerivation(grammar, children.get(0), childSubstate));
                break;
            }
            case 2: {
                int[] childSubstates = findBinaryChildSubstates(grammar, state,
                        children.get(0).getLabel(), children.get(1).getLabel(), substate, parentScore);
                derivedChildren.add(extractBestViterbiDerivation(grammar, children.get(0), childSubstates[0]));
                derivedChildren.add(extractBestViterbiDerivation(grammar, children.get(1), childSubstates[1]));
                break;
            }
            default:
                throw new Error("Malformed tree: more than two children");
        }

        String category = (String) this.tagNumberer.object(state);
        if (category.endsWith("^g")) {
            category = category.substring(0, category.length() - 2);
        }
        return new Tree<>(withGF(category, state, substate), derivedChildren);
    }

    /** Returns the first child substate whose unary contribution matches {@code parentScore}, or -1 if none does. */
    private int findUnaryChildSubstate(Grammar grammar, short parentState, StateSet child,
                                       int parentSubstate, double parentScore) {
        double[][] unaryScore = grammar.getUnaryScore(parentState, child.getState());
        int numChildSubStates = child.numSubStates();
        for (int c = 0; c < numChildSubStates; c++) {
            if (unaryScore[c] == null) {
                continue;
            }
            double childScore = child.getIScore(c);
            if (childScore == 0.0d) {
                continue;
            }
            double rule = unaryScore[c][parentSubstate];
            if (rule != 0.0d && matches(rule * childScore, parentScore)) {
                return c;
            }
        }
        return -1;
    }

    /**
     * Returns the first (left, right) child-substate pair whose binary contribution matches
     * {@code parentScore}, or {-1, -1} if none does.
     */
    private int[] findBinaryChildSubstates(Grammar grammar, short parentState, StateSet left, StateSet right,
                                           int parentSubstate, double parentScore) {
        double[][][] binaryScore = grammar.getBinaryScore(parentState, left.getState(), right.getState());
        int numLeft = left.numSubStates();
        int numRight = right.numSubStates();
        for (int l = 0; l < numLeft; l++) {
            double leftScore = left.getIScore(l);
            if (leftScore == 0.0d) {
                continue;
            }
            for (int r = 0; r < numRight; r++) {
                double rightScore = right.getIScore(r);
                if (rightScore == 0.0d || binaryScore[l][r] == null) {
                    continue;
                }
                double rule = binaryScore[l][r][parentSubstate];
                if (rule != 0.0d && matches(parentScore, rule * leftScore * rightScore)) {
                    return new int[] {l, r};
                }
            }
        }
        return new int[] {-1, -1};
    }

    /** Appends "-GF" to {@code category} when {@code substate} of {@code state} maps to a non-empty GF suffix. */
    private String withGF(String category, short state, int substate) {
        String gf = (String) this.substateNumberers.get(state).object(substate);
        return (gf == null || gf.equals("")) ? category : category + "-" + gf;
    }

    /** True when {@code d} and {@code d2} differ by less than 1e-4 relative to their magnitudes (1e-10 guards 0/0). */
    protected boolean matches(double d, double d2) {
        return Math.abs(d - d2) / ((Math.abs(d) + Math.abs(d2)) + 1.0E-10d) < 1.0E-4d;
    }

    /**
     * Computes inside scores bottom-up, summing rule score times child inside score(s) over child substates.
     *
     * <p>A pre-terminal gets score 1.0 on the substate of the GF suffix of its tag in {@code list}
     * (indexed by the node's {@code from} position). If that GF was never seen with the category,
     * substate 0 is used and a warning is printed. A pre-terminal whose state is outside the grammar is
     * relabeled to a one-substate StateSet of state {@code grammar.numStates - 1}.
     *
     * @param list pre-terminal tags of the sentence, possibly carrying GF suffixes
     * @throws Error if a node has more than two children
     */
    void doInsideScores(Tree<StateSet> tree, Grammar grammar, List<String> list) {
        if (tree.isLeaf()) {
            return;
        }
        List<Tree<StateSet>> children = tree.getChildren();
        for (Tree<StateSet> child : children) {
            if (!child.isLeaf()) {
                doInsideScores(child, grammar, list);
            }
        }
        StateSet label = tree.getLabel();
        short state = label.getState();
        int numSubStates = label.numSubStates();

        if (tree.isPreTerminal()) {
            String tag = list.get(label.from);
            int substate = 0;
            if (state < grammar.numStates) {
                // number() also registers GFs not seen in training; those fall outside numSubStates below.
                substate = this.substateNumberers.get(state).number(splitLabel(tag)[1]);
                if (substate >= grammar.numSubStates[state]) {
                    System.err.println("Have never seen this POS: " + tag);
                    substate = 0;
                }
            } else {
                label = new StateSet((short) (grammar.numStates - 1), (short) 1);
                tree.setLabel(label);
            }
            label.setIScore(substate, 1.0d);
            label.scaleIScores(0);
            return;
        }

        switch (children.size()) {
            case 0:
                return;
            case 1: {
                StateSet child = children.get(0).getLabel();
                int numChildSubStates = child.numSubStates();
                double[][] unaryScore = grammar.getUnaryScore(state, child.getState());
                double[] scores = new double[numSubStates];
                for (int c = 0; c < numChildSubStates; c++) {
                    if (unaryScore[c] == null) {
                        continue;
                    }
                    double childScore = child.getIScore(c);
                    if (childScore == 0.0d) {
                        continue;
                    }
                    for (int p = 0; p < numSubStates; p++) {
                        double rule = unaryScore[c][p];
                        if (rule != 0.0d) {
                            scores[p] += rule * childScore;
                        }
                    }
                }
                label.setIScores(scores);
                label.scaleIScores(child.getIScale());
                return;
            }
            case 2: {
                StateSet left = children.get(0).getLabel();
                StateSet right = children.get(1).getLabel();
                int numLeft = left.numSubStates();
                int numRight = right.numSubStates();
                double[][][] binaryScore = grammar.getBinaryScore(state, left.getState(), right.getState());
                double[] scores = new double[numSubStates];
                for (int l = 0; l < numLeft; l++) {
                    double leftScore = left.getIScore(l);
                    if (leftScore == 0.0d) {
                        continue;
                    }
                    for (int r = 0; r < numRight; r++) {
                        double rightScore = right.getIScore(r);
                        if (rightScore == 0.0d || binaryScore[l][r] == null) {
                            continue;
                        }
                        for (int p = 0; p < numSubStates; p++) {
                            double rule = binaryScore[l][r][p];
                            if (rule != 0.0d) {
                                scores[p] += rule * leftScore * rightScore;
                            }
                        }
                    }
                }
                label.setIScores(scores);
                label.scaleIScores(left.getIScale() + right.getIScale());
                return;
            }
            default:
                throw new Error("Malformed tree: more than two children");
        }
    }

    /**
     * Reads all Penn-format trees from a UTF-8 file and closes the file afterwards.
     *
     * @throws IllegalArgumentException if the file does not exist
     * @throws RuntimeException         wrapping any other I/O failure
     */
    private static List<Tree<String>> loadTrees(String path) {
        try (InputStreamReader reader = new InputStreamReader(new FileInputStream(path), StandardCharsets.UTF_8)) {
            Trees.PennTreeReader treeReader = new Trees.PennTreeReader(reader);
            List<Tree<String>> trees = new ArrayList<>();
            while (treeReader.hasNext()) {
                trees.add(treeReader.next());
            }
            return trees;
        } catch (FileNotFoundException e) {
            throw new IllegalArgumentException("Tree file not found: " + path, e);
        } catch (IOException e) {
            throw new RuntimeException("Could not read trees from " + path, e);
        }
    }

    /**
     * Entry point: {@code GermanSharedTask <trainingTrees> [treesToLabel] [referenceTrees]}.
     * Extracts a grammar from the training trees, then labels and evaluates as in {@link #checkGrammar}.
     * The optional paths default to the original hard-coded locations.
     */
    public static void main(String[] args) {
        if (args.length < 1) {
            System.err.println("Usage: GermanSharedTask <trainingTrees> [treesToLabel] [referenceTrees]");
            System.exit(1);
        }
        String treesToLabelPath = args.length > 1 ? args[1] : DEFAULT_TREES_TO_LABEL;
        String referencePath = args.length > 2 ? args[2] : DEFAULT_REFERENCE_TREES;

        GermanSharedTask task = new GermanSharedTask();
        Grammar grammar = task.extractGrammar(loadTrees(args[0]));
        List<Tree<String>> treesToLabel = loadTrees(treesToLabelPath);
        List<Tree<String>> referenceTrees = loadTrees(referencePath);
        task.checkGrammar(grammar, treesToLabel, referenceTrees);
    }
}
