package com.intel.analytics.bigdl.nn.keras;

import com.intel.analytics.bigdl.nn.BiRecurrent$;
import com.intel.analytics.bigdl.nn.CAddTable$;
import com.intel.analytics.bigdl.nn.CAveTable;
import com.intel.analytics.bigdl.nn.CAveTable$;
import com.intel.analytics.bigdl.nn.CMulTable$;
import com.intel.analytics.bigdl.nn.Cell;
import com.intel.analytics.bigdl.nn.JoinTable$;
import com.intel.analytics.bigdl.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.tensor.Tensor;
import com.intel.analytics.bigdl.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.utils.Shape;
import com.intel.analytics.bigdl.utils.Shape$;
import com.intel.analytics.bigdl.utils.Table;
import scala.MatchError;
import scala.Predef$;
import scala.collection.mutable.ArrayOps;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;

/* compiled from: Bidirectional.scala */
@ScalaSignature(bytes = "\u0006\u0001\u00055g\u0001B\f\u0019\u0001\u0015B\u0001b\u0010\u0001\u0003\u0006\u0004%\t\u0001\u0011\u0005\t\t\u0002\u0011\t\u0011)A\u0005\u0003\"AQ\t\u0001BC\u0002\u0013\u0005a\t\u0003\u0005S\u0001\t\u0005\t\u0015!\u0003H\u0011!\u0019\u0006A!b\u0001\n\u0003!\u0006\u0002C.\u0001\u0005\u0003\u0005\u000b\u0011B+\t\u0011q\u0003!1!Q\u0001\fuC\u0001b\u0019\u0001\u0003\u0002\u0003\u0006Y\u0001\u001a\u0005\u0006k\u0002!\tA\u001e\u0005\b}\u0002\u0011\r\u0011\"\u0003��\u0011!\ty\u0001\u0001Q\u0001\n\u0005\u0005\u0001bBA\t\u0001\u0011\u0005\u00131\u0003\u0005\b\u0003/\u0001A\u0011IA\r\u000f\u001d\tI\u0003\u0007E\u0001\u0003W1aa\u0006\r\t\u0002\u00055\u0002BB;\u0010\t\u0003\tY\u0004C\u0004\u0002>=!\t!a\u0010\t\u0013\u0005ut\"%A\u0005\u0002\u0005}\u0004\"CAM\u001fE\u0005I\u0011AAN\u0011%\t\u0019kDI\u0001\n\u0003\t)\u000bC\u0005\u00024>\t\n\u0011\"\u0001\u00026\"I\u00111Y\b\u0002\u0002\u0013%\u0011Q\u0019\u0002\u000e\u0005&$\u0017N]3di&|g.\u00197\u000b\u0005eQ\u0012!B6fe\u0006\u001c(BA\u000e\u001d\u0003\tqgN\u0003\u0002\u001e=\u0005)!-[4eY*\u0011q\u0004I\u0001\nC:\fG.\u001f;jGNT!!\t\u0012\u0002\u000b%tG/\u001a7\u000b\u0003\r\n1aY8n\u0007\u0001)\"AJ\u001a\u0014\u0005\u00019\u0003#\u0002\u0015*W-\nT\"\u0001\r\n\u0005)B\"AC&fe\u0006\u001cH*Y=feB\u0019AfL\u0019\u000e\u00035R!A\f\u000f\u0002\rQ,gn]8s\u0013\t\u0001TF\u0001\u0004UK:\u001cxN\u001d\t\u0003eMb\u0001\u0001B\u00035\u0001\t\u0007QGA\u0001U#\t1D\b\u0005\u00028u5\t\u0001HC\u0001:\u0003\u0015\u00198-\u00197b\u0013\tY\u0004HA\u0004O_RD\u0017N\\4\u0011\u0005]j\u0014B\u0001 
9\u0005\r\te._\u0001\u0006Y\u0006LXM]\u000b\u0002\u0003B\u0019\u0001FQ\u0019\n\u0005\rC\"!\u0003*fGV\u0014(/\u001a8u\u0003\u0019a\u0017-_3sA\u0005IQ.\u001a:hK6{G-Z\u000b\u0002\u000fB\u0011\u0001j\u0014\b\u0003\u00136\u0003\"A\u0013\u001d\u000e\u0003-S!\u0001\u0014\u0013\u0002\rq\u0012xn\u001c;?\u0013\tq\u0005(\u0001\u0004Qe\u0016$WMZ\u0005\u0003!F\u0013aa\u0015;sS:<'B\u0001(9\u0003)iWM]4f\u001b>$W\rI\u0001\u000bS:\u0004X\u000f^*iCB,W#A+\u0011\u0005YKV\"A,\u000b\u0005ac\u0012!B;uS2\u001c\u0018B\u0001.X\u0005\u0015\u0019\u0006.\u00199f\u0003-Ig\u000e];u'\"\f\u0007/\u001a\u0011\u0002\u0015\u00154\u0018\u000eZ3oG\u0016$\u0013\u0007E\u0002_CFj\u0011a\u0018\u0006\u0003Ab\nqA]3gY\u0016\u001cG/\u0003\u0002c?\nA1\t\\1tgR\u000bw-\u0001\u0002fmB\u0019QM]\u0019\u000f\u0005\u0019\u0004hBA4p\u001d\tAgN\u0004\u0002j[:\u0011!\u000e\u001c\b\u0003\u0015.L\u0011aI\u0005\u0003C\tJ!a\b\u0011\n\u0005uq\u0012B\u0001\u0018\u001d\u0013\t\tX&A\tUK:\u001cxN\u001d(v[\u0016\u0014\u0018nY'bi\"L!a\u001d;\u0003\u001bQ+gn]8s\u001dVlWM]5d\u0015\t\tX&\u0001\u0004=S:LGO\u0010\u000b\u0005ondX\u0010F\u0002ysj\u00042\u0001\u000b\u00012\u0011\u0015a\u0016\u0002q\u0001^\u0011\u0015\u0019\u0017\u0002q\u0001e\u0011\u0015y\u0014\u00021\u0001B\u0011\u001d)\u0015\u0002%AA\u0002\u001dCqaU\u0005\u0011\u0002\u0003\u0007Q+\u0001\u0003n_\u0012,WCAA\u0001!\u0011\t\u0019!!\u0004\u000e\u0005\u0005\u0015!\u0002BA\u0004\u0003\u0013\tA\u0001\\1oO*\u0011\u00111B\u0001\u0005U\u00064\u0018-C\u0002Q\u0003\u000b\tQ!\\8eK\u0002\n!cY8naV$XmT;uaV$8\u000b[1qKR\u0019Q+!\u0006\t\u000bMc\u0001\u0019A+\u0002\u000f\u0011|')^5mIR!\u00111DA\u0014!\u001d\ti\"a\t,WEj!!a\b\u000b\u0007\u0005\u0005\"$\u0001\u0006bEN$(/Y2u]:LA!!\n\u0002 
\tq\u0011IY:ue\u0006\u001cG/T8ek2,\u0007\"B*\u000e\u0001\u0004)\u0016!\u0004\"jI&\u0014Xm\u0019;j_:\fG\u000e\u0005\u0002)\u001fM)q\"a\f\u00026A\u0019q'!\r\n\u0007\u0005M\u0002H\u0001\u0004B]f\u0014VM\u001a\t\u0004o\u0005]\u0012bAA\u001dq\ta1+\u001a:jC2L'0\u00192mKR\u0011\u00111F\u0001\u0006CB\u0004H._\u000b\u0005\u0003\u0003\nI\u0005\u0006\u0005\u0002D\u0005U\u0014\u0011PA>)\u0019\t)%a\u001b\u0002rA!\u0001\u0006AA$!\r\u0011\u0014\u0011\n\u0003\niE\u0001\u000b\u0011!AC\u0002UB\u0003\"!\u0013\u0002N\u0005M\u0013\u0011\r\t\u0004o\u0005=\u0013bAA)q\tY1\u000f]3dS\u0006d\u0017N_3ec%\u0019\u0013QKA,\u00037\nIFD\u00028\u0003/J1!!\u00179\u0003\u00151En\\1uc\u0019!\u0013QLA0s9\u0019!*a\u0018\n\u0003e\n\u0014bIA2\u0003K\nI'a\u001a\u000f\u0007]\n)'C\u0002\u0002ha\na\u0001R8vE2,\u0017G\u0002\u0013\u0002^\u0005}\u0013\bC\u0005\u0002nE\t\t\u0011q\u0001\u0002p\u0005QQM^5eK:\u001cW\r\n\u001a\u0011\ty\u000b\u0017q\t\u0005\u0007GF\u0001\u001d!a\u001d\u0011\t\u0015\u0014\u0018q\t\u0005\u0007\u007fE\u0001\r!a\u001e\u0011\t!\u0012\u0015q\t\u0005\b\u000bF\u0001\n\u00111\u0001H\u0011\u001d\u0019\u0016\u0003%AA\u0002U\u000b1\u0004\n7fgNLg.\u001b;%OJ,\u0017\r^3sI\u0011,g-Y;mi\u0012\u0012T\u0003BAA\u0003/+\"!a!+\u0007\u001d\u000b)i\u000b\u0002\u0002\bB!\u0011\u0011RAJ\u001b\t\tYI\u0003\u0003\u0002\u000e\u0006=\u0015!C;oG\",7m[3e\u0015\r\t\t\nO\u0001\u000bC:tw\u000e^1uS>t\u0017\u0002BAK\u0003\u0017\u0013\u0011#\u001e8dQ\u0016\u001c7.\u001a3WCJL\u0017M\\2f\t\u0015!$C1\u00016\u0003m!C.Z:tS:LG\u000fJ4sK\u0006$XM\u001d\u0013eK\u001a\fW\u000f\u001c;%gU!\u0011QTAQ+\t\tyJK\u0002V\u0003\u000b#Q\u0001N\nC\u0002U\nq\"\u00199qYf$C-\u001a4bk2$HEM\u000b\u0005\u0003\u0003\u000b9\u000bB\u00055)\u0001\u0006\t\u0011!b\u0001k!B\u0011qUA'\u0003W\u000by+M\u0005$\u0003+\n9&!,\u0002ZE2A%!\u0018\u0002`e\n\u0014bIA2\u0003K\n\t,a\u001a2\r\u0011\ni&a\u0018:\u0003=\t\u0007\u000f\u001d7zI\u0011,g-Y;mi\u0012\u001aT\u0003BAO\u0003o#\u0011\u0002N\u000b!\u0002\u0003\u0005)\u0019A\u001b)\u0011\u0005]\u0016QJA^\u0003\u007f\u00
0b\u0014bIA+\u0003/\ni,!\u00172\r\u0011\ni&a\u0018:c%\u0019\u00131MA3\u0003\u0003\f9'\r\u0004%\u0003;\ny&O\u0001\fe\u0016\fGMU3t_24X\r\u0006\u0002\u0002HB!\u00111AAe\u0013\u0011\tY-!\u0002\u0003\r=\u0013'.Z2u\u0001")
/* loaded from: input_file:com/intel/analytics/bigdl/nn/keras/Bidirectional.class */
/**
 * Keras-style bidirectional wrapper around a {@link Recurrent} layer.
 *
 * <p>The wrapped layer's cell is run in both directions by a {@code BiRecurrent}
 * container. The two outputs are combined with a merge module selected by
 * {@code mergeMode}, which is matched case-insensitively against
 * {@code "concat"}, {@code "sum"}, {@code "mul"} and {@code "ave"}.
 *
 * @param <T> numeric element type of the tensors
 */
public class Bidirectional<T> extends KerasLayer<Tensor<T>, Tensor<T>, T> {
    private final Recurrent<T> layer;
    private final String mergeMode;
    private final Shape inputShape;
    private final ClassTag<T> evidence$1;
    private final TensorNumericMath.TensorNumeric<T> ev;
    /** Lower-cased {@link #mergeMode}; the value actually dispatched on. */
    private final String mode;

    /** @return the wrapped recurrent layer */
    public Recurrent<T> layer() {
        return this.layer;
    }

    /** @return the merge mode exactly as passed to the constructor (not lower-cased) */
    public String mergeMode() {
        return this.mergeMode;
    }

    /** @return the input shape given at construction (without the batch dimension) */
    public Shape inputShape() {
        return this.inputShape;
    }

    private String mode() {
        return this.mode;
    }

    /**
     * Computes the output shape by delegating to the wrapped layer's {@code build}.
     * In {@code "concat"} mode the two directions are joined along the last axis,
     * so that axis is doubled. All other modes keep the wrapped layer's shape.
     *
     * @param shape input shape
     * @return the output shape of this layer
     */
    @Override // KerasLayer / AbstractModule / InferShape
    public Shape computeOutputShape(Shape shape) {
        Shape rnnOutput = layer().build(shape);
        if (!"concat".equals(mode())) {
            return rnnOutput;
        }
        int[] dims = (int[]) rnnOutput.toSingle().toArray(ClassTag$.MODULE$.Int());
        dims[dims.length - 1] *= 2;
        return Shape$.MODULE$.apply(dims);
    }

    /**
     * Builds the underlying {@code BiRecurrent} module. It uses the merge module
     * selected by the merge mode and holds the cell built by the wrapped layer.
     *
     * @param shape input shape
     * @return the bidirectional recurrent module
     * @throws MatchError if the merge mode is unsupported. The constructor already
     *     rejects such modes, so this is defensive only.
     */
    @Override // KerasLayer
    @SuppressWarnings({"rawtypes", "unchecked"})
    public AbstractModule<Tensor<T>, Tensor<T>, T> doBuild(Shape shape) {
        int[] dims = (int[]) shape.toSingle().toArray(ClassTag$.MODULE$.Int());
        Cell<T> cell = layer().buildCell(dims);
        AbstractModule<Table, Tensor<T>, T> merge = buildMergeModule(dims.length);
        // null = no batch-norm params. The Scala default (apply$default$2) is presumably
        // also null — confirm if BiRecurrent changes.
        // mo742add is the decompiler's name for BiRecurrent's add(); the raw cast
        // sidesteps the Cell (Table -> Table) vs. container generic mismatch.
        return BiRecurrent$.MODULE$
                .apply(merge, null, BiRecurrent$.MODULE$.apply$default$3(), this.evidence$1, this.ev)
                .mo742add((AbstractModule) cell);
    }

    /**
     * Creates the module that combines the forward and backward outputs.
     *
     * @param rank number of input dimensions. For concat, {@code rank - 1} is used as
     *     both the join dimension and {@code nInputDims} of {@code JoinTable}.
     * @return the merge module, typed as the table-to-tensor module BiRecurrent expects
     * @throws MatchError if the merge mode is not one of the supported names
     */
    private AbstractModule<Table, Tensor<T>, T> buildMergeModule(int rank) {
        final String m = mode();
        final AbstractModule<?, ?, T> merge;
        switch (m) {
            case "concat":
                merge = JoinTable$.MODULE$.apply(rank - 1, rank - 1, this.evidence$1, this.ev);
                break;
            case "sum":
                merge = CAddTable$.MODULE$.apply(CAddTable$.MODULE$.apply$default$1(), this.evidence$1, this.ev);
                break;
            case "mul":
                merge = CMulTable$.MODULE$.apply(this.evidence$1, this.ev);
                break;
            case "ave":
                merge = CAveTable$.MODULE$.apply(CAveTable$.MODULE$.apply$default$1(), this.evidence$1, this.ev);
                break;
            default:
                throw new MatchError(m);
        }
        // Unchecked: these merge modules all consume a Table of tensors. JoinTable may
        // declare a wildcard tensor output type, hence the explicit cast.
        @SuppressWarnings("unchecked")
        AbstractModule<Table, Tensor<T>, T> typed = (AbstractModule<Table, Tensor<T>, T>) merge;
        return typed;
    }

    /** @return whether {@code m} names a supported (already lower-cased) merge mode */
    private static boolean isSupportedMergeMode(String m) {
        return "sum".equals(m) || "mul".equals(m) || "concat".equals(m) || "ave".equals(m);
    }

    /**
     * Creates a bidirectional wrapper.
     *
     * @param recurrent the recurrent layer to wrap. It must return the full sequence.
     * @param str merge mode: one of concat / sum / mul / ave, matched case-insensitively
     * @param shape input shape without the batch dimension
     * @param classTag runtime class tag for {@code T}
     * @param tensorNumeric numeric operations for {@code T}
     * @throws IllegalArgumentException (from {@code Predef.require}) if the layer does
     *     not return sequences or the merge mode is unsupported
     */
    public Bidirectional(Recurrent<T> recurrent, String str, Shape shape, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        super(KerasLayer$.MODULE$.addBatch(shape), ClassTag$.MODULE$.apply(Tensor.class), ClassTag$.MODULE$.apply(Tensor.class), classTag, tensorNumeric);
        this.layer = recurrent;
        this.mergeMode = str;
        this.inputShape = shape;
        this.evidence$1 = classTag;
        this.ev = tensorNumeric;
        this.mode = str.toLowerCase();
        Predef$.MODULE$.require(recurrent.returnSequences(),
                () -> "Bidirectional currently requires RNNs to return the full sequence");
        Predef$.MODULE$.require(isSupportedMergeMode(this.mode),
                () -> "Invalid merge mode: " + this.mode());
    }
}
