package com.intel.analytics.bigdl.models.autoencoder;

import com.intel.analytics.bigdl.dataset.AbstractDataSet;
import com.intel.analytics.bigdl.dataset.ByteRecord;
import com.intel.analytics.bigdl.dataset.DataSet$;
import com.intel.analytics.bigdl.dataset.MiniBatch;
import com.intel.analytics.bigdl.dataset.image.BytesToGreyImg$;
import com.intel.analytics.bigdl.dataset.image.GreyImgNormalizer$;
import com.intel.analytics.bigdl.dataset.image.GreyImgToBatch$;
import com.intel.analytics.bigdl.dataset.image.LabeledGreyImage;
import com.intel.analytics.bigdl.models.autoencoder.Utils;
import com.intel.analytics.bigdl.nn.MSECriterion$mcF$sp;
import com.intel.analytics.bigdl.nn.Module$;
import com.intel.analytics.bigdl.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.nn.abstractnn.Activity;
import com.intel.analytics.bigdl.optim.Adagrad$mcF$sp;
import com.intel.analytics.bigdl.optim.OptimMethod$;
import com.intel.analytics.bigdl.optim.Optimizer;
import com.intel.analytics.bigdl.optim.Optimizer$;
import com.intel.analytics.bigdl.optim.Trigger$;
import com.intel.analytics.bigdl.package$;
import com.intel.analytics.bigdl.tensor.TensorNumericMath$TensorNumeric$NumericFloat$;
import com.intel.analytics.bigdl.utils.Engine$;
import com.intel.analytics.bigdl.utils.OptimizerV1$;
import com.intel.analytics.bigdl.utils.OptimizerV2$;
import java.nio.file.Paths;
import java.util.Locale;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.spark.SparkContext;
import scala.MatchError;
import scala.Predef$;
import scala.Serializable;
import scala.collection.Seq;
import scala.collection.mutable.WrappedArray;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;

/* compiled from: Train.scala */
/* loaded from: input_file:com/intel/analytics/bigdl/models/autoencoder/Train$.class */
public final class Train$ {
    public static Train$ MODULE$;

    static {
        new Train$();
    }

    public void main(String[] strArr) {
        Utils$.MODULE$.trainParser().parse((Seq<String>) Predef$.MODULE$.wrapRefArray(strArr), (WrappedArray) new Utils.TrainParams(Utils$TrainParams$.MODULE$.$lessinit$greater$default$1(), Utils$TrainParams$.MODULE$.$lessinit$greater$default$2(), Utils$TrainParams$.MODULE$.$lessinit$greater$default$3(), Utils$TrainParams$.MODULE$.$lessinit$greater$default$4(), Utils$TrainParams$.MODULE$.$lessinit$greater$default$5(), Utils$TrainParams$.MODULE$.$lessinit$greater$default$6(), Utils$TrainParams$.MODULE$.$lessinit$greater$default$7(), Utils$TrainParams$.MODULE$.$lessinit$greater$default$8())).map(trainParams -> {
            $anonfun$main$1(trainParams);
            return BoxedUnit.UNIT;
        });
    }

    public static final /* synthetic */ void $anonfun$main$1(Utils.TrainParams trainParams) {
        SparkContext sparkContext = new SparkContext(Engine$.MODULE$.createSparkConf(Engine$.MODULE$.createSparkConf$default$1()).setAppName("Train Autoencoder on MNIST"));
        Engine$.MODULE$.init();
        AbstractDataSet $minus$greater = DataSet$.MODULE$.array(Utils$.MODULE$.load(Paths.get(trainParams.folder(), "/train-images-idx3-ubyte"), Paths.get(trainParams.folder(), "/train-labels-idx1-ubyte")), sparkContext, ClassTag$.MODULE$.apply(ByteRecord.class)).$minus$greater(BytesToGreyImg$.MODULE$.apply(28, 28), ClassTag$.MODULE$.apply(LabeledGreyImage.class)).$minus$greater(GreyImgNormalizer$.MODULE$.apply(Utils$.MODULE$.trainMean(), Utils$.MODULE$.trainStd()), ClassTag$.MODULE$.apply(LabeledGreyImage.class)).$minus$greater(GreyImgToBatch$.MODULE$.apply(trainParams.batchSize()), ClassTag$.MODULE$.apply(MiniBatch.class)).$minus$greater(toAutoencoderBatch$.MODULE$.apply(), ClassTag$.MODULE$.apply(MiniBatch.class));
        AbstractModule<Activity, Activity, Object> load = trainParams.modelSnapshot().isDefined() ? Module$.MODULE$.load((String) trainParams.modelSnapshot().get(), ClassTag$.MODULE$.Float()) : trainParams.graphModel() ? Autoencoder$.MODULE$.graph(32) : Autoencoder$.MODULE$.apply(32);
        if (trainParams.optimizerVersion().isDefined()) {
            String lowerCase = ((String) trainParams.optimizerVersion().get()).toLowerCase();
            if ("optimizerv1".equals(lowerCase)) {
                Engine$.MODULE$.setOptimizerVersion(OptimizerV1$.MODULE$);
                BoxedUnit boxedUnit = BoxedUnit.UNIT;
            } else {
                if (!"optimizerv2".equals(lowerCase)) {
                    throw new MatchError(lowerCase);
                }
                Engine$.MODULE$.setOptimizerVersion(OptimizerV2$.MODULE$);
                BoxedUnit boxedUnit2 = BoxedUnit.UNIT;
            }
        }
        Serializable load2 = trainParams.stateSnapshot().isDefined() ? OptimMethod$.MODULE$.load((String) trainParams.stateSnapshot().get(), ClassTag$.MODULE$.Float()) : new Adagrad$mcF$sp(0.01d, 0.0d, 5.0E-4d, ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
        Optimizer apply = Optimizer$.MODULE$.apply(load, $minus$greater, package$.MODULE$.convCriterion(new MSECriterion$mcF$sp(ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$)), ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
        if (trainParams.checkpoint().isDefined()) {
            apply.setCheckpoint((String) trainParams.checkpoint().get(), Trigger$.MODULE$.everyEpoch());
        } else {
            BoxedUnit boxedUnit3 = BoxedUnit.UNIT;
        }
        apply.setOptimMethod(load2).setEndWhen(Trigger$.MODULE$.maxEpoch(trainParams.maxEpoch())).optimize();
        sparkContext.stop();
    }

    private Train$() {
        MODULE$ = this;
        Logger.getLogger("org").setLevel(Level.ERROR);
        Logger.getLogger("akka").setLevel(Level.ERROR);
        Logger.getLogger("breeze").setLevel(Level.ERROR);
    }
}
