package com.intel.analytics.bigdl.nn;

import com.intel.analytics.bigdl.nn.abstractnn.TensorModule;
import com.intel.analytics.bigdl.tensor.Tensor;
import com.intel.analytics.bigdl.tensor.Tensor$;
import com.intel.analytics.bigdl.tensor.TensorNumericMath;
import scala.reflect.ClassTag;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;

/* compiled from: Transformer.scala */
@ScalaSignature(bytes = "\u0006\u0001Q4Q\u0001D\u0007\u0001\u001b]A\u0001B\f\u0001\u0003\u0004\u0003\u0006Ya\f\u0005\tk\u0001\u0011\t\u0011)A\u0006m!)A\n\u0001C\u0001\u001b\")1\u000b\u0001C!)\")1\f\u0001C!9\"9\u0001\r\u0001a\u0001\n\u0013\t\u0007b\u00022\u0001\u0001\u0004%Ia\u0019\u0005\u0007S\u0002\u0001\u000b\u0015B+\t\u000f9\u0004\u0001\u0019!C\u0005C\"9q\u000e\u0001a\u0001\n\u0013\u0001\bB\u0002:\u0001A\u0003&QKA\fQ_NLG/[8o\u000b:\u001cw\u000eZ3XSRD7\u000b[5gi*\u0011abD\u0001\u0003]:T!\u0001E\t\u0002\u000b\tLw\r\u001a7\u000b\u0005I\u0019\u0012!C1oC2LH/[2t\u0015\t!R#A\u0003j]R,GNC\u0001\u0017\u0003\r\u0019w.\\\u000b\u00031\u0005\u001a\"\u0001A\r\u0011\u0007iir$D\u0001\u001c\u0015\taR\"\u0001\u0006bEN$(/Y2u]:L!AH\u000e\u0003\u0019Q+gn]8s\u001b>$W\u000f\\3\u0011\u0005\u0001\nC\u0002\u0001\u0003\u0006E\u0001\u0011\r\u0001\n\u0002\u0002)\u000e\u0001\u0011CA\u0013,!\t1\u0013&D\u0001(\u0015\u0005A\u0013!B:dC2\f\u0017B\u0001\u0016(\u0005\u001dqu\u000e\u001e5j]\u001e\u0004\"A\n\u0017\n\u00055:#aA!os\u0006QQM^5eK:\u001cW\r\n\u001c\u0011\u0007A\u001at$D\u00012\u0015\t\u0011t%A\u0004sK\u001adWm\u0019;\n\u0005Q\n$\u0001C\"mCN\u001cH+Y4\u0002\u0005\u00154\bcA\u001cJ?9\u0011\u0001H\u0012\b\u0003s\u0011s!AO\"\u000f\u0005m\u0012eB\u0001\u001fB\u001d\ti\u0004)D\u0001?\u0015\ty4%\u0001\u0004=e>|GOP\u0005\u0002-%\u0011A#F\u0005\u0003%MI!\u0001E\t\n\u0005\u0015{\u0011A\u0002;f]N|'/\u0003\u0002H\u0011\u0006\tB+\u001a8t_JtU/\\3sS\u000el\u0015\r\u001e5\u000b\u0005\u0015{\u0011B\u0001&L\u00055!VM\\:pe:+X.\u001a:jG*\u0011q\tS\u0001\u0007y%t\u0017\u000e\u001e 
\u0015\u00039#2aT)S!\r\u0001\u0006aH\u0007\u0002\u001b!)af\u0001a\u0002_!)Qg\u0001a\u0002m\u0005aQ\u000f\u001d3bi\u0016|U\u000f\u001e9viR\u0011Q+\u0017\t\u0004-^{R\"\u0001%\n\u0005aC%A\u0002+f]N|'\u000fC\u0003[\t\u0001\u0007Q+A\u0003j]B,H/A\bva\u0012\fG/Z$sC\u0012Le\u000e];u)\r)VL\u0018\u0005\u00065\u0016\u0001\r!\u0016\u0005\u0006?\u0016\u0001\r!V\u0001\u000bOJ\fGmT;uaV$\u0018a\u0003:b]\u001e,')\u001e4gKJ,\u0012!V\u0001\u0010e\u0006tw-\u001a\"vM\u001a,'o\u0018\u0013fcR\u0011Am\u001a\t\u0003M\u0015L!AZ\u0014\u0003\tUs\u0017\u000e\u001e\u0005\bQ\u001e\t\t\u00111\u0001V\u0003\rAH%M\u0001\re\u0006tw-\u001a\"vM\u001a,'\u000f\t\u0015\u0003\u0011-\u0004\"A\n7\n\u00055<#!\u0003;sC:\u001c\u0018.\u001a8u\u0003)!\u0018.\\3Ck\u001a4WM]\u0001\u000fi&lWMQ;gM\u0016\u0014x\fJ3r)\t!\u0017\u000fC\u0004i\u0015\u0005\u0005\t\u0019A+\u0002\u0017QLW.\u001a\"vM\u001a,'\u000f\t\u0015\u0003\u0017-\u0004")
/* loaded from: input_file:com/intel/analytics/bigdl/nn/PositionEncodeWithShift.class */
/**
 * Shifts the input one step to the right along dimension 2 (via
 * {@code TransformerOperation.shiftRight3D}) and then adds a position-encoding
 * table to every slice of the output along dimension 1.
 *
 * <p>NOTE(review): the input is presumably (batch, length, channel). The code reads
 * {@code size(1)}, {@code size(2)} and {@code size(3)} but never validates the rank,
 * so confirm this against callers.
 *
 * <p>The (length x channel) encoding table is cached in {@code timeBuffer}. It is
 * rebuilt only when the length or channel size of the input changes.
 *
 * @param <T> numeric element type of the tensors
 */
public class PositionEncodeWithShift<T> extends TensorModule<T> {
    private final ClassTag<T> evidence$6;
    private final TensorNumericMath.TensorNumeric<T> ev;
    /** Scratch tensor handed to {@code initRangeTensor}; lazily allocated, transient. */
    private transient Tensor<T> rangeBuffer;
    /** Cached (length x channel) position encoding; lazily allocated, transient. */
    private transient Tensor<T> timeBuffer;

    private Tensor<T> rangeBuffer() {
        return this.rangeBuffer;
    }

    private void rangeBuffer_$eq(Tensor<T> tensor) {
        this.rangeBuffer = tensor;
    }

    private Tensor<T> timeBuffer() {
        return this.timeBuffer;
    }

    private void timeBuffer_$eq(Tensor<T> tensor) {
        this.timeBuffer = tensor;
    }

    /**
     * Forward pass: writes the right-shifted input into {@code output()}, then adds
     * the cached position encoding to each slice along dimension 1.
     *
     * @param tensor input tensor (assumed 3-D; see class note)
     * @return {@code output()}
     */
    @Override // com.intel.analytics.bigdl.nn.abstractnn.AbstractModule
    public Tensor<T> updateOutput(Tensor<T> tensor) {
        TransformerOperation$.MODULE$.shiftRight3D(tensor, output(), this.evidence$6, this.ev);
        int length = output().size(2);
        int channel = output().size(3);

        // The buffers are transient, so they may be null after construction or deserialization.
        if (rangeBuffer() == null) {
            rangeBuffer_$eq(Tensor$.MODULE$.apply(this.evidence$6, this.ev));
        }
        if (timeBuffer() == null) {
            timeBuffer_$eq(Tensor$.MODULE$.apply(this.evidence$6, this.ev));
        }

        if (isTimeBufferStale(length, channel)) {
            TransformerOperation$.MODULE$.initRangeTensor(length, rangeBuffer(), this.evidence$6, this.ev);
            timeBuffer_$eq(Tensor$.MODULE$.apply(this.evidence$6, this.ev).resize(length, channel));
            TransformerOperation$.MODULE$.getPositionEncode(
                    length,
                    channel,
                    TransformerOperation$.MODULE$.getPositionEncode$default$3(),
                    TransformerOperation$.MODULE$.getPositionEncode$default$4(),
                    rangeBuffer(),
                    timeBuffer(),
                    this.evidence$6,
                    this.ev);
        }

        int batchSize = tensor.size(1);
        Tensor<T> encoding = timeBuffer();
        for (int b = 1; b <= batchSize; b++) {
            output().select(1, b).add(encoding);
        }
        return output();
    }

    /**
     * Returns true when the cached encoding cannot be reused for a (length x channel) input.
     *
     * <p>Checking only the element count would wrongly reuse, e.g., a (2 x 6) table for a
     * (3 x 4) input, so the actual dimensions are compared as well.
     */
    private boolean isTimeBufferStale(int length, int channel) {
        Tensor<T> cached = timeBuffer();
        if (cached.nElement() != length * channel) {
            return true;
        }
        if (cached.nElement() == 0) {
            // Nothing cached and nothing needed: keep the previous count-only behavior.
            return false;
        }
        return cached.dim() != 2 || cached.size(1) != length || cached.size(2) != channel;
    }

    /**
     * Backward pass for the right shift.
     *
     * <p>Sets {@code gradInput[:, i] = gradOutput[:, i + 1]} along dimension 2 for
     * every position except the last, which gets zero gradient. This matches the
     * forward shift, where input position i lands at output position i + 1 and the
     * last input position is dropped. The additive position encoding does not
     * depend on the input and contributes no gradient.
     *
     * <p>NOTE(review): {@code updateGradInput2} is the decompiler's name for the
     * {@code updateGradInput} bridge target. Confirm that the {@code @Override} resolves
     * against the real supertype.
     *
     * @param tensor the forward input (unused)
     * @param tensor2 gradient w.r.t. the output
     * @return {@code gradInput()}
     */
    @Override // com.intel.analytics.bigdl.nn.abstractnn.AbstractModule
    public Tensor<T> updateGradInput2(Tensor<T> tensor, Tensor<T> tensor2) {
        if (gradInput() == null) {
            gradInput_$eq(Tensor$.MODULE$.apply(this.evidence$6, this.ev));
        }
        // Zeroing first leaves the final position's gradient at zero.
        gradInput().resizeAs(tensor2).zero();
        int length = tensor2.size(2);
        for (int pos = 1; pos < length; pos++) {
            gradInput().select(2, pos).copy(tensor2.select(2, pos + 1));
        }
        return gradInput();
    }

    /**
     * Creates the module with its implicit Scala evidence; the encoding buffers start
     * out null and are allocated on the first forward pass.
     */
    public PositionEncodeWithShift(ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        super(classTag, tensorNumeric);
        this.evidence$6 = classTag;
        this.ev = tensorNumeric;
        this.rangeBuffer = null;
        this.timeBuffer = null;
    }
}
