package com.intel.analytics.bigdl.models.vgg;

import com.intel.analytics.bigdl.dataset.AbstractDataSet;
import com.intel.analytics.bigdl.dataset.MiniBatch;
import com.intel.analytics.bigdl.models.vgg.Utils;
import com.intel.analytics.bigdl.nn.CrossEntropyCriterion$;
import com.intel.analytics.bigdl.nn.Module$;
import com.intel.analytics.bigdl.nn.SoftmaxWithCriterion;
import com.intel.analytics.bigdl.nn.SoftmaxWithCriterion$;
import com.intel.analytics.bigdl.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.optim.OptimMethod$;
import com.intel.analytics.bigdl.optim.Optimizer;
import com.intel.analytics.bigdl.optim.Optimizer$;
import com.intel.analytics.bigdl.optim.SGD;
import com.intel.analytics.bigdl.optim.SGD$;
import com.intel.analytics.bigdl.optim.SGD$mcF$sp;
import com.intel.analytics.bigdl.optim.Top1Accuracy;
import com.intel.analytics.bigdl.optim.Top5Accuracy;
import com.intel.analytics.bigdl.optim.Trigger;
import com.intel.analytics.bigdl.optim.Trigger$;
import com.intel.analytics.bigdl.optim.ValidationMethod;
import com.intel.analytics.bigdl.tensor.TensorNumericMath$TensorNumeric$NumericFloat$;
import com.intel.analytics.bigdl.utils.Engine$;
import com.intel.analytics.bigdl.utils.EngineType;
import com.intel.analytics.bigdl.utils.LoggerFilter$;
import com.intel.analytics.bigdl.utils.MklBlas$;
import com.intel.analytics.bigdl.utils.MklDnn$;
import com.intel.analytics.bigdl.utils.OptimizerV1$;
import com.intel.analytics.bigdl.utils.OptimizerV2$;
import com.intel.analytics.bigdl.visualization.TrainSummary;
import com.intel.analytics.bigdl.visualization.TrainSummary$;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.spark.SparkContext;
import scala.MatchError;
import scala.Predef$;
import scala.Serializable;
import scala.math.package$;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/* compiled from: TrainImageNet.scala */
/* loaded from: input_file:com/intel/analytics/bigdl/models/vgg/TrainImageNet$.class */
/**
 * Java view of the Scala singleton object {@code TrainImageNet} (decompiled from
 * TrainImageNet.scala).
 *
 * <p>Command-line driver that trains a VGG-16 model with BigDL on Spark, reading training data
 * from {@code <folder>/train} and validation data from {@code <folder>/val}.
 */
public final class TrainImageNet$ {
    /** Singleton instance; assigned by the private constructor during static initialization. */
    public static TrainImageNet$ MODULE$;

    /**
     * Number of training images per epoch used to derive iterations-per-epoch for the
     * learning-rate schedule. NOTE(review): presumably the ILSVRC-2012 training-set size; it is
     * hard-coded rather than derived from the dataset actually loaded — confirm.
     */
    private static final int TRAIN_IMAGES_PER_EPOCH = 1281167;

    /**
     * Iteration horizon of the poly learning-rate decay phase (also used to size that phase).
     * NOTE(review): independent of {@code trainParams.maxIteration()}, which drives the end
     * trigger — confirm the two are meant to differ.
     */
    private static final int POLY_DECAY_ITERATIONS = 40000;

    private final Logger logger;

    static {
        new TrainImageNet$();
    }

    /** Returns this object's log4j logger (created in the constructor). */
    public Logger logger() {
        return this.logger;
    }

    /**
     * Command-line entry point.
     *
     * <p>Parses {@code strArr} with {@code Utils.trainParser()}, starting from a
     * {@code TrainParams} filled with all of its default values, and runs one training job per
     * value yielded by the parse result. The result is consumed with {@code foreach}, so nothing
     * runs when it yields no value (presumably on invalid arguments — confirm against the parser).
     *
     * @param strArr raw command-line arguments
     */
    public void main(String[] strArr) {
        Utils$.MODULE$.trainParser().parse(Predef$.MODULE$.wrapRefArray(strArr), new Utils.TrainParams(Utils$TrainParams$.MODULE$.apply$default$1(), Utils$TrainParams$.MODULE$.apply$default$2(), Utils$TrainParams$.MODULE$.apply$default$3(), Utils$TrainParams$.MODULE$.apply$default$4(), Utils$TrainParams$.MODULE$.apply$default$5(), Utils$TrainParams$.MODULE$.apply$default$6(), Utils$TrainParams$.MODULE$.apply$default$7(), Utils$TrainParams$.MODULE$.apply$default$8(), Utils$TrainParams$.MODULE$.apply$default$9(), Utils$TrainParams$.MODULE$.apply$default$10(), Utils$TrainParams$.MODULE$.apply$default$11(), Utils$TrainParams$.MODULE$.apply$default$12(), Utils$TrainParams$.MODULE$.apply$default$13(), Utils$TrainParams$.MODULE$.apply$default$14(), Utils$TrainParams$.MODULE$.apply$default$15(), Utils$TrainParams$.MODULE$.apply$default$16(), Utils$TrainParams$.MODULE$.apply$default$17(), Utils$TrainParams$.MODULE$.apply$default$18(), Utils$TrainParams$.MODULE$.apply$default$19(), Utils$TrainParams$.MODULE$.apply$default$20(), Utils$TrainParams$.MODULE$.apply$default$21(), Utils$TrainParams$.MODULE$.apply$default$22())).foreach(trainParams -> {
            $anonfun$main$1(trainParams);
            return BoxedUnit.UNIT;
        });
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v104, types: [com.intel.analytics.bigdl.nn.CrossEntropyCriterion] */
    /**
     * Runs one complete training job for the parsed {@code trainParams}.
     *
     * <p>It does the following, in order:
     * <ol>
     *   <li>Creates a SparkContext and initializes the BigDL engine.</li>
     *   <li>Builds the train and validation datasets.</li>
     *   <li>Loads the model from {@code modelSnapshot}, or builds VGG-16 for the current engine
     *       type.</li>
     *   <li>Optionally selects the optimizer version.</li>
     *   <li>Loads SGD state from {@code stateSnapshot}, or builds a fresh SGD whose schedule is
     *       an optional warm-up followed by poly decay.</li>
     *   <li>Picks the criterion for the engine type.</li>
     *   <li>Configures checkpointing, gradient clipping, Top-1/Top-5 validation and the end
     *       trigger, then optimizes.</li>
     * </ol>
     * The SparkContext is stopped whether training completes or throws.
     *
     * @param trainParams parsed command-line parameters
     * @throws MatchError if the engine type is neither MklBlas nor MklDnn, or if
     *     {@code optimizerVersion} is set to anything other than optimizerv1/optimizerv2
     *     (case-insensitive)
     * @throws IllegalArgumentException if a fresh SGD schedule must be built and
     *     {@code batchSize} is not positive
     */
    public static final /* synthetic */ void $anonfun$main$1(Utils.TrainParams trainParams) {
        SparkContext sparkContext = new SparkContext(Engine$.MODULE$.createSparkConf(Engine$.MODULE$.createSparkConf$default$1()).setAppName("Train VGG-16 on ImageNet2012").set("spark.rpc.message.maxSize", "200"));
        try {
            Engine$.MODULE$.init();
            int batchSize = trainParams.batchSize();
            String folder = trainParams.folder();
            int classNumber = trainParams.classNumber();
            // 224 is passed straight to Utils (presumably the image/crop size — confirm).
            AbstractDataSet<MiniBatch<Object>, ?> trainDataSet = Utils$.MODULE$.trainDataSet(new StringBuilder(6).append(folder).append("/train").toString(), sparkContext, 224, batchSize);
            AbstractDataSet<MiniBatch<Object>, ?> valDataSet = Utils$.MODULE$.valDataSet(new StringBuilder(4).append(folder).append("/val").toString(), sparkContext, 224, batchSize);

            // Model: resume from a snapshot when given, otherwise build VGG-16 for the engine.
            AbstractModule model;
            if (trainParams.modelSnapshot().isDefined()) {
                model = Module$.MODULE$.load((String) trainParams.modelSnapshot().get(), ClassTag$.MODULE$.Float());
            } else {
                EngineType engineType = Engine$.MODULE$.getEngineType();
                if (MklBlas$.MODULE$.equals(engineType)) {
                    model = Vgg_16$.MODULE$.apply(classNumber, Vgg_16$.MODULE$.apply$default$2());
                } else if (MklDnn$.MODULE$.equals(engineType)) {
                    // The MKL-DNN graph is built with batchSize divided by the engine's node number.
                    model = com.intel.analytics.bigdl.nn.mkldnn.models.Vgg_16$.MODULE$.graph(batchSize / Engine$.MODULE$.nodeNumber(), classNumber, com.intel.analytics.bigdl.nn.mkldnn.models.Vgg_16$.MODULE$.graph$default$3());
                } else {
                    throw new MatchError(engineType);
                }
            }
            Predef$.MODULE$.println(model);

            // Optional optimizer-version override; Locale.ROOT keeps the match locale-independent.
            if (trainParams.optimizerVersion().isDefined()) {
                String version = ((String) trainParams.optimizerVersion().get()).toLowerCase(java.util.Locale.ROOT);
                if ("optimizerv1".equals(version)) {
                    Engine$.MODULE$.setOptimizerVersion(OptimizerV1$.MODULE$);
                } else if ("optimizerv2".equals(version)) {
                    Engine$.MODULE$.setOptimizerVersion(OptimizerV2$.MODULE$);
                } else {
                    throw new MatchError(version);
                }
            }

            // Optimization method. The declared type Serializable is the decompiler's inference.
            Serializable optimMethod;
            if (trainParams.stateSnapshot().isDefined()) {
                optimMethod = (SGD) OptimMethod$.MODULE$.load((String) trainParams.stateSnapshot().get(), ClassTag$.MODULE$.Float());
            } else {
                if (batchSize <= 0) {
                    throw new IllegalArgumentException("batchSize must be positive, got " + batchSize);
                }
                // Divide as double so ceil actually rounds a trailing partial batch up;
                // integer division would truncate before ceil could act.
                int iterationsPerEpoch = (int) package$.MODULE$.ceil(((double) TRAIN_IMAGES_PER_EPOCH) / batchSize);
                SGD.SequentialSchedule schedule = new SGD.SequentialSchedule(iterationsPerEpoch);
                int warmupIterations = iterationsPerEpoch * BoxesRunTime.unboxToInt(trainParams.warmupEpoch().getOrElse(() -> {
                    return 0;
                }));
                if (warmupIterations != 0) {
                    // (maxLr - learningRate) spread evenly over the warm-up iterations.
                    double warmupDelta = (trainParams.maxLr() - trainParams.learningRate()) / warmupIterations;
                    schedule.add(new SGD.Warmup(warmupDelta), warmupIterations);
                    MODULE$.logger().info(new StringBuilder(45).append("warmUpIteraion: ").append(warmupIterations).append(", startLr: ").append(trainParams.learningRate()).append(", ").append("maxLr: ").append(trainParams.maxLr()).append(", delta: ").append(warmupDelta).toString());
                }
                // NOTE(review): if warmupIterations exceeds POLY_DECAY_ITERATIONS, the count below
                // is negative — confirm how SequentialSchedule treats that.
                schedule.add(new SGD.Poly(0.5d, POLY_DECAY_ITERATIONS), POLY_DECAY_ITERATIONS - warmupIterations);
                // Second argument is 0.0 (presumably learningRateDecay); the two nulls fill the
                // remaining constructor parameters before the ClassTag/numeric evidence.
                optimMethod = new SGD$mcF$sp(trainParams.learningRate(), 0.0d, trainParams.weightDecay(), trainParams.momentum(), trainParams.dampening(), trainParams.nesterov(), schedule, null, null, ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
            }

            // NOTE(review): this summary is configured but never handed to the optimizer
            // (there is no setTrainSummary call below), so these triggers have no effect here.
            TrainSummary trainSummary = TrainSummary$.MODULE$.apply("vgg16-imagenet", String.valueOf(sparkContext.applicationId()));
            trainSummary.setSummaryTrigger("LearningRate", Trigger$.MODULE$.severalIteration(1));
            trainSummary.setSummaryTrigger("Parameters", Trigger$.MODULE$.severalIteration(10));

            // Criterion per engine. The declared type is the decompiler's inference; the MklBlas
            // branch actually produces a CrossEntropyCriterion.
            SoftmaxWithCriterion<Object> criterion;
            EngineType criterionEngine = Engine$.MODULE$.getEngineType();
            if (MklBlas$.MODULE$.equals(criterionEngine)) {
                criterion = CrossEntropyCriterion$.MODULE$.apply$mFc$sp(null, CrossEntropyCriterion$.MODULE$.apply$default$2(), ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
            } else if (MklDnn$.MODULE$.equals(criterionEngine)) {
                criterion = SoftmaxWithCriterion$.MODULE$.apply$mFc$sp(SoftmaxWithCriterion$.MODULE$.apply$default$1(), SoftmaxWithCriterion$.MODULE$.apply$default$2(), ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
            } else {
                throw new MatchError(criterionEngine);
            }

            Optimizer optimizer = Optimizer$.MODULE$.apply(model, trainDataSet, com.intel.analytics.bigdl.package$.MODULE$.convCriterion(criterion), ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
            // One trigger drives both checkpointing and validation.
            Trigger checkpointTrigger = Trigger$.MODULE$.severalIteration(trainParams.checkpointIteration());
            ValidationMethod[] validationMethods = {new Top1Accuracy(ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$), new Top5Accuracy(ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$)};
            if (trainParams.checkpoint().isDefined()) {
                optimizer.setCheckpoint((String) trainParams.checkpoint().get(), checkpointTrigger);
            }
            // Gradient L2-norm clipping threshold defaults to 10000.0 when not supplied.
            optimizer.setGradientClippingByl2Norm(BoxesRunTime.unboxToDouble(trainParams.gradientL2NormThreshold().getOrElse(() -> {
                return 10000.0d;
            }))).setOptimMethod(optimMethod).setValidation(checkpointTrigger, valDataSet, validationMethods).setEndWhen(Trigger$.MODULE$.severalIteration(trainParams.maxIteration())).optimize();
        } finally {
            // Release Spark resources even when setup or training throws.
            sparkContext.stop();
        }
    }

    /**
     * Creates the singleton and publishes it as {@link #MODULE$}.
     *
     * <p>It also redirects Spark INFO logs via {@code LoggerFilter}, raises the
     * {@code com.intel.analytics.bigdl.optim} log4j logger to INFO, and creates this object's
     * logger.
     */
    private TrainImageNet$() {
        MODULE$ = this;
        LoggerFilter$.MODULE$.redirectSparkInfoLogs(LoggerFilter$.MODULE$.redirectSparkInfoLogs$default$1());
        Logger.getLogger("com.intel.analytics.bigdl.optim").setLevel(Level.INFO);
        this.logger = Logger.getLogger(getClass());
    }
}
