package com.intel.analytics.bigdl.nn;

import com.intel.analytics.bigdl.nn.abstractnn.TensorModule;
import com.intel.analytics.bigdl.tensor.Tensor;
import com.intel.analytics.bigdl.tensor.Tensor$;
import com.intel.analytics.bigdl.tensor.TensorNumericMath;
import scala.reflect.ClassTag;
import scala.reflect.ScalaSignature;

/* compiled from: Transformer.scala */
@ScalaSignature(bytes = "\u0006\u0001-4Q!\u0003\u0006\u0001\u0015QA\u0001b\u000b\u0001\u0003\u0004\u0003\u0006Y\u0001\f\u0005\te\u0001\u0011\t\u0011)A\u0006g!)\u0011\n\u0001C\u0001\u0015\")\u0001\u000b\u0001C!#\")\u0001\f\u0001C!3\"9Q\f\u0001a\u0001\n\u0013q\u0006bB0\u0001\u0001\u0004%I\u0001\u0019\u0005\u0007M\u0002\u0001\u000b\u0015\u0002*\u0003\u001dA{7/\u001b;j_:,enY8eK*\u00111\u0002D\u0001\u0003]:T!!\u0004\b\u0002\u000b\tLw\r\u001a7\u000b\u0005=\u0001\u0012!C1oC2LH/[2t\u0015\t\t\"#A\u0003j]R,GNC\u0001\u0014\u0003\r\u0019w.\\\u000b\u0003+y\u0019\"\u0001\u0001\f\u0011\u0007]QB$D\u0001\u0019\u0015\tI\"\"\u0001\u0006bEN$(/Y2u]:L!a\u0007\r\u0003\u0019Q+gn]8s\u001b>$W\u000f\\3\u0011\u0005uqB\u0002\u0001\u0003\u0006?\u0001\u0011\r!\t\u0002\u0002)\u000e\u0001\u0011C\u0001\u0012)!\t\u0019c%D\u0001%\u0015\u0005)\u0013!B:dC2\f\u0017BA\u0014%\u0005\u001dqu\u000e\u001e5j]\u001e\u0004\"aI\u0015\n\u0005)\"#aA!os\u0006QQM^5eK:\u001cW\rJ\u001b\u0011\u00075\u0002D$D\u0001/\u0015\tyC%A\u0004sK\u001adWm\u0019;\n\u0005Er#\u0001C\"mCN\u001cH+Y4\u0002\u0005\u00154\bc\u0001\u001bG99\u0011Qg\u0011\b\u0003m\u0005s!a\u000e!\u000f\u0005azdBA\u001d?\u001d\tQT(D\u0001<\u0015\ta\u0004%\u0001\u0004=e>|GOP\u0005\u0002'%\u0011\u0011CE\u0005\u0003\u001fAI!!\u0004\b\n\u0005\tc\u0011A\u0002;f]N|'/\u0003\u0002E\u000b\u0006\tB+\u001a8t_JtU/\\3sS\u000el\u0015\r\u001e5\u000b\u0005\tc\u0011BA$I\u00055!VM\\:pe:+X.\u001a:jG*\u0011A)R\u0001\u0007y%t\u0017\u000e\u001e 
\u0015\u0003-#2\u0001\u0014(P!\ri\u0005\u0001H\u0007\u0002\u0015!)1f\u0001a\u0002Y!)!g\u0001a\u0002g\u0005aQ\u000f\u001d3bi\u0016|U\u000f\u001e9viR\u0011!K\u0016\t\u0004'RcR\"A#\n\u0005U+%A\u0002+f]N|'\u000fC\u0003X\t\u0001\u0007!+A\u0003j]B,H/A\bva\u0012\fG/Z$sC\u0012Le\u000e];u)\r\u0011&l\u0017\u0005\u0006/\u0016\u0001\rA\u0015\u0005\u00069\u0016\u0001\rAU\u0001\u000bOJ\fGmT;uaV$\u0018a\u0003:b]\u001e,')\u001e4gKJ,\u0012AU\u0001\u0010e\u0006tw-\u001a\"vM\u001a,'o\u0018\u0013fcR\u0011\u0011\r\u001a\t\u0003G\tL!a\u0019\u0013\u0003\tUs\u0017\u000e\u001e\u0005\bK\u001e\t\t\u00111\u0001S\u0003\rAH%M\u0001\re\u0006tw-\u001a\"vM\u001a,'\u000f\t\u0015\u0003\u0011!\u0004\"aI5\n\u0005)$#!\u0003;sC:\u001c\u0018.\u001a8u\u0001")
/* loaded from: input_file:com/intel/analytics/bigdl/nn/PositionEncode.class */
/**
 * Module that produces a position-encoding table sized from the input's dimensions.
 *
 * <p>The output is a 2-D tensor of shape {@code (input.size(2), input.size(3))} filled by
 * {@code TransformerOperation.getPositionEncode}. It depends only on those two sizes, not on the
 * input values, so the gradient with respect to the input is an all-zero tensor.
 *
 * @param <T> numeric element type of the tensors
 */
public class PositionEncode<T> extends TensorModule<T> {
    /** Runtime class tag for {@code T}, forwarded to tensor factories and helpers. */
    private final ClassTag<T> evidence$5;
    /** Numeric operations for {@code T}, forwarded to tensor factories and helpers. */
    private final TensorNumericMath.TensorNumeric<T> ev;
    /**
     * Scratch tensor passed to {@code initRangeTensor}/{@code getPositionEncode}. It is
     * {@code transient} (not serialized), so it is lazily (re)created when null.
     */
    private transient Tensor<T> rangeBuffer;

    private Tensor<T> rangeBuffer() {
        return this.rangeBuffer;
    }

    private void rangeBuffer_$eq(Tensor<T> tensor) {
        this.rangeBuffer = tensor;
    }

    /**
     * Computes (or reuses) the position-encoding table for the input's dimensions 2 and 3.
     *
     * @param tensor input whose {@code size(2)} (length) and {@code size(3)} (channel) determine
     *     the output shape; its values are not read
     * @return the module's output tensor, of shape {@code (size(2), size(3))}
     */
    @Override // com.intel.analytics.bigdl.nn.abstractnn.AbstractModule
    public Tensor<T> updateOutput(Tensor<T> tensor) {
        int size = tensor.size(2);
        int size2 = tensor.size(3);

        // The table depends only on (length, channel), so it can be reused when the cached output
        // already has exactly that shape. Comparing nElement() alone is not enough: it would
        // wrongly reuse a (4, 8) table for an input that needs (8, 4).
        Tensor<T> cached = output();
        if (!cached.isEmpty()
                && cached.dim() == 2
                && cached.size(1) == size
                && cached.size(2) == size2) {
            return cached;
        }

        if (rangeBuffer() == null) {
            rangeBuffer_$eq(Tensor$.MODULE$.apply(this.evidence$5, this.ev));
        }
        TransformerOperation$.MODULE$.initRangeTensor(size, rangeBuffer(), this.evidence$5, this.ev);

        output().resize(size, size2);
        // The 3rd and 4th arguments use getPositionEncode's own Scala default values.
        TransformerOperation$.MODULE$.getPositionEncode(
                size,
                size2,
                TransformerOperation$.MODULE$.getPositionEncode$default$3(),
                TransformerOperation$.MODULE$.getPositionEncode$default$4(),
                rangeBuffer(),
                output(),
                this.evidence$5,
                this.ev);
        return output();
    }

    /**
     * Returns an all-zero gradient shaped like the input, since the output does not depend on the
     * input's values.
     *
     * @param tensor the forward input; only its shape is used
     * @param tensor2 gradient w.r.t. the output (unused)
     * @return the module's gradInput tensor, zero-filled and the same size as {@code tensor}
     */
    @Override // com.intel.analytics.bigdl.nn.abstractnn.AbstractModule
    /* renamed from: updateGradInput, reason: merged with bridge method [inline-methods] */
    public Tensor<T> updateGradInput2(Tensor<T> tensor, Tensor<T> tensor2) {
        // Reuse the existing zero buffer only when it already has the input's full shape. Equal
        // element counts are not sufficient, because the buffer must also match the dimensions.
        if (!gradInput().isEmpty() && gradInput().isSameSizeAs(tensor)) {
            return gradInput();
        }
        gradInput().resizeAs(tensor).zero();
        return gradInput();
    }

    /**
     * Creates the module. The range buffer is allocated lazily on the first forward pass.
     *
     * @param classTag runtime class tag for {@code T}
     * @param tensorNumeric numeric operations for {@code T}
     */
    /* JADX WARN: 'super' call moved to the top of the method (can break code semantics) */
    public PositionEncode(ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        super(classTag, tensorNumeric);
        this.evidence$5 = classTag;
        this.ev = tensorNumeric;
        this.rangeBuffer = null;
    }
}
