package com.intel.analytics.bigdl.nn;

import com.intel.analytics.bigdl.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.tensor.Tensor;
import com.intel.analytics.bigdl.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.utils.T$;
import com.intel.analytics.bigdl.utils.Table;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;

/* compiled from: Transformer.scala */
@ScalaSignature(bytes = "\u0006\u0001E4QAC\u0006\u0001\u0017UA\u0001\u0002\u000f\u0001\u0003\u0002\u0003\u0006I!\u000f\u0005\ty\u0001\u0011\t\u0011)A\u0005s!AQ\b\u0001B\u0002B\u0003-a\b\u0003\u0005E\u0001\t\u0005\t\u0015a\u0003F\u0011\u0015I\u0006\u0001\"\u0001[\u0011\u001d\u0011\u0007A1A\u0005\n\rDaa\u001a\u0001!\u0002\u0013!\u0007\"\u00025\u0001\t\u0003J\u0007\"\u00027\u0001\t\u0003j'aC*qY&$H+\u001a8t_JT!\u0001D\u0007\u0002\u00059t'B\u0001\b\u0010\u0003\u0015\u0011\u0017n\u001a3m\u0015\t\u0001\u0012#A\u0005b]\u0006d\u0017\u0010^5dg*\u0011!cE\u0001\u0006S:$X\r\u001c\u0006\u0002)\u0005\u00191m\\7\u0016\u0005Y)3C\u0001\u0001\u0018!\u0015A2$\b\u001a$\u001b\u0005I\"B\u0001\u000e\f\u0003)\t'm\u001d;sC\u000e$hN\\\u0005\u00039e\u0011a\"\u00112tiJ\f7\r^'pIVdW\rE\u0002\u001fC\rj\u0011a\b\u0006\u0003A5\ta\u0001^3og>\u0014\u0018B\u0001\u0012 \u0005\u0019!VM\\:peB\u0011A%\n\u0007\u0001\t\u00151\u0003A1\u0001)\u0005\u0005!6\u0001A\t\u0003S=\u0002\"AK\u0017\u000e\u0003-R\u0011\u0001L\u0001\u0006g\u000e\fG.Y\u0005\u0003]-\u0012qAT8uQ&tw\r\u0005\u0002+a%\u0011\u0011g\u000b\u0002\u0004\u0003:L\bCA\u001a7\u001b\u0005!$BA\u001b\u000e\u0003\u0015)H/\u001b7t\u0013\t9DGA\u0003UC\ndW-A\u0005eS6,gn]5p]B\u0011!FO\u0005\u0003w-\u00121!\u00138u\u0003\rqW/\\\u0001\fKZLG-\u001a8dK\u0012\n\u0004\u0007E\u0002@\u0005\u000ej\u0011\u0001\u0011\u0006\u0003\u0003.\nqA]3gY\u0016\u001cG/\u0003\u0002D\u0001\nA1\t\\1tgR\u000bw-\u0001\u0002fmB\u0019aIV\u0012\u000f\u0005\u001d#fB\u0001%T\u001d\tI%K\u0004\u0002K#:\u00111\n\u0015\b\u0003\u0019>k\u0011!\u0014\u0006\u0003\u001d\u001e\na\u0001\u0010:p_Rt\u0014\"\u0001\u000b\n\u0005I\u0019\u0012B\u0001\t\u0012\u0013\tqq\"\u0003\u0002!\u001b%\u0011QkH\u0001\u0012)\u0016t7o\u001c:Ok6,'/[2NCRD\u0017BA,Y\u00055!VM\\:pe:+X.\u001a:jG*\u0011QkH\u0001\u0007y%t\u0017\u000e\u001e 
\u0015\u0007m\u0003\u0017\rF\u0002]=~\u00032!\u0018\u0001$\u001b\u0005Y\u0001\"B\u001f\u0006\u0001\bq\u0004\"\u0002#\u0006\u0001\b)\u0005\"\u0002\u001d\u0006\u0001\u0004I\u0004\"\u0002\u001f\u0006\u0001\u0004I\u0014AC5o]\u0016\u0014H*Y=feV\tA\rE\u0002^K\u000eJ!AZ\u0006\u0003\u0013){\u0017N\u001c+bE2,\u0017aC5o]\u0016\u0014H*Y=fe\u0002\nA\"\u001e9eCR,w*\u001e;qkR$\"A\r6\t\u000b-D\u0001\u0019A\u000f\u0002\u000b%t\u0007/\u001e;\u0002\u001fU\u0004H-\u0019;f\u000fJ\fG-\u00138qkR$2!\b8p\u0011\u0015Y\u0017\u00021\u0001\u001e\u0011\u0015\u0001\u0018\u00021\u00013\u0003)9'/\u00193PkR\u0004X\u000f\u001e")
/* loaded from: input_file:com/intel/analytics/bigdl/nn/SplitTensor.class */
public class SplitTensor<T> extends AbstractModule<Tensor<T>, Table, T> {
    /** Dimension along which the input is split and along which gradients are re-joined. */
    private final int dimension;
    /** Requested number of chunks; validated to be positive at construction. */
    private final int num;
    /** Numeric implementation used to convert the joined gradient back to a tensor. */
    private final TensorNumericMath.TensorNumeric<T> ev;
    /** Concatenates the incoming gradient table back into one tensor along {@code dimension}. */
    private final JoinTable<T> innerLayer;

    private JoinTable<T> innerLayer() {
        return this.innerLayer;
    }

    /**
     * Splits {@code tensor} along {@code dimension} into chunks of size
     * {@code tensor.size(dimension) / num} and stores them, as a table, in {@code output}.
     *
     * @param tensor the tensor to split
     * @return the table of chunks (also stored as this module's output)
     * @throws IllegalArgumentException if the size of {@code dimension} is smaller than
     *     {@code num}, which would otherwise request zero-sized chunks from {@code split}
     */
    @Override // com.intel.analytics.bigdl.nn.abstractnn.AbstractModule
    public Table updateOutput(Tensor<T> tensor) {
        int dimSize = tensor.size(this.dimension);
        int chunkSize = dimSize / this.num;
        if (chunkSize < 1) {
            // A zero chunk size is meaningless for split(); fail fast with a clear message.
            throw new IllegalArgumentException(
                "SplitTensor: size " + dimSize + " of dimension " + this.dimension
                    + " is smaller than the number of splits " + this.num);
        }
        // NOTE(review): when dimSize is not a multiple of num, split() presumably yields an
        // extra, smaller trailing chunk, so the table may hold more than num entries — confirm.
        output_$eq(T$.MODULE$.array(tensor.split(chunkSize, this.dimension)));
        return output();
    }

    /**
     * Joins the per-chunk gradients in {@code table} back into a single tensor along
     * {@code dimension} and stores it in {@code gradInput}.
     *
     * @param tensor the original input (unused; the shape is recovered by joining)
     * @param table the gradients with respect to each output chunk
     * @return the gradient with respect to the input (also stored as this module's gradInput)
     */
    @Override // com.intel.analytics.bigdl.nn.abstractnn.AbstractModule
    public Tensor<T> updateGradInput(Tensor<T> tensor, Table table) {
        gradInput_$eq(innerLayer().forward(table).toTensor(this.ev));
        return gradInput();
    }

    /**
     * Decompiler-generated alias of {@link #updateGradInput(Tensor, Table)}, kept so existing
     * source callers keep compiling.
     *
     * @deprecated call {@link #updateGradInput(Tensor, Table)} instead.
     */
    @Deprecated
    public Tensor<T> updateGradInput2(Tensor<T> tensor, Table table) {
        return updateGradInput(tensor, table);
    }

    /**
     * Creates a module that splits its input along {@code dimension} into {@code num} chunks.
     *
     * <p>The decompiler noted that the original bytecode initialized fields before the super
     * call. Java requires the super call first, so the field assignments follow it here.
     *
     * @param dimension dimension to split along (also used by the gradient join)
     * @param num number of chunks; must be positive
     * @param classTag runtime class tag for {@code T}
     * @param tensorNumeric numeric implementation for {@code T}
     * @throws IllegalArgumentException if {@code num <= 0}
     */
    public SplitTensor(int dimension, int num, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        super(ClassTag$.MODULE$.apply(Tensor.class), ClassTag$.MODULE$.apply(Table.class), classTag, tensorNumeric);
        if (num <= 0) {
            throw new IllegalArgumentException(
                "SplitTensor: number of splits must be positive, got " + num);
        }
        this.dimension = dimension;
        this.num = num;
        this.ev = tensorNumeric;
        // Arguments mirror the original construction: same dimension, and -1 as the second argument.
        this.innerLayer = new JoinTable<>(dimension, -1, classTag, tensorNumeric);
    }
}
