package cmu.arktweetnlp.impl;

import cmu.arktweetnlp.util.BasicFileIO;
import edu.berkeley.nlp.util.ArrayUtil;
import edu.berkeley.nlp.util.Triple;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.util.Pair;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;

/* loaded from: input_file:cmu/arktweetnlp/impl/Model.class */
/**
 * Locally normalized, first-order sequence tagging model.
 *
 * <p>The score of label {@code y} at position {@code t} is the sum of a per-label bias,
 * a transition weight {@code edgeCoefs[prev][y]} (where {@code prev} is the previous label,
 * or {@link #startMarker()} at the first position), and the weighted observation features
 * of position {@code t}. Scores are turned into probabilities with a per-position softmax.
 *
 * <p>Besides decoding (greedy and Viterbi), the class provides the log-likelihood and its
 * gradient over a flat parameter vector (see {@link #flatIDsize()}), and a plain-text
 * serialization format ({@link #saveModelAsText(String)} / {@link #loadModelFromText(String)}).
 */
public class Model {
    /** Label strings; label id {@code i} indexes the label dimension of every coefficient array. */
    public Vocabulary labelVocab = new Vocabulary();
    /** Observation feature strings; feature id {@code f} indexes rows of {@link #observationFeatureCoefs}. */
    public Vocabulary featureVocab = new Vocabulary();
    /** Per-label bias weights. */
    public double[] biasCoefs;
    /** Transition weights {@code [previousLabel][label]}; the extra last row belongs to the start marker. */
    public double[][] edgeCoefs;
    /** Observation weights {@code [featureId][label]}. */
    public double[][] observationFeatureCoefs;
    /** Number of labels; every scoring loop runs over {@code 0..numLabels-1}. */
    public int numLabels;

    /**
     * Returns the pseudo-label id used as the "previous label" of the first token.
     *
     * <p>It is one past the last real label id, i.e. the extra row of {@link #edgeCoefs}.
     * Only meaningful once {@link #labelVocab} is locked; this is checked with {@code assert}.
     */
    public int startMarker() {
        assert this.labelVocab.isLocked();
        return this.labelVocab.size();
    }

    /**
     * Freezes both vocabularies and allocates zeroed coefficient arrays sized to them.
     */
    public void lockdownAfterFeatureExtraction() {
        this.labelVocab.lock();
        this.featureVocab.lock();
        // Keep numLabels consistent with the arrays, as loadModelFromText does. Otherwise
        // every scoring loop (which iterates to numLabels) could silently do nothing.
        this.numLabels = this.labelVocab.size();
        allocateCoefs(this.labelVocab.size(), this.featureVocab.size());
    }

    /**
     * Allocates fresh, zero-filled coefficient arrays.
     *
     * @param labelCount number of real labels
     * @param featureCount number of observation features
     */
    public void allocateCoefs(int labelCount, int featureCount) {
        this.observationFeatureCoefs = new double[featureCount][labelCount];
        // One extra row so edgeCoefs[startMarker()] holds the transitions out of the start state.
        this.edgeCoefs = new double[labelCount + 1][labelCount];
        this.biasCoefs = new double[labelCount];
    }

    /**
     * Computes, for every position, the softmax distribution over labels.
     *
     * <p>Each position is conditioned on the previous-label id already stored in
     * {@code modelSentence.edgeFeatures} (presumably the gold labels during training).
     *
     * @return a {@code [T][labelVocab.size()]} matrix whose first {@code numLabels} columns
     *     hold the per-position probabilities
     */
    public double[][] inferPosteriorGivenLabels(ModelSentence modelSentence) {
        double[][] posteriors = new double[modelSentence.T][this.labelVocab.size()];
        double[] scores = new double[this.numLabels];
        for (int t = 0; t < modelSentence.T; t++) {
            computeLabelScores(t, modelSentence, scores);
            softmaxInPlace(scores);
            System.arraycopy(scores, 0, posteriors[t], 0, this.numLabels);
        }
        return posteriors;
    }

    /**
     * Tags the sentence left to right, picking the highest-scoring label at each position.
     *
     * <p>The chosen label is fed back as the previous label of the next position.
     * Writes {@code modelSentence.labels} and {@code modelSentence.edgeFeatures}.
     *
     * @param computeConfidences when true, also fills {@code modelSentence.confidences} with
     *     the softmax probability of each chosen label
     */
    public void greedyDecode(ModelSentence modelSentence, boolean computeConfidences) {
        int length = modelSentence.T;
        modelSentence.labels = new int[length];
        modelSentence.edgeFeatures[0] = startMarker();
        if (computeConfidences) {
            modelSentence.confidences = new double[length];
        }
        double[] scores = new double[this.numLabels];
        for (int t = 0; t < length; t++) {
            computeLabelScores(t, modelSentence, scores);
            int best = ArrayMath.argmax(scores);
            modelSentence.labels[t] = best;
            if (t < length - 1) {
                modelSentence.edgeFeatures[t + 1] = best;
            }
            if (computeConfidences) {
                softmaxInPlace(scores);
                modelSentence.confidences[t] = scores[best];
            }
        }
    }

    /**
     * Not implemented.
     *
     * @throws UnsupportedOperationException always
     */
    public double[][] inferPosteriorForUnknownLabels(ModelSentence modelSentence) {
        // Fail fast instead of returning null, which callers such as mbrDecode would dereference.
        throw new UnsupportedOperationException("Unimplemented");
    }

    /**
     * Finds the best label sequence with the Viterbi algorithm and stores it in
     * {@code modelSentence.labels}.
     *
     * <p>Each transition is scored by the (log-normalized) local label scores at that
     * position, given the candidate previous label.
     */
    public void viterbiDecode(ModelSentence modelSentence) {
        int length = modelSentence.T;
        modelSentence.labels = new int[length];
        if (length == 0) {
            return; // nothing to decode; the trellis below needs at least one position
        }
        int start = startMarker();
        int[][] backpointers = new int[length][this.numLabels];
        double[][] bestScores = new double[length][this.numLabels];

        double[] previousBest = new double[this.numLabels];
        computeVitLabelScores(0, start, modelSentence, previousBest);
        ArrayUtil.logNormalize(previousBest);
        bestScores[0] = previousBest;
        Arrays.fill(backpointers[0], start);

        for (int t = 1; t < length; t++) {
            // candidate[prev][cur]: local score of cur given prev at t, plus best path score into prev.
            double[][] candidate = new double[this.numLabels][this.numLabels];
            for (int prev = 0; prev < this.numLabels; prev++) {
                computeVitLabelScores(t, prev, modelSentence, candidate[prev]);
                ArrayUtil.logNormalize(candidate[prev]);
                candidate[prev] = ArrayUtil.add(candidate[prev], previousBest[prev]);
            }
            for (int cur = 0; cur < this.numLabels; cur++) {
                double[] column = getColumn(candidate, cur);
                int bestPrev = ArrayUtil.argmax(column);
                backpointers[t][cur] = bestPrev;
                bestScores[t][cur] = column[bestPrev];
            }
            previousBest = bestScores[t];
        }

        // Follow backpointers from the best final label back to the start marker.
        int lastLabel = ArrayUtil.argmax(bestScores[length - 1]);
        modelSentence.labels[length - 1] = lastLabel;
        int backLabel = backpointers[length - 1][lastLabel];
        for (int t = length - 2; t >= 0 && backLabel != start; t--) {
            modelSentence.labels[t] = backLabel;
            backLabel = backpointers[t][backLabel];
        }
        assert backLabel == start;
    }

    /** Returns a copy of column {@code col} of a rectangular matrix. */
    private double[] getColumn(double[][] matrix, int col) {
        double[] column = new double[matrix.length];
        for (int row = 0; row < matrix.length; row++) {
            column[row] = matrix[row][col];
        }
        return column;
    }

    /**
     * Minimum-Bayes-risk decoding.
     *
     * <p>Currently always fails, because {@link #inferPosteriorForUnknownLabels} is not implemented.
     */
    public void mbrDecode(ModelSentence modelSentence) {
        double[][] posteriors = inferPosteriorForUnknownLabels(modelSentence);
        for (int t = 0; t < modelSentence.T; t++) {
            modelSentence.labels[t] = ArrayMath.argmax(posteriors[t]);
        }
    }

    /**
     * Overwrites {@code dArr} with the unnormalized label scores at position {@code i}.
     *
     * <p>The previous label is taken from {@code modelSentence.edgeFeatures[i]}.
     */
    public void computeLabelScores(int i, ModelSentence modelSentence, double[] dArr) {
        Arrays.fill(dArr, 0.0d);
        computeBiasScores(dArr);
        computeEdgeScores(i, modelSentence, dArr);
        computeObservedFeatureScores(i, modelSentence, dArr);
    }

    /**
     * Overwrites {@code dArr} with the unnormalized label scores at {@code position}.
     *
     * <p>Assumes {@code previousLabel} as the previous label, which may be {@link #startMarker()}.
     */
    public void computeVitLabelScores(int position, int previousLabel, ModelSentence modelSentence, double[] dArr) {
        Arrays.fill(dArr, 0.0d);
        computeBiasScores(dArr);
        viterbiEdgeScores(previousLabel, modelSentence, dArr);
        computeObservedFeatureScores(position, modelSentence, dArr);
    }

    /** Adds each label's bias weight to {@code dArr}. */
    public void computeBiasScores(double[] dArr) {
        for (int label = 0; label < this.numLabels; label++) {
            dArr[label] += this.biasCoefs[label];
        }
    }

    /** Adds the transition weights out of {@code modelSentence.edgeFeatures[i]} to {@code dArr}. */
    public void computeEdgeScores(int i, ModelSentence modelSentence, double[] dArr) {
        viterbiEdgeScores(modelSentence.edgeFeatures[i], modelSentence, dArr);
    }

    /** Adds the transition weights out of {@code previousLabel} to {@code dArr}. */
    public void viterbiEdgeScores(int previousLabel, ModelSentence modelSentence, double[] dArr) {
        double[] transitions = this.edgeCoefs[previousLabel];
        for (int label = 0; label < this.numLabels; label++) {
            dArr[label] += transitions[label];
        }
    }

    /**
     * Adds {@code sum_f weight[f][label] * value_f} over the observation features of
     * position {@code i} to {@code dArr}.
     */
    public void computeObservedFeatureScores(int i, ModelSentence modelSentence, double[] dArr) {
        // Walk the feature list once. Each label still accumulates in feature order,
        // so the sums are identical to a label-major loop.
        for (Pair<Integer, Double> feature : modelSentence.observationFeatures.get(i)) {
            double[] weights = this.observationFeatureCoefs[feature.first.intValue()];
            double value = feature.second.doubleValue();
            for (int label = 0; label < this.numLabels; label++) {
                dArr[label] += weights[label] * value;
            }
        }
    }

    /**
     * Element-wise product of three equal-length arrays.
     *
     * @throws IllegalArgumentException if the lengths differ
     */
    public double[] ThreewiseMultiply(double[] dArr, double[] dArr2, double[] dArr3) {
        if (dArr.length != dArr2.length || dArr2.length != dArr3.length) {
            throw new IllegalArgumentException("length mismatch: " + dArr.length + ", "
                    + dArr2.length + ", " + dArr3.length);
        }
        double[] product = new double[dArr.length];
        for (int k = 0; k < product.length; k++) {
            product[k] = dArr[k] * dArr2[k] * dArr3[k];
        }
        return product;
    }

    /**
     * Adds the gradient of {@link #computeLogLik} for this sentence into {@code dArr}.
     *
     * <p>Each feature's contribution is {@code (gold indicator - posterior) * feature value}.
     * {@code dArr} is indexed by the flat layout described at {@link #flatIDsize()}.
     */
    public void computeGradient(ModelSentence modelSentence, double[] dArr) {
        assert dArr.length == flatIDsize();
        double[][] posteriors = inferPosteriorGivenLabels(modelSentence);
        for (int t = 0; t < modelSentence.T; t++) {
            int previousLabel = modelSentence.edgeFeatures[t];
            int goldLabel = modelSentence.labels[t];
            for (int label = 0; label < this.numLabels; label++) {
                double residual = (goldLabel == label ? 1.0d : 0.0d) - posteriors[t][label];
                dArr[biasFeature_to_flatID(label)] += residual;
                dArr[edgeFeature_to_flatID(previousLabel, label)] += residual;
                for (Pair<Integer, Double> feature : modelSentence.observationFeatures.get(t)) {
                    int flatID = observationFeature_to_flatID(feature.first.intValue(), label);
                    dArr[flatID] += residual * feature.second.doubleValue();
                }
            }
        }
    }

    /** Returns the sum over positions of {@code log P(gold label | previous label, features)}. */
    public double computeLogLik(ModelSentence modelSentence) {
        double[][] posteriors = inferPosteriorGivenLabels(modelSentence);
        double logLik = 0.0d;
        for (int t = 0; t < modelSentence.T; t++) {
            logLik += Math.log(posteriors[t][modelSentence.labels[t]]);
        }
        return logLik;
    }

    /** Copies a flat parameter vector (layout as in {@link #flatIDsize()}) into the coefficient arrays. */
    public void setCoefsFromFlat(double[] dArr) {
        for (int label = 0; label < this.numLabels; label++) {
            this.biasCoefs[label] = dArr[biasFeature_to_flatID(label)];
        }
        for (int prev = 0; prev < this.numLabels + 1; prev++) {
            for (int label = 0; label < this.numLabels; label++) {
                this.edgeCoefs[prev][label] = dArr[edgeFeature_to_flatID(prev, label)];
            }
        }
        for (int feat = 0; feat < this.featureVocab.size(); feat++) {
            for (int label = 0; label < this.numLabels; label++) {
                this.observationFeatureCoefs[feat][label] = dArr[observationFeature_to_flatID(feat, label)];
            }
        }
    }

    /** Packs the coefficient arrays into a new flat vector (layout as in {@link #flatIDsize()}). */
    public double[] convertCoefsToFlat() {
        double[] flat = new double[flatIDsize()];
        for (int label = 0; label < this.numLabels; label++) {
            flat[biasFeature_to_flatID(label)] = this.biasCoefs[label];
        }
        for (int prev = 0; prev < this.numLabels + 1; prev++) {
            for (int label = 0; label < this.numLabels; label++) {
                flat[edgeFeature_to_flatID(prev, label)] = this.edgeCoefs[prev][label];
            }
        }
        for (int feat = 0; feat < this.featureVocab.size(); feat++) {
            for (int label = 0; label < this.numLabels; label++) {
                flat[observationFeature_to_flatID(feat, label)] = this.observationFeatureCoefs[feat][label];
            }
        }
        return flat;
    }

    /**
     * Returns the length of the flat parameter vector.
     *
     * <p>With {@code L = labelVocab.size()} and {@code F = featureVocab.size()}, the layout is
     * {@code L} bias weights, then {@code (L+1)*L} edge weights, then {@code F*L}
     * observation weights, each block row-major.
     */
    public int flatIDsize() {
        int labels = this.labelVocab.size();
        return labels + ((labels + 1) * labels) + (this.featureVocab.size() * labels);
    }

    private int biasFeature_to_flatID(int label) {
        return label;
    }

    private int edgeFeature_to_flatID(int previousLabel, int label) {
        int labels = this.labelVocab.size();
        return labels + (previousLabel * labels) + label;
    }

    private int observationFeature_to_flatID(int feature, int label) {
        int labels = this.labelVocab.size();
        return labels + ((labels + 1) * labels) + (feature * labels) + label;
    }

    /**
     * Replaces {@code scores} with {@code exp(scores) / sum(exp(scores))}.
     *
     * <p>The maximum is subtracted before exponentiating, so large scores cannot overflow to
     * infinity and produce NaN. The result is mathematically unchanged.
     */
    private static void softmaxInPlace(double[] scores) {
        if (scores.length == 0) {
            return;
        }
        double max = Double.NEGATIVE_INFINITY;
        for (double s : scores) {
            if (s > max) {
                max = s;
            }
        }
        double sum = 0.0d;
        for (int k = 0; k < scores.length; k++) {
            scores[k] = Math.exp(scores[k] - max);
            sum += scores[k];
        }
        for (int k = 0; k < scores.length; k++) {
            scores[k] /= sum;
        }
    }

    /**
     * Writes the model as UTF-8 text in three sections.
     *
     * <ul>
     *   <li>{@code ***BIAS***\t<label>\t<weight>}, one line per label in id order;</li>
     *   <li>{@code ***EDGE***\t<prevId> <labelId>\t<weight>}, one line per transition;</li>
     *   <li>{@code <feature>\t<label>\t<weight>}, for non-zero observation weights only.</li>
     * </ul>
     *
     * <p>Numbers are formatted with {@link Locale#ROOT}, so the file always uses ASCII digits
     * and a '.' decimal point, which {@link Double#parseDouble} in
     * {@link #loadModelFromText} requires. {@code %g} keeps the existing 6-significant-digit format.
     *
     * @throws IOException if the file cannot be opened or any write fails
     */
    public void saveModelAsText(String str) throws IOException {
        BufferedWriter fileWriter = BasicFileIO.openFileToWriteUTF8(str);
        PrintWriter out = new PrintWriter(fileWriter);
        try {
            for (int label = 0; label < this.numLabels; label++) {
                out.printf(Locale.ROOT, "***BIAS***\t%s\t%g\n", this.labelVocab.name(label), Double.valueOf(this.biasCoefs[label]));
            }
            for (int prev = 0; prev < this.numLabels + 1; prev++) {
                for (int label = 0; label < this.numLabels; label++) {
                    out.printf(Locale.ROOT, "***EDGE***\t%s %s\t%s\n", Integer.valueOf(prev), Integer.valueOf(label), Double.valueOf(this.edgeCoefs[prev][label]));
                }
            }
            assert this.featureVocab.size() == this.observationFeatureCoefs.length;
            for (int feat = 0; feat < this.featureVocab.size(); feat++) {
                for (int label = 0; label < this.numLabels; label++) {
                    if (this.observationFeatureCoefs[feat][label] != 0.0d) {
                        out.printf(Locale.ROOT, "%s\t%s\t%g\n", this.featureVocab.name(feat), this.labelVocab.name(label), Double.valueOf(this.observationFeatureCoefs[feat][label]));
                    }
                }
            }
        } finally {
            out.close();
            fileWriter.close();
        }
        // PrintWriter never throws IOException; it only records that one happened
        // (including during the flush in close()).
        if (out.checkError()) {
            throw new IOException("Error writing model file " + str);
        }
    }

    /**
     * Reads a model written by {@link #saveModelAsText(String)}.
     *
     * <p>The BIAS lines define the label vocabulary in id order, and that vocabulary is locked
     * before the EDGE section is read. Every line after the EDGE section is parsed as an
     * observation weight. That section may be empty.
     *
     * @throws IOException if the file cannot be read or contains no {@code ***BIAS***} lines
     */
    public static Model loadModelFromText(String str) throws IOException {
        Model model = new Model();
        List<Double> biases = new ArrayList<Double>();
        List<Triple<Integer, Integer, Double>> edges = new ArrayList<Triple<Integer, Integer, Double>>();
        List<Triple<Integer, Integer, Double>> observations = new ArrayList<Triple<Integer, Integer, Double>>();

        BufferedReader reader = BasicFileIO.openFileOrResource(str);
        try {
            String line = reader.readLine();

            while (line != null) {
                String[] parts = line.split("\t");
                if (!parts[0].equals("***BIAS***")) {
                    break;
                }
                model.labelVocab.num(parts[1]);
                biases.add(Double.valueOf(Double.parseDouble(parts[2])));
                line = reader.readLine();
            }
            if (biases.isEmpty()) {
                throw new IOException("No ***BIAS*** lines found in model file " + str);
            }
            model.labelVocab.lock();
            model.numLabels = model.labelVocab.size();

            while (line != null) {
                String[] parts = line.split("\t");
                if (!parts[0].equals("***EDGE***")) {
                    break;
                }
                String[] ids = parts[1].split(" ");
                edges.add(new Triple<Integer, Integer, Double>(
                        Integer.valueOf(Integer.parseInt(ids[0])),
                        Integer.valueOf(Integer.parseInt(ids[1])),
                        Double.valueOf(Double.parseDouble(parts[2]))));
                line = reader.readLine();
            }

            // Null-checked loop: a model with no non-zero observation weights ends right after
            // the EDGE section.
            while (line != null) {
                String[] parts = line.split("\t");
                observations.add(new Triple<Integer, Integer, Double>(
                        Integer.valueOf(model.featureVocab.num(parts[0])),
                        Integer.valueOf(model.labelVocab.num(parts[1])),
                        Double.valueOf(Double.parseDouble(parts[2]))));
                line = reader.readLine();
            }
        } finally {
            reader.close();
        }

        model.featureVocab.lock();
        model.allocateCoefs(model.labelVocab.size(), model.featureVocab.size());
        for (int label = 0; label < model.numLabels; label++) {
            model.biasCoefs[label] = biases.get(label).doubleValue();
        }
        for (Triple<Integer, Integer, Double> edge : edges) {
            model.edgeCoefs[edge.getFirst().intValue()][edge.getSecond().intValue()] = edge.getThird().doubleValue();
        }
        for (Triple<Integer, Integer, Double> obs : observations) {
            model.observationFeatureCoefs[obs.getFirst().intValue()][obs.getSecond().intValue()] = obs.getThird().doubleValue();
        }
        return model;
    }

    /**
     * Warm-starts {@code model2} from {@code model}.
     *
     * <p>Bias and edge coefficients are copied wholesale. Observation coefficients are copied
     * only for features present in both feature vocabularies. {@code model2}'s
     * {@code observationFeatureCoefs} is assumed to be allocated already.
     *
     * @param model source model
     * @param model2 target model, whose labels must match {@code model}'s by count and id order
     * @throws RuntimeException if the label vocabularies differ
     */
    public static void copyCoefsForIntersectingFeatures(Model model, Model model2) {
        int labelCount = model.numLabels;
        if (labelCount != model2.numLabels) {
            throw new RuntimeException("label vocabs must be same size for warm-start");
        }
        for (int label = 0; label < labelCount; label++) {
            if (!model2.labelVocab.name(label).equals(model.labelVocab.name(label))) {
                throw new RuntimeException("label vocabs must agree for warm-start");
            }
        }
        model2.biasCoefs = ArrayUtil.copy(model.biasCoefs);
        model2.edgeCoefs = ArrayUtil.copy(model.edgeCoefs);
        for (int feat = 0; feat < model.featureVocab.size(); feat++) {
            String name = model.featureVocab.name(feat);
            if (model2.featureVocab.contains(name)) {
                model2.observationFeatureCoefs[model2.featureVocab.num(name)] = ArrayUtil.copy(model.observationFeatureCoefs[feat]);
            }
        }
    }
}
