package com.intel.analytics.bigdl.nn;

import com.intel.analytics.bigdl.nn.abstractnn.TensorModule;
import com.intel.analytics.bigdl.tensor.Tensor;
import com.intel.analytics.bigdl.tensor.Tensor$;
import com.intel.analytics.bigdl.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.utils.Engine$;
import com.intel.analytics.bigdl.utils.Shape;
import com.intel.analytics.bigdl.utils.ThreadPool;
import scala.Function0;
import scala.MatchError;
import scala.Predef$;
import scala.Tuple2;
import scala.concurrent.Future;
import scala.reflect.ClassTag;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;

/* compiled from: LogSoftMax.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005\rh\u0001B\u0011#\u00015B\u0001b\u0011\u0001\u0003\u0004\u0003\u0006Y\u0001\u0012\u0005\t\u0015\u0002\u0011\t\u0011)A\u0006\u0017\")\u0011\r\u0001C\u0001E\"9\u0001\u000e\u0001b\u0001\n\u0013I\u0007B\u00028\u0001A\u0003%!\u000eC\u0004p\u0001\t\u0007I\u0011B5\t\rA\u0004\u0001\u0015!\u0003k\u0011\u0015\t\b\u0001\"\u0011s\u0011\u0015)\b\u0001\"\u0003w\u0011\u0015q\b\u0001\"\u0011��\u0011\u001d\t9\u0001\u0001C\u0005\u0003\u0013Aq!!\u0005\u0001\t\u0003\n\u0019\u0002C\u0004\u0002\u0018\u0001!\t%!\u0007\t\u0013\u0005-\u0002\u00011A\u0005\n\u00055\u0002\"CA!\u0001\u0001\u0007I\u0011BA\"\u0011!\tI\u0005\u0001Q!\n\u0005=raBA0E!\u0005\u0011\u0011\r\u0004\u0007C\tB\t!a\u0019\t\r\u0005\u0014B\u0011AA9\u0011\u001d\t\u0019H\u0005C\u0001\u0003kB\u0011\"a+\u0013\u0005\u0004%I!!,\t\u0011\u0005U&\u0003)A\u0005\u0003_C\u0011\"a.\u0013\u0005\u0004%I!!,\t\u0011\u0005e&\u0003)A\u0005\u0003_C\u0011\"a/\u0013\u0005\u0004%I!!,\t\u0011\u0005u&\u0003)A\u0005\u0003_C\u0011\"a0\u0013\u0005\u0004%I!!,\t\u0011\u0005\u0005'\u0003)A\u0005\u0003_C\u0011\"a1\u0013\u0005\u0004%I!!,\t\u0011\u0005\u0015'\u0003)A\u0005\u0003_Cq!a2\u0013\t\u0003\tI\rC\u0005\u0002PJ\t\t\u0011\"\u0003\u0002R\nQAj\\4T_\u001a$X*\u0019=\u000b\u0005\r\"\u0013A\u00018o\u0015\t)c%A\u0003cS\u001e$GN\u0003\u0002(Q\u0005I\u0011M\\1msRL7m\u001d\u0006\u0003S)\nQ!\u001b8uK2T\u0011aK\u0001\u0004G>l7\u0001A\u000b\u0003]]\u001a\"\u0001A\u0018\u0011\u0007A\u001aT'D\u00012\u0015\t\u0011$%\u0001\u0006bEN$(/Y2u]:L!\u0001N\u0019\u0003\u0019Q+gn]8s\u001b>$W\u000f\\3\u0011\u0005Y:D\u0002\u0001\u0003\u0006q\u0001\u0011\r!\u000f\u0002\u0002)F\u0011!\b\u0011\t\u0003wyj\u0011\u0001\u0010\u0006\u0002{\u0005)1oY1mC&\u0011q\b\u0010\u0002\b\u001d>$\b.\u001b8h!\tY\u0014)\u0003\u0002Cy\t\u0019\u0011I\\=\u0002\u0015\u00154\u0018\u000eZ3oG\u0016$\u0013\u0007E\u0002F\u0011Vj\u0011A\u0012\u0006\u0003\u000fr\nqA]3gY\u0016\u001cG/\u0003\u0002J\r\nA1\t\\1tgR\u000bw-\u0001\u0002fmB\u0019AJX\u001b\u000f\u00055[fB\u0001(Z\u0
01d\ty\u0005L\u0004\u0002Q/:\u0011\u0011K\u0016\b\u0003%Vk\u0011a\u0015\u0006\u0003)2\na\u0001\u0010:p_Rt\u0014\"A\u0016\n\u0005%R\u0013BA\u0014)\u0013\t)c%\u0003\u0002[I\u00051A/\u001a8t_JL!\u0001X/\u0002#Q+gn]8s\u001dVlWM]5d\u001b\u0006$\bN\u0003\u0002[I%\u0011q\f\u0019\u0002\u000e)\u0016t7o\u001c:Ok6,'/[2\u000b\u0005qk\u0016A\u0002\u001fj]&$h\bF\u0001d)\r!gm\u001a\t\u0004K\u0002)T\"\u0001\u0012\t\u000b\r\u001b\u00019\u0001#\t\u000b)\u001b\u00019A&\u0002\t=tWm]\u000b\u0002UB\u00191\u000e\\\u001b\u000e\u0003uK!!\\/\u0003\rQ+gn]8s\u0003\u0015yg.Z:!\u0003\u0019\u0011WO\u001a4fe\u00069!-\u001e4gKJ\u0004\u0013\u0001D;qI\u0006$XmT;uaV$HC\u00016t\u0011\u0015!\b\u00021\u0001k\u0003\u0015Ig\u000e];u\u0003E)\b\u000fZ1uK>+H\u000f];u\rJ\fW.\u001a\u000b\u0004ojd\bCA\u001ey\u0013\tIHH\u0001\u0003V]&$\b\"B>\n\u0001\u0004Q\u0017AA5o\u0011\u0015i\u0018\u00021\u0001k\u0003\ryW\u000f^\u0001\u0010kB$\u0017\r^3He\u0006$\u0017J\u001c9viR)!.!\u0001\u0002\u0004!)AO\u0003a\u0001U\"1\u0011Q\u0001\u0006A\u0002)\f!b\u001a:bI>+H\u000f];u\u0003Q)\b\u000fZ1uK\u001e\u0013\u0018\rZ%oaV$hI]1nKR)q/a\u0003\u0002\u000e!)Qp\u0003a\u0001U\"1\u0011qB\u0006A\u0002)\fqa\u001a:bI>+H/\u0001\u0006dY\u0016\f'o\u0015;bi\u0016$\"!!\u0006\u000e\u0003\u0001\t!cY8naV$XmT;uaV$8\u000b[1qKR!\u00111DA\u0014!\u0011\ti\"a\t\u000e\u0005\u0005}!bAA\u0011I\u0005)Q\u000f^5mg&!\u0011QEA\u0010\u0005\u0015\u0019\u0006.\u00199f\u0011\u001d\tI#\u0004a\u0001\u00037\t!\"\u001b8qkR\u001c\u0006.\u00199f\u0003\u001d\u0011Xm];miN,\"!a\f\u0011\u000bm\n\t$!\u000e\n\u0007\u0005MBHA\u0003BeJ\f\u0017\u0010E\u0003\u00028\u0005ur/\u0004\u0002\u0002:)\u0019\u00111\b\u001f\u0002\u0015\r|gnY;se\u0016tG/\u0003\u0003\u0002@\u0005e\"A\u0002$viV\u0014X-A\u0006sKN,H\u000e^:`I\u0015\fHcA<\u0002F!I\u0011qI\b\u0002\u0002\u0003\u0007\u0011qF\u0001\u0004q\u0012\n\u0014\u0001\u0003:fgVdGo\u001d\u0011)\u0007A\ti\u0005E\u0002<\u0003\u001fJ1!!\u0015=\u0005%!(/\u00198tS\u0016tG\u000fK\u0004\u0001\u0003+\nY&!\u0018\u0011\u0007m\n9&C\u0002\u0002Zq\u0012\u0001cU
3sS\u0006dg+\u001a:tS>tW+\u0013#\u0002\u000bY\fG.^3\u001f\u0011Y\u007f\u0010Q\t\u000b{M��\u000f!\u0002T8h'>4G/T1y!\t)'cE\u0003\u0013\u0003K\nY\u0007E\u0002<\u0003OJ1!!\u001b=\u0005\u0019\te.\u001f*fMB\u00191(!\u001c\n\u0007\u0005=DH\u0001\u0007TKJL\u0017\r\\5{C\ndW\r\u0006\u0002\u0002b\u0005)\u0011\r\u001d9msV!\u0011qOA@)\t\tI\b\u0006\u0004\u0002|\u0005\u0005\u0016q\u0015\t\u0005K\u0002\ti\bE\u00027\u0003\u007f\"\u0011\u0002\u000f\u000b!\u0002\u0003\u0005)\u0019A\u001d)\u0011\u0005}\u00141QAE\u0003/\u00032aOAC\u0013\r\t9\t\u0010\u0002\fgB,7-[1mSj,G-M\u0005$\u0003\u0017\u000bi)!%\u0002\u0010:\u00191(!$\n\u0007\u0005=E(A\u0003GY>\fG/\r\u0004%\u0003'\u000b)*\u0010\b\u0004%\u0006U\u0015\"A\u001f2\u0013\r\nI*a'\u0002 \u0006uebA\u001e\u0002\u001c&\u0019\u0011Q\u0014\u001f\u0002\r\u0011{WO\u00197fc\u0019!\u00131SAK{!I\u00111\u0015\u000b\u0002\u0002\u0003\u000f\u0011QU\u0001\u000bKZLG-\u001a8dK\u0012\u0012\u0004\u0003B#I\u0003{BaA\u0013\u000bA\u0004\u0005%\u0006\u0003\u0002'_\u0003{\n!!\u0011\u0019\u0016\u0005\u0005=\u0006cA\u001e\u00022&\u0019\u00111\u0017\u001f\u0003\r\u0011{WO\u00197f\u0003\r\t\u0005\u0007I\u0001\u0003\u0003F\n1!Q\u0019!\u0003\t\t%'A\u0002Be\u0001\n!!Q\u001a\u0002\u0007\u0005\u001b\u0004%\u0001\u0002Bi\u0005\u0019\u0011\t\u000e\u0011\u0002\u001d\u0015D\b/T5okN\f\u0005\u000f\u001d:pqR!\u0011qVAf\u0011\u001d\tim\ba\u0001\u0003_\u000b\u0011\u0001_\u0001\fe\u0016\fGMU3t_24X\r\u0006\u0002\u0002TB!\u0011Q[Ap\u001b\t\t9N\u0003\u0003\u0002Z\u0006m\u0017\u0001\u00027b]\u001eT!!!8\u0002\t)\fg/Y\u0005\u0005\u0003C\f9N\u0001\u0004PE*,7\r\u001e")
/* loaded from: input_file:com/intel/analytics/bigdl/nn/LogSoftMax.class */
/**
 * Log-softmax layer: maps each input frame {@code x} to {@code x - log(sum(exp(x)))}.
 *
 * <p>Accepts a 1-D vector (a single frame) or a 2-D tensor whose first dimension indexes frames.
 * When there is more than one frame, each frame is submitted as a task to
 * {@code Engine.model()}. All tasks share the scratch tensors {@code ones} and {@code buffer},
 * so every use of them happens while holding the monitor of {@code buffer}. That field is final,
 * which keeps the lock object stable. Without this guard, concurrently running frame tasks would
 * overwrite each other's intermediate results.
 *
 * @param <T> element type of the tensors handled by this module
 */
public class LogSoftMax<T> extends TensorModule<T> {
    public static final long serialVersionUID = -2954501946670913825L;

    /** Numeric operations for {@code T}; the same instance is handed to the superclass. */
    private final TensorNumericMath.TensorNumeric<T> ev;
    /** Tensor filled with {@code ev.one()}; a dot product with it sums a frame. Grows lazily. */
    private final Tensor<T> ones;
    /** Scratch tensor shared by all frame computations; its own monitor guards it and {@code ones}. */
    private final Tensor<T> buffer;
    /** Futures of the frame tasks from the latest batched pass; transient, reallocated on demand. */
    private transient Future<BoxedUnit>[] results;

    /** Delegates to the companion object's {@code expMinusApprox}. */
    public static double expMinusApprox(double d) {
        return LogSoftMax$.MODULE$.expMinusApprox(d);
    }

    private Future<BoxedUnit>[] results() {
        return this.results;
    }

    private void results_$eq(Future<BoxedUnit>[] futureArr) {
        this.results = futureArr;
    }

    private Tensor<T> ones() {
        return this.ones;
    }

    private Tensor<T> buffer() {
        return this.buffer;
    }

    /**
     * Computes the log-softmax of {@code tensor} into {@code output()}.
     *
     * @param tensor a 1-D frame, or a 2-D tensor holding one frame per row
     * @return {@code output()}, resized to the shape of {@code tensor}
     * @throws IllegalArgumentException (via {@code Predef.require}) if {@code tensor} is neither
     *     1-D nor 2-D
     */
    @Override // com.intel.analytics.bigdl.nn.abstractnn.AbstractModule
    public Tensor<T> updateOutput(Tensor<T> tensor) {
        Predef$.MODULE$.require(tensor.dim() == 1 || tensor.dim() == 2, () -> {
            return new StringBuilder(22).append("LogSoftMax: ").append(ErrorInfo$.MODULE$.constrainInputAsVectorOrBatch()).append("input dim ").append(tensor.dim()).toString();
        });
        // The output starts as a copy of the input; each frame is then shifted in place.
        output().resizeAs(tensor).copy(tensor);
        int nFrame = tensor.nDimension() == 1 ? 1 : tensor.size(1);
        if (nFrame == 1) {
            updateOutputFrame(tensor, output());
        } else {
            runFrames(nFrame, row -> updateOutputFrame(tensor.select(1, row), output().select(1, row)));
        }
        return output();
    }

    /**
     * Shifts one output frame by its log-sum-exp.
     *
     * <p>{@code tensor2} must already contain a copy of {@code tensor}. After the call it holds
     * {@code tensor - (max + log(sum(exp(tensor - max))))}, where {@code max} is the frame maximum.
     *
     * @param tensor the input frame (only read)
     * @param tensor2 the output frame, updated in place
     */
    /* JADX INFO: Access modifiers changed from: private */
    public void updateOutputFrame(Tensor<T> tensor, Tensor<T> tensor2) {
        // Subtracting the frame maximum before exp() avoids overflow to Infinity.
        T max = tensor.mo2936max();
        T logSumExp;
        // Tasks for other frames use the same scratch tensors; serialize access to them.
        synchronized (buffer()) {
            if (ones().nElement() < tensor.nElement()) {
                ones().resizeAs(tensor).fill(this.ev.one());
            }
            if (buffer().nElement() != tensor2.nElement()) {
                buffer().resizeAs(tensor2);
            }
            buffer().fill(this.ev.negative(max));
            buffer().add(tensor);
            buffer().exp();
            logSumExp = this.ev.plus(max, this.ev.log(buffer().mo2933dot(ones())));
        }
        // Scalar add: no cast, so the add(T) overload is chosen rather than add(Tensor).
        tensor2.add(this.ev.negative(logSumExp));
    }

    /**
     * Computes the gradient of the log-softmax with respect to its input.
     *
     * <p>The gradient is computed from {@code output()} as produced by the preceding forward
     * pass, not from {@code tensor}.
     *
     * @param tensor the forward input (used for shape checks and sizing only)
     * @param tensor2 the gradient with respect to the output
     * @return {@code gradInput()}, resized to the shape of {@code tensor}
     * @throws IllegalArgumentException (via {@code Predef.require}) if {@code output()} is neither
     *     1-D nor 2-D, or if {@code tensor2} and {@code tensor} differ in number of dimensions
     */
    @Override // com.intel.analytics.bigdl.nn.abstractnn.AbstractModule
    public Tensor<T> updateGradInput(Tensor<T> tensor, Tensor<T> tensor2) {
        Predef$.MODULE$.require(output().nDimension() == 1 || output().nDimension() == 2, () -> {
            return "vector or matrix expected";
        });
        Predef$.MODULE$.require(tensor2.dim() == tensor.dim(), () -> {
            return new StringBuilder(83).append("LogSoftMax: input and gradOutput shapes do not match, input_dim: ").append(tensor.dim()).append(", gradOutput_dim: ").append(tensor2.dim()).toString();
        });
        // gradInput starts as a copy of gradOutput and is corrected in place per frame.
        gradInput().resizeAs(tensor).copy(tensor2);
        int nFrame = output().nDimension() == 1 ? 1 : output().size(1);
        if (nFrame == 1) {
            updateGradInputFrame(output(), gradInput());
        } else {
            runFrames(nFrame, row -> updateGradInputFrame(output().select(1, row), gradInput().select(1, row)));
        }
        return gradInput();
    }

    /**
     * Applies {@code tensor2 += -(tensor2 . ones) * exp(tensor)} for one frame.
     *
     * <p>This assumes {@code add(value, y)} adds {@code value * y} to the receiver; confirm
     * against the Tensor API.
     *
     * @param tensor the forward output frame (log-probabilities; only read)
     * @param tensor2 the gradient frame, holding gradOutput on entry and updated in place
     */
    /* JADX INFO: Access modifiers changed from: private */
    public void updateGradInputFrame(Tensor<T> tensor, Tensor<T> tensor2) {
        // buffer and ones are shared with the other frame tasks.
        synchronized (buffer()) {
            buffer().exp(tensor);
            T gradSum = tensor2.mo2933dot(ones());
            // Scalar-then-tensor overload; the decompiler's casts selected the wrong method.
            tensor2.add(this.ev.negative(gradSum), buffer());
        }
    }

    /**
     * Submits {@code frameTask} for frame indices {@code 1..nFrame} to {@code Engine.model()}
     * and blocks until all of them have finished.
     *
     * @param nFrame number of frames, at least 1
     * @param frameTask work for a single 1-based frame index
     */
    private void runFrames(int nFrame, java.util.function.IntConsumer frameTask) {
        if (results() == null || results().length != nFrame) {
            results_$eq(new Future[nFrame]);
        }
        ThreadPool pool = Engine$.MODULE$.model();
        for (int frame = 1; frame <= nFrame; frame++) {
            final int row = frame;
            // Function0 must produce a value; Scala's Unit is represented by BoxedUnit.UNIT.
            Function0<BoxedUnit> task = () -> {
                frameTask.accept(row);
                return BoxedUnit.UNIT;
            };
            results()[frame - 1] = pool.invoke(task);
        }
        pool.sync(Predef$.MODULE$.wrapRefArray(results()), pool.sync$default$2());
    }

    /**
     * Releases the scratch tensors and task futures in addition to the superclass state.
     *
     * @return this module
     */
    @Override // com.intel.analytics.bigdl.nn.abstractnn.AbstractModule
    public LogSoftMax<T> clearState() {
        super.clearState();
        ones().set();
        buffer().set();
        results_$eq(null);
        return this;
    }

    /** Log-softmax is shape-preserving, so the input shape is returned unchanged. */
    @Override // com.intel.analytics.bigdl.nn.abstractnn.AbstractModule, com.intel.analytics.bigdl.nn.abstractnn.InferShape
    public Shape computeOutputShape(Shape shape) {
        return shape;
    }

    /**
     * Creates the module with empty scratch tensors.
     *
     * @param classTag runtime class tag for {@code T}
     * @param tensorNumeric numeric operations for {@code T}
     */
    public LogSoftMax(ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        super(classTag, tensorNumeric);
        this.ev = tensorNumeric;
        this.results = null;
        this.ones = Tensor$.MODULE$.apply(classTag, tensorNumeric);
        this.buffer = Tensor$.MODULE$.apply(classTag, tensorNumeric);
    }
}
