package com.intel.analytics.bigdl.nn;

import com.intel.analytics.bigdl.nn.abstractnn.TensorModule;
import com.intel.analytics.bigdl.tensor.ConvertableFrom$ConvertableFromDouble$;
import com.intel.analytics.bigdl.tensor.Tensor;
import com.intel.analytics.bigdl.tensor.Tensor$;
import com.intel.analytics.bigdl.tensor.TensorNumericMath;
import opennlp.tools.parser.Parse;
import scala.reflect.ClassTag;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;

/* compiled from: SpatialDropout1D.scala */
@ScalaSignature(bytes = "\u0006\u0001\u00055e\u0001B\f\u0019\u0001\rB\u0001\"\u000f\u0001\u0003\u0006\u0004%\tA\u000f\u0005\t}\u0001\u0011\t\u0011)A\u0005w!Aq\b\u0001B\u0002B\u0003-\u0001\t\u0003\u0005G\u0001\t\u0005\t\u0015a\u0003H\u0011\u0015i\u0006\u0001\"\u0001_\u0011\u001d)\u0007\u00011A\u0005\u0002iBqA\u001a\u0001A\u0002\u0013\u0005q\r\u0003\u0004n\u0001\u0001\u0006Ka\u000f\u0005\b]\u0002\u0001\r\u0011\"\u0001p\u0011\u001d!\b\u00011A\u0005\u0002UDaa\u001e\u0001!B\u0013\u0001\b\"\u0002=\u0001\t\u0003J\b\"\u0002?\u0001\t\u0003j\bbBA\u0002\u0001\u0011\u0005\u0013Q\u0001\u0005\b\u0003\u0013\u0001A\u0011IA\u0006\u000f\u001d\tI\u0003\u0007E\u0001\u0003W1aa\u0006\r\t\u0002\u00055\u0002BB/\u0012\t\u0003\tY\u0004C\u0004\u0002>E!\t!a\u0010\t\u0013\u0005]\u0013#%A\u0005\u0002\u0005e\u0003\"CA:#E\u0005I\u0011AA;\u0011%\tI(EA\u0001\n\u0013\tYH\u0001\tTa\u0006$\u0018.\u00197Ee>\u0004x.\u001e;2\t*\u0011\u0011DG\u0001\u0003]:T!a\u0007\u000f\u0002\u000b\tLw\r\u001a7\u000b\u0005uq\u0012!C1oC2LH/[2t\u0015\ty\u0002%A\u0003j]R,GNC\u0001\"\u0003\r\u0019w.\\\u0002\u0001+\t!Sf\u0005\u0002\u0001KA\u0019a%K\u0016\u000e\u0003\u001dR!\u0001\u000b\r\u0002\u0015\u0005\u00147\u000f\u001e:bGRtg.\u0003\u0002+O\taA+\u001a8t_Jlu\u000eZ;mKB\u0011A&\f\u0007\u0001\t\u0015q\u0003A1\u00010\u0005\u0005!\u0016C\u0001\u00197!\t\tD'D\u00013\u0015\u0005\u0019\u0014!B:dC2\f\u0017BA\u001b3\u0005\u001dqu\u000e\u001e5j]\u001e\u0004\"!M\u001c\n\u0005a\u0012$aA!os\u0006)\u0011N\\5u!V\t1\b\u0005\u00022y%\u0011QH\r\u0002\u0007\t>,(\r\\3\u0002\r%t\u0017\u000e\u001e)!\u0003))g/\u001b3f]\u000e,G%\r\t\u0004\u0003\u0012[S\"\u0001\"\u000b\u0005\r\u0013\u0014a\u0002:fM2,7\r^\u0005\u0003\u000b\n\u0013\u0001b\u00117bgN$\u0016mZ\u0001\u0003KZ\u00042\u0001\u0013.,\u001d\tIuK\u0004\u0002K+:\u00111\n\u0016\b\u0003\u0019Ns!!\u0014*\u000f\u00059\u000bV\"A(\u000b\u0005A\u0013\u0013A\u0002\u001fs_>$h(C\u0001\"\u0013\ty\u0002%\u0003\u0002\u001e=%\u00111\u0004H\u0005\u0003-j\ta\u0001^3og>\u0014\u0018B\u0001-Z\u0003E!VM\\:pe:+X.\u0
01a:jG6\u000bG\u000f\u001b\u0006\u0003-jI!a\u0017/\u0003\u001bQ+gn]8s\u001dVlWM]5d\u0015\tA\u0016,\u0001\u0004=S:LGO\u0010\u000b\u0003?\u0012$2\u0001\u00192d!\r\t\u0007aK\u0007\u00021!)q(\u0002a\u0002\u0001\")a)\u0002a\u0002\u000f\"9\u0011(\u0002I\u0001\u0002\u0004Y\u0014!\u00019\u0002\u000bA|F%Z9\u0015\u0005!\\\u0007CA\u0019j\u0013\tQ'G\u0001\u0003V]&$\bb\u00027\b\u0003\u0003\u0005\raO\u0001\u0004q\u0012\n\u0014A\u00019!\u0003\u0015qw.[:f+\u0005\u0001\bcA9sW5\t\u0011,\u0003\u0002t3\n1A+\u001a8t_J\f\u0011B\\8jg\u0016|F%Z9\u0015\u0005!4\bb\u00027\u000b\u0003\u0003\u0005\r\u0001]\u0001\u0007]>L7/\u001a\u0011\u0002\u0019U\u0004H-\u0019;f\u001fV$\b/\u001e;\u0015\u0005AT\b\"B>\r\u0001\u0004\u0001\u0018!B5oaV$\u0018aD;qI\u0006$Xm\u0012:bI&s\u0007/\u001e;\u0015\u0007Atx\u0010C\u0003|\u001b\u0001\u0007\u0001\u000f\u0003\u0004\u0002\u00025\u0001\r\u0001]\u0001\u000bOJ\fGmT;uaV$\u0018AC2mK\u0006\u00148\u000b^1uKR\u0011\u0011qA\u0007\u0002\u0001\u0005AAo\\*ue&tw\r\u0006\u0002\u0002\u000eA!\u0011qBA\f\u001d\u0011\t\t\"a\u0005\u0011\u00059\u0013\u0014bAA\u000be\u00051\u0001K]3eK\u001aLA!!\u0007\u0002\u001c\t11\u000b\u001e:j]\u001eT1!!\u00063Q\u001d\u0001\u0011qDA\u0013\u0003O\u00012!MA\u0011\u0013\r\t\u0019C\r\u0002\u0011'\u0016\u0014\u0018.\u00197WKJ\u001c\u0018n\u001c8V\u0013\u0012\u000bQA^1mk\u0016t\u0002b0Uq?\u001a 
~Dk\u0001\u0011'B\fG/[1m\tJ|\u0007o\\;uc\u0011\u0003\"!Y\t\u0014\u000bE\ty#!\u000e\u0011\u0007E\n\t$C\u0002\u00024I\u0012a!\u00118z%\u00164\u0007cA\u0019\u00028%\u0019\u0011\u0011\b\u001a\u0003\u0019M+'/[1mSj\f'\r\\3\u0015\u0005\u0005-\u0012!B1qa2LX\u0003BA!\u0003\u0013\"B!a\u0011\u0002VQ1\u0011QIA&\u0003#\u0002B!\u0019\u0001\u0002HA\u0019A&!\u0013\u0005\u000b9\u001a\"\u0019A\u0018\t\u0013\u000553#!AA\u0004\u0005=\u0013AC3wS\u0012,gnY3%eA!\u0011\tRA$\u0011\u001915\u0003q\u0001\u0002TA!\u0001JWA$\u0011\u001dI4\u0003%AA\u0002m\nq\"\u00199qYf$C-\u001a4bk2$H%M\u000b\u0005\u00037\n\t(\u0006\u0002\u0002^)\u001a1(a\u0018,\u0005\u0005\u0005\u0004\u0003BA2\u0003[j!!!\u001a\u000b\t\u0005\u001d\u0014\u0011N\u0001\nk:\u001c\u0007.Z2lK\u0012T1!a\u001b3\u0003)\tgN\\8uCRLwN\\\u0005\u0005\u0003_\n)GA\tv]\u000eDWmY6fIZ\u000b'/[1oG\u0016$QA\f\u000bC\u0002=\n1\u0004\n7fgNLg.\u001b;%OJ,\u0017\r^3sI\u0011,g-Y;mi\u0012\nT\u0003BA.\u0003o\"QAL\u000bC\u0002=\n1B]3bIJ+7o\u001c7wKR\u0011\u0011Q\u0010\t\u0005\u0003\u007f\nI)\u0004\u0002\u0002\u0002*!\u00111QAC\u0003\u0011a\u0017M\\4\u000b\u0005\u0005\u001d\u0015\u0001\u00026bm\u0006LA!a#\u0002\u0002\n1qJ\u00196fGR\u0004")
/* loaded from: input_file:com/intel/analytics/bigdl/nn/SpatialDropout1D.class */
/*
 * Dropout-style layer whose random mask is shared along one axis of the input.
 *
 * In training mode a noise tensor with a singleton axis is filled via
 * bernoulli(1 - p), expanded to the input's shape and multiplied into a copy
 * of the input. Outside training mode the copy of the input is multiplied by
 * the scalar (1 - p) instead.
 *
 * Accepted input ranks in training mode are 2 and 3; the noise tensor is sized
 * {1, size[1]} and {size[0], 1, size[2]} respectively.
 */
public class SpatialDropout1D<T> extends TensorModule<T> {
    public static final long serialVersionUID = -4636332259181125718L;

    /** Value of p supplied at construction; never reassigned. */
    private final double initP;

    /** Numeric helper used to convert the double scale factor into an element of type T. */
    private final TensorNumericMath.TensorNumeric<T> ev;

    /** Current p; starts out equal to {@link #initP()} and can be changed via {@link #p_$eq(double)}. */
    private double p;

    /** Mask tensor produced by the most recent training-mode {@link #updateOutput(Tensor)} call. */
    private Tensor<T> noise;

    /**
     * @return the p value this layer was constructed with
     */
    public double initP() {
        return this.initP;
    }

    /**
     * @return the p value currently in use
     */
    public double p() {
        return this.p;
    }

    /**
     * Replaces the current p (Scala-style setter).
     *
     * @param d the new p value
     */
    public void p_$eq(double d) {
        this.p = d;
    }

    /**
     * @return the tensor holding the current noise mask
     */
    public Tensor<T> noise() {
        return this.noise;
    }

    /**
     * Replaces the noise tensor (Scala-style setter).
     *
     * @param tensor the tensor to use as the noise mask from now on
     */
    public void noise_$eq(Tensor<T> tensor) {
        this.noise = tensor;
    }

    /**
     * Forward pass.
     *
     * The output is first resized to the input's shape and filled with a copy of the input.
     * Outside training mode the output is then multiplied by (1 - p) and returned, with no
     * rank check. In training mode the noise tensor is resized according to the input rank,
     * filled via {@code bernoulli(1 - p)}, expanded to the input's shape and multiplied into
     * the output.
     *
     * @param tensor the input tensor; must have 2 or 3 dimensions in training mode
     * @return the result of the final {@code mul} / {@code cmul} call on the output tensor
     * @throws RuntimeException in training mode when the input is neither 2D nor 3D
     */
    @Override // com.intel.analytics.bigdl.nn.abstractnn.AbstractModule
    public Tensor<T> updateOutput(Tensor<T> tensor) {
        output().resizeAs(tensor).copy(tensor);
        if (!train()) {
            return output().mul(this.ev.mo2991fromType(BoxesRunTime.boxToDouble(1 - p()), ConvertableFrom$ConvertableFromDouble$.MODULE$));
        }
        int[] size = tensor.size();
        int dim = tensor.dim();
        if (dim == 2) {
            // Singleton first axis: one mask row of length size[1], later expanded over the rows.
            Tensor<T> mask = noise();
            mask.resize(new int[]{1, size[1]}, mask.resize$default$2());
        } else if (dim == 3) {
            // Singleton middle axis: the mask is later expanded along that axis.
            Tensor<T> mask = noise();
            mask.resize(new int[]{size[0], 1, size[2]}, mask.resize$default$2());
        } else {
            // The message now matches the ranks actually accepted above and reports the offending rank.
            throw new RuntimeException("SpatialDropout1D: Input must be 2D or 3D, but got " + dim + "D");
        }
        noise().bernoulli(1 - p());
        return output().cmul(noise().expandAs(tensor));
    }

    /**
     * Backward pass.
     *
     * Copies the incoming gradient into gradInput and multiplies it by the stored noise
     * tensor expanded to the shape of the forward input. The noise is not regenerated here,
     * so the mask is whatever the noise tensor currently holds.
     *
     * @param tensor  the input that was given to the forward pass
     * @param tensor2 the gradient with respect to this layer's output
     * @return the gradInput tensor
     * @throws RuntimeException when the layer is not in training mode
     */
    @Override // com.intel.analytics.bigdl.nn.abstractnn.AbstractModule
    public Tensor<T> updateGradInput(Tensor<T> tensor, Tensor<T> tensor2) {
        if (!train()) {
            throw new RuntimeException("SpatialDropout1D: backprop only defined while training");
        }
        gradInput().resizeAs(tensor2).copy(tensor2);
        gradInput().cmul(noise().expandAs(tensor));
        return gradInput();
    }

    /**
     * Runs the superclass state clearing and then calls {@code set()} on the noise tensor
     * (presumably emptying it; confirm against the Tensor API).
     *
     * @return this layer
     */
    @Override // com.intel.analytics.bigdl.nn.abstractnn.AbstractModule
    public SpatialDropout1D<T> clearState() {
        super.clearState();
        noise().set();
        return this;
    }

    /**
     * @return the print name followed by the current p in parentheses
     */
    @Override // com.intel.analytics.bigdl.nn.abstractnn.AbstractModule
    public String toString() {
        // Plain literals instead of the OpenNLP Parse bracket constants the decompiler substituted.
        return getPrintName() + "(" + p() + ")";
    }

    /**
     * Creates the layer with p initialised to {@code d} and a freshly created noise tensor.
     *
     * @param d             initial p value
     * @param classTag      Scala class tag for the element type, forwarded to the superclass
     * @param tensorNumeric numeric operations for the element type
     */
    /* JADX WARN: 'super' call moved to the top of the method (can break code semantics) */
    public SpatialDropout1D(double d, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        super(classTag, tensorNumeric);
        this.initP = d;
        this.ev = tensorNumeric;
        this.p = d;
        this.noise = Tensor$.MODULE$.apply(classTag, tensorNumeric);
    }
}
