package com.intel.analytics.bigdl.nn;

import com.intel.analytics.bigdl.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.nn.abstractnn.Activity;
import com.intel.analytics.bigdl.tensor.Tensor;
import com.intel.analytics.bigdl.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.utils.Node;
import com.intel.analytics.bigdl.utils.T$;
import com.intel.analytics.bigdl.utils.Table;
import scala.MatchError;
import scala.Predef$;
import scala.Tuple2;
import scala.Tuple3;
import scala.collection.Seq;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/* compiled from: Attention.scala */
@ScalaSignature(bytes = "\u0006\u0001\t\u001dh\u0001\u0002#F\u0001AC\u0001\"\u001b\u0001\u0003\u0006\u0004%\tA\u001b\u0005\t]\u0002\u0011\t\u0011)A\u0005W\"Aq\u000e\u0001BC\u0002\u0013\u0005!\u000e\u0003\u0005q\u0001\t\u0005\t\u0015!\u0003l\u0011!\t\bA!b\u0001\n\u0003\u0011\b\u0002\u0003<\u0001\u0005\u0003\u0005\u000b\u0011B:\t\u0011]\u0004!1!Q\u0001\faD\u0001B \u0001\u0003\u0002\u0003\u0006Ya \u0005\b\u0003W\u0001A\u0011AA\u0017\u0011%\ty\u0004\u0001b\u0001\n\u0013\t\t\u0005\u0003\u0005\u0002J\u0001\u0001\u000b\u0011BA\"\u0011%\tY\u0005\u0001b\u0001\n\u0013\t\t\u0005\u0003\u0005\u0002N\u0001\u0001\u000b\u0011BA\"\u0011%\ty\u0005\u0001b\u0001\n\u0013\t\t\u0006\u0003\u0005\u0002b\u0001\u0001\u000b\u0011BA*\u0011%\t\u0019\u0007\u0001b\u0001\n\u0013\t\t\u0006\u0003\u0005\u0002f\u0001\u0001\u000b\u0011BA*\u0011%\t9\u0007\u0001b\u0001\n\u0013\t\t\u0006\u0003\u0005\u0002j\u0001\u0001\u000b\u0011BA*\u0011%\tY\u0007\u0001b\u0001\n\u0013\ti\u0007\u0003\u0005\u0002v\u0001\u0001\u000b\u0011BA8\u0011%\t9\b\u0001b\u0001\n\u0013\ti\u0007\u0003\u0005\u0002z\u0001\u0001\u000b\u0011BA8\u0011%\tY\b\u0001b\u0001\n\u0013\ti\u0007\u0003\u0005\u0002~\u0001\u0001\u000b\u0011BA8\u0011%\ty\b\u0001b\u0001\n\u0013\t\t\t\u0003\u0005\u0002\n\u0002\u0001\u000b\u0011BAB\u0011%\tY\t\u0001b\u0001\n\u0013\t\t\t\u0003\u0005\u0002\u000e\u0002\u0001\u000b\u0011BAB\u0011%\ty\t\u0001b\u0001\n\u0013\t\t\t\u0003\u0005\u0002\u0012\u0002\u0001\u000b\u0011BAB\u0011%\t\u0019\n\u0001b\u0001\n\u0013\t)\n\u0003\u0005\u0002\u001e\u0002\u0001\u000b\u0011BAL\u0011%\ty\n\u0001b\u0001\n\u0013\t\t\u000b\u0003\u0005\u0002*\u0002\u0001\u000b\u0011BAR\u0011%\tY\u000b\u0001b\u0001\n\u0013\t\t\u0006\u0003\u0005\u0002.\u0002\u0001\u000b\u0011BA*\u0011%\ty\u000b\u0001b\u0001\n\u0013\t\t\f\u0003\u0005\u0002:\u0002\u0001\u000b\u0011BAZ\u0011%\tY\f\u0001b\u0001\n\u0013\t)\n\u0003\u0005\u0002>\u0002\u0001\u000b\u0011BAL\u0011%\ty\f\u0001b\u0001\n\u0013\t\t\r\u0003\u0005\u0002J\u0002\u0001\u000b\u0011BAb\u0011%\tY\r\u0001b\u0001\n
\u0013\t\t\u0006\u0003\u0005\u0002N\u0002\u0001\u000b\u0011BA*\u0011)\ty\r\u0001b\u0001\n\u00039\u0015\u0011\u000b\u0005\t\u0003#\u0004\u0001\u0015!\u0003\u0002T!I\u00111\u001b\u0001C\u0002\u0013%\u0011\u0011\u000b\u0005\t\u0003+\u0004\u0001\u0015!\u0003\u0002T!9\u0011q\u001b\u0001\u0005\n\u0005e\u0007bBA\u007f\u0001\u0011%\u0011q \u0005\b\u0005\u000b\u0001A\u0011\tB\u0004\u0011\u001d\u0011Y\u0001\u0001C!\u0005\u001bAqA!\u0006\u0001\t\u0003\u00129\u0002C\u0004\u0003$\u0001!\tE!\n\t\u000f\t%\u0002\u0001\"\u0011\u0003&!9!1\u0006\u0001\u0005B\t5\u0002b\u0002B\u001f\u0001\u0011\u0005#q\b\u0005\b\u0005C\u0002A\u0011\tB2\u0011\u001d\u0011)\u0007\u0001C!\u0005OBqAa\u001c\u0001\t\u0003\u0012\t\bC\u0004\u0003��\u0001!\tE!\n\b\u000f\t\u0005U\t#\u0001\u0003\u0004\u001a1A)\u0012E\u0001\u0005\u000bCq!a\u000bA\t\u0003\u0011\u0019\nC\u0004\u0003\u0016\u0002#\tAa&\t\u0013\tM\u0007)!A\u0005\n\tU'!C!ui\u0016tG/[8o\u0015\t1u)\u0001\u0002o]*\u0011\u0001*S\u0001\u0006E&<G\r\u001c\u0006\u0003\u0015.\u000b\u0011\"\u00198bYf$\u0018nY:\u000b\u00051k\u0015!B5oi\u0016d'\"\u0001(\u0002\u0007\r|Wn\u0001\u0001\u0016\u0005Ek6C\u0001\u0001S!\u0015\u0019f\u000b\u0017-\\\u001b\u0005!&BA+F\u0003)\t'm\u001d;sC\u000e$hN\\\u0005\u0003/R\u0013a\"\u00112tiJ\f7\r^'pIVdW\r\u0005\u0002T3&\u0011!\f\u0016\u0002\t\u0003\u000e$\u0018N^5usB\u0011A,\u0018\u0007\u0001\t\u0015q\u0006A1\u0001`\u0005\u0005!\u0016C\u00011g!\t\tG-D\u0001c\u0015\u0005\u0019\u0017!B:dC2\f\u0017BA3c\u0005\u001dqu\u000e\u001e5j]\u001e\u0004\"!Y4\n\u0005!\u0014'aA!os\u0006Q\u0001.\u001b3eK:\u001c\u0016N_3\u0016\u0003-\u0004\"!\u00197\n\u00055\u0014'aA%oi\u0006Y\u0001.\u001b3eK:\u001c\u0016N_3!\u0003!qW/\u001c%fC\u0012\u001c\u0018!\u00038v[\"+\u0017\rZ:!\u0003A\tG\u000f^3oi&|g\u000e\u0012:pa>,H/F\u0001t!\t\tG/\u0003\u0002vE\n)a\t\\8bi\u0006\t\u0012\r\u001e;f]RLwN\u001c#s_B|W\u000f\u001e\u0011\u0002\u0015\u00154\u0018\u000eZ3oG\u0016$\u0013\u0007E\u0002zynk\u0011A\u001f\u0006\u0003w\n\fqA]3gY\u0016\u001cG/\u0003\u0002~u\nA1\t\\1tgR\u000bw-\u00
01\u0002fmB)\u0011\u0011AA\u00137:!\u00111AA\u0010\u001d\u0011\t)!a\u0007\u000f\t\u0005\u001d\u0011\u0011\u0004\b\u0005\u0003\u0013\t9B\u0004\u0003\u0002\f\u0005Ua\u0002BA\u0007\u0003'i!!a\u0004\u000b\u0007\u0005Eq*\u0001\u0004=e>|GOP\u0005\u0002\u001d&\u0011A*T\u0005\u0003\u0015.K!\u0001S%\n\u0007\u0005uq)\u0001\u0004uK:\u001cxN]\u0005\u0005\u0003C\t\u0019#A\tUK:\u001cxN\u001d(v[\u0016\u0014\u0018nY'bi\"T1!!\bH\u0013\u0011\t9#!\u000b\u0003\u001bQ+gn]8s\u001dVlWM]5d\u0015\u0011\t\t#a\t\u0002\rqJg.\u001b;?)!\ty#!\u000f\u0002<\u0005uBCBA\u0019\u0003k\t9\u0004\u0005\u0003\u00024\u0001YV\"A#\t\u000b]L\u00019\u0001=\t\u000byL\u00019A@\t\u000b%L\u0001\u0019A6\t\u000b=L\u0001\u0019A6\t\u000bEL\u0001\u0019A:\u0002\u000b)|\u0017N\\&\u0016\u0005\u0005\r\u0003#BA\u001a\u0003\u000bZ\u0016bAA$\u000b\nI!j\\5o)\u0006\u0014G.Z\u0001\u0007U>Lgn\u0013\u0011\u0002\u000b)|\u0017N\u001c,\u0002\r)|\u0017N\u001c,!\u0003)\tX/\u001a:z\u0019\u0006LXM]\u000b\u0003\u0003'\u0002R!!\u0016\u0002\\msA!!\u0002\u0002X%\u0019\u0011\u0011L$\u0002\u000fA\f7m[1hK&!\u0011QLA0\u0005\u0019iu\u000eZ;mK*\u0019\u0011\u0011L$\u0002\u0017E,XM]=MCf,'\u000fI\u0001\tW\u0016LH*Y=fe\u0006I1.Z=MCf,'\u000fI\u0001\u000bm\u0006dW/\u001a'bs\u0016\u0014\u0018a\u0003<bYV,G*Y=fe\u0002\nq\"];fef\u001c\u0006\u000f\\5u\u0019\u0006LXM]\u000b\u0003\u0003_\u0002R!a\r\u0002rmK1!a\u001dF\u0005)\u0019\u0006\u000f\\5u\u0011\u0016\fGm]\u0001\u0011cV,'/_*qY&$H*Y=fe\u0002\nQb[3z'Bd\u0017\u000e\u001e'bs\u0016\u0014\u0018AD6fsN\u0003H.\u001b;MCf,'\u000fI\u0001\u0010m\u0006dW/Z*qY&$H*Y=fe\u0006\u0001b/\u00197vKN\u0003H.\u001b;MCf,'\u000fI\u0001\u0011G>tG/[4v_V\u001c\u0018\u000bT1zKJ,\"!a!\u0011\u000b\u0005M\u0012QQ.\n\u0007\u0005\u001dUI\u0001\u0006D_:$\u0018nZ;pkN\f\u0011cY8oi&<Wo\\;t#2\u000b\u00170\u001a:!\u0003A\u0019wN\u001c;jOV|Wo]&MCf,'/A\td_:$\u0018nZ;pkN\\E*Y=fe\u0002\n\u0001cY8oi&<Wo\\;t-2\u000b\u00170\u001a:\u0002#\r|g\u000e^5hk>,8O\u0016'bs\u0016\u0014\b%A\u0006nCRlW\u000f\u001c'bs\u0016\u0014XCAAL!\u0015\t\u0019$!'\\\u0013\r\tY
*\u0012\u0002\u0003\u001b6\u000bA\"\\1u[VdG*Y=fe\u0002\n\u0011bY1eI2\u000b\u00170\u001a:\u0016\u0005\u0005\r\u0006CBA\u001a\u0003K[6,C\u0002\u0002(\u0016\u0013\u0011bQ!eIR\u000b'\r\\3\u0002\u0015\r\fG\r\u001a'bs\u0016\u0014\b%\u0001\u0007t_\u001a$X*\u0019=MCf,'/A\u0007t_\u001a$X*\u0019=MCf,'\u000fI\u0001\nIJ|\u0007\u000fT1zKJ,\"!a-\u0011\u000b\u0005M\u0012QW.\n\u0007\u0005]VIA\u0004Ee>\u0004x.\u001e;\u0002\u0015\u0011\u0014x\u000e\u001d'bs\u0016\u0014\b%\u0001\nnCRlW\u000f\u001c(p)J\fgn\u001d'bs\u0016\u0014\u0018aE7bi6,HNT8Ue\u0006t7\u000fT1zKJ\u0004\u0013!E2p[\nLg.\u001a%fC\u0012\u001cH*Y=feV\u0011\u00111\u0019\t\u0006\u0003g\t)mW\u0005\u0004\u0003\u000f,%\u0001D\"p[\nLg.\u001a%fC\u0012\u001c\u0018AE2p[\nLg.\u001a%fC\u0012\u001cH*Y=fe\u0002\n1b\\;uaV$H*Y=fe\u0006aq.\u001e;qkRd\u0015-_3sA\u0005)Qn\u001c3fY\u00061Qn\u001c3fY\u0002\nQa\u001a:ba\"\faa\u001a:ba\"\u0004\u0013\u0001D2sK\u0006$X-T8ek2,GCCAn\u0003[\f\t0!>\u0002zB)\u0011Q\\At7:!\u0011q\\Ar\u001d\u0011\t)!!9\n\u0005\u0019;\u0015bAAs\u000b\u0006)qI]1qQ&!\u0011\u0011^Av\u0005)iu\u000eZ;mK:{G-\u001a\u0006\u0004\u0003K,\u0005bBAxe\u0001\u0007\u00111\\\u0001\u000bS:\u0004X\u000f^)vKJL\bbBAze\u0001\u0007\u00111\\\u0001\tS:\u0004X\u000f^&fs\"9\u0011q\u001f\u001aA\u0002\u0005m\u0017AC5oaV$h+\u00197vK\"9\u00111 
\u001aA\u0002\u0005m\u0017!C5oaV$()[1t\u0003E)\b\u000fZ1uK>+H\u000f];u\u0007\u0006\u001c\u0007.\u001a\u000b\u00041\n\u0005\u0001B\u0002B\u0002g\u0001\u0007\u0001,A\u0003j]B,H/\u0001\u0007va\u0012\fG/Z(viB,H\u000fF\u0002Y\u0005\u0013AaAa\u00015\u0001\u0004A\u0016aD;qI\u0006$Xm\u0012:bI&s\u0007/\u001e;\u0015\u000ba\u0013yA!\u0005\t\r\t\rQ\u00071\u0001Y\u0011\u0019\u0011\u0019\"\u000ea\u00011\u0006QqM]1e\u001fV$\b/\u001e;\u0002#\u0005\u001c7m\u0012:bIB\u000b'/Y7fi\u0016\u00148\u000f\u0006\u0004\u0003\u001a\t}!\u0011\u0005\t\u0004C\nm\u0011b\u0001B\u000fE\n!QK\\5u\u0011\u0019\u0011\u0019A\u000ea\u00011\"1!1\u0003\u001cA\u0002a\u000b\u0001\u0002\u001e:bS:Lgn\u001a\u000b\u0003\u0005Oi\u0011\u0001A\u0001\tKZ\fG.^1uK\u0006\tr-\u001a;FqR\u0014\u0018\rU1sC6,G/\u001a:\u0015\u0005\t=\u0002#B1\u00032\tU\u0012b\u0001B\u001aE\n)\u0011I\u001d:bsB)!q\u0007B\u001d76\u0011\u00111E\u0005\u0005\u0005w\t\u0019C\u0001\u0004UK:\u001cxN]\u0001\tO\u0016$H+[7fgR\u0011!\u0011\t\t\u0006C\nE\"1\t\t\nC\n\u0015#\u0011\nB.\u00057J1Aa\u0012c\u0005\u0019!V\u000f\u001d7fgA2!1\nB(\u0005/\u0002ra\u0015,\u0003N\tU3\fE\u0002]\u0005\u001f\"1B!\u0015;\u0003\u0003\u0005\tQ!\u0001\u0003T\t\u0019q\fJ\u0019\u0012\u0005\u0001D\u0006c\u0001/\u0003X\u0011Y!\u0011\f\u001e\u0002\u0002\u0003\u0005)\u0011\u0001B*\u0005\ryFE\r\t\u0004C\nu\u0013b\u0001B0E\n!Aj\u001c8h\u0003)\u0011Xm]3u)&lWm\u001d\u000b\u0003\u00053\t!\u0002]1sC6,G/\u001a:t)\t\u0011I\u0007E\u0004b\u0005W\u0012yCa\f\n\u0007\t5$M\u0001\u0004UkBdWMM\u0001\u0013O\u0016$\b+\u0019:b[\u0016$XM]:UC\ndW\r\u0006\u0002\u0003tA!!Q\u000fB>\u001b\t\u00119HC\u0002\u0003z\u001d\u000bQ!\u001e;jYNLAA! 
\u0003x\t)A+\u00192mK\u0006Q1\r\\3beN#\u0018\r^3\u0002\u0013\u0005#H/\u001a8uS>t\u0007cAA\u001a\u0001N)\u0001Ia\"\u0003\u000eB\u0019\u0011M!#\n\u0007\t-%M\u0001\u0004B]f\u0014VM\u001a\t\u0004C\n=\u0015b\u0001BIE\na1+\u001a:jC2L'0\u00192mKR\u0011!1Q\u0001\u0006CB\u0004H._\u000b\u0005\u00053\u0013\t\u000b\u0006\u0005\u0003\u001c\n5'q\u001aBi)\u0019\u0011iJa1\u0003JB)\u00111\u0007\u0001\u0003 B\u0019AL!)\u0005\u0013y\u0013\u0005\u0015!A\u0001\u0006\u0004y\u0006\u0006\u0003BQ\u0005K\u0013YK!/\u0011\u0007\u0005\u00149+C\u0002\u0003*\n\u00141b\u001d9fG&\fG.\u001b>fIFJ1E!,\u00030\nM&\u0011\u0017\b\u0004C\n=\u0016b\u0001BYE\u0006)a\t\\8biF2AE!.\u00038\u000etA!!\u0004\u00038&\t1-M\u0005$\u0005w\u0013iL!1\u0003@:\u0019\u0011M!0\n\u0007\t}&-\u0001\u0004E_V\u0014G.Z\u0019\u0007I\tU&qW2\t\u0013\t\u0015')!AA\u0004\t\u001d\u0017AC3wS\u0012,gnY3%iA!\u0011\u0010 BP\u0011\u0019q(\tq\u0001\u0003LB1\u0011\u0011AA\u0013\u0005?CQ!\u001b\"A\u0002-DQa\u001c\"A\u0002-DQ!\u001d\"A\u0002M\f1B]3bIJ+7o\u001c7wKR\u0011!q\u001b\t\u0005\u00053\u0014\u0019/\u0004\u0002\u0003\\*!!Q\u001cBp\u0003\u0011a\u0017M\\4\u000b\u0005\t\u0005\u0018\u0001\u00026bm\u0006LAA!:\u0003\\\n1qJ\u00196fGR\u0004")
/* loaded from: input_file:com/intel/analytics/bigdl/nn/Attention.class */
/*
 * Attention layer assembled from BigDL graph modules. This Java source is decompiled from
 * Attention.scala, so some constructs below are compiler/decompiler artifacts.
 *
 * Forward input is a 3-element Table:
 *   1 -> query-side input Tensor (projected by queryLayer)
 *   2 -> key/value-side input Tensor (projected by both keyLayer and valueLayer)
 *   3 -> either a Tensor that is added to the query/key matmul output before softmax (runs `model`),
 *        or a Table {1: that additive Tensor, 2: cache Table}, which runs the inference-only cached
 *        path (updateOutputCache).
 *
 * `model` and `graph` are two Graph containers wired through the SAME sub-module instances (both are
 * built by createModule from this object's fields). `model` additionally contains the q/k/v
 * projections and is the one used for backward, parameters and train/eval switching; `graph` starts
 * after the projections so the cached path can run the projections itself and join cached keys/values.
 */
public class Attention<T> extends AbstractModule<Activity, Activity, T> {
    // Constructor arguments, exposed through the public getters below.
    private final int hiddenSize;
    private final int numHeads;
    // The dropout layer is built with (1 - attentionDropout). NOTE(review): this suggests the value is a
    // keep probability rather than a drop rate -- confirm against Dropout's initP semantics.
    private final float attentionDropout;
    // Numeric implementation used to convert Activity results to Tensor<T>.
    private final TensorNumericMath.TensorNumeric<T> ev;
    // JoinTable(2, -1) modules used only by the cached path to join fresh and cached keys/values.
    private final JoinTable<T> joinK;
    private final JoinTable<T> joinV;
    // Projections built with TransformerOperation.dense(hiddenSize, hiddenSize, false, ...).
    private final AbstractModule<Activity, Activity, T> queryLayer;
    private final AbstractModule<Activity, Activity, T> keyLayer;
    private final AbstractModule<Activity, Activity, T> valueLayer;
    // Attention-core modules; both `model` and `graph` are wired through these same instances.
    private final SplitHeads<T> querySplitLayer;
    private final SplitHeads<T> keySplitLayer;
    private final SplitHeads<T> valueSplitLayer;
    private final Contiguous<T> contiguousQLayer;
    private final Contiguous<T> contiguousKLayer;
    private final Contiguous<T> contiguousVLayer;
    private final MM<T> matmulLayer;
    private final CAddTable<T, T> caddLayer;
    private final AbstractModule<Activity, Activity, T> softMaxLayer;
    private final Dropout<T> dropLayer;
    private final MM<T> matmulNoTransLayer;
    private final CombineHeads<T> combineHeadsLayer;
    private final AbstractModule<Activity, Activity, T> outputLayer;
    // Full graph: (query input, key/value input, additive term) -> output.
    private final AbstractModule<Activity, Activity, T> model;
    // Post-projection graph used by the cached path: (query, key, value, additive term) -> output.
    private final AbstractModule<Activity, Activity, T> graph;

    public int hiddenSize() {
        return this.hiddenSize;
    }

    public int numHeads() {
        return this.numHeads;
    }

    public float attentionDropout() {
        return this.attentionDropout;
    }

    private JoinTable<T> joinK() {
        return this.joinK;
    }

    private JoinTable<T> joinV() {
        return this.joinV;
    }

    private AbstractModule<Activity, Activity, T> queryLayer() {
        return this.queryLayer;
    }

    private AbstractModule<Activity, Activity, T> keyLayer() {
        return this.keyLayer;
    }

    private AbstractModule<Activity, Activity, T> valueLayer() {
        return this.valueLayer;
    }

    private SplitHeads<T> querySplitLayer() {
        return this.querySplitLayer;
    }

    private SplitHeads<T> keySplitLayer() {
        return this.keySplitLayer;
    }

    private SplitHeads<T> valueSplitLayer() {
        return this.valueSplitLayer;
    }

    private Contiguous<T> contiguousQLayer() {
        return this.contiguousQLayer;
    }

    private Contiguous<T> contiguousKLayer() {
        return this.contiguousKLayer;
    }

    private Contiguous<T> contiguousVLayer() {
        return this.contiguousVLayer;
    }

    private MM<T> matmulLayer() {
        return this.matmulLayer;
    }

    private CAddTable<T, T> caddLayer() {
        return this.caddLayer;
    }

    private AbstractModule<Activity, Activity, T> softMaxLayer() {
        return this.softMaxLayer;
    }

    private Dropout<T> dropLayer() {
        return this.dropLayer;
    }

    private MM<T> matmulNoTransLayer() {
        return this.matmulNoTransLayer;
    }

    private CombineHeads<T> combineHeadsLayer() {
        return this.combineHeadsLayer;
    }

    private AbstractModule<Activity, Activity, T> outputLayer() {
        return this.outputLayer;
    }

    /** Returns the full attention graph (projections included) used for the non-cached path. */
    public AbstractModule<Activity, Activity, T> model() {
        return this.model;
    }

    private AbstractModule<Activity, Activity, T> graph() {
        return this.graph;
    }

    /**
     * Wires the shared attention-core modules onto the given nodes and returns the final node.
     *
     * <p>Chain: split heads (q, k, v) -> contiguous -> matmulLayer(q, k) -> caddLayer(+ additive term)
     * -> softmax -> dropout -> matmulNoTransLayer(weights, v) -> combine heads -> outputLayer.
     * Every call wires the same field instances, so all graphs built from it share those modules
     * (and therefore their parameters and internal state).
     *
     * @param queryNode node producing the projected queries
     * @param keyNode node producing the projected keys
     * @param valueNode node producing the projected values
     * @param biasNode node producing the term added to the query/key matmul output
     * @return the outputLayer node
     */
    private Node<AbstractModule<Activity, Activity, T>> createModule(Node<AbstractModule<Activity, Activity, T>> queryNode, Node<AbstractModule<Activity, Activity, T>> keyNode, Node<AbstractModule<Activity, Activity, T>> valueNode, Node<AbstractModule<Activity, Activity, T>> biasNode) {
        Node<AbstractModule<Activity, Activity, T>> splitQ = querySplitLayer().inputs((Seq) Predef$.MODULE$.wrapRefArray(new Node[]{queryNode}));
        Node<AbstractModule<Activity, Activity, T>> splitK = keySplitLayer().inputs((Seq) Predef$.MODULE$.wrapRefArray(new Node[]{keyNode}));
        Node<AbstractModule<Activity, Activity, T>> splitV = valueSplitLayer().inputs((Seq) Predef$.MODULE$.wrapRefArray(new Node[]{valueNode}));
        Node<AbstractModule<Activity, Activity, T>> contiguousQ = contiguousQLayer().inputs((Seq) Predef$.MODULE$.wrapRefArray(new Node[]{splitQ}));
        Node<AbstractModule<Activity, Activity, T>> contiguousK = contiguousKLayer().inputs((Seq) Predef$.MODULE$.wrapRefArray(new Node[]{splitK}));
        Node<AbstractModule<Activity, Activity, T>> contiguousV = contiguousVLayer().inputs((Seq) Predef$.MODULE$.wrapRefArray(new Node[]{splitV}));
        // matmulLayer is built with its second MM flag set (presumably transB) -- see the constructor.
        Node<AbstractModule<Activity, Activity, T>> scores = matmulLayer().inputs((Seq) Predef$.MODULE$.wrapRefArray(new Node[]{contiguousQ, contiguousK}));
        Node<AbstractModule<Activity, Activity, T>> biased = caddLayer().inputs((Seq) Predef$.MODULE$.wrapRefArray(new Node[]{scores, biasNode}));
        Node<AbstractModule<Activity, Activity, T>> weights = softMaxLayer().inputs((Seq) Predef$.MODULE$.wrapRefArray(new Node[]{biased}));
        Node<AbstractModule<Activity, Activity, T>> dropped = dropLayer().inputs((Seq) Predef$.MODULE$.wrapRefArray(new Node[]{weights}));
        Node<AbstractModule<Activity, Activity, T>> attended = matmulNoTransLayer().inputs((Seq) Predef$.MODULE$.wrapRefArray(new Node[]{dropped, contiguousV}));
        Node<AbstractModule<Activity, Activity, T>> combined = combineHeadsLayer().inputs((Seq) Predef$.MODULE$.wrapRefArray(new Node[]{attended}));
        return outputLayer().inputs((Seq) Predef$.MODULE$.wrapRefArray(new Node[]{combined}));
    }

    /**
     * Runs {@code projection} on {@code input} and, when {@code cached} is non-null and non-empty,
     * joins the fresh projection with it as T(fresh, cached) through {@code join}.
     *
     * <p>The emptiness check happens before the projection's forward, and the projection's forward
     * happens before the join, matching the original evaluation order (this matters if
     * {@code cached} aliases an output buffer of {@code projection}).
     */
    private Tensor<T> projectAndJoin(AbstractModule<Activity, Activity, T> projection, JoinTable<T> join, Tensor<T> input, Tensor<T> cached) {
        if (cached == null || cached.isEmpty()) {
            return projection.forward(input).toTensor(this.ev);
        }
        Tensor<T> fresh = projection.forward(input).toTensor(this.ev);
        return join.forward(T$.MODULE$.apply(fresh, (Seq<Object>) Predef$.MODULE$.genericWrapArray(new Object[]{cached})));
    }

    /**
     * Inference-only forward pass with incremental key/value caching.
     *
     * <p>Expects {@code {1: query input, 2: key/value input, 3: {1: additive Tensor, 2: cache Table}}}.
     * When the cache table is non-empty, the entries {@code getName() + "_k"} and
     * {@code getName() + "_v"} are read. A null or empty entry means "no cached value". After the
     * fresh projections have been joined with the cached values, both entries are overwritten with
     * the results. An empty cache table disables both the lookup and the update.
     *
     * <p>NOTE(review): the tensors written back into the cache are the objects returned by
     * keyLayer/valueLayer or joinK/joinV forward. If those modules reuse their output buffers across
     * calls, feeding the same cache tensors back unchanged on the next step could alias them. Confirm
     * that callers copy or gather cache entries between steps.
     *
     * @param activity the 3-element input table
     * @return {@code graph}'s output for (query, key, value, additive term), also stored as this module's output
     * @throws IllegalArgumentException (from {@code require}) if the module is in training mode
     */
    @SuppressWarnings("unchecked")
    private Activity updateOutputCache(Activity activity) {
        Predef$.MODULE$.require(!isTraining(), () -> {
            return "Only support input cache for model inference";
        });
        Table inputTable = activity.toTable();
        Tensor<T> inputX = (Tensor<T>) inputTable.apply(BoxesRunTime.boxToInteger(1));
        Tensor<T> inputY = (Tensor<T>) inputTable.apply(BoxesRunTime.boxToInteger(2));
        Table biasAndCache = (Table) inputTable.apply(BoxesRunTime.boxToInteger(3));
        Tensor<T> inputBias = (Tensor<T>) biasAndCache.apply(BoxesRunTime.boxToInteger(1));
        Table cache = (Table) biasAndCache.apply(BoxesRunTime.boxToInteger(2));

        Tensor<T> query = queryLayer().forward(inputX).toTensor(this.ev);

        String keyName = getName() + "_k";
        String valueName = getName() + "_v";
        boolean useCache = cache.length() > 0;
        Tensor<T> cachedK = useCache ? (Tensor<T>) cache.apply(keyName) : null;
        Tensor<T> cachedV = useCache ? (Tensor<T>) cache.apply(valueName) : null;

        Tensor<T> key = projectAndJoin(keyLayer(), joinK(), inputY, cachedK);
        Tensor<T> value = projectAndJoin(valueLayer(), joinV(), inputY, cachedV);

        if (useCache) {
            cache.update(keyName, key);
            cache.update(valueName, value);
        }
        output_$eq(graph().updateOutput(T$.MODULE$.apply(query, (Seq<Object>) Predef$.MODULE$.genericWrapArray(new Object[]{key, value, inputBias}))));
        return output();
    }

    /**
     * Forward pass. Dispatches on the type of the third input: a Tensor runs {@link #model()}, and a
     * Table runs the cached inference path.
     *
     * @throws IllegalArgumentException if the input does not have exactly 3 elements, or if the third
     *     element is neither a Tensor nor a Table. Previously that case silently returned the output
     *     of an earlier call.
     */
    @Override // com.intel.analytics.bigdl.nn.abstractnn.AbstractModule
    public Activity updateOutput(Activity activity) {
        Table inputTable = activity.toTable();
        Predef$.MODULE$.require(inputTable.length() == 3, () -> {
            return new StringBuilder(31).append("only support 3 inputs, but get ").append(inputTable.length()).toString();
        });
        Activity third = (Activity) inputTable.apply(BoxesRunTime.boxToInteger(3));
        if (third instanceof Tensor) {
            output_$eq(model().updateOutput(activity));
        } else if (third instanceof Table) {
            output_$eq(updateOutputCache(activity));
        } else {
            throw new IllegalArgumentException("the 3rd input must be a Tensor or a Table, but get "
                + (third == null ? "null" : third.getClass().getName()));
        }
        return output();
    }

    /** Backward pass; delegates to {@link #model()} (the cached path is inference-only). */
    @Override // com.intel.analytics.bigdl.nn.abstractnn.AbstractModule
    /* renamed from: updateGradInput */
    public Activity updateGradInput2(Activity activity, Activity activity2) {
        gradInput_$eq(model().updateGradInput2(activity, activity2));
        return gradInput();
    }

    /** Accumulates parameter gradients; delegates to {@link #model()}. */
    @Override // com.intel.analytics.bigdl.nn.abstractnn.AbstractModule
    public void accGradParameters(Activity activity, Activity activity2) {
        model().accGradParameters(activity, activity2);
    }

    /**
     * Switches this module and {@link #model()} to training mode. The `graph` container itself is not
     * toggled, but its sub-modules are the same instances that `model` contains.
     */
    @Override // com.intel.analytics.bigdl.nn.abstractnn.AbstractModule
    /* renamed from: training */
    public Attention<T> training2() {
        train_$eq(true);
        model().training2();
        return this;
    }

    /**
     * Switches this module and {@link #model()} to evaluation mode. The `graph` container itself is not
     * toggled, but its sub-modules are the same instances that `model` contains.
     */
    @Override // com.intel.analytics.bigdl.nn.abstractnn.AbstractModule
    /* renamed from: evaluate */
    public Attention<T> evaluate2() {
        train_$eq(false);
        model().evaluate2();
        return this;
    }

    /** Delegates to {@link #model()}. */
    @Override // com.intel.analytics.bigdl.nn.abstractnn.AbstractModule
    public Tensor<T>[] getExtraParameter() {
        return model().getExtraParameter();
    }

    /** Delegates to {@link #model()}; timings of the cached-path modules joinK/joinV/graph are not included. */
    @Override // com.intel.analytics.bigdl.nn.abstractnn.AbstractModule
    public Tuple3<AbstractModule<? extends Activity, ? extends Activity, T>, Object, Object>[] getTimes() {
        return model().getTimes();
    }

    /** Delegates to {@link #model()}. */
    @Override // com.intel.analytics.bigdl.nn.abstractnn.AbstractModule
    public void resetTimes() {
        model().resetTimes();
    }

    /** Returns {@link #model()}'s (weights, gradients) arrays; all trainable modules live in `model`. */
    @Override // com.intel.analytics.bigdl.nn.abstractnn.AbstractModule
    public Tuple2<Tensor<T>[], Tensor<T>[]> parameters() {
        return model().parameters();
    }

    /** Delegates to {@link #model()}. */
    @Override // com.intel.analytics.bigdl.nn.abstractnn.AbstractModule
    public Table getParametersTable() {
        return model().getParametersTable();
    }

    /**
     * Clears {@link #model()}'s state. NOTE(review): joinK/joinV and the `graph` container are not
     * cleared here; confirm whether their buffers should be released as well.
     */
    @Override // com.intel.analytics.bigdl.nn.abstractnn.AbstractModule
    /* renamed from: clearState */
    public Attention<T> clearState2() {
        model().clearState2();
        return this;
    }

    /**
     * Builds all sub-modules and both graphs. Field initialization order matches the Scala source.
     *
     * @param hiddenSize size passed to every dense projection and to SplitHeads
     * @param numHeads number of heads passed to SplitHeads
     * @param attentionDropout dropout setting; the Dropout layer receives {@code 1 - attentionDropout}
     * @param classTag class tag of the element type
     * @param tensorNumeric numeric implementation of the element type
     */
    /* JADX WARN: 'super' call moved to the top of the method (can break code semantics) */
    public Attention(int hiddenSize, int numHeads, float attentionDropout, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        super(ClassTag$.MODULE$.apply(Activity.class), ClassTag$.MODULE$.apply(Activity.class), classTag, tensorNumeric);
        this.hiddenSize = hiddenSize;
        this.numHeads = numHeads;
        this.attentionDropout = attentionDropout;
        this.ev = tensorNumeric;
        this.joinK = JoinTable$.MODULE$.apply(2, -1, classTag, tensorNumeric);
        this.joinV = JoinTable$.MODULE$.apply(2, -1, classTag, tensorNumeric);

        // Projection layers. The dense$default$N() calls are emitted by the Scala compiler for
        // default arguments; their results are discarded and null is passed explicitly instead.
        String queryName = new StringBuilder(2).append(getName()).append("_q").toString();
        TransformerOperation$.MODULE$.dense$default$4();
        TransformerOperation$.MODULE$.dense$default$5();
        TransformerOperation$.MODULE$.dense$default$6();
        this.queryLayer = TransformerOperation$.MODULE$.dense(hiddenSize, hiddenSize, false, null, null, null, queryName, classTag, tensorNumeric);
        String keyName = new StringBuilder(2).append(getName()).append("_k").toString();
        TransformerOperation$.MODULE$.dense$default$4();
        TransformerOperation$.MODULE$.dense$default$5();
        TransformerOperation$.MODULE$.dense$default$6();
        this.keyLayer = TransformerOperation$.MODULE$.dense(hiddenSize, hiddenSize, false, null, null, null, keyName, classTag, tensorNumeric);
        String valueName = new StringBuilder(2).append(getName()).append("_v").toString();
        TransformerOperation$.MODULE$.dense$default$4();
        TransformerOperation$.MODULE$.dense$default$5();
        TransformerOperation$.MODULE$.dense$default$6();
        this.valueLayer = TransformerOperation$.MODULE$.dense(hiddenSize, hiddenSize, false, null, null, null, valueName, classTag, tensorNumeric);

        // Attention core. Only the query split layer gets `true` as its third argument; key/value use the default.
        this.querySplitLayer = new SplitHeads<>(hiddenSize, numHeads, true, classTag, tensorNumeric);
        this.keySplitLayer = new SplitHeads<>(hiddenSize, numHeads, SplitHeads$.MODULE$.$lessinit$greater$default$3(), classTag, tensorNumeric);
        this.valueSplitLayer = new SplitHeads<>(hiddenSize, numHeads, SplitHeads$.MODULE$.$lessinit$greater$default$3(), classTag, tensorNumeric);
        this.contiguousQLayer = new Contiguous<>(classTag, tensorNumeric);
        this.contiguousKLayer = new Contiguous<>(classTag, tensorNumeric);
        this.contiguousVLayer = new Contiguous<>(classTag, tensorNumeric);
        this.matmulLayer = MM$.MODULE$.apply(MM$.MODULE$.apply$default$1(), true, classTag, tensorNumeric);
        this.caddLayer = CAddTable$.MODULE$.apply(CAddTable$.MODULE$.apply$default$1(), classTag, tensorNumeric);
        this.softMaxLayer = TransformerOperation$.MODULE$.softMax(classTag, tensorNumeric);
        this.dropLayer = Dropout$.MODULE$.apply(1.0d - attentionDropout, Dropout$.MODULE$.apply$default$2(), Dropout$.MODULE$.apply$default$3(), classTag, tensorNumeric);
        this.matmulNoTransLayer = MM$.MODULE$.apply(MM$.MODULE$.apply$default$1(), MM$.MODULE$.apply$default$2(), classTag, tensorNumeric);
        this.combineHeadsLayer = new CombineHeads<>(classTag, tensorNumeric);
        String outputName = new StringBuilder(17).append(getName()).append("_output_transform").toString();
        TransformerOperation$.MODULE$.dense$default$4();
        TransformerOperation$.MODULE$.dense$default$5();
        TransformerOperation$.MODULE$.dense$default$6();
        this.outputLayer = TransformerOperation$.MODULE$.dense(hiddenSize, hiddenSize, false, null, null, null, outputName, classTag, tensorNumeric);

        // `model`: (query input, key/value input, additive term) -> output. Key and value both read input 2.
        Node<AbstractModule<Activity, Activity, T>> inputX = Input$.MODULE$.apply(Input$.MODULE$.apply$default$1(), classTag, tensorNumeric);
        Node<AbstractModule<Activity, Activity, T>> inputY = Input$.MODULE$.apply(Input$.MODULE$.apply$default$1(), classTag, tensorNumeric);
        Node<AbstractModule<Activity, Activity, T>> inputBias = Input$.MODULE$.apply(Input$.MODULE$.apply$default$1(), classTag, tensorNumeric);
        Node<AbstractModule<Activity, Activity, T>> projectedQ = queryLayer().inputs((Seq) Predef$.MODULE$.wrapRefArray(new Node[]{inputX}));
        Node<AbstractModule<Activity, Activity, T>> projectedK = keyLayer().inputs((Seq) Predef$.MODULE$.wrapRefArray(new Node[]{inputY}));
        Node<AbstractModule<Activity, Activity, T>> projectedV = valueLayer().inputs((Seq) Predef$.MODULE$.wrapRefArray(new Node[]{inputY}));
        Node<AbstractModule<Activity, Activity, T>> modelOutput = createModule(projectedQ, projectedK, projectedV, inputBias);
        Graph<T> fullGraph = Graph$.MODULE$.apply(new Node[]{inputX, inputY, inputBias}, new Node[]{modelOutput}, Graph$.MODULE$.apply$default$3(), classTag, tensorNumeric);
        this.model = train() ? fullGraph.training2() : fullGraph.evaluate2();

        // `graph`: (query, key, value, additive term) -> output, for the cached path.
        Node<AbstractModule<Activity, Activity, T>> queryNode = Input$.MODULE$.apply(Input$.MODULE$.apply$default$1(), classTag, tensorNumeric);
        Node<AbstractModule<Activity, Activity, T>> keyNode = Input$.MODULE$.apply(Input$.MODULE$.apply$default$1(), classTag, tensorNumeric);
        Node<AbstractModule<Activity, Activity, T>> valueNode = Input$.MODULE$.apply(Input$.MODULE$.apply$default$1(), classTag, tensorNumeric);
        Node<AbstractModule<Activity, Activity, T>> biasNode = Input$.MODULE$.apply(Input$.MODULE$.apply$default$1(), classTag, tensorNumeric);
        this.graph = Graph$.MODULE$.apply(new Node[]{queryNode, keyNode, valueNode, biasNode}, new Node[]{createModule(queryNode, keyNode, valueNode, biasNode)}, Graph$.MODULE$.apply$default$3(), classTag, tensorNumeric);
    }
}
