package edu.stanford.nlp.sentiment;

import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.io.RuntimeIOException;
import edu.stanford.nlp.neural.Embedding;
import edu.stanford.nlp.neural.NeuralUtils;
import edu.stanford.nlp.neural.SimpleTensor;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.TwoDimensionalMap;
import edu.stanford.nlp.util.TwoDimensionalSet;
import java.io.IOException;
import java.io.Serializable;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.TreeMap;
import org.ejml.simple.SimpleMatrix;

/* loaded from: input_file:edu/stanford/nlp/sentiment/SentimentModel.class */
/**
 * Parameters of a tree-structured sentiment model: binary composition matrices (and optionally
 * tensors), classification matrices and word vectors, all keyed by the "basic category" of tree
 * nodes as computed by {@link #basicCategory(String)}.
 *
 * <p>With {@code op.simplifiedModel} every category collapses to the empty string, so each map
 * holds a single shared entry under the key {@code ""}. The public constructor only supports
 * that simplified configuration.
 */
public class SentimentModel implements Serializable {
    /** Binary composition matrices keyed by the basic categories of the (left, right) children. */
    public TwoDimensionalMap<String, String, SimpleMatrix> binaryTransform;
    /** Binary composition tensors, keyed like {@link #binaryTransform}. */
    public TwoDimensionalMap<String, String, SimpleTensor> binaryTensors;
    /** Per-production classification matrices; bypassed when {@code op.combineClassification}. */
    public TwoDimensionalMap<String, String, SimpleMatrix> binaryClassification;
    /** Classification matrices keyed by basic category ({@code ""} when classification is combined). */
    public Map<String, SimpleMatrix> unaryClassification;
    /** Word vectors keyed by vocabulary word. */
    public Map<String, SimpleMatrix> wordVectors;
    public final int numClasses;
    /** Hidden size: {@code op.numHid} when positive, otherwise inferred from the word vectors. */
    public final int numHid;
    public final int numBinaryMatrices;
    public final int binaryTransformSize;
    public final int binaryTensorSize;
    public final int binaryClassificationSize;
    public final int numUnaryMatrices;
    public final int unaryClassificationSize;
    /**
     * numHid x numHid identity added to random transform blocks. It is transient, so it is null
     * after deserialization; {@link #randomTransformBlock()} rebuilds it on demand.
     */
    transient SimpleMatrix identity;
    final Random rand;
    static final String UNKNOWN_WORD = "*UNK*";
    final RNNOptions op;
    private static final long serialVersionUID = 1;

    /**
     * Builds a simplified, combined-classification model from already-trained parameters.
     *
     * @param transformMatrix the single binary transform, stored under ("", "")
     * @param classificationMatrix the single shared classifier, stored as unary classification ""
     * @param tensor the single binary tensor, stored under ("", "")
     * @param wordVectors word vectors, used as-is (not copied)
     * @param op model options; must have combineClassification and simplifiedModel set
     * @throws IllegalArgumentException if either required option is off
     */
    static SentimentModel modelFromMatrices(SimpleMatrix transformMatrix, SimpleMatrix classificationMatrix, SimpleTensor tensor, Map<String, SimpleMatrix> wordVectors, RNNOptions op) {
        if (!op.combineClassification || !op.simplifiedModel) {
            throw new IllegalArgumentException("Can only create a model using this method if combineClassification and simplifiedModel are turned on");
        }
        TwoDimensionalMap<String, String, SimpleMatrix> transforms = TwoDimensionalMap.treeMap();
        transforms.put("", "", transformMatrix);
        TwoDimensionalMap<String, String, SimpleTensor> tensors = TwoDimensionalMap.treeMap();
        tensors.put("", "", tensor);
        // classification is combined, so no per-production binary classifiers are needed
        TwoDimensionalMap<String, String, SimpleMatrix> binaryClassifiers = TwoDimensionalMap.treeMap();
        Map<String, SimpleMatrix> unaryClassifiers = Generics.newTreeMap();
        unaryClassifiers.put("", classificationMatrix);
        return new SentimentModel(transforms, tensors, binaryClassifiers, unaryClassifiers, wordVectors, op);
    }

    /** Wraps existing parameter maps (no copies) and derives the size bookkeeping from them. */
    private SentimentModel(TwoDimensionalMap<String, String, SimpleMatrix> binaryTransform, TwoDimensionalMap<String, String, SimpleTensor> binaryTensors, TwoDimensionalMap<String, String, SimpleMatrix> binaryClassification, Map<String, SimpleMatrix> unaryClassification, Map<String, SimpleMatrix> wordVectors, RNNOptions op) {
        this.op = op;
        this.binaryTransform = binaryTransform;
        this.binaryTensors = binaryTensors;
        this.binaryClassification = binaryClassification;
        this.unaryClassification = unaryClassification;
        this.wordVectors = wordVectors;
        this.numClasses = op.numClasses;
        this.numHid = inferNumHid(op, wordVectors);
        this.numBinaryMatrices = binaryTransform.size();
        this.binaryTransformSize = transformSize(this.numHid);
        this.binaryTensorSize = tensorSize(op, this.numHid);
        this.binaryClassificationSize = op.combineClassification ? 0 : classificationSize(this.numClasses, this.numHid);
        this.numUnaryMatrices = unaryClassification.size();
        this.unaryClassificationSize = classificationSize(this.numClasses, this.numHid);
        this.rand = new Random(op.randomSeed);
        this.identity = SimpleMatrix.identity(this.numHid);
    }

    /**
     * Creates a randomly initialized model.
     *
     * @param op model options; only {@code simplifiedModel} is supported
     * @param list training trees, whose leaves form the vocabulary when
     *     {@code op.randomWordVectors} is set; otherwise word vectors are read from
     *     {@code op.wordVectors}
     * @throws UnsupportedOperationException if {@code op.simplifiedModel} is false
     */
    public SentimentModel(RNNOptions op, List<Tree> list) {
        // Fail before doing any (possibly expensive) word-vector loading.
        if (!op.simplifiedModel) {
            throw new UnsupportedOperationException("Not yet implemented");
        }
        this.op = op;
        this.rand = new Random(op.randomSeed);
        if (op.randomWordVectors) {
            initRandomWordVectors(list);
        } else {
            readWordVectors();
        }
        this.numHid = inferNumHid(op, this.wordVectors);
        this.numClasses = op.numClasses;
        this.identity = SimpleMatrix.identity(this.numHid);

        // In the simplified model every production collapses to the "" category.
        TwoDimensionalSet<String, String> binaryProductions = TwoDimensionalSet.hashSet();
        binaryProductions.add("", "");
        Set<String> unaryProductions = Generics.newHashSet();
        unaryProductions.add("");

        // NOTE: the order of the random* calls below determines the random stream consumed,
        // so it is kept stable for reproducibility under a fixed seed.
        this.binaryTransform = TwoDimensionalMap.treeMap();
        this.binaryTensors = TwoDimensionalMap.treeMap();
        this.binaryClassification = TwoDimensionalMap.treeMap();
        for (Pair<String, String> production : binaryProductions) {
            String left = basicCategory(production.first);
            String right = basicCategory(production.second);
            if (this.binaryTransform.contains(left, right)) {
                continue;
            }
            this.binaryTransform.put(left, right, randomTransformMatrix());
            if (op.useTensors) {
                this.binaryTensors.put(left, right, randomBinaryTensor());
            }
            if (!op.combineClassification) {
                this.binaryClassification.put(left, right, randomClassificationMatrix());
            }
        }
        this.numBinaryMatrices = this.binaryTransform.size();
        this.binaryTransformSize = transformSize(this.numHid);
        this.binaryTensorSize = tensorSize(op, this.numHid);
        this.binaryClassificationSize = op.combineClassification ? 0 : classificationSize(this.numClasses, this.numHid);

        this.unaryClassification = Generics.newTreeMap();
        for (String production : unaryProductions) {
            String category = basicCategory(production);
            if (!this.unaryClassification.containsKey(category)) {
                this.unaryClassification.put(category, randomClassificationMatrix());
            }
        }
        this.numUnaryMatrices = this.unaryClassification.size();
        this.unaryClassificationSize = classificationSize(this.numClasses, this.numHid);
    }

    /**
     * Returns {@code op.numHid} if positive; otherwise the element count of the first word
     * vector, or 0 when there are no word vectors.
     */
    private static int inferNumHid(RNNOptions op, Map<String, SimpleMatrix> wordVectors) {
        if (op.numHid > 0) {
            return op.numHid;
        }
        Iterator<SimpleMatrix> vectors = wordVectors.values().iterator();
        return vectors.hasNext() ? vectors.next().getNumElements() : 0;
    }

    /** Number of entries in one numHid x (2*numHid + 1) binary transform. */
    private static int transformSize(int numHid) {
        return numHid * ((2 * numHid) + 1);
    }

    /** Number of entries in one (2*numHid) x (2*numHid) x numHid tensor, or 0 without tensors. */
    private static int tensorSize(RNNOptions op, int numHid) {
        return op.useTensors ? numHid * numHid * numHid * 4 : 0;
    }

    /** Number of entries in one numClasses x (numHid + 1) classification matrix. */
    private static int classificationSize(int numClasses, int numHid) {
        return numClasses * (numHid + 1);
    }

    /** Random (2*numHid) x (2*numHid) x numHid tensor in +/- 1/(4*numHid), scaled by scalingForInit. */
    SimpleTensor randomBinaryTensor() {
        double range = 1.0d / (4.0d * this.numHid);
        return SimpleTensor.random(this.numHid * 2, this.numHid * 2, this.numHid, -range, range, this.rand).scale(this.op.trainOptions.scalingForInit);
    }

    /**
     * Random numHid x (2*numHid + 1) transform: two identity-plus-noise blocks side by side,
     * with the last (bias) column left at zero, scaled by scalingForInit.
     */
    SimpleMatrix randomTransformMatrix() {
        SimpleMatrix transform = new SimpleMatrix(this.numHid, (this.numHid * 2) + 1);
        transform.insertIntoThis(0, 0, randomTransformBlock());
        transform.insertIntoThis(0, this.numHid, randomTransformBlock());
        return transform.scale(this.op.trainOptions.scalingForInit);
    }

    /** Identity plus uniform noise in +/- 1/(2*sqrt(numHid)). */
    SimpleMatrix randomTransformBlock() {
        if (this.identity == null) {
            // identity is transient and therefore missing on a deserialized model
            this.identity = SimpleMatrix.identity(this.numHid);
        }
        double range = 1.0d / (Math.sqrt(this.numHid) * 2.0d);
        return SimpleMatrix.random(this.numHid, this.numHid, -range, range, this.rand).plus(this.identity);
    }

    /**
     * Random numClasses x (numHid + 1) classifier in +/- 1/sqrt(numHid), with the last (bias)
     * column left at zero, scaled by scalingForInit.
     */
    SimpleMatrix randomClassificationMatrix() {
        SimpleMatrix classifier = new SimpleMatrix(this.numClasses, this.numHid + 1);
        double range = 1.0d / Math.sqrt(this.numHid);
        classifier.insertIntoThis(0, 0, SimpleMatrix.random(this.numClasses, this.numHid, -range, range, this.rand));
        return classifier.scale(this.op.trainOptions.scalingForInit);
    }

    /**
     * Random Gaussian word vector of size {@code op.numHid}. It uses the option rather than the
     * {@link #numHid} field because it is called from the constructor before that field is set.
     */
    SimpleMatrix randomWordVector() {
        return randomWordVector(this.op.numHid, this.rand);
    }

    /** Random Gaussian column vector with {@code i} rows. */
    static SimpleMatrix randomWordVector(int i, Random random) {
        return NeuralUtils.randomGaussian(i, 1, random);
    }

    /**
     * Builds random word vectors for every leaf word in {@code list} plus {@link #UNKNOWN_WORD}.
     *
     * @throws RuntimeException if {@code op.numHid} is not positive (the size cannot be inferred)
     */
    void initRandomWordVectors(List<Tree> list) {
        // numHid <= 0 means "infer from word vectors" elsewhere, which is impossible here.
        if (this.op.numHid <= 0) {
            throw new RuntimeException("Cannot create random word vectors for an unknown numHid");
        }
        Set<String> vocabulary = Generics.newHashSet();
        vocabulary.add(UNKNOWN_WORD);
        for (Tree tree : list) {
            for (Tree leaf : tree.getLeaves()) {
                String word = leaf.label().value();
                if (this.op.lowercaseWordVectors) {
                    word = word.toLowerCase();
                }
                vocabulary.add(word);
            }
        }
        this.wordVectors = Generics.newTreeMap();
        for (String word : vocabulary) {
            this.wordVectors.put(word, randomWordVector());
        }
    }

    /**
     * Loads word vectors from {@code op.wordVectors} and aliases the vector of
     * {@code op.unkWord} as {@link #UNKNOWN_WORD}.
     *
     * @throws RuntimeException if the file has no vector for {@code op.unkWord}
     */
    void readWordVectors() {
        Embedding embedding = new Embedding(this.op.wordVectors, this.op.numHid);
        this.wordVectors = Generics.newTreeMap();
        for (String word : embedding.keySet()) {
            this.wordVectors.put(word, embedding.get(word));
        }
        SimpleMatrix unknownVector = this.wordVectors.get(this.op.unkWord);
        // validate before inserting so a null vector never ends up in the map
        if (unknownVector == null) {
            throw new RuntimeException("Unknown word vector not specified in the word vector file");
        }
        this.wordVectors.put(UNKNOWN_WORD, unknownVector);
    }

    /** Total number of scalar parameters, matching the length of {@link #paramsToVector()}. */
    public int totalParamSize() {
        return (this.numBinaryMatrices * (this.binaryTransformSize + this.binaryClassificationSize + this.binaryTensorSize)) + (this.numUnaryMatrices * this.unaryClassificationSize) + (this.wordVectors.size() * this.numHid);
    }

    /**
     * Flattens all parameters in the order: binary transforms, binary classifiers, binary
     * tensors, unary classifiers, word vectors.
     */
    public double[] paramsToVector() {
        return NeuralUtils.paramsToVector(totalParamSize(), this.binaryTransform.valueIterator(), this.binaryClassification.valueIterator(), SimpleTensor.iteratorSimpleMatrix(this.binaryTensors.valueIterator()), this.unaryClassification.values().iterator(), this.wordVectors.values().iterator());
    }

    /** Writes a flat vector back into the parameters, using the order of {@link #paramsToVector()}. */
    public void vectorToParams(double[] dArr) {
        NeuralUtils.vectorToParams(dArr, this.binaryTransform.valueIterator(), this.binaryClassification.valueIterator(), SimpleTensor.iteratorSimpleMatrix(this.binaryTensors.valueIterator()), this.unaryClassification.values().iterator(), this.wordVectors.values().iterator());
    }

    /**
     * Returns the binary transform for a node with two children.
     *
     * @throws AssertionError for unary or other non-binary nodes
     */
    public SimpleMatrix getWForNode(Tree tree) {
        Tree[] children = tree.children();
        if (children.length == 2) {
            return this.binaryTransform.get(basicCategory(children[0].value()), basicCategory(children[1].value()));
        }
        if (children.length == 1) {
            throw new AssertionError("No unary transform matrices, only unary classification");
        }
        throw new AssertionError("Unexpected tree children size of " + children.length);
    }

    /**
     * Returns the binary tensor for a node with two children.
     *
     * @throws AssertionError if tensors are disabled or the node is not binary
     */
    public SimpleTensor getTensorForNode(Tree tree) {
        if (!this.op.useTensors) {
            throw new AssertionError("Not using tensors");
        }
        Tree[] children = tree.children();
        if (children.length == 2) {
            return this.binaryTensors.get(basicCategory(children[0].value()), basicCategory(children[1].value()));
        }
        if (children.length == 1) {
            throw new AssertionError("No unary transform matrices, only unary classification");
        }
        throw new AssertionError("Unexpected tree children size of " + children.length);
    }

    /**
     * Returns the classifier for a node: the shared "" classifier when classification is
     * combined, otherwise the binary or unary classifier for the node's children.
     *
     * @throws AssertionError if classification is not combined and the node has neither 1 nor 2 children
     */
    public SimpleMatrix getClassWForNode(Tree tree) {
        if (this.op.combineClassification) {
            return this.unaryClassification.get("");
        }
        Tree[] children = tree.children();
        if (children.length == 2) {
            return this.binaryClassification.get(basicCategory(children[0].value()), basicCategory(children[1].value()));
        }
        if (children.length != 1) {
            throw new AssertionError("Unexpected tree children size of " + children.length);
        }
        return this.unaryClassification.get(basicCategory(children[0].value()));
    }

    /** Returns the vector for {@code str}, falling back to the unknown-word vector. */
    public SimpleMatrix getWordVector(String str) {
        return this.wordVectors.get(getVocabWord(str));
    }

    /**
     * Maps a word to its vocabulary key: lowercased when {@code op.lowercaseWordVectors}, and
     * replaced by {@link #UNKNOWN_WORD} when absent from the word vectors.
     */
    public String getVocabWord(String str) {
        if (this.op.lowercaseWordVectors) {
            str = str.toLowerCase();
        }
        return this.wordVectors.containsKey(str) ? str : UNKNOWN_WORD;
    }

    /**
     * Returns {@code ""} for the simplified model; otherwise the language pack's basic category
     * with one leading '@' removed.
     */
    public String basicCategory(String str) {
        if (this.op.simplifiedModel) {
            return "";
        }
        String category = this.op.langpack.basicCategory(str);
        if (category.startsWith("@")) {
            category = category.substring(1);
        }
        return category;
    }

    /** Unary classifier for the basic category of {@code str}. */
    public SimpleMatrix getUnaryClassification(String str) {
        return this.unaryClassification.get(basicCategory(str));
    }

    /** Binary classifier for the given children, or the shared "" classifier when combined. */
    public SimpleMatrix getBinaryClassification(String str, String str2) {
        if (this.op.combineClassification) {
            return this.unaryClassification.get("");
        }
        return this.binaryClassification.get(basicCategory(str), basicCategory(str2));
    }

    /** Binary transform for the basic categories of the given children. */
    public SimpleMatrix getBinaryTransform(String str, String str2) {
        return this.binaryTransform.get(basicCategory(str), basicCategory(str2));
    }

    /** Binary tensor for the basic categories of the given children. */
    public SimpleTensor getBinaryTensor(String str, String str2) {
        return this.binaryTensors.get(basicCategory(str), basicCategory(str2));
    }

    /**
     * Java-serializes this model to the file {@code str}.
     *
     * @throws RuntimeIOException wrapping any IOException
     */
    public void saveSerialized(String str) {
        try {
            IOUtils.writeObjectToFile(this, str);
        } catch (IOException e) {
            throw new RuntimeIOException(e);
        }
    }

    /**
     * Loads a Java-serialized model from a URL, classpath resource or file path. Only use this
     * on trusted model files: Java deserialization of untrusted data is unsafe.
     *
     * @throws RuntimeIOException wrapping IO or class-resolution failures
     */
    public static SentimentModel loadSerialized(String str) {
        try {
            return (SentimentModel) IOUtils.readObjectFromURLOrClasspathOrFileSystem(str);
        } catch (IOException e) {
            throw new RuntimeIOException(e);
        } catch (ClassNotFoundException e2) {
            throw new RuntimeIOException(e2);
        }
    }

    /**
     * Prints to stderr which parameter a flat index from {@link #paramsToVector()} refers to,
     * walking the blocks in the same order that method uses.
     */
    public void printParamInformation(int i) {
        int offset = 0;
        for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> entry : this.binaryTransform) {
            int size = entry.getValue().getNumElements();
            if (offset <= i && i < offset + size) {
                System.err.println("Index " + i + " is element " + (i - offset) + " of binaryTransform \"" + entry.getFirstKey() + "," + entry.getSecondKey() + "\"");
                return;
            }
            offset += size;
        }
        for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> entry : this.binaryClassification) {
            int size = entry.getValue().getNumElements();
            if (offset <= i && i < offset + size) {
                System.err.println("Index " + i + " is element " + (i - offset) + " of binaryClassification \"" + entry.getFirstKey() + "," + entry.getSecondKey() + "\"");
                return;
            }
            offset += size;
        }
        for (TwoDimensionalMap.Entry<String, String, SimpleTensor> entry : this.binaryTensors) {
            int size = entry.getValue().getNumElements();
            if (offset <= i && i < offset + size) {
                System.err.println("Index " + i + " is element " + (i - offset) + " of binaryTensor \"" + entry.getFirstKey() + "," + entry.getSecondKey() + "\"");
                return;
            }
            offset += size;
        }
        for (Map.Entry<String, SimpleMatrix> entry : this.unaryClassification.entrySet()) {
            int size = entry.getValue().getNumElements();
            if (offset <= i && i < offset + size) {
                System.err.println("Index " + i + " is element " + (i - offset) + " of unaryClassification \"" + entry.getKey() + "\"");
                return;
            }
            offset += size;
        }
        for (Map.Entry<String, SimpleMatrix> entry : this.wordVectors.entrySet()) {
            int size = entry.getValue().getNumElements();
            if (offset <= i && i < offset + size) {
                System.err.println("Index " + i + " is element " + (i - offset) + " of wordVector \"" + entry.getKey() + "\"");
                return;
            }
            offset += size;
        }
        System.err.println("Index " + i + " is beyond the length of the parameters; total parameter space was " + totalParamSize());
    }
}
