package com.intel.analytics.bigdl.example.tensorflow.transferlearning;

import com.intel.analytics.bigdl.dataset.Sample;
import com.intel.analytics.bigdl.dataset.Sample$;
import com.intel.analytics.bigdl.example.tensorflow.transferlearning.Utils;
import com.intel.analytics.bigdl.nn.CrossEntropyCriterion;
import com.intel.analytics.bigdl.nn.CrossEntropyCriterion$;
import com.intel.analytics.bigdl.nn.Linear$;
import com.intel.analytics.bigdl.nn.Sequential;
import com.intel.analytics.bigdl.nn.Sequential$;
import com.intel.analytics.bigdl.nn.Squeeze$;
import com.intel.analytics.bigdl.nn.abstractnn.AbstractCriterion;
import com.intel.analytics.bigdl.optim.Optimizer;
import com.intel.analytics.bigdl.optim.Optimizer$;
import com.intel.analytics.bigdl.optim.RMSprop$;
import com.intel.analytics.bigdl.optim.RMSprop$mcF$sp;
import com.intel.analytics.bigdl.optim.Top1Accuracy;
import com.intel.analytics.bigdl.optim.Trigger;
import com.intel.analytics.bigdl.optim.Trigger$;
import com.intel.analytics.bigdl.optim.ValidationMethod;
import com.intel.analytics.bigdl.package$;
import com.intel.analytics.bigdl.tensor.Tensor;
import com.intel.analytics.bigdl.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.tensor.TensorNumericMath$TensorNumeric$NumericFloat$;
import com.intel.analytics.bigdl.utils.Engine$;
import com.intel.analytics.bigdl.utils.LoggerFilter$;
import com.intel.analytics.bigdl.utils.tf.BigDLSessionImpl;
import com.intel.analytics.bigdl.utils.tf.TensorflowLoader$;
import java.nio.ByteOrder;
import org.apache.spark.SparkContext;
import org.apache.spark.rdd.RDD;
import scala.Predef$;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/* compiled from: TransferLearning.scala */
/* loaded from: input_file:com/intel/analytics/bigdl/example/tensorflow/transferlearning/TransferLearning$.class */
/**
 * Module class backing the Scala {@code object TransferLearning}: the TensorFlow transfer-learning
 * example entry point.
 *
 * <p>It loads a TensorFlow graph plus checkpoint ({@code model.pb} / {@code model.bin}) from a
 * directory and evaluates two graph endpoints to obtain features and one-hot labels. It then trains
 * a small {@code Squeeze -> Linear(1024, 5)} head on them with BigDL on Spark.
 */
public final class TransferLearning$ {
    /** Singleton instance; assigned by the private constructor while the class is initialized. */
    public static TransferLearning$ MODULE$;

    static {
        new TransferLearning$();
    }

    /**
     * Program entry point.
     *
     * <p>Parses {@code strArr} with {@code Utils.trainParser()}, starting from a {@code TrainParams}
     * built from its default values. For each parse result (presumably a Scala {@code Option}, so
     * training runs only when parsing succeeds — confirm against {@code Utils}), training is run via
     * {@link #$anonfun$main$1(Utils.TrainParams)}.
     *
     * @param strArr raw command-line arguments
     */
    public void main(String[] strArr) {
        Utils$.MODULE$.trainParser().parse(Predef$.MODULE$.wrapRefArray(strArr), new Utils.TrainParams(Utils$TrainParams$.MODULE$.apply$default$1(), Utils$TrainParams$.MODULE$.apply$default$2(), Utils$TrainParams$.MODULE$.apply$default$3(), Utils$TrainParams$.MODULE$.apply$default$4())).map(trainParams -> {
            $anonfun$main$1(trainParams);
            return BoxedUnit.UNIT;
        });
    }

    /**
     * Loads a TensorFlow model directory and turns two of its graph outputs into training samples.
     *
     * <p>Reads {@code <str>/model.pb} and {@code <str>/model.bin} with little-endian byte order as a
     * float-typed BigDL session. It evaluates the endpoints
     * {@code InceptionV1/Logits/AvgPool_0a_7x7/AvgPool} and {@code OneHotEncoding/one_hot} into an
     * RDD. Each sample's feature is output 1 of the evaluated table. Its label is the index part
     * ({@code _2}) of {@code max(1)} over output 2, i.e. the position of the hot entry of the one-hot
     * vector (presumably a 1-based class index, as BigDL criteria expect — confirm).
     *
     * @param str directory containing {@code model.pb} and {@code model.bin}
     * @param sparkContext active Spark context used to build the RDD
     * @return RDD of (feature, label) samples
     */
    private RDD<Sample<Object>> getData(String str, SparkContext sparkContext) {
        BigDLSessionImpl bigDLSessionImpl = (BigDLSessionImpl) TensorflowLoader$.MODULE$.checkpoints(new StringBuilder(9).append(str).append("/model.pb").toString(), new StringBuilder(10).append(str).append("/model.bin").toString(), ByteOrder.LITTLE_ENDIAN, ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
        return bigDLSessionImpl.getRDD((Seq) Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new String[]{"InceptionV1/Logits/AvgPool_0a_7x7/AvgPool", "OneHotEncoding/one_hot"})), sparkContext, bigDLSessionImpl.getRDD$default$3()).map(table -> {
            return Sample$.MODULE$.apply((Tensor) table.apply(BoxesRunTime.boxToInteger(1)), (Tensor) ((Tensor) table.apply(BoxesRunTime.boxToInteger(2))).max(1)._2(), ClassTag$.MODULE$.Float(), (TensorNumericMath.TensorNumeric) TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
        }, ClassTag$.MODULE$.apply(Sample.class));
    }

    /**
     * Body of the {@code main} lambda: runs one full training job for the parsed parameters.
     *
     * <p>Steps:
     * <ol>
     *   <li>Create a SparkContext (app name "Transfer Learning", {@code spark.task.maxFailures=1}) and
     *       initialize the BigDL engine.</li>
     *   <li>Build a {@code Squeeze -> Linear(1024, 5)} model trained with {@code CrossEntropyCriterion}
     *       and RMSprop for {@code nEpochs} epochs.</li>
     *   <li>If a validation directory is given, validate with Top-1 accuracy after every epoch.</li>
     * </ol>
     *
     * <p>The SparkContext is always stopped, even if any step after its creation throws.
     *
     * @param trainParams parsed command-line parameters (model dirs, batch size, epoch count)
     */
    public static final /* synthetic */ void $anonfun$main$1(Utils.TrainParams trainParams) {
        SparkContext sparkContext = new SparkContext(Engine$.MODULE$.createSparkConf(Engine$.MODULE$.createSparkConf$default$1()).setAppName("Transfer Learning").set("spark.task.maxFailures", "1"));
        try {
            Engine$.MODULE$.init();
            RDD<Sample<Object>> data = MODULE$.getData(trainParams.trainingModelDir(), sparkContext);
            Sequential<Object> apply$mFc$sp = Sequential$.MODULE$.apply$mFc$sp(ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
            // Squeeze(null, true): presumably squeezes every singleton dimension of a batched input
            // (e.g. the pooled Inception features) — confirm against BigDL's Squeeze.
            apply$mFc$sp.mo742add(Squeeze$.MODULE$.apply((int[]) null, true, ClassTag$.MODULE$.Float(), (TensorNumericMath.TensorNumeric) TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$));
            Linear$ linear$ = Linear$.MODULE$;
            boolean apply$default$3 = Linear$.MODULE$.apply$default$3();
            // Decompiler residue: default-argument getters whose results are not used; the matching
            // arguments are passed as null below.
            Linear$.MODULE$.apply$default$4();
            Linear$.MODULE$.apply$default$5();
            Linear$.MODULE$.apply$default$6();
            Linear$.MODULE$.apply$default$7();
            Linear$.MODULE$.apply$default$8();
            Linear$.MODULE$.apply$default$9();
            // Classifier head: 1024 inputs -> 5 outputs (presumably the number of target classes).
            apply$mFc$sp.mo742add(linear$.apply$mFc$sp(1024, 5, apply$default$3, null, null, null, null, null, null, ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$));
            CrossEntropyCriterion$ crossEntropyCriterion$ = CrossEntropyCriterion$.MODULE$;
            CrossEntropyCriterion$.MODULE$.apply$default$1();
            CrossEntropyCriterion<Object> apply$mFc$sp2 = crossEntropyCriterion$.apply$mFc$sp(null, CrossEntropyCriterion$.MODULE$.apply$default$2(), ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
            Optimizer$ optimizer$ = Optimizer$.MODULE$;
            AbstractCriterion convCriterion = package$.MODULE$.convCriterion(apply$mFc$sp2);
            int batchSize = trainParams.batchSize();
            Optimizer$.MODULE$.apply$default$5();
            Optimizer$.MODULE$.apply$default$6();
            Optimizer apply = optimizer$.apply(apply$mFc$sp, data, convCriterion, batchSize, null, null, ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
            Trigger maxEpoch = Trigger$.MODULE$.maxEpoch(trainParams.nEpochs());
            // RMSprop with 0.001 and 0.9 as the 1st and 3rd constructor arguments (presumably the
            // learning rate and decay rate); the 2nd and 4th use the constructor defaults.
            RMSprop$mcF$sp rMSprop$mcF$sp = new RMSprop$mcF$sp(0.001d, RMSprop$.MODULE$.$lessinit$greater$default$2(), 0.9d, RMSprop$.MODULE$.$lessinit$greater$default$4(), ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
            apply.setEndWhen(maxEpoch);
            apply.setOptimMethod(rMSprop$mcF$sp);
            if (trainParams.validationModelDir().isDefined()) {
                apply.setValidation(Trigger$.MODULE$.everyEpoch(), MODULE$.getData((String) trainParams.validationModelDir().get(), sparkContext), new ValidationMethod[]{new Top1Accuracy(ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$)}, trainParams.batchSize());
            }
            apply.optimize();
        } finally {
            // Stop the context even when loading, setup or training throws, so a failed run does
            // not leave the SparkContext (and its executors/resources) running.
            sparkContext.stop();
        }
    }

    /**
     * Registers this instance as {@link #MODULE$}, then redirects Spark INFO logs through
     * {@code LoggerFilter} using its default argument.
     */
    private TransferLearning$() {
        MODULE$ = this;
        LoggerFilter$.MODULE$.redirectSparkInfoLogs(LoggerFilter$.MODULE$.redirectSparkInfoLogs$default$1());
    }
}
