package com.intel.analytics.bigdl.optim;

import com.intel.analytics.bigdl.tensor.ConvertableFrom$ConvertableFromDouble$;
import com.intel.analytics.bigdl.tensor.Tensor;
import com.intel.analytics.bigdl.tensor.Tensor$;
import com.intel.analytics.bigdl.tensor.TensorNumericMath;
import scala.Array$;
import scala.Function1;
import scala.MatchError;
import scala.Predef$;
import scala.Tuple2;
import scala.Tuple4;
import scala.reflect.ClassTag;
import scala.runtime.BoxesRunTime;

/* compiled from: Adadelta.scala */
/* loaded from: input_file:com/intel/analytics/bigdl/optim/Adadelta$mcD$sp.class */
/**
 * Double-typed variant of {@link Adadelta}, decompiled from the class file generated for
 * {@code Adadelta.scala}.
 *
 * <p>The {@code $mcD$sp} suffix and the use of {@code BoxesRunTime} suggest this is the
 * Scala-compiler-generated specialization for {@code double} (NOTE(review): inferred from naming;
 * confirm against the Scala source). The code below is decompiler output: raw types and some casts
 * are artifacts of erasure and are not meant to be read as hand-written Java.
 */
public class Adadelta$mcD$sp extends Adadelta<Object> implements OptimMethod$mcD$sp {
    /** Numeric helper used to convert the double hyper-parameters before they are passed to tensor operations. */
    public final TensorNumericMath.TensorNumeric<Object> ev$mcD$sp;
    /** Class tag received by the constructor; stored here but not read anywhere in this class as shown. */
    private final ClassTag<Object> evidence$1;

    /**
     * Forwards to {@link #optimize$mcD$sp(Function1, Tensor)} with the same arguments.
     *
     * @param function1 function evaluated on {@code tensor} by the delegate
     * @param tensor    tensor handed to {@code function1} and updated by the delegate
     * @return whatever {@code optimize$mcD$sp} returns
     */
    @Override // com.intel.analytics.bigdl.optim.Adadelta, com.intel.analytics.bigdl.optim.OptimMethod
    public Tuple2<Tensor<Object>, Object> optimize(Function1<Tensor<Object>, Tuple2<Object, Tensor<Object>>> function1, Tensor<Object> tensor) {
        return optimize$mcD$sp(function1, tensor);
    }

    /**
     * Performs one update step on {@code tensor} and records the step's buffers in {@code state()}.
     *
     * <p>Sequence, as written below: evaluate {@code function1} once on {@code tensor}; fetch (or
     * create zero-filled) the four buffers keyed "paramVariance", "paramStd", "delta" and
     * "accDelta"; update them from the second component of the function result; add the scaled
     * "delta" buffer to {@code tensor}; increment "evalCounter"; write all buffers back to
     * {@code state()}. The arithmetic appears to follow the Adadelta update rule, assuming
     * {@code addcmul}/{@code cdiv}/{@code cmul} are element-wise (NOTE(review): the Tensor API is
     * not visible here; confirm).
     *
     * @param function1 evaluated exactly once on {@code tensor}; the first component of its result
     *                  is read as a double and the second is cast to {@code Tensor}
     *                  (presumably loss and gradient; verify against callers)
     * @param tensor    tensor passed to {@code function1} and then used as the receiver of the
     *                  final {@code add} call; the same reference is returned
     * @return a pair of {@code tensor} and a one-element array holding the double obtained from
     *         {@code function1}
     * @throws MatchError if {@code function1} returns {@code null}
     */
    @Override // com.intel.analytics.bigdl.optim.Adadelta, com.intel.analytics.bigdl.optim.OptimMethod
    public Tuple2<Tensor<Object>, double[]> optimize$mcD$sp(Function1<Tensor<Object>, Tuple2<Object, Tensor<Object>>> function1, Tensor<Object> tensor) {
        // Number of previous evaluations recorded in state; 0 when the key is absent.
        int unboxToInt = BoxesRunTime.unboxToInt(state().getOrElse("evalCounter", BoxesRunTime.boxToInteger(0)));
        // Hyper-parameters are read once, before the function is evaluated.
        double decayRate = decayRate();
        double Epsilon = Epsilon();
        Tuple2 tuple2 = (Tuple2) function1.apply(tensor);
        if (tuple2 == null) {
            throw new MatchError(tuple2);
        }
        // Re-wrapped copy of the function result (pattern-match residue): _1 as double, _2 as Tensor.
        Tuple2 tuple22 = new Tuple2(BoxesRunTime.boxToDouble(tuple2._1$mcD$sp()), (Tensor) tuple2._2());
        // _1$mcD$sp: the double value returned by function1; tensor2: the tensor returned by function1.
        double _1$mcD$sp = tuple22._1$mcD$sp();
        Tensor<?> tensor2 = (Tensor) tuple22._2();
        // Presence of "paramVariance" alone decides whether all four buffers are read from state;
        // otherwise four new tensors are created, resized like tensor2 and zeroed.
        Tuple4 tuple4 = state().get("paramVariance").isDefined() ? new Tuple4(state().get("paramVariance").get(), state().get("paramStd").get(), state().get("delta").get(), state().get("accDelta").get()) : new Tuple4(Tensor$.MODULE$.apply$mDc$sp(this.com$intel$analytics$bigdl$optim$Adadelta$$evidence$1, this.ev$mcD$sp).resizeAs(tensor2).zero(), Tensor$.MODULE$.apply$mDc$sp(this.com$intel$analytics$bigdl$optim$Adadelta$$evidence$1, this.ev$mcD$sp).resizeAs(tensor2).zero(), Tensor$.MODULE$.apply$mDc$sp(this.com$intel$analytics$bigdl$optim$Adadelta$$evidence$1, this.ev$mcD$sp).resizeAs(tensor2).zero(), Tensor$.MODULE$.apply$mDc$sp(this.com$intel$analytics$bigdl$optim$Adadelta$$evidence$1, this.ev$mcD$sp).resizeAs(tensor2).zero());
        // tuple4 was just constructed, so this check cannot fire; it is pattern-match residue.
        if (tuple4 == null) {
            throw new MatchError(tuple4);
        }
        Tuple4 tuple42 = new Tuple4((Tensor) tuple4._1(), (Tensor) tuple4._2(), (Tensor) tuple4._3(), (Tensor) tuple4._4());
        // tensor3 = "paramVariance", tensor4 = "paramStd", tensor5 = "delta", tensor6 = "accDelta"
        // (these are the keys they are stored under at the end of this method).
        Tensor<Object> tensor3 = (Tensor) tuple42._1();
        Tensor<Object> tensor4 = (Tensor) tuple42._2();
        Tensor<Object> tensor5 = (Tensor) tuple42._3();
        Tensor<Object> tensor6 = (Tensor) tuple42._4();
        // paramVariance: scale by decayRate, then addcmul with factor (1 - decayRate) and tensor2 twice
        // (presumably decayRate * v + (1 - decayRate) * g * g element-wise; confirm addcmul semantics).
        tensor3.mul(BoxesRunTime.boxToDouble(this.ev$mcD$sp.fromType$mcD$sp(BoxesRunTime.boxToDouble(decayRate), ConvertableFrom$ConvertableFromDouble$.MODULE$))).addcmul(BoxesRunTime.boxToDouble(this.ev$mcD$sp.fromType$mcD$sp(BoxesRunTime.boxToDouble(1 - decayRate), ConvertableFrom$ConvertableFromDouble$.MODULE$)), tensor2, tensor2);
        // paramStd: copy of paramVariance, plus Epsilon, then sqrt.
        // NOTE(review): the (Tensor<Object>) cast on a boxed double here and below looks like a
        // decompiler artifact of the scalar add overload; confirm against the Scala source.
        tensor4.copy(tensor3).add((Tensor<Object>) BoxesRunTime.boxToDouble(this.ev$mcD$sp.fromType$mcD$sp(BoxesRunTime.boxToDouble(Epsilon), ConvertableFrom$ConvertableFromDouble$.MODULE$))).sqrt();
        // delta: copy of accDelta, plus Epsilon, sqrt, then cdiv by paramStd and cmul by tensor2.
        tensor5.copy(tensor6).add((Tensor<Object>) BoxesRunTime.boxToDouble(this.ev$mcD$sp.fromType$mcD$sp(BoxesRunTime.boxToDouble(Epsilon), ConvertableFrom$ConvertableFromDouble$.MODULE$))).sqrt().cdiv(tensor4).cmul(tensor2);
        // Apply the step: add is called on the input tensor with factor -1.0 and delta
        // (presumably tensor += -1.0 * delta; confirm add(value, tensor) semantics).
        tensor.add((Tensor<Object>) BoxesRunTime.boxToDouble(this.ev$mcD$sp.fromType$mcD$sp(BoxesRunTime.boxToDouble(-1.0d), ConvertableFrom$ConvertableFromDouble$.MODULE$)), (Tensor<Tensor<Object>>) tensor5);
        // accDelta: same mul/addcmul form as paramVariance, but using delta instead of tensor2.
        // This runs after delta has been recomputed above, so it uses the new delta.
        tensor6.mul(BoxesRunTime.boxToDouble(this.ev$mcD$sp.fromType$mcD$sp(BoxesRunTime.boxToDouble(decayRate), ConvertableFrom$ConvertableFromDouble$.MODULE$))).addcmul(BoxesRunTime.boxToDouble(this.ev$mcD$sp.fromType$mcD$sp(BoxesRunTime.boxToDouble(1 - decayRate), ConvertableFrom$ConvertableFromDouble$.MODULE$)), tensor5, tensor5);
        // Persist the incremented counter and all four buffers for the next call.
        state().update("evalCounter", BoxesRunTime.boxToInteger(unboxToInt + 1));
        state().update("paramVariance", tensor3);
        state().update("paramStd", tensor4);
        state().update("delta", tensor5);
        state().update("accDelta", tensor6);
        // Result: the input tensor reference plus a one-element array wrapping the function's double value.
        return new Tuple2<>(tensor, Array$.MODULE$.apply(Predef$.MODULE$.genericWrapArray(new double[]{_1$mcD$sp}), this.com$intel$analytics$bigdl$optim$Adadelta$$evidence$1));
    }

    /**
     * Passes all four arguments to the {@link Adadelta} constructor, then keeps local references
     * to {@code tensorNumeric} and {@code classTag}.
     *
     * @param d             first double forwarded to the superclass constructor
     *                      (presumably the decay rate; confirm against Adadelta's constructor)
     * @param d2            second double forwarded to the superclass constructor
     *                      (presumably epsilon; confirm against Adadelta's constructor)
     * @param classTag      class tag forwarded to the superclass and stored in {@code evidence$1}
     * @param tensorNumeric numeric helper forwarded to the superclass and stored in {@code ev$mcD$sp}
     */
    /* JADX WARN: 'super' call moved to the top of the method (can break code semantics) */
    public Adadelta$mcD$sp(double d, double d2, ClassTag<Object> classTag, TensorNumericMath.TensorNumeric<Object> tensorNumeric) {
        super(d, d2, classTag, tensorNumeric);
        this.ev$mcD$sp = tensorNumeric;
        this.evidence$1 = classTag;
    }
}
