package com.intel.analytics.bigdl.nn.keras;

import com.intel.analytics.bigdl.nn.AbsCriterion$;
import com.intel.analytics.bigdl.nn.BCECriterion$;
import com.intel.analytics.bigdl.nn.CategoricalCrossEntropy$;
import com.intel.analytics.bigdl.nn.ClassNLLCriterion$;
import com.intel.analytics.bigdl.nn.CosineProximityCriterion$;
import com.intel.analytics.bigdl.nn.HardSigmoid$;
import com.intel.analytics.bigdl.nn.InitializationMethod;
import com.intel.analytics.bigdl.nn.KullbackLeiblerDivergenceCriterion$;
import com.intel.analytics.bigdl.nn.MSECriterion$;
import com.intel.analytics.bigdl.nn.MarginCriterion$;
import com.intel.analytics.bigdl.nn.MeanAbsolutePercentageCriterion$;
import com.intel.analytics.bigdl.nn.MeanSquaredLogarithmicCriterion$;
import com.intel.analytics.bigdl.nn.Ones$;
import com.intel.analytics.bigdl.nn.PoissonCriterion$;
import com.intel.analytics.bigdl.nn.RandomNormal;
import com.intel.analytics.bigdl.nn.RandomUniform;
import com.intel.analytics.bigdl.nn.ReLU$;
import com.intel.analytics.bigdl.nn.Sigmoid$;
import com.intel.analytics.bigdl.nn.SoftPlus$;
import com.intel.analytics.bigdl.nn.SoftSign$;
import com.intel.analytics.bigdl.nn.Tanh$;
import com.intel.analytics.bigdl.nn.Xavier$;
import com.intel.analytics.bigdl.nn.Zeros$;
import com.intel.analytics.bigdl.nn.abstractnn.AbstractCriterion;
import com.intel.analytics.bigdl.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.nn.abstractnn.Activity;
import com.intel.analytics.bigdl.optim.Adadelta;
import com.intel.analytics.bigdl.optim.Adagrad;
import com.intel.analytics.bigdl.optim.Adagrad$;
import com.intel.analytics.bigdl.optim.Adam;
import com.intel.analytics.bigdl.optim.Adam$;
import com.intel.analytics.bigdl.optim.Adamax;
import com.intel.analytics.bigdl.optim.Adamax$;
import com.intel.analytics.bigdl.optim.OptimMethod;
import com.intel.analytics.bigdl.optim.RMSprop;
import com.intel.analytics.bigdl.optim.RMSprop$;
import com.intel.analytics.bigdl.optim.SGD;
import com.intel.analytics.bigdl.optim.SGD$;
import com.intel.analytics.bigdl.optim.Top1Accuracy;
import com.intel.analytics.bigdl.optim.ValidationMethod;
import com.intel.analytics.bigdl.tensor.Tensor;
import com.intel.analytics.bigdl.tensor.TensorNumericMath;
import java.util.Arrays;
import java.util.Locale;
import scala.MatchError;
import scala.Predef$;
import scala.Serializable;
import scala.Tuple2;
import scala.Tuple3;
import scala.collection.mutable.ArrayOps;
import scala.reflect.ClassTag;
import scala.runtime.BoxesRunTime;

/* compiled from: KerasUtils.scala */
/* loaded from: input_file:com/intel/analytics/bigdl/nn/keras/KerasUtils$.class */
public final class KerasUtils$ {
    public static KerasUtils$ MODULE$;

    static {
        new KerasUtils$();
    }

    public Tuple2<Object, Object> getPadsFromBorderMode(String str) {
        return (str != null ? !str.equals("same") : "same" != 0) ? new Tuple2.mcII.sp(0, 0) : new Tuple2.mcII.sp(-1, -1);
    }

    public String getPadsFromBorderMode$default$1() {
        return "valid";
    }

    public InitializationMethod getInitMethod(String str) {
        Serializable randomNormal;
        String lowerCase = str.toLowerCase();
        if ("glorot_uniform".equals(lowerCase)) {
            randomNormal = Xavier$.MODULE$;
        } else if ("one".equals(lowerCase)) {
            randomNormal = Ones$.MODULE$;
        } else if ("zero".equals(lowerCase)) {
            randomNormal = Zeros$.MODULE$;
        } else if ("uniform".equals(lowerCase)) {
            randomNormal = new RandomUniform(-0.05d, 0.05d);
        } else {
            if (!"normal".equals(lowerCase)) {
                throw new IllegalArgumentException(new StringBuilder(35).append("Unsupported initialization method: ").append(str.toLowerCase()).toString());
            }
            randomNormal = new RandomNormal(0.0d, 0.05d);
        }
        return randomNormal;
    }

    public <T> KerasLayer<Tensor<T>, Tensor<T>, T> getKerasActivation(String str, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        if (str == null) {
            return null;
        }
        String lowerCase = str.toLowerCase();
        if (lowerCase != null ? lowerCase.equals("softmax") : "softmax" == 0) {
            return SoftMax$.MODULE$.apply(SoftMax$.MODULE$.apply$default$1(), classTag, tensorNumeric);
        }
        return new KerasIdentityWrapper(com.intel.analytics.bigdl.package$.MODULE$.convModule(getTorchActivation(str, classTag, tensorNumeric)), classTag, tensorNumeric);
    }

    public <T> AbstractModule<Tensor<T>, Tensor<T>, T> getTorchActivation(String str, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        AbstractModule apply;
        if (str == null) {
            return null;
        }
        String lowerCase = str.toLowerCase();
        if ("tanh".equals(lowerCase)) {
            apply = Tanh$.MODULE$.apply(classTag, tensorNumeric);
        } else if ("sigmoid".equals(lowerCase)) {
            apply = Sigmoid$.MODULE$.apply(classTag, tensorNumeric);
        } else if ("relu".equals(lowerCase)) {
            apply = ReLU$.MODULE$.apply(ReLU$.MODULE$.apply$default$1(), classTag, tensorNumeric);
        } else if ("softmax".equals(lowerCase)) {
            apply = com.intel.analytics.bigdl.nn.SoftMax$.MODULE$.apply(com.intel.analytics.bigdl.nn.SoftMax$.MODULE$.apply$default$1(), classTag, tensorNumeric);
        } else if ("softplus".equals(lowerCase)) {
            apply = SoftPlus$.MODULE$.apply(SoftPlus$.MODULE$.apply$default$1(), classTag, tensorNumeric);
        } else if ("softsign".equals(lowerCase)) {
            apply = SoftSign$.MODULE$.apply(classTag, tensorNumeric);
        } else {
            if (!"hard_sigmoid".equals(lowerCase)) {
                throw new IllegalArgumentException(new StringBuilder(77).append("Invalid activation: ").append(str.toLowerCase()).append(". Only simple activations can be constructed using string").toString());
            }
            apply = HardSigmoid$.MODULE$.apply(classTag, tensorNumeric);
        }
        return apply;
    }

    public int computeConvOutputLength(int i, int i2, String str, int i3, int i4) {
        int i5;
        int i6 = i2 + ((i2 - 1) * (i4 - 1));
        if ("valid".equals(str)) {
            i5 = (i - i6) + 1;
        } else {
            if (!"same".equals(str)) {
                throw new MatchError(str);
            }
            i5 = i;
        }
        return ((i5 + i3) - 1) / i3;
    }

    public int computeConvOutputLength$default$5() {
        return 1;
    }

    public Tuple3<Object, Object, Object> getPadsFromBorderMode3D(String str) {
        return (str != null ? !str.equals("same") : "same" != 0) ? new Tuple3<>(BoxesRunTime.boxToInteger(0), BoxesRunTime.boxToInteger(0), BoxesRunTime.boxToInteger(0)) : new Tuple3<>(BoxesRunTime.boxToInteger(-1), BoxesRunTime.boxToInteger(-1), BoxesRunTime.boxToInteger(-1));
    }

    public String getPadsFromBorderMode3D$default$1() {
        return "valid";
    }

    /* JADX WARN: Removed duplicated region for block: B:12:0x0065  */
    /* JADX WARN: Removed duplicated region for block: B:8:0x005b  */
    /*
        Code decompiled incorrectly, please refer to instructions dump.
        To view partially-correct add '--show-bad-code' argument
    */
    public com.intel.analytics.bigdl.nn.abstractnn.DataFormat toBigDLFormat(java.lang.String r5) {
        /*
            r4 = this;
            scala.Predef$ r0 = scala.Predef$.MODULE$
            r1 = r5
            java.lang.String r1 = r1.toLowerCase()
            java.lang.String r2 = "tf"
            r7 = r2
            r2 = r1
            if (r2 != 0) goto L17
        L10:
            r1 = r7
            if (r1 == 0) goto L3c
            goto L1e
        L17:
            r2 = r7
            boolean r1 = r1.equals(r2)
            if (r1 != 0) goto L3c
        L1e:
            r1 = r5
            java.lang.String r1 = r1.toLowerCase()
            java.lang.String r2 = "th"
            r8 = r2
            r2 = r1
            if (r2 != 0) goto L34
        L2c:
            r1 = r8
            if (r1 == 0) goto L3c
            goto L40
        L34:
            r2 = r8
            boolean r1 = r1.equals(r2)
            if (r1 == 0) goto L40
        L3c:
            r1 = 1
            goto L41
        L40:
            r1 = 0
        L41:
            r2 = r5
            com.intel.analytics.bigdl.nn.abstractnn.DataFormat r2 = () -> { // scala.Function0.apply():java.lang.Object
                return $anonfun$toBigDLFormat$1(r2);
            }
            r0.require(r1, r2)
            r0 = r5
            java.lang.String r0 = r0.toLowerCase()
            r9 = r0
            java.lang.String r0 = "tf"
            r1 = r9
            boolean r0 = r0.equals(r1)
            if (r0 == 0) goto L62
            com.intel.analytics.bigdl.nn.abstractnn.DataFormat$NHWC$ r0 = com.intel.analytics.bigdl.nn.abstractnn.DataFormat$NHWC$.MODULE$
            r6 = r0
            goto L84
        L62:
            goto L65
        L65:
            java.lang.String r0 = "th"
            r1 = r9
            boolean r0 = r0.equals(r1)
            if (r0 == 0) goto L77
            com.intel.analytics.bigdl.nn.abstractnn.DataFormat$NCHW$ r0 = com.intel.analytics.bigdl.nn.abstractnn.DataFormat$NCHW$.MODULE$
            r6 = r0
            goto L84
        L77:
            goto L7a
        L7a:
            scala.MatchError r0 = new scala.MatchError
            r1 = r0
            r2 = r9
            r1.<init>(r2)
            throw r0
        L84:
            r0 = r6
            return r0
        */
        throw new UnsupportedOperationException("Method not decompiled: com.intel.analytics.bigdl.nn.keras.KerasUtils$.toBigDLFormat(java.lang.String):com.intel.analytics.bigdl.nn.abstractnn.DataFormat");
    }

    /* JADX WARN: Removed duplicated region for block: B:12:0x0065  */
    /* JADX WARN: Removed duplicated region for block: B:8:0x005b  */
    /*
        Code decompiled incorrectly, please refer to instructions dump.
        To view partially-correct add '--show-bad-code' argument
    */
    public java.lang.String toBigDLFormat5D(java.lang.String r5) {
        /*
            r4 = this;
            scala.Predef$ r0 = scala.Predef$.MODULE$
            r1 = r5
            java.lang.String r1 = r1.toLowerCase()
            java.lang.String r2 = "tf"
            r7 = r2
            r2 = r1
            if (r2 != 0) goto L17
        L10:
            r1 = r7
            if (r1 == 0) goto L3c
            goto L1e
        L17:
            r2 = r7
            boolean r1 = r1.equals(r2)
            if (r1 != 0) goto L3c
        L1e:
            r1 = r5
            java.lang.String r1 = r1.toLowerCase()
            java.lang.String r2 = "th"
            r8 = r2
            r2 = r1
            if (r2 != 0) goto L34
        L2c:
            r1 = r8
            if (r1 == 0) goto L3c
            goto L40
        L34:
            r2 = r8
            boolean r1 = r1.equals(r2)
            if (r1 == 0) goto L40
        L3c:
            r1 = 1
            goto L41
        L40:
            r1 = 0
        L41:
            r2 = r5
            java.lang.String r2 = () -> { // scala.Function0.apply():java.lang.Object
                return $anonfun$toBigDLFormat5D$1(r2);
            }
            r0.require(r1, r2)
            r0 = r5
            java.lang.String r0 = r0.toLowerCase()
            r9 = r0
            java.lang.String r0 = "tf"
            r1 = r9
            boolean r0 = r0.equals(r1)
            if (r0 == 0) goto L62
            java.lang.String r0 = "CHANNEL_LAST"
            r6 = r0
            goto L84
        L62:
            goto L65
        L65:
            java.lang.String r0 = "th"
            r1 = r9
            boolean r0 = r0.equals(r1)
            if (r0 == 0) goto L77
            java.lang.String r0 = "CHANNEL_FIRST"
            r6 = r0
            goto L84
        L77:
            goto L7a
        L7a:
            scala.MatchError r0 = new scala.MatchError
            r1 = r0
            r2 = r9
            r1.<init>(r2)
            throw r0
        L84:
            r0 = r6
            return r0
        */
        throw new UnsupportedOperationException("Method not decompiled: com.intel.analytics.bigdl.nn.keras.KerasUtils$.toBigDLFormat5D(java.lang.String):java.lang.String");
    }

    public <T> AbstractCriterion<Activity, Activity, T> toBigDLCriterion(String str, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        AbstractCriterion<Activity, Activity, T> convCriterion;
        String lowerCase = str.toLowerCase();
        if ("binary_crossentropy".equals(lowerCase)) {
            com.intel.analytics.bigdl.package$ package_ = com.intel.analytics.bigdl.package$.MODULE$;
            BCECriterion$ bCECriterion$ = BCECriterion$.MODULE$;
            BCECriterion$.MODULE$.apply$default$1();
            convCriterion = package_.convCriterion(bCECriterion$.apply(null, BCECriterion$.MODULE$.apply$default$2(), classTag, tensorNumeric));
        } else if ("categorical_crossentropy".equals(lowerCase)) {
            convCriterion = com.intel.analytics.bigdl.package$.MODULE$.convCriterion(CategoricalCrossEntropy$.MODULE$.apply(classTag, tensorNumeric));
        } else if ("mse".equals(lowerCase)) {
            convCriterion = com.intel.analytics.bigdl.package$.MODULE$.convCriterion(MSECriterion$.MODULE$.apply(classTag, tensorNumeric));
        } else if ("mean_squared_error".equals(lowerCase)) {
            convCriterion = com.intel.analytics.bigdl.package$.MODULE$.convCriterion(MSECriterion$.MODULE$.apply(classTag, tensorNumeric));
        } else if ("mae".equals(lowerCase)) {
            convCriterion = com.intel.analytics.bigdl.package$.MODULE$.convCriterion(AbsCriterion$.MODULE$.apply(AbsCriterion$.MODULE$.apply$default$1(), classTag, tensorNumeric));
        } else if ("mean_absolute_error".equals(lowerCase)) {
            convCriterion = com.intel.analytics.bigdl.package$.MODULE$.convCriterion(AbsCriterion$.MODULE$.apply(AbsCriterion$.MODULE$.apply$default$1(), classTag, tensorNumeric));
        } else if ("hinge".equals(lowerCase)) {
            convCriterion = com.intel.analytics.bigdl.package$.MODULE$.convCriterion(MarginCriterion$.MODULE$.apply(MarginCriterion$.MODULE$.apply$default$1(), MarginCriterion$.MODULE$.apply$default$2(), MarginCriterion$.MODULE$.apply$default$3(), classTag, tensorNumeric));
        } else if ("mape".equals(lowerCase)) {
            convCriterion = com.intel.analytics.bigdl.package$.MODULE$.convCriterion(MeanAbsolutePercentageCriterion$.MODULE$.apply(classTag, tensorNumeric));
        } else if ("mean_absolute_percentage_error".equals(lowerCase)) {
            convCriterion = com.intel.analytics.bigdl.package$.MODULE$.convCriterion(MeanAbsolutePercentageCriterion$.MODULE$.apply(classTag, tensorNumeric));
        } else if ("msle".equals(lowerCase)) {
            convCriterion = com.intel.analytics.bigdl.package$.MODULE$.convCriterion(MeanSquaredLogarithmicCriterion$.MODULE$.apply(classTag, tensorNumeric));
        } else if ("mean_squared_logarithmic_error".equals(lowerCase)) {
            convCriterion = com.intel.analytics.bigdl.package$.MODULE$.convCriterion(MeanSquaredLogarithmicCriterion$.MODULE$.apply(classTag, tensorNumeric));
        } else if ("squared_hinge".equals(lowerCase)) {
            convCriterion = com.intel.analytics.bigdl.package$.MODULE$.convCriterion(MarginCriterion$.MODULE$.apply(MarginCriterion$.MODULE$.apply$default$1(), MarginCriterion$.MODULE$.apply$default$2(), true, classTag, tensorNumeric));
        } else if ("sparse_categorical_crossentropy".equals(lowerCase)) {
            ClassNLLCriterion$.MODULE$.apply$default$1();
            convCriterion = com.intel.analytics.bigdl.package$.MODULE$.convCriterion(ClassNLLCriterion$.MODULE$.apply(null, ClassNLLCriterion$.MODULE$.apply$default$2(), false, ClassNLLCriterion$.MODULE$.apply$default$4(), classTag, tensorNumeric));
        } else if ("kld".equals(lowerCase)) {
            convCriterion = com.intel.analytics.bigdl.package$.MODULE$.convCriterion(KullbackLeiblerDivergenceCriterion$.MODULE$.apply(classTag, tensorNumeric));
        } else if ("kullback_leibler_divergence".equals(lowerCase)) {
            convCriterion = com.intel.analytics.bigdl.package$.MODULE$.convCriterion(KullbackLeiblerDivergenceCriterion$.MODULE$.apply(classTag, tensorNumeric));
        } else if ("cosine_proximity".equals(lowerCase)) {
            convCriterion = com.intel.analytics.bigdl.package$.MODULE$.convCriterion(CosineProximityCriterion$.MODULE$.apply(classTag, tensorNumeric));
        } else {
            if (!"poisson".equals(lowerCase)) {
                throw new IllegalArgumentException(new StringBuilder(14).append("Invalid loss: ").append(str.toLowerCase()).toString());
            }
            convCriterion = com.intel.analytics.bigdl.package$.MODULE$.convCriterion(PoissonCriterion$.MODULE$.apply(classTag, tensorNumeric));
        }
        return convCriterion;
    }

    public <T> OptimMethod<T> toBigDLOptimMethod(String str, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        OptimMethod adam;
        String lowerCase = str.toLowerCase();
        if ("sgd".equals(lowerCase)) {
            double $lessinit$greater$default$2 = SGD$.MODULE$.$lessinit$greater$default$2();
            double $lessinit$greater$default$3 = SGD$.MODULE$.$lessinit$greater$default$3();
            double $lessinit$greater$default$4 = SGD$.MODULE$.$lessinit$greater$default$4();
            double $lessinit$greater$default$5 = SGD$.MODULE$.$lessinit$greater$default$5();
            boolean $lessinit$greater$default$6 = SGD$.MODULE$.$lessinit$greater$default$6();
            SGD.LearningRateSchedule $lessinit$greater$default$7 = SGD$.MODULE$.$lessinit$greater$default$7();
            SGD$.MODULE$.$lessinit$greater$default$8();
            SGD$.MODULE$.$lessinit$greater$default$9();
            adam = new SGD(0.01d, $lessinit$greater$default$2, $lessinit$greater$default$3, $lessinit$greater$default$4, $lessinit$greater$default$5, $lessinit$greater$default$6, $lessinit$greater$default$7, null, null, classTag, tensorNumeric);
        } else if ("rmsprop".equals(lowerCase)) {
            adam = new RMSprop(0.001d, RMSprop$.MODULE$.$lessinit$greater$default$2(), 0.9d, RMSprop$.MODULE$.$lessinit$greater$default$4(), classTag, tensorNumeric);
        } else if ("adamax".equals(lowerCase)) {
            adam = new Adamax(Adamax$.MODULE$.$lessinit$greater$default$1(), Adamax$.MODULE$.$lessinit$greater$default$2(), Adamax$.MODULE$.$lessinit$greater$default$3(), 1.0E-8d, classTag, tensorNumeric);
        } else if ("adagrad".equals(lowerCase)) {
            adam = new Adagrad(0.01d, Adagrad$.MODULE$.$lessinit$greater$default$2(), Adagrad$.MODULE$.$lessinit$greater$default$3(), classTag, tensorNumeric);
        } else if ("adadelta".equals(lowerCase)) {
            adam = new Adadelta(0.95d, 1.0E-8d, classTag, tensorNumeric);
        } else {
            if (!"adam".equals(lowerCase)) {
                throw new MatchError(lowerCase);
            }
            adam = new Adam(Adam$.MODULE$.$lessinit$greater$default$1(), Adam$.MODULE$.$lessinit$greater$default$2(), Adam$.MODULE$.$lessinit$greater$default$3(), Adam$.MODULE$.$lessinit$greater$default$4(), Adam$.MODULE$.$lessinit$greater$default$5(), classTag, tensorNumeric);
        }
        return adam;
    }

    public <T> ValidationMethod<T>[] toBigDLMetrics(String[] strArr, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        if (strArr == null) {
            return null;
        }
        if (new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(strArr)).sameElements(Predef$.MODULE$.wrapRefArray(new String[]{"accuracy"}))) {
            return new ValidationMethod[]{new Top1Accuracy(classTag, tensorNumeric)};
        }
        throw new IllegalArgumentException(new StringBuilder(21).append("Unsupported metrics: ").append(new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(strArr)).mkString(", ")).toString());
    }

    private KerasUtils$() {
        MODULE$ = this;
    }
}
