package edu.stanford.nlp.neural;

import java.io.Serializable;
import java.util.Arrays;
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.Random;
import org.ejml.simple.SimpleMatrix;

/* loaded from: input_file:edu/stanford/nlp/neural/SimpleTensor.class */
/**
 * A rank-3 tensor stored as an array of equally sized {@link SimpleMatrix} slices.
 *
 * <p>Every slice has {@link #numRows()} rows and {@link #numCols()} columns, and there are
 * {@link #numSlices()} slices. {@link #scale}, {@link #plus} and {@link #elementMult} build and
 * return a new tensor; {@link #set(double)} and {@link #setSlice} modify this tensor in place.
 */
public class SimpleTensor implements Serializable {
    private static final long serialVersionUID = 1L;

    /** The slices of this tensor; each entry is a {@code numRows x numCols} matrix. */
    private final SimpleMatrix[] slices;

    /** Number of rows in every slice. */
    final int numRows;
    /** Number of columns in every slice. */
    final int numCols;
    /** Number of slices. */
    final int numSlices;

    /**
     * Flattens an iterator over tensors into a single iterator over all of their slices,
     * in tensor order and then slice order. Tensors without slices are skipped.
     */
    private static class SimpleMatrixIteratorWrapper implements Iterator<SimpleMatrix> {
        private final Iterator<SimpleTensor> tensors;
        /** Slice iterator of the current tensor, or {@code null} once no tensor with slices remains. */
        private Iterator<SimpleMatrix> currentIterator;

        public SimpleMatrixIteratorWrapper(Iterator<SimpleTensor> tensors) {
            this.tensors = tensors;
            advanceIterator();
        }

        @Override
        public boolean hasNext() {
            if (currentIterator == null) {
                return false;
            }
            if (currentIterator.hasNext()) {
                return true;
            }
            advanceIterator();
            return currentIterator != null;
        }

        @Override
        public SimpleMatrix next() {
            if (currentIterator != null && currentIterator.hasNext()) {
                return currentIterator.next();
            }
            advanceIterator();
            if (currentIterator != null) {
                return currentIterator.next();
            }
            throw new NoSuchElementException();
        }

        /**
         * If the current slice iterator is missing or exhausted, moves to the next tensor that has
         * at least one slice; sets {@code currentIterator} to {@code null} if there is none.
         */
        private void advanceIterator() {
            if (currentIterator == null || !currentIterator.hasNext()) {
                while (tensors.hasNext()) {
                    currentIterator = tensors.next().iteratorSimpleMatrix();
                    if (currentIterator.hasNext()) {
                        return;
                    }
                }
                currentIterator = null;
            }
        }

        /** Removal is not supported. */
        @Override
        public void remove() {
            throw new UnsupportedOperationException();
        }
    }

    /**
     * Creates a tensor of {@code numSlices} freshly allocated {@code numRows x numCols} matrices.
     *
     * @param numRows rows per slice
     * @param numCols columns per slice
     * @param numSlices number of slices
     */
    public SimpleTensor(int numRows, int numCols, int numSlices) {
        this.slices = new SimpleMatrix[numSlices];
        for (int s = 0; s < numSlices; s++) {
            this.slices[s] = new SimpleMatrix(numRows, numCols);
        }
        this.numRows = numRows;
        this.numCols = numCols;
        this.numSlices = numSlices;
    }

    /**
     * Creates a tensor whose slices are copies of the given matrices.
     *
     * @param matrices the slices; all must share the dimensions of the first one
     * @throws IllegalArgumentException if {@code matrices} is empty or the slice dimensions differ
     */
    public SimpleTensor(SimpleMatrix[] matrices) {
        if (matrices.length == 0) {
            // Without at least one slice the slice dimensions cannot be determined.
            throw new IllegalArgumentException("Cannot build a SimpleTensor from an empty array of slices");
        }
        this.numRows = matrices[0].numRows();
        this.numCols = matrices[0].numCols();
        this.numSlices = matrices.length;
        this.slices = new SimpleMatrix[matrices.length];
        for (int s = 0; s < numSlices; s++) {
            SimpleMatrix matrix = matrices[s];
            if (matrix.numRows() != numRows || matrix.numCols() != numCols) {
                throw new IllegalArgumentException("Slice " + s + " has matrix dimensions " + matrix.numRows() + "," + matrix.numCols() + ", expected " + numRows + "," + numCols);
            }
            // Copy each matrix so later changes by the caller do not leak into the tensor.
            this.slices[s] = new SimpleMatrix(matrix);
        }
    }

    /**
     * Creates a tensor whose slices are produced by {@code SimpleMatrix.random}.
     *
     * @param numRows rows per slice
     * @param numCols columns per slice
     * @param numSlices number of slices
     * @param minValue lower bound passed to {@code SimpleMatrix.random}
     * @param maxValue upper bound passed to {@code SimpleMatrix.random}
     * @param rand source of randomness
     * @return the new random tensor
     */
    public static SimpleTensor random(int numRows, int numCols, int numSlices, double minValue, double maxValue, Random rand) {
        SimpleTensor tensor = new SimpleTensor(numRows, numCols, numSlices);
        for (int s = 0; s < numSlices; s++) {
            tensor.slices[s] = SimpleMatrix.random(numRows, numCols, minValue, maxValue, rand);
        }
        return tensor;
    }

    /** @return number of rows in each slice */
    public int numRows() {
        return numRows;
    }

    /** @return number of columns in each slice */
    public int numCols() {
        return numCols;
    }

    /** @return number of slices */
    public int numSlices() {
        return numSlices;
    }

    /** @return total number of entries, {@code numRows * numCols * numSlices} */
    public int getNumElements() {
        return numRows * numCols * numSlices;
    }

    /**
     * Sets every entry of every slice to {@code value}, modifying this tensor in place.
     *
     * @param value the value to store
     */
    public void set(double value) {
        for (int s = 0; s < numSlices; s++) {
            slices[s].set(value);
        }
    }

    /**
     * Multiplies every entry by {@code factor}.
     *
     * @param factor the scale factor
     * @return a new tensor holding the scaled slices
     */
    public SimpleTensor scale(double factor) {
        SimpleTensor result = new SimpleTensor(numRows, numCols, numSlices);
        for (int s = 0; s < numSlices; s++) {
            result.slices[s] = (SimpleMatrix) slices[s].scale(factor);
        }
        return result;
    }

    /**
     * Adds {@code other} to this tensor slice by slice.
     *
     * @param other a tensor with identical dimensions
     * @return a new tensor holding the sum
     * @throws IllegalArgumentException if the dimensions differ
     */
    public SimpleTensor plus(SimpleTensor other) {
        checkSameSize(other);
        SimpleTensor result = new SimpleTensor(numRows, numCols, numSlices);
        for (int s = 0; s < numSlices; s++) {
            result.slices[s] = (SimpleMatrix) slices[s].plus(other.slices[s]);
        }
        return result;
    }

    /**
     * Multiplies this tensor by {@code other} element by element.
     *
     * @param other a tensor with identical dimensions
     * @return a new tensor holding the element-wise product
     * @throws IllegalArgumentException if the dimensions differ
     */
    public SimpleTensor elementMult(SimpleTensor other) {
        checkSameSize(other);
        SimpleTensor result = new SimpleTensor(numRows, numCols, numSlices);
        for (int s = 0; s < numSlices; s++) {
            result.slices[s] = (SimpleMatrix) slices[s].elementMult(other.slices[s]);
        }
        return result;
    }

    /** @return the sum of all entries across all slices */
    public double elementSum() {
        double total = 0.0;
        for (SimpleMatrix slice : slices) {
            total += slice.elementSum();
        }
        return total;
    }

    /**
     * Replaces slice {@code slice} with {@code matrix}. The matrix is stored by reference, not
     * copied, so later changes to it are visible through this tensor.
     *
     * @param slice index of the slice to replace
     * @param matrix replacement matrix with this tensor's slice dimensions
     * @throws IllegalArgumentException if the index is out of range or the dimensions differ
     */
    public void setSlice(int slice, SimpleMatrix matrix) {
        checkSliceIndex(slice);
        if (matrix.numCols() != numCols) {
            throw new IllegalArgumentException("Incompatible matrix size.  Has " + matrix.numCols() + " columns, tensor has " + numCols);
        }
        if (matrix.numRows() != numRows) {
            throw new IllegalArgumentException("Incompatible matrix size.  Has " + matrix.numRows() + " rows, tensor has " + numRows);
        }
        slices[slice] = matrix;
    }

    /**
     * Returns slice {@code slice}. This is the tensor's own matrix, not a copy, so modifying it
     * modifies the tensor.
     *
     * @param slice index of the slice
     * @return the stored slice matrix
     * @throws IllegalArgumentException if the index is out of range
     */
    public SimpleMatrix getSlice(int slice) {
        checkSliceIndex(slice);
        return slices[slice];
    }

    /**
     * Computes the quadratic form {@code x^T * S_i * x} for every slice {@code S_i}.
     *
     * @param x a column vector with {@code numCols} rows
     * @return a {@code numSlices x 1} column vector whose entry {@code i} is the form for slice {@code i}
     * @throws AssertionError if {@code x} is not a column vector of matching length or the
     *     slices are not square
     */
    public SimpleMatrix bilinearProducts(SimpleMatrix x) {
        if (x.numCols() != 1) {
            throw new AssertionError("Expected a column vector");
        }
        if (x.numRows() != numCols) {
            throw new AssertionError("Number of rows in the input does not match number of columns in tensor");
        }
        if (numRows != numCols) {
            throw new AssertionError("Can only perform this operation on a SimpleTensor with square slices");
        }
        SimpleMatrix xTransposed = x.transpose();
        SimpleMatrix products = new SimpleMatrix(numSlices, 1);
        for (int s = 0; s < numSlices; s++) {
            // (1 x n) * (n x n) * (n x 1) yields a 1x1 matrix; take its only entry.
            products.set(s, xTransposed.mult(slices[s]).mult(x).get(0));
        }
        return products;
    }

    /** @return {@code true} if {@code NeuralUtils.isZero} holds for every slice */
    public boolean isZero() {
        for (int s = 0; s < numSlices; s++) {
            if (!NeuralUtils.isZero(slices[s])) {
                return false;
            }
        }
        return true;
    }

    /**
     * @return an iterator over this tensor's slices in index order; the slices are not copied
     *     and {@code remove()} is unsupported
     */
    public Iterator<SimpleMatrix> iteratorSimpleMatrix() {
        return Arrays.asList(slices).iterator();
    }

    /**
     * Returns one iterator over all slices of all tensors produced by {@code tensors}.
     *
     * @param tensors the tensors whose slices are iterated, in order
     * @return an iterator over their slices; {@code remove()} is unsupported
     */
    public static Iterator<SimpleMatrix> iteratorSimpleMatrix(Iterator<SimpleTensor> tensors) {
        return new SimpleMatrixIteratorWrapper(tensors);
    }

    /** Throws if {@code other} does not have exactly this tensor's dimensions. */
    private void checkSameSize(SimpleTensor other) {
        if (other.numRows != numRows || other.numCols != numCols || other.numSlices != numSlices) {
            throw new IllegalArgumentException("Sizes of tensors do not match.  Our size: " + numRows + "," + numCols + "," + numSlices + "; other size " + other.numRows + "," + other.numCols + "," + other.numSlices);
        }
    }

    /** Throws if {@code slice} is not a valid slice index for this tensor. */
    private void checkSliceIndex(int slice) {
        if (slice < 0 || slice >= numSlices) {
            throw new IllegalArgumentException("Unexpected slice number " + slice + " for tensor with " + numSlices + " slices");
        }
    }

    /** Renders each slice preceded by a {@code "Slice i"} header line. */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        for (int s = 0; s < numSlices; s++) {
            sb.append("Slice ").append(s).append('\n');
            sb.append(slices[s]);
        }
        return sb.toString();
    }
}
