package com.intel.analytics.bigdl.nn;

import com.intel.analytics.bigdl.tensor.ConvertableFrom$ConvertableFromFloat$;
import com.intel.analytics.bigdl.tensor.ConvertableFrom$ConvertableFromInt$;
import com.intel.analytics.bigdl.tensor.Tensor;
import com.intel.analytics.bigdl.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.utils.Engine$;
import com.intel.analytics.bigdl.utils.ThreadPool;
import scala.Float$;
import scala.Function0;
import scala.MatchError;
import scala.Predef$;
import scala.Serializable;
import scala.Tuple3;
import scala.collection.IndexedSeq;
import scala.collection.mutable.ArrayOps;
import scala.concurrent.Future;
import scala.reflect.ClassTag;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.ScalaRunTime$;

/* compiled from: SoftMax.scala */
/* loaded from: input_file:com/intel/analytics/bigdl/nn/SoftMax$.class */
/**
 * Module (singleton) object for {@code SoftMax}, decompiled from {@code SoftMax.scala}.
 *
 * <p>Holds the factory methods for {@link SoftMax}, the Scala-generated default-argument
 * accessors, and the shared forward/backward kernels. The single instance lives in
 * {@link #MODULE$}; {@link #readResolve()} keeps it a singleton across Java serialization.
 */
public final class SoftMax$ implements Serializable {
    /** The singleton instance, assigned by the private constructor during class initialization. */
    public static SoftMax$ MODULE$;

    static {
        new SoftMax$();
    }

    /** Scala-generated default for the first constructor argument of {@code SoftMax}; always 1. */
    public <T> int $lessinit$greater$default$1() {
        return 1;
    }

    /** Creates a {@code SoftMax} with the given position argument {@code i}. */
    public <T> SoftMax<T> apply(int i, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        return new SoftMax<>(i, classTag, tensorNumeric);
    }

    /** Creates a {@code SoftMax} with the default position argument of 1. */
    public <T> SoftMax<T> apply(ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        return new SoftMax<>(1, classTag, tensorNumeric);
    }

    /** Scala-generated default for the first argument of {@code apply}; always 1. */
    public <T> int apply$default$1() {
        return 1;
    }

    /**
     * Forward pass: writes softmax(tensor) into {@code tensor2}.
     *
     * <p>The input is viewed as {@code nFrame x dim x stride} starting at dimension {@code i}
     * (see {@link #softMaxShape}); softmax is taken over the {@code dim} axis independently for
     * each of the {@code nFrame * stride} slices, one task per slice on {@code Engine.model()}.
     *
     * @param tensor        input tensor; read through a contiguous copy if it is not contiguous
     * @param tensor2       output tensor, written in place and returned. NOTE(review): assumed to
     *                      be contiguous with the input's shape (presumably resized by the caller)
     * @param futureArr     receives one future per slice; needs at least {@code nFrame * stride}
     *                      entries, and every entry is passed to {@code sync}
     * @param i             1-based dimension at which the shape decomposition starts
     * @param classTag      unused here
     * @param tensorNumeric arithmetic for element type {@code T}
     * @return {@code tensor2}
     * @throws IllegalArgumentException if {@code tensor} has fewer than {@code i} dimensions
     */
    public <T> Tensor<T> updateOutput(Tensor<T> tensor, Tensor<T> tensor2, Future<BoxedUnit>[] futureArr, int i, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        int[] shape = softMaxShape(tensor, i);
        int nFrame = shape[0];
        int dim = shape[1];
        int stride = shape[2];

        // Take the offset from the tensor whose storage is actually indexed: a contiguous()
        // copy has its own offset, unrelated to that of a non-contiguous view.
        Tensor<T> input = tensor.isContiguous() ? tensor : tensor.contiguous();
        Object inputArray = input.storage().array();
        int inputOffset = input.storageOffset() - 1;
        Object outputArray = tensor2.storage().array();
        int outputOffset = tensor2.storageOffset() - 1;

        ThreadPool model = Engine$.MODULE$.model();
        int tasks = nFrame * stride;
        for (int t = 0; t < tasks; t++) {
            // Slice t lives in frame (t / stride) at column (t % stride); its elements are `stride` apart.
            int base = (t / stride) * dim * stride + t % stride;
            int inBase = inputOffset + base;
            int outBase = outputOffset + base;
            Function0<BoxedUnit> task = () -> {
                softMaxSlice(inputArray, inBase, outputArray, outBase, dim, stride, tensorNumeric);
                return BoxedUnit.UNIT;
            };
            futureArr[t] = model.invoke(task);
        }
        model.sync(Predef$.MODULE$.wrapRefArray(futureArr), model.sync$default$2());
        return tensor2;
    }

    /** Scala-generated default for the {@code i} argument of {@code updateOutput}; always 1. */
    public <T> int updateOutput$default$4() {
        return 1;
    }

    /**
     * Backward pass: writes {@code output * (gradOutput - sum(gradOutput * output))} into
     * {@code tensor3}, with the sum taken over the softmax axis of each slice.
     *
     * <p>The slice decomposition is computed from {@code tensor4}'s shape, exactly as in
     * {@link #updateOutput}.
     *
     * @param tensor        input of the forward pass; only its size is checked
     * @param tensor2       gradient w.r.t. the output; must have the same size as {@code tensor}
     * @param tensor3       gradient w.r.t. the input, written in place and returned.
     *                      NOTE(review): assumed contiguous (presumably resized by the caller)
     * @param tensor4       output of the forward pass
     * @param futureArr     receives one future per slice; needs at least {@code nFrame * stride}
     *                      entries, and every entry is passed to {@code sync}
     * @param i             1-based dimension at which the shape decomposition starts
     * @param classTag      unused here
     * @param tensorNumeric arithmetic for element type {@code T}
     * @return {@code tensor3}
     * @throws IllegalArgumentException if the sizes of {@code tensor} and {@code tensor2} differ,
     *                                  or {@code tensor4} has fewer than {@code i} dimensions
     */
    public <T> Tensor<T> updateGradInput(Tensor<T> tensor, Tensor<T> tensor2, Tensor<T> tensor3, Tensor<T> tensor4, Future<BoxedUnit>[] futureArr, int i, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        Predef$.MODULE$.require(java.util.Arrays.equals(tensor.size(), tensor2.size()), () -> {
            return new StringBuilder(71).append("input should have the same size with gradOutput. ").append("inputsize ").append(new ArrayOps.ofInt(Predef$.MODULE$.intArrayOps(tensor.size())).deep()).append(" gradOutput ").append(new ArrayOps.ofInt(Predef$.MODULE$.intArrayOps(tensor2.size())).deep()).toString();
        });
        int[] shape = softMaxShape(tensor4, i);
        int nFrame = shape[0];
        int dim = shape[1];
        int stride = shape[2];

        // Every array is indexed relative to the storage offset of the tensor it came from.
        Tensor<T> output = tensor4.isContiguous() ? tensor4 : tensor4.contiguous();
        Object outputArray = output.storage().array();
        int outputOffset = output.storageOffset() - 1;
        Tensor<T> gradOutput = tensor2.isContiguous() ? tensor2 : tensor2.contiguous();
        Object gradOutputArray = gradOutput.storage().array();
        int gradOutputOffset = gradOutput.storageOffset() - 1;
        Object gradInputArray = tensor3.storage().array();
        int gradInputOffset = tensor3.storageOffset() - 1;

        ThreadPool model = Engine$.MODULE$.model();
        int tasks = nFrame * stride;
        for (int t = 0; t < tasks; t++) {
            int base = (t / stride) * dim * stride + t % stride;
            int outBase = outputOffset + base;
            int gradOutBase = gradOutputOffset + base;
            int gradInBase = gradInputOffset + base;
            Function0<BoxedUnit> task = () -> {
                softMaxGradSlice(outputArray, outBase, gradOutputArray, gradOutBase,
                        gradInputArray, gradInBase, dim, stride, tensorNumeric);
                return BoxedUnit.UNIT;
            };
            futureArr[t] = model.invoke(task);
        }
        model.sync(Predef$.MODULE$.wrapRefArray(futureArr), model.sync$default$2());
        return tensor3;
    }

    /** Scala-generated default for the {@code i} argument of {@code updateGradInput}; always 1. */
    public <T> int updateGradInput$default$6() {
        return 1;
    }

    /** Preserves the singleton when this object is deserialized. */
    private Object readResolve() {
        return MODULE$;
    }

    /** Double-specialized variant of {@code apply(int, ...)}. */
    public SoftMax<Object> apply$mDc$sp(int i, ClassTag<Object> classTag, TensorNumericMath.TensorNumeric<Object> tensorNumeric) {
        return new SoftMax<>(i, classTag, tensorNumeric);
    }

    /** Float-specialized variant of {@code apply(int, ...)}. */
    public SoftMax<Object> apply$mFc$sp(int i, ClassTag<Object> classTag, TensorNumericMath.TensorNumeric<Object> tensorNumeric) {
        return new SoftMax<>(i, classTag, tensorNumeric);
    }

    /** Double-specialized variant of {@code apply(...)} with position 1. */
    public SoftMax<Object> apply$mDc$sp(ClassTag<Object> classTag, TensorNumericMath.TensorNumeric<Object> tensorNumeric) {
        return new SoftMax<>(1, classTag, tensorNumeric);
    }

    /** Float-specialized variant of {@code apply(...)} with position 1. */
    public SoftMax<Object> apply$mFc$sp(ClassTag<Object> classTag, TensorNumericMath.TensorNumeric<Object> tensorNumeric) {
        return new SoftMax<>(1, classTag, tensorNumeric);
    }

    private SoftMax$() {
        MODULE$ = this;
    }

    /**
     * Decomposes {@code t} into {@code {nFrame, dim, stride}} starting at dimension {@code pos}.
     * Dimensions beyond {@code pos + 3} are not consulted.
     */
    private static <T> int[] softMaxShape(Tensor<T> t, int pos) {
        int extra = t.nDimension() - pos;
        if (extra < 0) {
            throw new IllegalArgumentException("SoftMax: tensor has " + t.nDimension()
                    + " dimension(s), fewer than pos = " + pos);
        }
        switch (extra) {
            case 0:
                return new int[] {1, t.size(pos), 1};
            case 1:
                return new int[] {t.size(pos), t.size(pos + 1), 1};
            case 2:
                return new int[] {1, t.size(pos), t.size(pos + 1) * t.size(pos + 2)};
            default:
                return new int[] {t.size(pos), t.size(pos + 1), t.size(pos + 2) * t.size(pos + 3)};
        }
    }

    /** Writes exp(x - max) / sum(exp(x - max)) for one slice of {@code dim} elements spaced {@code stride} apart. */
    private static <T> void softMaxSlice(Object in, int inBase, Object out, int outBase, int dim, int stride, TensorNumericMath.TensorNumeric<T> ev) {
        if (dim <= 0) {
            return;
        }
        // Seed the max from the data itself so any representable input (including doubles
        // below Float.MinValue) is shifted by its true maximum.
        T max = at(in, inBase);
        for (int d = 1; d < dim; d++) {
            T v = at(in, inBase + d * stride);
            if (ev.isGreater(v, max)) {
                max = v;
            }
        }
        T sum = ev.mo2991fromType(BoxesRunTime.boxToInteger(0), ConvertableFrom$ConvertableFromInt$.MODULE$);
        for (int d = 0; d < dim; d++) {
            T e = ev.exp(ev.minus(at(in, inBase + d * stride), max));
            ScalaRunTime$.MODULE$.array_update(out, outBase + d * stride, e);
            sum = ev.plus(sum, e);
        }
        // The reciprocal is the same for every element of the slice; compute it once.
        T invSum = ev.divide(ev.mo2991fromType(BoxesRunTime.boxToInteger(1), ConvertableFrom$ConvertableFromInt$.MODULE$), sum);
        for (int d = 0; d < dim; d++) {
            int idx = outBase + d * stride;
            ScalaRunTime$.MODULE$.array_update(out, idx, ev.times(at(out, idx), invSum));
        }
    }

    /** Writes output * (gradOutput - dot(gradOutput, output)) for one slice into gradInput. */
    private static <T> void softMaxGradSlice(Object output, int outBase, Object gradOutput, int gradOutBase, Object gradInput, int gradInBase, int dim, int stride, TensorNumericMath.TensorNumeric<T> ev) {
        T dot = ev.mo2991fromType(BoxesRunTime.boxToInteger(0), ConvertableFrom$ConvertableFromInt$.MODULE$);
        for (int d = 0; d < dim; d++) {
            int step = d * stride;
            dot = ev.plus(dot, ev.times(at(gradOutput, gradOutBase + step), at(output, outBase + step)));
        }
        for (int d = 0; d < dim; d++) {
            int step = d * stride;
            ScalaRunTime$.MODULE$.array_update(gradInput, gradInBase + step,
                    ev.times(at(output, outBase + step), ev.minus(at(gradOutput, gradOutBase + step), dot)));
        }
    }

    /** Reads element {@code index} of a Scala array of erased element type {@code T}. */
    @SuppressWarnings("unchecked")
    private static <T> T at(Object array, int index) {
        return (T) ScalaRunTime$.MODULE$.array_apply(array, index);
    }
}
