package com.intel.analytics.bigdl.nn;

import com.intel.analytics.bigdl.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.nn.abstractnn.Activity;
import com.intel.analytics.bigdl.serialization.Bigdl;
import com.intel.analytics.bigdl.tensor.ConvertableTo$ConvertableToDouble$;
import com.intel.analytics.bigdl.tensor.Tensor;
import com.intel.analytics.bigdl.tensor.Tensor$;
import com.intel.analytics.bigdl.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.utils.Table;
import com.intel.analytics.bigdl.utils.serializer.DeserializeContext;
import com.intel.analytics.bigdl.utils.serializer.ModuleData;
import com.intel.analytics.bigdl.utils.serializer.ModuleSerializable;
import com.intel.analytics.bigdl.utils.serializer.SerializeContext;
import com.intel.analytics.bigdl.utils.serializer.SerializeResult;
import scala.Predef$;
import scala.collection.Seq$;
import scala.collection.TraversableOnce;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/* compiled from: MaskedSelect.scala */
@ScalaSignature(bytes = "\u0006\u0001\t\ra\u0001B\r\u001b\u0001\u0015B\u0001b\u0012\u0001\u0003\u0004\u0003\u0006Y\u0001\u0013\u0005\t\u001d\u0002\u0011\t\u0011)A\u0006\u001f\")1\r\u0001C\u0001I\"9!\u000e\u0001b\u0001\n\u0013Y\u0007B\u00027\u0001A\u0003%1\u0007C\u0004n\u0001\t\u0007I\u0011B6\t\r9\u0004\u0001\u0015!\u00034\u0011\u001dy\u0007A1A\u0005\n-Da\u0001\u001d\u0001!\u0002\u0013\u0019\u0004bB9\u0001\u0005\u0004%Ia\u001b\u0005\u0007e\u0002\u0001\u000b\u0011B\u001a\t\u000bM\u0004A\u0011\t;\t\u000b]\u0004A\u0011\t=\t\u000bq\u0004A\u0011I?\t\r}\u0004A\u0011IA\u0001\u0011\u001d\ti\u0001\u0001C!\u0003\u001fAq!a\u0005\u0001\t\u0003\n)bB\u0004\u0002*iA\t!a\u000b\u0007\reQ\u0002\u0012AA\u0017\u0011\u0019\u00197\u0003\"\u0001\u0002H!9\u0011\u0011J\n\u0005\u0002\u0005-\u0003bBAA'\u0011\u0005\u00131\u0011\u0005\b\u0003S\u001bB\u0011IAV\u0011%\tyoEA\u0001\n\u0013\t\tP\u0001\u0007NCN\\W\rZ*fY\u0016\u001cGO\u0003\u0002\u001c9\u0005\u0011aN\u001c\u0006\u0003;y\tQAY5hI2T!a\b\u0011\u0002\u0013\u0005t\u0017\r\\=uS\u000e\u001c(BA\u0011#\u0003\u0015Ig\u000e^3m\u0015\u0005\u0019\u0013aA2p[\u000e\u0001QC\u0001\u0014<'\t\u0001q\u0005E\u0003)W5\u001a\u0014(D\u0001*\u0015\tQ#$\u0001\u0006bEN$(/Y2u]:L!\u0001L\u0015\u0003\u001d\u0005\u00137\u000f\u001e:bGRlu\u000eZ;mKB\u0011a&M\u0007\u0002_)\u0011\u0001\u0007H\u0001\u0006kRLGn]\u0005\u0003e=\u0012Q\u0001V1cY\u0016\u00042\u0001N\u001c:\u001b\u0005)$B\u0001\u001c\u001d\u0003\u0019!XM\\:pe&\u0011\u0001(\u000e\u0002\u0007)\u0016t7o\u001c:\u0011\u0005iZD\u0002\u0001\u0003\u0006y\u0001\u0011\r!\u0010\u0002\u0002)F\u0011a\b\u0012\t\u0003\u007f\tk\u0011\u0001\u0011\u0006\u0002\u0003\u0006)1oY1mC&\u00111\t\u0011\u0002\b\u001d>$\b.\u001b8h!\tyT)\u0003\u0002G\u0001\n\u0019\u0011I\\=\u0002\u0015\u00154\u0018\u000eZ3oG\u0016$\u0013\u0007E\u0002J\u0019fj\u0011A\u0013\u0006\u0003\u0017\u0002\u000bqA]3gY\u0016\u001cG/\u0003\u0002N\u0015\nA1\t\\1tgR\u000bw-\u0001\u0002fmB\u0019\u0001\u000bY\u001d\u000f\u0005EsfB\u0001*^\u001d\t\u0019FL\u0004\u0002U7:\u00
11QK\u0017\b\u0003-fk\u0011a\u0016\u0006\u00031\u0012\na\u0001\u0010:p_Rt\u0014\"A\u0012\n\u0005\u0005\u0012\u0013BA\u0010!\u0013\tib$\u0003\u000279%\u0011q,N\u0001\u0012)\u0016t7o\u001c:Ok6,'/[2NCRD\u0017BA1c\u00055!VM\\:pe:+X.\u001a:jG*\u0011q,N\u0001\u0007y%t\u0017\u000e\u001e \u0015\u0003\u0015$2A\u001a5j!\r9\u0007!O\u0007\u00025!)qi\u0001a\u0002\u0011\")aj\u0001a\u0002\u001f\u0006YQ.Y:l\u0013:$\u0017nY3t+\u0005\u0019\u0014\u0001D7bg.Le\u000eZ5dKN\u0004\u0013aD7bg.Le\u000eZ3y\u0005V4g-\u001a:\u0002!5\f7o[%oI\u0016D()\u001e4gKJ\u0004\u0013AC4sC\u0012\u0014UO\u001a4fe\u0006YqM]1e\u0005V4g-\u001a:!\u0003!9'/\u00193NCN\\\u0017!C4sC\u0012l\u0015m]6!\u00031)\b\u000fZ1uK>+H\u000f];u)\t\u0019T\u000fC\u0003w\u0019\u0001\u0007Q&A\u0003j]B,H/A\bva\u0012\fG/Z$sC\u0012Le\u000e];u)\ri\u0013P\u001f\u0005\u0006m6\u0001\r!\f\u0005\u0006w6\u0001\raM\u0001\u000bOJ\fGmT;uaV$\u0018AC2mK\u0006\u00148\u000b^1uKR\ta0D\u0001\u0001\u0003!\u0019\u0017M\\#rk\u0006dG\u0003BA\u0002\u0003\u0013\u00012aPA\u0003\u0013\r\t9\u0001\u0011\u0002\b\u0005>|G.Z1o\u0011\u0019\tYa\u0004a\u0001\t\u0006)q\u000e\u001e5fe\u00061Q-];bYN$B!a\u0001\u0002\u0012!1\u00111\u0002\tA\u0002\u0011\u000b\u0001\u0002[1tQ\u000e{G-\u001a\u000b\u0003\u0003/\u00012aPA\r\u0013\r\tY\u0002\u0011\u0002\u0004\u0013:$\bf\u0002\u0001\u0002 \u0005\u0015\u0012q\u0005\t\u0004\u007f\u0005\u0005\u0012bAA\u0012\u0001\n\u00012+\u001a:jC24VM]:j_:,\u0016\nR\u0001\u0006m\u0006dW/\u001a\u0010\to2Sdj\tCA.\u0005aQ*Y:lK\u0012\u001cV\r\\3diB\u0011qmE\n\b'\u0005=\u0012QGA!!\ry\u0014\u0011G\u0005\u0004\u0003g\u0001%AB!osJ+g\r\u0005\u0003\u00028\u0005uRBAA\u001d\u0015\r\tYdL\u0001\u000bg\u0016\u0014\u0018.\u00197ju\u0016\u0014\u0018\u0002BA 
\u0003s\u0011!#T8ek2,7+\u001a:jC2L'0\u00192mKB\u0019q(a\u0011\n\u0007\u0005\u0015\u0003I\u0001\u0007TKJL\u0017\r\\5{C\ndW\r\u0006\u0002\u0002,\u0005)\u0011\r\u001d9msV!\u0011QJA+)\t\ty\u0005\u0006\u0004\u0002R\u0005]\u0014Q\u0010\t\u0005O\u0002\t\u0019\u0006E\u0002;\u0003+\"\u0011\u0002P\u000b!\u0002\u0003\u0005)\u0019A\u001f)\u0011\u0005U\u0013\u0011LA0\u0003[\u00022aPA.\u0013\r\ti\u0006\u0011\u0002\fgB,7-[1mSj,G-M\u0005$\u0003C\n\u0019'a\u001a\u0002f9\u0019q(a\u0019\n\u0007\u0005\u0015\u0004)A\u0003GY>\fG/\r\u0004%\u0003S\nY'\u0011\b\u0004-\u0006-\u0014\"A!2\u0013\r\ny'!\u001d\u0002v\u0005MdbA \u0002r%\u0019\u00111\u000f!\u0002\r\u0011{WO\u00197fc\u0019!\u0013\u0011NA6\u0003\"I\u0011\u0011P\u000b\u0002\u0002\u0003\u000f\u00111P\u0001\u000bKZLG-\u001a8dK\u0012\u0012\u0004\u0003B%M\u0003'BaAT\u000bA\u0004\u0005}\u0004\u0003\u0002)a\u0003'\nA\u0002Z8M_\u0006$Wj\u001c3vY\u0016,B!!\"\u0002\u0014R!\u0011qQAP)\u0019\tI)!&\u0002\u001cBA\u0001fKAF\u0003\u0017\u000b\t\nE\u0002)\u0003\u001bK1!a$*\u0005!\t5\r^5wSRL\bc\u0001\u001e\u0002\u0014\u0012)AH\u0006b\u0001{!I\u0011q\u0013\f\u0002\u0002\u0003\u000f\u0011\u0011T\u0001\u000bKZLG-\u001a8dK\u0012\u001a\u0004\u0003B%M\u0003#CaA\u0014\fA\u0004\u0005u\u0005\u0003\u0002)a\u0003#Cq!!)\u0017\u0001\u0004\t\u0019+A\u0004d_:$X\r\u001f;\u0011\t\u0005]\u0012QU\u0005\u0005\u0003O\u000bID\u0001\nEKN,'/[1mSj,7i\u001c8uKb$\u0018!\u00053p'\u0016\u0014\u0018.\u00197ju\u0016lu\u000eZ;mKV!\u0011QVA`)\u0019\ty+!2\u0002NR1\u0011\u0011WA\\\u0003\u0003\u00042aPAZ\u0013\r\t)\f\u0011\u0002\u0005+:LG\u000fC\u0005\u0002:^\t\t\u0011q\u0001\u0002<\u0006QQM^5eK:\u001cW\r\n\u001b\u0011\t%c\u0015Q\u0018\t\u0004u\u0005}F!\u0002\u001f\u0018\u0005\u0004i\u0004B\u0002(\u0018\u0001\b\t\u0019\r\u0005\u0003QA\u0006u\u0006bBAQ/\u0001\u0007\u0011q\u0019\t\u0007\u0003o\tI-!0\n\t\u0005-\u0017\u0011\b\u0002\u0011'\u0016\u0014\u0018.\u00197ju\u0016\u001cuN\u001c;fqRDq!a4\u0018\u0001\u0004\t\t.A\nnCN\\W\rZ*fY\u0016\u001cGOQ;jY\u0012,'\u000f\u0005\u0003\u0002T\u0006%h\u
0002BAk\u0003GtA!a6\u0002^:\u0019!+!7\n\u0007\u0005mG$A\u0007tKJL\u0017\r\\5{CRLwN\\\u0005\u0005\u0003?\f\t/A\u0003CS\u001e$GNC\u0002\u0002\\rIA!!:\u0002h\u0006Y!)[4E\u00196{G-\u001e7f\u0015\u0011\ty.!9\n\t\u0005-\u0018Q\u001e\u0002\b\u0005VLG\u000eZ3s\u0015\u0011\t)/a:\u0002\u0017I,\u0017\r\u001a*fg>dg/\u001a\u000b\u0003\u0003g\u0004B!!>\u0002��6\u0011\u0011q\u001f\u0006\u0005\u0003s\fY0\u0001\u0003mC:<'BAA\u007f\u0003\u0011Q\u0017M^1\n\t\t\u0005\u0011q\u001f\u0002\u0007\u001f\nTWm\u0019;")
/* loaded from: input_file:com/intel/analytics/bigdl/nn/MaskedSelect.class */
/**
 * Selects the entries of a data tensor at the positions flagged by a mask tensor.
 *
 * <p>The input is a {@link Table} holding the data tensor at key 1 and the mask tensor at key 2.
 * The forward result is produced by {@code Tensor.maskedSelect}. The backward pass scatters
 * {@code gradOutput} back into a zero tensor shaped like the data tensor, using the 1-based flat
 * indices of the selected positions. The mask always receives an all-zero gradient.
 *
 * <p>This source was decompiled from Scala ({@code MaskedSelect.scala}). The {@code com$intel$...}
 * member names and the {@code 2}-suffixed method names are compiler/decompiler artifacts. They are
 * kept as-is because other code may link against them.
 */
public class MaskedSelect<T> extends AbstractModule<Table, Tensor<T>, T> {
    public static final long serialVersionUID = 8596309896021196822L;

    // NOTE(review): the private field names below intentionally keep their Scala-mangled form.
    // If this module is Java-serialized (the pinned serialVersionUID suggests it is), field names
    // are part of the stream format, so renaming them would break loading saved models.

    /** Numeric type-class for {@code T}; used here to convert the mask sum to a double. */
    private final TensorNumericMath.TensorNumeric<T> ev;
    /** 1-based flat indices of the selected positions; refreshed during backward. */
    private final Tensor<T> com$intel$analytics$bigdl$nn$MaskedSelect$$maskIndices;
    /** Scratch tensor holding 1..nElement(mask), reshaped like the mask, during backward. */
    private final Tensor<T> com$intel$analytics$bigdl$nn$MaskedSelect$$maskIndexBuffer;
    /** Gradient with respect to the data tensor (table key 1). */
    private final Tensor<T> com$intel$analytics$bigdl$nn$MaskedSelect$$gradBuffer;
    /** Gradient with respect to the mask tensor (table key 2); always zero-filled. */
    private final Tensor<T> com$intel$analytics$bigdl$nn$MaskedSelect$$gradMask;

    /** Static forwarder to the companion object's {@code doSerializeModule}. */
    public static <T> void doSerializeModule(SerializeContext<T> serializeContext, Bigdl.BigDLModule.Builder builder, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        MaskedSelect$.MODULE$.doSerializeModule(serializeContext, builder, classTag, tensorNumeric);
    }

    /** Static forwarder to the companion object's {@code doLoadModule}. */
    public static <T> AbstractModule<Activity, Activity, T> doLoadModule(DeserializeContext deserializeContext, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        return MaskedSelect$.MODULE$.doLoadModule(deserializeContext, classTag, tensorNumeric);
    }

    /** Static forwarder to the companion object's {@code serializeModule}. */
    public static <T> SerializeResult serializeModule(SerializeContext<T> serializeContext, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        return MaskedSelect$.MODULE$.serializeModule(serializeContext, classTag, tensorNumeric);
    }

    /** Static forwarder to the companion object's {@code loadModule}. */
    public static <T> ModuleData<T> loadModule(DeserializeContext deserializeContext, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        return MaskedSelect$.MODULE$.loadModule(deserializeContext, classTag, tensorNumeric);
    }

    /** Static forwarder to the companion object's {@code setCopyWeightAndBias}. */
    public static ModuleSerializable setCopyWeightAndBias(boolean z) {
        return MaskedSelect$.MODULE$.setCopyWeightAndBias(z);
    }

    /** Scala-generated accessor for the selected-index buffer. */
    public Tensor<T> com$intel$analytics$bigdl$nn$MaskedSelect$$maskIndices() {
        return this.com$intel$analytics$bigdl$nn$MaskedSelect$$maskIndices;
    }

    /** Scala-generated accessor for the index scratch buffer. */
    public Tensor<T> com$intel$analytics$bigdl$nn$MaskedSelect$$maskIndexBuffer() {
        return this.com$intel$analytics$bigdl$nn$MaskedSelect$$maskIndexBuffer;
    }

    /** Scala-generated accessor for the data-tensor gradient buffer. */
    public Tensor<T> com$intel$analytics$bigdl$nn$MaskedSelect$$gradBuffer() {
        return this.com$intel$analytics$bigdl$nn$MaskedSelect$$gradBuffer;
    }

    /** Scala-generated accessor for the (always zero) mask gradient buffer. */
    public Tensor<T> com$intel$analytics$bigdl$nn$MaskedSelect$$gradMask() {
        return this.com$intel$analytics$bigdl$nn$MaskedSelect$$gradMask;
    }

    /**
     * Returns whether {@code mask} selects anything, i.e. whether the sum of its entries,
     * converted to double, is strictly positive.
     *
     * <p>NOTE(review): this assumes a 0/1 mask. A mask whose entries sum to a non-positive value
     * is treated as selecting nothing, even if it has non-zero entries.
     */
    private boolean hasSelection(Tensor<T> mask) {
        double total = BoxesRunTime.unboxToDouble(
                this.ev.toType(mask.mo1129sum(), ConvertableTo$ConvertableToDouble$.MODULE$));
        return total > 0;
    }

    /**
     * Forward pass: selects from the data tensor (key 1) the entries flagged by the mask (key 2).
     *
     * @param table input table holding the data tensor at key 1 and the mask tensor at key 2
     * @return {@code output()}: the selected entries, or an empty tensor when nothing is selected
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    @Override // com.intel.analytics.bigdl.nn.abstractnn.AbstractModule
    public Tensor<T> updateOutput(Table table) {
        Tensor<T> input = (Tensor) table.apply(BoxesRunTime.boxToInteger(1));
        Tensor<T> mask = (Tensor) table.apply(BoxesRunTime.boxToInteger(2));
        if (hasSelection(mask)) {
            input.maskedSelect(mask, output());
        } else {
            // Nothing is selected: return an empty result instead of whatever an earlier
            // forward pass left in output().
            output().set();
        }
        return output();
    }

    /**
     * Backward pass: routes {@code tensor} (the gradient of the output) back to the selected
     * positions of the data tensor and gives the mask a zero gradient.
     *
     * <p>Renamed by the decompiler from {@code updateGradInput} (merged with its bridge method).
     *
     * @param table  the same input table passed to {@link #updateOutput(Table)}
     * @param tensor gradient with respect to the forward output
     * @return {@code gradInput()}, with the data gradient at key 1 and the mask gradient at key 2
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    @Override // com.intel.analytics.bigdl.nn.abstractnn.AbstractModule
    public Table updateGradInput2(Table table, Tensor<T> tensor) {
        Tensor<T> input = (Tensor) table.apply(BoxesRunTime.boxToInteger(1));
        Tensor<T> mask = (Tensor) table.apply(BoxesRunTime.boxToInteger(2));

        // Start from an all-zero flat gradient; only selected positions receive values.
        Tensor<T> gradData = com$intel$analytics$bigdl$nn$MaskedSelect$$gradBuffer();
        gradData.resize(input.nElement()).zero();

        if (hasSelection(mask)) {
            // Number every element 1..nElement(mask) in the mask's shape, then keep the numbers at
            // the selected positions: these are the flat indices the outputs came from.
            Tensor<T> indexBuffer = com$intel$analytics$bigdl$nn$MaskedSelect$$maskIndexBuffer();
            indexBuffer.range(1.0d, mask.nElement(), indexBuffer.range$default$3());
            indexBuffer.resizeAs(mask);
            indexBuffer.maskedSelect(mask, com$intel$analytics$bigdl$nn$MaskedSelect$$maskIndices());
            gradData.scatter(1, com$intel$analytics$bigdl$nn$MaskedSelect$$maskIndices(), tensor);
        }
        // Otherwise the gradient stays all-zero. maskIndices may hold indices from an earlier
        // call, so scattering here would write the wrong gradient.

        gradData.resizeAs(input);
        // NOTE(review): if Table.insert shifts existing entries (Lua-style insert), repeated
        // backward passes make gradInput() grow by two entries per call. Consider writing keys
        // 1 and 2 directly instead; confirm against Table's implementation.
        gradInput().insert(1, gradData);
        gradInput().insert(2, com$intel$analytics$bigdl$nn$MaskedSelect$$gradMask().resizeAs(mask).zero());
        return gradInput();
    }

    /**
     * Releases the module's state: delegates to the superclass, then empties every internal
     * buffer via {@code Tensor.set()}.
     *
     * @return this module
     */
    @Override // com.intel.analytics.bigdl.nn.abstractnn.AbstractModule
    /* renamed from: clearState */
    public MaskedSelect<T> clearState2() {
        super.clearState2();
        com$intel$analytics$bigdl$nn$MaskedSelect$$maskIndices().set();
        com$intel$analytics$bigdl$nn$MaskedSelect$$maskIndexBuffer().set();
        com$intel$analytics$bigdl$nn$MaskedSelect$$gradBuffer().set();
        com$intel$analytics$bigdl$nn$MaskedSelect$$gradMask().set();
        return this;
    }

    /** Scala-style equality guard: only other {@code MaskedSelect} instances may compare equal. */
    @Override // com.intel.analytics.bigdl.nn.abstractnn.AbstractModule
    public boolean canEqual(Object obj) {
        return obj instanceof MaskedSelect;
    }

    /** Equal when {@code obj} is a {@code MaskedSelect}, the superclass agrees, and it accepts us. */
    @Override // com.intel.analytics.bigdl.nn.abstractnn.AbstractModule
    public boolean equals(Object obj) {
        if (!(obj instanceof MaskedSelect)) {
            return false;
        }
        MaskedSelect<?> other = (MaskedSelect<?>) obj;
        return super.equals(other) && other.canEqual(this);
    }

    /**
     * Returns the superclass hash code.
     *
     * <p>The Scala-generated version folded {@code Seq(super.hashCode())} with seed 0 and step
     * {@code 31 * acc + h}. For a single element that is {@code 31 * 0 + h == h}, so this is
     * exactly equivalent.
     */
    @Override // com.intel.analytics.bigdl.nn.abstractnn.AbstractModule
    public int hashCode() {
        return super.hashCode();
    }

    /**
     * Null-safe hash helper emitted by the Scala compiler. It is no longer used by
     * {@link #hashCode()} and is retained only so existing references keep linking.
     */
    public static final int getHashCode$1(Object obj) {
        if (obj == null) {
            return 0;
        }
        return obj.hashCode();
    }

    /**
     * Creates the module and allocates its empty internal buffers.
     *
     * @param classTag      runtime class tag for the element type {@code T}
     * @param tensorNumeric numeric type-class for {@code T}
     */
    public MaskedSelect(ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        super(ClassTag$.MODULE$.apply(Table.class), ClassTag$.MODULE$.apply(Tensor.class), classTag, tensorNumeric);
        this.ev = tensorNumeric;
        this.com$intel$analytics$bigdl$nn$MaskedSelect$$maskIndices = Tensor$.MODULE$.apply(classTag, tensorNumeric);
        this.com$intel$analytics$bigdl$nn$MaskedSelect$$maskIndexBuffer = Tensor$.MODULE$.apply(classTag, tensorNumeric);
        this.com$intel$analytics$bigdl$nn$MaskedSelect$$gradBuffer = Tensor$.MODULE$.apply(classTag, tensorNumeric);
        this.com$intel$analytics$bigdl$nn$MaskedSelect$$gradMask = Tensor$.MODULE$.apply(classTag, tensorNumeric);
    }
}
