package edu.berkeley.nlp.PCFGLA;

import edu.berkeley.nlp.util.ArrayUtil;
import edu.berkeley.nlp.util.Numberer;
import edu.berkeley.nlp.util.ScalingTools;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.ObjectInputStream;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.io.Writer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.zip.GZIPInputStream;

/* loaded from: input_file:edu/berkeley/nlp/PCFGLA/PosteriorMerger.class */
public class PosteriorMerger {
    // Chart state shared between doCombinedMaxCScores(), which allocates and fills these
    // arrays for one sentence, and main(), which copies the references into the parser
    // before extracting a tree. Every array is indexed [start][end][state] and is
    // allocated as [n][n + 1][numStates] for a sentence of n words.

    /** Best combined log score found for each (start, end, state); cells start at negative infinity. */
    static double[][][] maxcScore;
    /** Split point chosen by the best binary rule for each cell, or -1 if none was recorded. */
    static int[][][] maxcSplit;
    /** Child state chosen by the best unary rule for each cell, or -1 if none was recorded. */
    static int[][][] maxcChild;
    /** Left child state of the best binary rule for each cell, or -1 if none was recorded. */
    static int[][][] maxcLeftChild;
    /** Right child state of the best binary rule for each cell, or -1 if none was recorded. */
    static int[][][] maxcRightChild;

    /**
     * Command-line options for {@link PosteriorMerger#main(String[])}; the fields are
     * populated by the option parser from the flags named in the annotations.
     */
    public static class Options {

        /**
         * Common file-name prefix of the grammars. main() loads grammar i from
         * {@code <grammarFiles>.<i>} (i starting at 1) and the matching posteriors from
         * {@code <grammarFiles>.<i>.posteriors.<chunk>} (chunk starting at 0).
         */
        @Option(name = "-grammarFiles", required = true, usage = "Input Files for Grammars.")
        public String grammarFiles;

        /** Sentence file, one space-separated sentence per line; main() reads STDIN when this is null. */
        @Option(name = "-inputFile", usage = "Read input from this file instead of reading it from STDIN.")
        public String inputFile;

        /** Destination for the output trees; main() writes to STDOUT when this is null. */
        @Option(name = "-outputFile", usage = "Store output in this file instead of printing it to STDOUT.")
        public String outputFile;

        /**
         * Number of grammars (and posterior sets) to combine. No default is assigned here,
         * so the field stays 0 unless the flag is given.
         */
        @Option(name = "-nGrammars", usage = "Number of Grammars")
        public int nGrammars;

        /** Sentences with more tokens than this are not parsed; main() emits "(())" for them. */
        @Option(name = "-maxLength", usage = "Maximum sentence length (Default = 200).")
        public int maxLength = 200;
    }

    /* JADX WARN: Type inference failed for: r0v10, types: [short[], short[][]] */
    public static void main(String[] strArr) {
        OptionParser optionParser = new OptionParser(Options.class);
        Options options = (Options) optionParser.parse(strArr, true);
        System.err.println("Calling with " + optionParser.getPassedInOptions());
        String str = options.grammarFiles;
        if (str == null) {
            throw new Error("Did not provide a grammar.");
        }
        ?? r0 = new short[options.nGrammars];
        Grammar[] grammarArr = new Grammar[options.nGrammars];
        Lexicon[] lexiconArr = new Lexicon[options.nGrammars];
        for (int i = 0; i < options.nGrammars; i++) {
            System.err.println("Loading grammar from " + str + "." + (i + 1));
            ParserData Load = ParserData.Load(str + "." + (i + 1));
            if (Load == null) {
                System.out.println("Failed to load grammar from file" + str + ".");
                System.exit(1);
            }
            r0[i] = Load.getGrammar().numSubStates;
            Numberer.setNumberers(Load.getNumbs());
            grammarArr[i] = Load.getGrammar();
            lexiconArr[i] = Load.getLexicon();
        }
        int length = r0.length;
        CoarseToFineMaxRuleParser coarseToFineMaxRuleParser = new CoarseToFineMaxRuleParser(grammarArr[0], lexiconArr[0], 1.0d, -1, false, false, false, true, false, false, false);
        try {
            BufferedReader bufferedReader = options.inputFile == null ? new BufferedReader(new InputStreamReader(System.in)) : new BufferedReader(new InputStreamReader(new FileInputStream(options.inputFile), "UTF-8"));
            PrintWriter printWriter = options.outputFile == null ? new PrintWriter(new OutputStreamWriter(System.out)) : new PrintWriter((Writer) new OutputStreamWriter(new FileOutputStream(options.outputFile), "UTF-8"), true);
            int i2 = 0;
            int i3 = 0;
            List[] listArr = null;
            while (true) {
                String readLine = bufferedReader.readLine();
                if (readLine == null) {
                    break;
                }
                List<String> asList = Arrays.asList(readLine.split(" "));
                if (listArr == null || i3 == listArr[0].size()) {
                    listArr = new ArrayList[length];
                    for (int i4 = 0; i4 < length; i4++) {
                        listArr[i4] = loadPosteriors(options.grammarFiles + "." + (i4 + 1) + ".posteriors." + i2);
                    }
                    i3 = 0;
                    i2++;
                }
                int size = asList.size();
                if (size > options.maxLength) {
                    printWriter.write("(())\n");
                } else {
                    ArrayList arrayList = new ArrayList(length);
                    ArrayList arrayList2 = new ArrayList(length);
                    ArrayList arrayList3 = new ArrayList(length);
                    ArrayList arrayList4 = new ArrayList(length);
                    boolean[][][] zArr = (boolean[][][]) null;
                    boolean z = false;
                    for (int i5 = 0; i5 < length; i5++) {
                        Posterior posterior = (Posterior) listArr[i5].get(i3);
                        arrayList.add(posterior.iScore);
                        arrayList2.add(posterior.oScore);
                        arrayList3.add(posterior.iScale);
                        arrayList4.add(posterior.oScale);
                        zArr = mergeAllowedStates(zArr, posterior.allowedStates);
                        countAllowedStates(zArr);
                        if (posterior.iScale != null) {
                            z = true;
                            System.err.println("Scaling will be used.");
                            if (size != posterior.iScale.length) {
                                System.err.println("G: " + i5 + " sentence " + i3 + " Length mismatch. Expected: " + size + " Got: " + posterior.iScale.length);
                            }
                        }
                    }
                    i3++;
                    if (z) {
                        printWriter.write("(()) \n");
                    } else {
                        doCombinedMaxCScores(asList, arrayList, arrayList2, arrayList3, arrayList4, zArr, grammarArr, lexiconArr, r0, arrayList3.get(0) != null);
                        System.err.println("Done with scores");
                        if (maxcScore[0][asList.size()][0] == Double.NEGATIVE_INFINITY) {
                            System.err.println("MaxCscore for ROOT is -Inf.");
                            printWriter.write("(()) \n");
                        } else {
                            coarseToFineMaxRuleParser.maxcScore = maxcScore;
                            coarseToFineMaxRuleParser.maxcChild = maxcChild;
                            coarseToFineMaxRuleParser.maxcLeftChild = maxcLeftChild;
                            coarseToFineMaxRuleParser.maxcRightChild = maxcRightChild;
                            coarseToFineMaxRuleParser.maxcSplit = maxcSplit;
                            coarseToFineMaxRuleParser.allowedStates = zArr;
                            printWriter.write(TreeAnnotations.unAnnotateTree(coarseToFineMaxRuleParser.extractBestMaxRuleParse(0, asList.size(), asList), false) + "\n");
                            printWriter.flush();
                        }
                    }
                }
            }
            printWriter.flush();
            printWriter.close();
        } catch (Exception e) {
            e.printStackTrace();
        }
        System.exit(0);
    }

    /**
     * Intersects two allowed-state charts cell by cell.
     *
     * <p>When {@code zArr} is {@code null} the second chart is returned as-is. Otherwise
     * {@code zArr} is updated in place: every cell {@code [start][end][state]} with
     * {@code end > start} is cleared if {@code zArr2} has it cleared. Cells with
     * {@code end <= start} are never touched.
     *
     * @param zArr  the accumulated chart, or {@code null} if nothing has been merged yet
     * @param zArr2 the chart to merge in; must cover every cell visited in {@code zArr}
     * @return {@code zArr2} if {@code zArr} is {@code null}, otherwise the mutated {@code zArr}
     */
    private static boolean[][][] mergeAllowedStates(boolean[][][] zArr, boolean[][][] zArr2) {
        if (zArr != null) {
            for (int start = 0; start < zArr.length; start++) {
                boolean[][] mergedRow = zArr[start];
                for (int end = start + 1; end < mergedRow.length; end++) {
                    boolean[] mergedCell = mergedRow[end];
                    for (int state = 0; state < mergedCell.length; state++) {
                        // Clearing an already-cleared flag is a no-op, so only the incoming flag matters.
                        if (!zArr2[start][end][state]) {
                            mergedCell[state] = false;
                        }
                    }
                }
            }
            return zArr;
        }
        return zArr2;
    }

    /**
     * Prints to standard error how many chart cells are allowed out of all cells visited.
     *
     * <p>Only cells {@code [start][end][state]} with {@code end > start} are counted. The
     * reported length is the size of the first dimension of the chart.
     *
     * @param zArr the allowed-state chart to summarise; must not be {@code null}
     */
    private static void countAllowedStates(boolean[][][] zArr) {
        int visited = 0;
        int allowed = 0;
        for (int start = 0; start < zArr.length; start++) {
            boolean[][] row = zArr[start];
            for (int end = start + 1; end < row.length; end++) {
                for (boolean flag : row[end]) {
                    visited++;
                    if (flag) {
                        allowed++;
                    }
                }
            }
        }
        System.err.println(allowed + "/" + visited + " allowed for sentence of length " + zArr.length);
    }

    /**
     * Fills the static charts {@link #maxcScore}, {@link #maxcSplit}, {@link #maxcChild},
     * {@link #maxcLeftChild} and {@link #maxcRightChild} for one sentence by combining the
     * posteriors of several grammars.
     *
     * <p>Spans are processed from shortest to longest. For each span:
     * <ul>
     *   <li>length-1 spans are scored from the lexicons, for states that are not grammar tags;</li>
     *   <li>longer spans try every binary rule of the first grammar at every split point;</li>
     *   <li>afterwards every closed unary rule of the first grammar is tried on top of the
     *       scores computed so far for the same span.</li>
     * </ul>
     * The score of a candidate is the sum of its children's chart scores plus, for every
     * grammar, the log of that grammar's rule posterior (outside score times rule score times
     * inside score(s), divided by the grammar's score for the whole sentence). A grammar that
     * lacks the rule is reported on standard error and contributes nothing.
     *
     * <p>As called from {@code main}, {@code list2}/{@code list3} hold each posterior's
     * {@code iScore}/{@code oScore} and {@code list4}/{@code list5} its
     * {@code iScale}/{@code oScale} (presumably inside/outside scores and their scales).
     *
     * @param list        the sentence tokens
     * @param list2       per grammar, scores indexed [start][end][state][substate] (iScore)
     * @param list3       per grammar, scores indexed [start][end][state][substate] (oScore)
     * @param list4       per grammar, scales indexed [start][end][state] (iScale); read only if {@code z}
     * @param list5       per grammar, scales indexed [start][end][state] (oScale); read only if {@code z}
     * @param zArr        allowed cells, indexed [start][end][state]
     * @param grammarArr  the grammars; rules are enumerated from the first one
     * @param lexiconArr  the lexicons, parallel to {@code grammarArr}
     * @param sArr        sArr[g][state] = number of substates of {@code state} in grammar g
     * @param z           whether scale corrections from {@code list4}/{@code list5} are applied
     */
    static void doCombinedMaxCScores(List<String> list, List<double[][][][]> list2, List<double[][][][]> list3, List<int[][][]> list4, List<int[][][]> list5, boolean[][][] zArr, Grammar[] grammarArr, Lexicon[] lexiconArr, short[][] sArr, boolean z) {
        final int sentenceLength = list.size();
        final int numGrammars = sArr.length;
        final int numStates = sArr[0].length;
        final boolean[] isGrammarTag = grammarArr[0].isGrammarTag;
        final Numberer tagNumberer = Numberer.getGlobalNumberer("tags");
        final List<double[][][][]> iScores = list2;
        final List<double[][][][]> oScores = list3;
        final List<int[][][]> iScales = list4;
        final List<int[][][]> oScales = list5;

        maxcScore = new double[sentenceLength][sentenceLength + 1][numStates];
        maxcSplit = new int[sentenceLength][sentenceLength + 1][numStates];
        maxcChild = new int[sentenceLength][sentenceLength + 1][numStates];
        maxcLeftChild = new int[sentenceLength][sentenceLength + 1][numStates];
        maxcRightChild = new int[sentenceLength][sentenceLength + 1][numStates];
        ArrayUtil.fill(maxcScore, Double.NEGATIVE_INFINITY);

        // Per-grammar normaliser: the iScore of state 0, substate 0 over the whole sentence.
        double[] sentenceScores = new double[numGrammars];
        for (int g = 0; g < numGrammars; g++) {
            sentenceScores[g] = iScores.get(g)[0][sentenceLength][0][0];
        }

        for (int span = 1; span <= sentenceLength; span++) {
            for (int start = 0; start < (sentenceLength - span) + 1; start++) {
                final int end = start + span;
                Arrays.fill(maxcSplit[start][end], -1);
                Arrays.fill(maxcChild[start][end], -1);
                Arrays.fill(maxcLeftChild[start][end], -1);
                Arrays.fill(maxcRightChild[start][end], -1);

                if (span > 1) {
                    // ---- binary rules over every split point ----
                    for (short parent = 0; parent < numStates; parent++) {
                        if (!zArr[start][end][parent]) {
                            continue;
                        }
                        for (BinaryRule rule : grammarArr[0].splitRulesWithP(parent)) {
                            short leftState = rule.leftChildState;
                            short rightState = rule.rightChildState;
                            double best = maxcScore[start][end][parent];
                            for (int split = start + 1; split <= end - 1; split++) {
                                if (!zArr[start][split][leftState] || !zArr[split][end][rightState]) {
                                    continue;
                                }
                                double leftScore = maxcScore[start][split][leftState];
                                double rightScore = maxcScore[split][end][rightState];
                                if (leftScore == Double.NEGATIVE_INFINITY || rightScore == Double.NEGATIVE_INFINITY) {
                                    continue;
                                }
                                double scaleLog = 0.0d;
                                if (z) {
                                    for (int g = 0; g < numGrammars; g++) {
                                        scaleLog += ((oScales.get(g)[start][end][parent] + iScales.get(g)[start][split][leftState]) + iScales.get(g)[split][end][rightState]) - iScales.get(g)[0][sentenceLength][0];
                                    }
                                    scaleLog = Math.log(ScalingTools.calcScaleFactor(scaleLog));
                                }
                                double candidate = leftScore + scaleLog + rightScore;
                                // Rule posteriors are only computed when the children alone could still win.
                                // Kept as a positive ">=" test so a NaN candidate is skipped.
                                if (candidate >= best) {
                                    for (int g = 0; g < numGrammars; g++) {
                                        BinaryRule grammarRule = grammarArr[g].getBinaryRule(parent, leftState, rightState);
                                        if (grammarRule == null) {
                                            System.err.println("Dont have rule " + ((String) tagNumberer.object(parent)) + " -> " + ((String) tagNumberer.object(leftState)) + " " + ((String) tagNumberer.object(rightState)) + " in grammar " + g);
                                            continue;
                                        }
                                        double[][][] ruleScores = grammarRule.getScores2();
                                        short numParentSub = sArr[g][parent];
                                        short numLeftSub = sArr[g][leftState];
                                        short numRightSub = sArr[g][rightState];
                                        double rulePosterior = 0.0d;
                                        for (int l = 0; l < numLeftSub; l++) {
                                            double leftInside = iScores.get(g)[start][split][leftState][l];
                                            if (leftInside == 0.0d) {
                                                continue;
                                            }
                                            for (int r = 0; r < numRightSub; r++) {
                                                if (ruleScores[l][r] == null) {
                                                    continue;
                                                }
                                                double rightInside = iScores.get(g)[split][end][rightState][r];
                                                if (rightInside == 0.0d) {
                                                    continue;
                                                }
                                                for (int p = 0; p < numParentSub; p++) {
                                                    double parentOutside = oScores.get(g)[start][end][parent][p];
                                                    if (parentOutside == 0.0d) {
                                                        continue;
                                                    }
                                                    double ruleScore = ruleScores[l][r][p];
                                                    if (ruleScore == 0.0d) {
                                                        continue;
                                                    }
                                                    rulePosterior += (((parentOutside * ruleScore) * leftInside) * rightInside) / sentenceScores[g];
                                                }
                                            }
                                        }
                                        candidate += Math.log(rulePosterior);
                                    }
                                    if (candidate > best) {
                                        best = candidate;
                                        maxcScore[start][end][parent] = candidate;
                                        maxcSplit[start][end][parent] = split;
                                        maxcLeftChild[start][end][parent] = leftState;
                                        maxcRightChild[start][end][parent] = rightState;
                                    }
                                }
                            }
                        }
                    }
                } else {
                    // ---- single word: score the states that are not grammar tags from the lexicons ----
                    final String word = list.get(start);
                    for (int tag = 0; tag < numStates; tag++) {
                        if (!zArr[start][end][tag] || isGrammarTag[tag]) {
                            continue;
                        }
                        double lexicalLog = 0.0d;
                        for (int g = 0; g < numGrammars; g++) {
                            double tagPosterior = 0.0d;
                            double[] wordScores = lexiconArr[g].score(word, (short) tag, start, false, false);
                            for (int sub = 0; sub < sArr[g][tag]; sub++) {
                                tagPosterior += (oScores.get(g)[start][end][tag][sub] * wordScores[sub]) / sentenceScores[g];
                            }
                            lexicalLog += Math.log(tagPosterior);
                        }
                        if (sentenceLength != iScores.get(0).length) {
                            System.err.println("Length mismatch. Expected: " + sentenceLength + " Got: " + iScores.get(0).length);
                            System.err.println(list);
                        }
                        double scaleLog = 0.0d;
                        if (z) {
                            for (int g = 0; g < numGrammars; g++) {
                                try {
                                    scaleLog += oScales.get(g)[start][end][tag] - iScales.get(g)[0][sentenceLength][0];
                                } catch (ArrayIndexOutOfBoundsException e) {
                                    // Best effort: dump the offending dimensions and carry on with the next grammar.
                                    reportScaleIndexError(list, iScores, iScales, oScales, start, end, tag, g);
                                }
                            }
                            scaleLog = Math.log(ScalingTools.calcScaleFactor(scaleLog));
                        }
                        maxcScore[start][end][tag] = lexicalLog + scaleLog;
                    }
                }

                // ---- unary rules on top of the scores computed above for this span ----
                // Improvements go into a copy so that every unary rule reads the pre-unary child scores.
                double[] scoresWithUnaries = new double[numStates];
                System.arraycopy(maxcScore[start][end], 0, scoresWithUnaries, 0, numStates);
                // Bounded loop: the decompiled "while (true)" form had no exit and never terminated.
                for (short parent = 0; parent < numStates; parent++) {
                    if (!zArr[start][end][parent]) {
                        continue;
                    }
                    for (UnaryRule rule : grammarArr[0].getClosedSumUnaryRulesByParent(parent)) {
                        short child = rule.childState;
                        if (parent == child || !zArr[start][end][child]) {
                            continue;
                        }
                        double childScore = maxcScore[start][end][child];
                        if (childScore == Double.NEGATIVE_INFINITY) {
                            continue;
                        }
                        double scaleLog = 0.0d;
                        if (z) {
                            for (int g = 0; g < numGrammars; g++) {
                                scaleLog += (oScales.get(g)[start][end][parent] + iScales.get(g)[start][end][child]) - iScales.get(g)[0][sentenceLength][0];
                            }
                            scaleLog = Math.log(ScalingTools.calcScaleFactor(scaleLog));
                        }
                        double candidate = scaleLog + childScore;
                        if (candidate >= scoresWithUnaries[parent]) {
                            for (int g = 0; g < numGrammars; g++) {
                                UnaryRule grammarRule = grammarArr[g].getUnaryRule(parent, child);
                                if (grammarRule == null) {
                                    System.err.println("Dont have rule " + ((String) tagNumberer.object(parent)) + " -> " + ((String) tagNumberer.object(child)) + " in grammar " + g);
                                    continue;
                                }
                                double[][] ruleScores = grammarRule.getScores2();
                                short numChildSub = sArr[g][child];
                                short numParentSub = sArr[g][parent];
                                double rulePosterior = 0.0d;
                                for (int c = 0; c < numChildSub; c++) {
                                    double childInside = iScores.get(g)[start][end][child][c];
                                    if (childInside == 0.0d || ruleScores[c] == null) {
                                        continue;
                                    }
                                    for (int p = 0; p < numParentSub; p++) {
                                        double parentOutside = oScores.get(g)[start][end][parent][p];
                                        // ">= 0" (not "!= 0") is what the original code tests here.
                                        if (parentOutside >= 0.0d) {
                                            double ruleScore = ruleScores[c][p];
                                            if (ruleScore != 0.0d) {
                                                rulePosterior += ((parentOutside * ruleScore) * childInside) / sentenceScores[g];
                                            }
                                        }
                                    }
                                }
                                candidate += Math.log(rulePosterior);
                            }
                            if (candidate > scoresWithUnaries[parent]) {
                                scoresWithUnaries[parent] = candidate;
                                maxcChild[start][end][parent] = child;
                            }
                        }
                    }
                }
                maxcScore[start][end] = scoresWithUnaries;
            }
        }
    }

    /** Dumps the chart dimensions involved in a failed scale lookup for one cell to standard error. */
    private static void reportScaleIndexError(List<String> sentence, List<double[][][][]> iScores, List<int[][][]> iScales, List<int[][][]> oScales, int start, int end, int tag, int grammar) {
        int sentenceLength = sentence.size();
        System.err.println("Start " + start);
        System.err.println("End " + end);
        System.err.println("Length " + sentenceLength);
        System.err.println("Tag " + tag);
        System.err.println("Grammar " + grammar);
        int[][][] oScale = oScales.get(grammar);
        System.err.println("oS.l " + oScale.length);
        System.err.println("oS[].l " + oScale[start].length);
        System.err.println("oS[][].l " + oScale[start][end].length);
        int[][][] iScale = iScales.get(grammar);
        System.err.println("iS.l " + iScale.length);
        System.err.println("iS[].l " + iScale[start].length);
        System.err.println("iS[][].l " + iScale[start][end].length);
        double[][][][] iScore = iScores.get(grammar);
        System.err.println("iS.l " + iScore.length);
        System.err.println("iS[].l " + iScore[start].length);
        System.err.println("iS[][].l " + iScore[start][end].length);
        System.err.println("Length mismatch. Expected: " + sentenceLength + " Got: " + iScales.get(grammar).length);
        System.err.println(sentence);
    }

    /**
     * Reads a gzip-compressed, Java-serialized list of posteriors from a file.
     *
     * <p>All three streams are closed on every path, including when the gzip header is
     * invalid or deserialization fails. Failures are reported on standard output and
     * signalled to the caller by a {@code null} return, so callers must check for it.
     *
     * <p>NOTE(review): this uses native Java deserialization; only point it at files from
     * a trusted source.
     *
     * @param str path of the file to read
     * @return the deserialized list, or {@code null} if the file could not be read or a
     *         serialized class could not be resolved
     */
    public static List<Posterior> loadPosteriors(String str) {
        // Separate resources so the file handle is released even if a wrapping stream fails to open.
        try (FileInputStream fileIn = new FileInputStream(str);
             GZIPInputStream gzipIn = new GZIPInputStream(fileIn);
             ObjectInputStream objectIn = new ObjectInputStream(gzipIn)) {
            // Element types cannot be checked at runtime; the file format is trusted to hold posteriors.
            @SuppressWarnings("unchecked")
            List<Posterior> posteriors = (List<Posterior>) objectIn.readObject();
            return posteriors;
        } catch (IOException e) {
            System.out.println("IOException\n" + e);
            return null;
        } catch (ClassNotFoundException e2) {
            System.out.println("Class not found!");
            return null;
        }
    }
}
