package com.intel.analytics.bigdl.nn.ops;

import com.intel.analytics.bigdl.tensor.Tensor;
import com.intel.analytics.bigdl.tensor.Tensor$;
import com.intel.analytics.bigdl.tensor.TensorMath;
import com.intel.analytics.bigdl.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.utils.T$;
import com.intel.analytics.bigdl.utils.Table;
import scala.Array$;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.mutable.ArrayOps;
import scala.math.Numeric$IntIsIntegral$;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.reflect.package$;
import scala.runtime.BoxesRunTime;
import scala.runtime.ObjectRef;

/* compiled from: BatchMatMul.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005mg\u0001B\u000b\u0017\u0001\rB\u0001B\u0012\u0001\u0003\u0006\u0004%\ta\u0012\u0005\t\u0017\u0002\u0011\t\u0011)A\u0005\u0011\"AA\n\u0001BC\u0002\u0013\u0005q\t\u0003\u0005N\u0001\t\u0005\t\u0015!\u0003I\u0011!q\u0005AaA!\u0002\u0017y\u0005\u0002C+\u0001\u0005\u0007\u0005\u000b1\u0002,\t\u0011]\u0003!\u0011!Q\u0001\faC\u0001\u0002\u001c\u0001\u0003\u0002\u0003\u0006Y!\u001c\u0005\u0006]\u0002!\ta\u001c\u0005\u0006q\u0002!\t%\u001f\u0005\u0006y\u0002!\t%`\u0004\b\u0003?1\u0002\u0012AA\u0011\r\u0019)b\u0003#\u0001\u0002$!1a.\u0004C\u0001\u0003cAq!a\r\u000e\t\u0003\t)\u0004C\u0005\u0002~5\t\n\u0011\"\u0001\u0002��!I\u0011QU\u0007\u0012\u0002\u0013\u0005\u0011q\u0015\u0005\n\u0003ok\u0011\u0013!C\u0001\u0003sC\u0011\"a0\u000e#\u0003%\t!!1\t\u0013\u0005\u001dW\"!A\u0005\n\u0005%'a\u0003\"bi\u000eDW*\u0019;Nk2T!a\u0006\r\u0002\u0007=\u00048O\u0003\u0002\u001a5\u0005\u0011aN\u001c\u0006\u00037q\tQAY5hI2T!!\b\u0010\u0002\u0013\u0005t\u0017\r\\=uS\u000e\u001c(BA\u0010!\u0003\u0015Ig\u000e^3m\u0015\u0005\t\u0013aA2p[\u000e\u0001Qc\u0001\u0013EoM\u0011\u0001!\n\t\u0006M\u001dJsfQ\u0007\u0002-%\u0011\u0001F\u0006\u0002\n\u001fB,'/\u0019;j_:\u0004\"AK\u0017\u000e\u0003-R!\u0001\f\u000e\u0002\u000bU$\u0018\u000e\\:\n\u00059Z#!\u0002+bE2,\u0007c\u0001\u00194k5\t\u0011G\u0003\u000235\u00051A/\u001a8t_JL!\u0001N\u0019\u0003\rQ+gn]8s!\t1t\u0007\u0004\u0001\u0005\u000ba\u0002!\u0019A\u001d\u0003\u0003\u0011\u000b\"A\u000f!\u0011\u0005mrT\"\u0001\u001f\u000b\u0003u\nQa]2bY\u0006L!a\u0010\u001f\u0003\u000f9{G\u000f[5oOB\u00111(Q\u0005\u0003\u0005r\u00121!\u00118z!\t1D\tB\u0003F\u0001\t\u0007\u0011HA\u0001U\u0003\u0011\tGM\u001b-\u0016\u0003!\u0003\"aO%\n\u0005)c$a\u0002\"p_2,\u0017M\\\u0001\u0006C\u0012T\u0007\fI\u0001\u0005C\u0012T\u0017,A\u0003bI*L\u0006%\u0001\u0006fm&$WM\\2fIE\u00022\u0001U*D\u001b\u0005\t&B\u0001*=\u0003\u001d\u0011XM\u001a7fGRL!\u0001V)\u0003\u0011\rc\u0017m]:UC\u001e\f!\"\u001a<jI\u0016t7-\u001a\u00133!\r\u00016+N\u0001\
u0003KZ\u00042!W5D\u001d\tQvM\u0004\u0002\\M:\u0011A,\u001a\b\u0003;\u0012t!AX2\u000f\u0005}\u0013W\"\u00011\u000b\u0005\u0005\u0014\u0013A\u0002\u001fs_>$h(C\u0001\"\u0013\ty\u0002%\u0003\u0002\u001e=%\u00111\u0004H\u0005\u0003eiI!\u0001[\u0019\u0002#Q+gn]8s\u001dVlWM]5d\u001b\u0006$\b.\u0003\u0002kW\niA+\u001a8t_JtU/\\3sS\u000eT!\u0001[\u0019\u0002\u0007\u00154(\u0007E\u0002ZSV\na\u0001P5oSRtDc\u00019woR)\u0011O]:ukB!a\u0005A\"6\u0011\u0015q\u0015\u0002q\u0001P\u0011\u0015)\u0016\u0002q\u0001W\u0011\u00159\u0016\u0002q\u0001Y\u0011\u0015a\u0017\u0002q\u0001n\u0011\u001d1\u0015\u0002%AA\u0002!Cq\u0001T\u0005\u0011\u0002\u0003\u0007\u0001*\u0001\u0007va\u0012\fG/Z(viB,H\u000f\u0006\u00020u\")1P\u0003a\u0001S\u0005)\u0011N\u001c9vi\u0006\u0019r-\u001a;DY\u0006\u001c8\u000fV1h\u001dVlWM]5dgR\ta\u0010\u0005\u0004<\u007f\u0006\r\u00111C\u0005\u0004\u0003\u0003a$A\u0002+va2,'\u0007E\u0003<\u0003\u000b\tI!C\u0002\u0002\bq\u0012Q!\u0011:sCf\u0004D!a\u0003\u0002\u0010A!\u0001kUA\u0007!\r1\u0014q\u0002\u0003\u000b\u0003#Y\u0011\u0011!A\u0001\u0006\u0003I$aA0%cA)1(!\u0002\u0002\u0016A\"\u0011qCA\u000e!\u0011I\u0016.!\u0007\u0011\u0007Y\nY\u0002\u0002\u0006\u0002\u001e-\t\t\u0011!A\u0003\u0002e\u00121a\u0018\u00133\u0003-\u0011\u0015\r^2i\u001b\u0006$X*\u001e7\u0011\u0005\u0019j1#B\u0007\u0002&\u0005-\u0002cA\u001e\u0002(%\u0019\u0011\u0011\u0006\u001f\u0003\r\u0005s\u0017PU3g!\rY\u0014QF\u0005\u0004\u0003_a$\u0001D*fe&\fG.\u001b>bE2,GCAA\u0011\u0003\u0015\t\u0007\u000f\u001d7z+\u0019\t9$a\u0010\u0002dQ1\u0011\u0011HA=\u0003w\"\"\"a\u000f\u0002f\u0005-\u0014\u0011OA;!\u00191\u0003!!\u0010\u0002bA\u0019a'a\u0010\u0005\u0013\u0015{\u0001\u0015!A\u0001\u0006\u0004I\u0004\u0006CA 
\u0003\u0007\nI%a\u0016\u0011\u0007m\n)%C\u0002\u0002Hq\u00121b\u001d9fG&\fG.\u001b>fIFJ1%a\u0013\u0002N\u0005E\u0013q\n\b\u0004w\u00055\u0013bAA(y\u0005)a\t\\8biF2A%a\u0015\u0002Vur1aXA+\u0013\u0005i\u0014'C\u0012\u0002Z\u0005m\u0013qLA/\u001d\rY\u00141L\u0005\u0004\u0003;b\u0014A\u0002#pk\ndW-\r\u0004%\u0003'\n)&\u0010\t\u0004m\u0005\rD!\u0002\u001d\u0010\u0005\u0004I\u0004\"CA4\u001f\u0005\u0005\t9AA5\u0003))g/\u001b3f]\u000e,Ge\r\t\u0005!N\u000bi\u0004C\u0005\u0002n=\t\t\u0011q\u0001\u0002p\u0005QQM^5eK:\u001cW\r\n\u001b\u0011\tA\u001b\u0016\u0011\r\u0005\u0007/>\u0001\u001d!a\u001d\u0011\teK\u0017Q\b\u0005\u0007Y>\u0001\u001d!a\u001e\u0011\teK\u0017\u0011\r\u0005\b\r>\u0001\n\u00111\u0001I\u0011\u001dau\u0002%AA\u0002!\u000bq\"\u00199qYf$C-\u001a4bk2$H%M\u000b\u0007\u0003\u0003\u000b9*a)\u0016\u0005\u0005\r%f\u0001%\u0002\u0006.\u0012\u0011q\u0011\t\u0005\u0003\u0013\u000b\u0019*\u0004\u0002\u0002\f*!\u0011QRAH\u0003%)hn\u00195fG.,GMC\u0002\u0002\u0012r\n!\"\u00198o_R\fG/[8o\u0013\u0011\t)*a#\u0003#Ut7\r[3dW\u0016$g+\u0019:jC:\u001cW\rB\u0005F!\u0001\u0006\t\u0011!b\u0001s!B\u0011qSA\"\u00037\u000by*M\u0005$\u0003\u0017\ni%!(\u0002PE2A%a\u0015\u0002Vu\n\u0014bIA-\u00037\n\t+!\u00182\r\u0011\n\u0019&!\u0016>\t\u0015A\u0004C1\u0001:\u0003=\t\u0007\u000f\u001d7zI\u0011,g-Y;mi\u0012\u0012TCBAA\u0003S\u000b)\fB\u0005F#\u0001\u0006\t\u0011!b\u0001s!B\u0011\u0011VA\"\u0003[\u000b\t,M\u0005$\u0003\u0017\ni%a,\u0002PE2A%a\u0015\u0002Vu\n\u0014bIA-\u00037\n\u0019,!\u00182\r\u0011\n\u0019&!\u0016>\t\u0015A\u0014C1\u0001:\u0003m!C.Z:tS:LG\u000fJ4sK\u0006$XM\u001d\u0013eK\u001a\fW\u000f\u001c;%cU1\u0011\u0011QA^\u0003{#Q!\u0012\nC\u0002e\"Q\u0001\u000f\nC\u0002e\n1\u0004\n7fgNLg.\u001b;%OJ,\u0017\r^3sI\u0011,g-Y;mi\u0012\u0012TCBAA\u0003\u0007\f)\rB\u0003F'\t\u0007\u0011\bB\u00039'\t\u0007\u0011(A\u0006sK\u0006$'+Z:pYZ,GCAAf!\u0011\ti-a6\u000e\u0005\u0005='\u0002BAi\u0003'\fA\u0001\\1oO*\u0011\u0011Q[\u0001\u0005U\u00064\u0018-\u0003\u0003\u0002Z\u0006='AB(cU\u0016\u001cG\u
000f")
/* loaded from: input_file:com/intel/analytics/bigdl/nn/ops/BatchMatMul.class */
/**
 * Matrix product of the two tensors stored at keys 1 ({@code x}) and 2 ({@code y}) of the input
 * {@link Table}.
 *
 * <p>Both inputs must have the same rank, which must be at least 2. For rank 2 this is a plain
 * matrix product {@code x * y}. For higher ranks, all leading dimensions (everything except the
 * last two) are treated as batch dimensions. They must be identical in both inputs. The trailing
 * two dimensions of each batch entry are multiplied.
 *
 * <p>The output shape is {@code x}'s leading dimensions followed by {@code [rows, cols]} of the
 * per-batch product.
 *
 * @param <T> numeric type of the module itself (passed to {@code Operation})
 * @param <D> element type of the input and output tensors
 */
public class BatchMatMul<T, D> extends Operation<Table, Tensor<D>, T> {
    /** When true, the last two dimensions of the first input are transposed before multiplying. */
    private final boolean adjX;
    /** When true, the last two dimensions of the second input are transposed before multiplying. */
    private final boolean adjY;
    private final ClassTag<T> evidence$1;
    private final ClassTag<D> evidence$2;
    private final TensorNumericMath.TensorNumeric<T> ev;
    private final TensorNumericMath.TensorNumeric<D> ev2;

    /** @return whether the first operand's trailing two dims are transposed before multiplying */
    public boolean adjX() {
        return this.adjX;
    }

    /** @return whether the second operand's trailing two dims are transposed before multiplying */
    public boolean adjY() {
        return this.adjY;
    }

    /**
     * Computes the (batched) matrix product of table entries 1 and 2 into {@code output()}.
     *
     * @param table input table holding {@code x} at key 1 and {@code y} at key 2
     * @return the module's output tensor, resized to hold the product
     * @throws IllegalArgumentException (via {@code Predef.require}) if the ranks differ, the rank
     *     is below 2, the batch dimensions differ, or the inner matrix dimensions do not agree
     */
    @Override // com.intel.analytics.bigdl.nn.abstractnn.AbstractModule
    @SuppressWarnings("unchecked")
    public Tensor<D> updateOutput(Table table) {
        final Tensor<D> x = (Tensor<D>) table.apply(BoxesRunTime.boxToInteger(1));
        final Tensor<D> y = (Tensor<D>) table.apply(BoxesRunTime.boxToInteger(2));
        final int dim = x.dim();
        final int dimY = y.dim();

        Predef$.MODULE$.require(dim == dimY, () -> "tensor x and tensor y must have the same number of dims");
        Predef$.MODULE$.require(dim >= 2, () -> "tensor dim num must be at least 2");

        final Tensor<D> out = (Tensor<D>) output();
        if (dim == 2) {
            // Rank equality was checked above, so y is also 2-D here.
            multiply2D(x, y, out);
        } else {
            multiplyBatched(x, y, out, dim);
        }
        return out;
    }

    /** Plain matrix product for rank-2 inputs, honouring adjX/adjY. */
    private void multiply2D(Tensor<D> x, Tensor<D> y, Tensor<D> out) {
        final Tensor<D> a = adjX() ? x.t() : x;
        final Tensor<D> b = adjY() ? y.t() : y;

        final int innerA = a.size(2);
        final int innerB = b.size(1);
        Predef$.MODULE$.require(innerA == innerB,
            () -> "matrix sizes do not match. The sizes are " + innerA + " and " + innerB);

        out.resize(a.size(1), b.size(2));
        out.mm(a, b);
    }

    /** Batched product for rank >= 3: flattens leading dims into one batch axis, then restores them. */
    private void multiplyBatched(Tensor<D> x, Tensor<D> y, Tensor<D> out, int dim) {
        final int[] xSizes = x.size();
        final int[] ySizes = y.size();
        final int batchDims = dim - 2;

        final int xBatch = xSizes[0];
        final int yBatch = ySizes[0];
        Predef$.MODULE$.require(xBatch == yBatch,
            () -> "inputs must contain the same number of minibatches. The minibatches of each are "
                + xBatch + " and " + yBatch);
        // Every remaining leading dim must also agree. Otherwise y's view below would either fail
        // opaquely or silently pair unrelated matrices.
        for (int i = 1; i < batchDims; i++) {
            final int axis = i + 1;
            final int xd = xSizes[i];
            final int yd = ySizes[i];
            Predef$.MODULE$.require(xd == yd,
                () -> "batch dimension " + axis + " must match. The sizes are " + xd + " and " + yd);
        }

        int batch = 1;
        for (int i = 0; i < batchDims; i++) {
            batch *= xSizes[i];
        }

        // Collapse all leading dims into one batch axis: [batch, rows, cols].
        Tensor<D> a = x.view(new int[]{batch, x.size(dim - 1), x.size(dim)});
        Tensor<D> b = y.view(new int[]{batch, y.size(dim - 1), y.size(dim)});
        if (adjX()) {
            a = a.transpose(2, 3);
        }
        if (adjY()) {
            b = b.transpose(2, 3);
        }

        final int innerA = a.size(3);
        final int innerB = b.size(2);
        Predef$.MODULE$.require(innerA == innerB,
            () -> "matrix sizes do not match. The matrix sizes are " + innerA + " and " + innerB);

        final int rows = a.size(2);
        final int cols = b.size(3);
        out.resize(batch, rows, cols);
        out.bmm(a, b);

        // Restore x's original leading dims in front of the per-batch [rows, cols] product.
        final int[] outShape = new int[dim];
        System.arraycopy(xSizes, 0, outShape, 0, batchDims);
        outShape[dim - 2] = rows;
        outShape[dim - 1] = cols;
        out.resize(outShape, out.resize$default$2());
    }

    @Override // com.intel.analytics.bigdl.nn.abstractnn.AbstractModule
    public Tuple2<ClassTag<?>[], TensorNumericMath.TensorNumeric<?>[]> getClassTagNumerics() {
        return new Tuple2<>(new ClassTag[]{package$.MODULE$.classTag(this.evidence$1), package$.MODULE$.classTag(this.evidence$2)}, new TensorNumericMath.TensorNumeric[]{this.ev, this.ev2});
    }

    /**
     * Creates the op.
     *
     * @param z adjX: transpose the first operand's trailing two dims before multiplying
     * @param z2 adjY: transpose the second operand's trailing two dims before multiplying
     * @param classTag class tag of the module numeric type {@code T}
     * @param classTag2 class tag of the tensor element type {@code D}
     * @param tensorNumeric numeric operations for {@code T}
     * @param tensorNumeric2 numeric operations for {@code D}; used to build output/gradInput tensors
     */
    /* JADX WARN: 'super' call moved to the top of the method (can break code semantics) */
    // NOTE(review): the fields below are assigned after super(). Confirm that Operation's
    // constructor does not call overridable methods that read them.
    public BatchMatMul(boolean z, boolean z2, ClassTag<T> classTag, ClassTag<D> classTag2, TensorNumericMath.TensorNumeric<T> tensorNumeric, TensorNumericMath.TensorNumeric<D> tensorNumeric2) {
        super(ClassTag$.MODULE$.apply(Table.class), ClassTag$.MODULE$.apply(Tensor.class), classTag, tensorNumeric);
        this.adjX = z;
        this.adjY = z2;
        this.evidence$1 = classTag;
        this.evidence$2 = classTag2;
        this.ev = tensorNumeric;
        this.ev2 = tensorNumeric2;
        // gradInput is a Table of two empty D-typed tensors; output is an empty D-typed tensor.
        gradInput_$eq(T$.MODULE$.apply(Tensor$.MODULE$.apply(classTag2, tensorNumeric2), (Seq<Object>) Predef$.MODULE$.genericWrapArray(new Object[]{Tensor$.MODULE$.apply(classTag2, tensorNumeric2)})));
        output_$eq(Tensor$.MODULE$.apply(classTag2, tensorNumeric2));
    }
}
