package com.intel.analytics.bigdl.nn;

import com.intel.analytics.bigdl.nn.abstractnn.TensorModule;
import com.intel.analytics.bigdl.tensor.ConvertableFrom$ConvertableFromDouble$;
import com.intel.analytics.bigdl.tensor.Tensor;
import com.intel.analytics.bigdl.tensor.TensorNumericMath;
import scala.Predef$;
import scala.reflect.ClassTag;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;
import scala.runtime.RichInt$;
import scala.runtime.ScalaRunTime$;

/* compiled from: Transformer.scala */
@ScalaSignature(bytes = "\u0006\u0001Y4Q!\u0003\u0006\u0001\u0015QA\u0001b\u000b\u0001\u0003\u0004\u0003\u0006Y\u0001\f\u0005\te\u0001\u0011\t\u0011)A\u0006g!)\u0011\n\u0001C\u0001\u0015\"9\u0001\u000b\u0001b\u0001\n\u0013\t\u0006BB+\u0001A\u0003%!\u000bC\u0003W\u0001\u0011%q\u000bC\u0003m\u0001\u0011\u0005S\u000eC\u0003r\u0001\u0011\u0005#OA\tTK24\u0017\t\u001e;f]RLwN\\'bg.T!a\u0003\u0007\u0002\u00059t'BA\u0007\u000f\u0003\u0015\u0011\u0017n\u001a3m\u0015\ty\u0001#A\u0005b]\u0006d\u0017\u0010^5dg*\u0011\u0011CE\u0001\u0006S:$X\r\u001c\u0006\u0002'\u0005\u00191m\\7\u0016\u0005Uq2C\u0001\u0001\u0017!\r9\"\u0004H\u0007\u00021)\u0011\u0011DC\u0001\u000bC\n\u001cHO]1di:t\u0017BA\u000e\u0019\u00051!VM\\:pe6{G-\u001e7f!\tib\u0004\u0004\u0001\u0005\u000b}\u0001!\u0019A\u0011\u0003\u0003Q\u001b\u0001!\u0005\u0002#QA\u00111EJ\u0007\u0002I)\tQ%A\u0003tG\u0006d\u0017-\u0003\u0002(I\t9aj\u001c;iS:<\u0007CA\u0012*\u0013\tQCEA\u0002B]f\f!\"\u001a<jI\u0016t7-\u001a\u00139!\ri\u0003\u0007H\u0007\u0002])\u0011q\u0006J\u0001\be\u00164G.Z2u\u0013\t\tdF\u0001\u0005DY\u0006\u001c8\u000fV1h\u0003\t)g\u000fE\u00025\rrq!!N\"\u000f\u0005Y\neBA\u001cA\u001d\tAtH\u0004\u0002:}9\u0011!(P\u0007\u0002w)\u0011A\bI\u0001\u0007yI|w\u000e\u001e 
\n\u0003MI!!\u0005\n\n\u0005=\u0001\u0012BA\u0007\u000f\u0013\t\u0011E\"\u0001\u0004uK:\u001cxN]\u0005\u0003\t\u0016\u000b\u0011\u0003V3og>\u0014h*^7fe&\u001cW*\u0019;i\u0015\t\u0011E\"\u0003\u0002H\u0011\niA+\u001a8t_JtU/\\3sS\u000eT!\u0001R#\u0002\rqJg.\u001b;?)\u0005YEc\u0001'O\u001fB\u0019Q\n\u0001\u000f\u000e\u0003)AQaK\u0002A\u00041BQAM\u0002A\u0004M\n\u0011\"\\1tWZ\u000bG.^3\u0016\u0003I\u0003\"aI*\n\u0005Q##A\u0002#pk\ndW-\u0001\u0006nCN\\g+\u00197vK\u0002\n!$\u0019;uK:$\u0018n\u001c8CS\u0006\u001cHj\\<feR\u0013\u0018.\u00198hY\u0016,\"\u0001W0\u0015\u0007e+'\u000eF\u0002[A\u000e\u00042a\u0017/_\u001b\u0005)\u0015BA/F\u0005\u0019!VM\\:peB\u0011Qd\u0018\u0003\u0006?\u0019\u0011\r!\t\u0005\bC\u001a\t\t\u0011q\u0001c\u0003))g/\u001b3f]\u000e,G%\u000f\t\u0004[Ar\u0006\"\u0002\u001a\u0007\u0001\b!\u0007c\u0001\u001bG=\")aM\u0002a\u0001O\u00061A.\u001a8hi\"\u0004\"a\t5\n\u0005%$#aA%oi\")1N\u0002a\u00015\u00061q.\u001e;qkR\fA\"\u001e9eCR,w*\u001e;qkR$\"A\\8\u0011\u0007mcF\u0004C\u0003q\u000f\u0001\u0007a.A\u0003j]B,H/A\bva\u0012\fG/Z$sC\u0012Le\u000e];u)\rq7\u000f\u001e\u0005\u0006a\"\u0001\rA\u001c\u0005\u0006k\"\u0001\rA\\\u0001\u000bOJ\fGmT;uaV$\b")
/* loaded from: input_file:com/intel/analytics/bigdl/nn/SelfAttentionMask.class */
/**
 * Module that emits an additive attention-bias mask whose size depends only on the input's
 * second dimension (BigDL dimensions are 1-based).
 *
 * <p>For {@code L = input.size(2)} the output is a {@code [1, 1, L, L]} tensor in which every
 * cell with column index greater than its row index holds {@link #maskValue()} ({@code -1e9}).
 * All other cells, including the diagonal, are {@code 0}. The input's values are never read,
 * so the gradient with respect to the input is all zeros.
 *
 * <p>NOTE(review): presumably the input is {@code [batch, seq, hidden]} and the mask is added to
 * attention logits so position {@code i} cannot attend to positions {@code > i} — confirm with
 * callers.
 *
 * @param <T> numeric element type of the tensors
 */
public class SelfAttentionMask<T> extends TensorModule<T> {
    private final ClassTag<T> evidence$8;
    private final TensorNumericMath.TensorNumeric<T> ev;
    /** Value written into masked (column &gt; row) cells; set to -1e9 in the constructor. */
    private final double maskValue;

    /** Returns the value written into masked cells. */
    private double maskValue() {
        return this.maskValue;
    }

    /**
     * Fills the strictly-upper triangle (column &gt; row) of a zeroed {@code length x length}
     * tensor with {@link #maskValue()} and reshapes it to {@code [1, 1, length, length]}.
     *
     * <p>Expects {@code tensor} to be contiguous, already resized to {@code (length, length)} and
     * zeroed (as {@code updateOutput} does); cells on and below the diagonal are not written.
     *
     * @param length        sequence length (side of the square mask)
     * @param tensor        destination tensor, modified in place
     * @param classTag      element class tag (unused here, kept for the call signature)
     * @param tensorNumeric numeric ops used to convert the mask value to the element type
     * @return {@code tensor}, resized to {@code [1, 1, length, length]}
     */
    private <E> Tensor<E> attentionBiasLowerTriangle(int length, Tensor<E> tensor, ClassTag<E> classTag, TensorNumericMath.TensorNumeric<E> tensorNumeric) {
        Object array = tensor.storage().array();
        // storageOffset() is 1-based in BigDL; convert to a 0-based array index.
        int base = tensor.storageOffset() - 1;
        // The same converted value goes into every masked cell, so convert it once.
        Object masked = tensorNumeric.mo1182fromType(BoxesRunTime.boxToDouble(maskValue()), ConvertableFrom$ConvertableFromDouble$.MODULE$);
        for (int row = 0; row < length; row++) {
            // Only columns strictly to the right of the diagonal are masked.
            for (int col = row + 1; col < length; col++) {
                ScalaRunTime$.MODULE$.array_update(array, base + row * length + col, masked);
            }
        }
        return tensor.resize(new int[]{1, 1, length, length}, tensor.resize$default$2());
    }

    /**
     * Builds (or reuses) the {@code [1, 1, L, L]} mask for {@code L = input.size(2)}.
     *
     * <p>The mask depends only on {@code L}. A previously built output is therefore reused
     * whenever it is already a 4-D tensor whose last dimension equals {@code L}.
     *
     * @param tensor input tensor; only {@code size(2)} is read
     * @return the mask tensor (this module's {@code output()})
     */
    @Override // com.intel.analytics.bigdl.nn.abstractnn.AbstractModule
    public Tensor<T> updateOutput(Tensor<T> tensor) {
        int length = tensor.size(2);
        Tensor<T> out = output();
        // Cache key is the sequence length. The previous check compared the mask's L*L
        // elements against the input's total element count, which is unrelated.
        if (!out.isEmpty() && out.nDimension() == 4 && out.size(4) == length) {
            return out;
        }
        out.resize(length, length).zero();
        attentionBiasLowerTriangle(length, out, this.evidence$8, this.ev);
        return output();
    }

    /**
     * Returns an all-zero gradient shaped like the input, because the output does not depend
     * on the input's values.
     *
     * <p>The existing {@code gradInput()} is reused when it already has exactly the input's shape.
     * Matching element count alone is not enough, since the returned gradient must have the
     * input's shape.
     *
     * @param tensor  the forward input
     * @param tensor2 gradient w.r.t. the output (unused)
     * @return zero tensor with the same size as {@code tensor}
     */
    @Override // com.intel.analytics.bigdl.nn.abstractnn.AbstractModule
    /* renamed from: updateGradInput, reason: merged with bridge method [inline-methods] */
    public Tensor<T> updateGradInput2(Tensor<T> tensor, Tensor<T> tensor2) {
        if (!gradInput().isEmpty() && gradInput().isSameSizeAs(tensor)) {
            return gradInput();
        }
        gradInput().resizeAs(tensor).zero();
        return gradInput();
    }

    /**
     * Creates the module.
     *
     * @param classTag      element class tag
     * @param tensorNumeric numeric operations for {@code T}
     */
    /* JADX WARN: 'super' call moved to the top of the method (can break code semantics) */
    public SelfAttentionMask(ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        super(classTag, tensorNumeric);
        this.evidence$8 = classTag;
        this.ev = tensorNumeric;
        this.maskValue = -1.0E9d;
    }
}
