package com.intel.analytics.bigdl.example.treeLSTMSentiment;

import caffe.Caffe;
import com.intel.analytics.bigdl.dataset.PaddingParam;
import com.intel.analytics.bigdl.dataset.PaddingParam$;
import com.intel.analytics.bigdl.dataset.Sample;
import com.intel.analytics.bigdl.example.treeLSTMSentiment.Utils;
import com.intel.analytics.bigdl.nn.ClassNLLCriterion$;
import com.intel.analytics.bigdl.nn.TimeDistributedCriterion;
import com.intel.analytics.bigdl.nn.TimeDistributedCriterion$;
import com.intel.analytics.bigdl.optim.Adagrad$;
import com.intel.analytics.bigdl.optim.Adagrad$mcF$sp;
import com.intel.analytics.bigdl.optim.Optimizer$;
import com.intel.analytics.bigdl.optim.TreeNNAccuracy;
import com.intel.analytics.bigdl.optim.Trigger$;
import com.intel.analytics.bigdl.optim.ValidationMethod;
import com.intel.analytics.bigdl.package$;
import com.intel.analytics.bigdl.tensor.Tensor;
import com.intel.analytics.bigdl.tensor.Tensor$;
import com.intel.analytics.bigdl.tensor.TensorNumericMath$TensorNumeric$NumericFloat$;
import com.intel.analytics.bigdl.utils.Engine$;
import com.intel.analytics.bigdl.utils.LoggerFilter$;
import com.intel.analytics.bigdl.utils.OptimizerV1$;
import com.intel.analytics.bigdl.utils.OptimizerV2$;
import com.intel.analytics.bigdl.utils.T$;
import java.util.Locale;
import org.apache.log4j.Level;
import org.apache.spark.SparkContext;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.rdd.RDD;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.MatchError;
import scala.Predef$;
import scala.Some;
import scala.Tuple2;
import scala.Tuple3;
import scala.collection.Seq;
import scala.collection.immutable.Map;
import scala.collection.immutable.StringOps;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/* compiled from: Train.scala */
/* loaded from: input_file:com/intel/analytics/bigdl/example/treeLSTMSentiment/Train$.class */
public final class Train$ {
    public static Train$ MODULE$;
    private final Logger log;

    static {
        new Train$();
    }

    public Logger log() {
        return this.log;
    }

    public void main(String[] strArr) {
        train((Utils.TreeLSTMSentimentParam) Utils$.MODULE$.paramParser().parse(Predef$.MODULE$.wrapRefArray(strArr), new Utils.TreeLSTMSentimentParam(Utils$TreeLSTMSentimentParam$.MODULE$.apply$default$1(), Utils$TreeLSTMSentimentParam$.MODULE$.apply$default$2(), Utils$TreeLSTMSentimentParam$.MODULE$.apply$default$3(), Utils$TreeLSTMSentimentParam$.MODULE$.apply$default$4(), Utils$TreeLSTMSentimentParam$.MODULE$.apply$default$5(), Utils$TreeLSTMSentimentParam$.MODULE$.apply$default$6(), Utils$TreeLSTMSentimentParam$.MODULE$.apply$default$7(), Utils$TreeLSTMSentimentParam$.MODULE$.apply$default$8())).get());
    }

    public void train(Utils.TreeLSTMSentimentParam treeLSTMSentimentParam) {
        String baseDir = treeLSTMSentimentParam.baseDir();
        TimeDistributedCriterion$ timeDistributedCriterion$ = TimeDistributedCriterion$.MODULE$;
        ClassNLLCriterion$ classNLLCriterion$ = ClassNLLCriterion$.MODULE$;
        ClassNLLCriterion$.MODULE$.apply$default$1();
        TimeDistributedCriterion<Object> apply$mFc$sp = timeDistributedCriterion$.apply$mFc$sp(classNLLCriterion$.apply$mFc$sp(null, ClassNLLCriterion$.MODULE$.apply$default$2(), ClassNLLCriterion$.MODULE$.apply$default$3(), ClassNLLCriterion$.MODULE$.apply$default$4(), ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$), TimeDistributedCriterion$.MODULE$.apply$default$2(), TimeDistributedCriterion$.MODULE$.apply$default$3(), ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
        SparkContext sparkContext = new SparkContext(Engine$.MODULE$.createSparkConf(Engine$.MODULE$.createSparkConf$default$1()).setAppName("Text classification").set("spark.task.maxFailures", "1"));
        Engine$.MODULE$.init();
        String sb = new StringBuilder(26).append(baseDir).append("/glove/glove.840B.300d.txt").toString();
        String sb2 = new StringBuilder(20).append(baseDir).append("/sst/vocab-cased.txt").toString();
        log().info("Start loading embeddings\n");
        Tuple2<Tensor<Object>, Map<String, Object>> loadEmbeddingAndVocabulary = Utils$.MODULE$.loadEmbeddingAndVocabulary(sparkContext, sb, sb2, 3);
        if (loadEmbeddingAndVocabulary == null) {
            throw new MatchError(loadEmbeddingAndVocabulary);
        }
        Tuple2 tuple2 = new Tuple2((Tensor) loadEmbeddingAndVocabulary._1(), (Map) loadEmbeddingAndVocabulary._2());
        Tensor<Object> tensor = (Tensor) tuple2._1();
        Map map = (Map) tuple2._2();
        log().info("Finish loading embeddings\n");
        Broadcast<Map<String, Object>> broadcast = sparkContext.broadcast(map, ClassTag$.MODULE$.apply(Map.class));
        Tuple3<RDD<Tensor<Object>>, RDD<float[]>, RDD<int[]>> preProcessData = Utils$.MODULE$.preProcessData(sparkContext, broadcast, 2, new StringBuilder(22).append(baseDir).append("/sst/train/parents.txt").toString(), new StringBuilder(21).append(baseDir).append("/sst/train/labels.txt").toString(), new StringBuilder(20).append(baseDir).append("/sst/train/sents.txt").toString());
        if (preProcessData == null) {
            throw new MatchError(preProcessData);
        }
        Tuple3 tuple3 = new Tuple3((RDD) preProcessData._1(), (RDD) preProcessData._2(), (RDD) preProcessData._3());
        RDD<Tensor<Object>> rdd = (RDD) tuple3._1();
        RDD<float[]> rdd2 = (RDD) tuple3._2();
        RDD<int[]> rdd3 = (RDD) tuple3._3();
        log().info(new StringOps(Predef$.MODULE$.augmentString(new StringBuilder(Caffe.LayerParameter.DROPOUT_PARAM_FIELD_NUMBER).append("\n         |train treeRDD count: ").append(rdd.count()).append("\n         |train labelRDD count: ").append(rdd2.count()).append("\n         |train sentenceRDD count: ").append(rdd3.count()).append("\n      ").toString())).stripMargin());
        Tuple3<RDD<Tensor<Object>>, RDD<float[]>, RDD<int[]>> preProcessData2 = Utils$.MODULE$.preProcessData(sparkContext, broadcast, 2, new StringBuilder(20).append(baseDir).append("/sst/dev/parents.txt").toString(), new StringBuilder(19).append(baseDir).append("/sst/dev/labels.txt").toString(), new StringBuilder(18).append(baseDir).append("/sst/dev/sents.txt").toString());
        if (preProcessData2 == null) {
            throw new MatchError(preProcessData2);
        }
        Tuple3 tuple32 = new Tuple3((RDD) preProcessData2._1(), (RDD) preProcessData2._2(), (RDD) preProcessData2._3());
        RDD<Tensor<Object>> rdd4 = (RDD) tuple32._1();
        RDD<float[]> rdd5 = (RDD) tuple32._2();
        RDD<int[]> rdd6 = (RDD) tuple32._3();
        log().info(new StringOps(Predef$.MODULE$.augmentString(new StringBuilder(Caffe.LayerParameter.ACCURACY_PARAM_FIELD_NUMBER).append("\n         |dev treeRDD count: ").append(rdd4.count()).append("\n         |dev labelRDD count: ").append(rdd5.count()).append("\n         |dev sentenceRDD count: ").append(rdd6.count()).append("\n      ").toString())).stripMargin());
        RDD<Sample<Object>> sample = Utils$.MODULE$.toSample(rdd, rdd2, rdd3);
        RDD<Sample<Object>> sample2 = Utils$.MODULE$.toSample(rdd4, rdd5, rdd6);
        if (treeLSTMSentimentParam.optimizerVersion().isDefined()) {
            String lowerCase = ((String) treeLSTMSentimentParam.optimizerVersion().get()).toLowerCase();
            if ("optimizerv1".equals(lowerCase)) {
                Engine$.MODULE$.setOptimizerVersion(OptimizerV1$.MODULE$);
                BoxedUnit boxedUnit = BoxedUnit.UNIT;
            } else {
                if (!"optimizerv2".equals(lowerCase)) {
                    throw new MatchError(lowerCase);
                }
                Engine$.MODULE$.setOptimizerVersion(OptimizerV2$.MODULE$);
                BoxedUnit boxedUnit2 = BoxedUnit.UNIT;
            }
        }
        Optimizer$.MODULE$.apply(TreeLSTMSentiment$.MODULE$.apply(tensor, treeLSTMSentimentParam.hiddenSize(), 5, treeLSTMSentimentParam.p()), sample, package$.MODULE$.convCriterion(apply$mFc$sp), treeLSTMSentimentParam.batchSize(), new PaddingParam(new Some(new Tensor[]{Tensor$.MODULE$.apply$mFc$sp(T$.MODULE$.apply(BoxesRunTime.boxToFloat(1), (Seq<Object>) Predef$.MODULE$.genericWrapArray(new Object[0])), ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$), Tensor$.MODULE$.apply$mFc$sp(T$.MODULE$.apply(BoxesRunTime.boxToFloat(-1.0f), (Seq<Object>) Predef$.MODULE$.genericWrapArray(new Object[]{BoxesRunTime.boxToFloat(-1.0f), BoxesRunTime.boxToFloat(-1.0f)})), ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$)}), PaddingParam$.MODULE$.apply$default$2(), ClassTag$.MODULE$.Float()), new PaddingParam(new Some(new Tensor[]{Tensor$.MODULE$.apply$mFc$sp(T$.MODULE$.apply(BoxesRunTime.boxToFloat(-1.0f), (Seq<Object>) Predef$.MODULE$.genericWrapArray(new Object[0])), ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$)}), PaddingParam$.MODULE$.apply$default$2(), ClassTag$.MODULE$.Float()), ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$).setOptimMethod(new Adagrad$mcF$sp(treeLSTMSentimentParam.learningRate(), Adagrad$.MODULE$.$lessinit$greater$default$2(), treeLSTMSentimentParam.regRate(), ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$)).setValidation(Trigger$.MODULE$.everyEpoch(), sample2, new ValidationMethod[]{new TreeNNAccuracy(ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$)}, treeLSTMSentimentParam.batchSize(), new PaddingParam(new Some(new Tensor[]{Tensor$.MODULE$.apply$mFc$sp(T$.MODULE$.apply(BoxesRunTime.boxToFloat(1), (Seq<Object>) Predef$.MODULE$.genericWrapArray(new Object[0])), ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$), 
Tensor$.MODULE$.apply$mFc$sp(T$.MODULE$.apply(BoxesRunTime.boxToFloat(-1.0f), (Seq<Object>) Predef$.MODULE$.genericWrapArray(new Object[]{BoxesRunTime.boxToFloat(-1.0f), BoxesRunTime.boxToFloat(-1.0f)})), ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$)}), PaddingParam$.MODULE$.apply$default$2(), ClassTag$.MODULE$.Float()), new PaddingParam(new Some(new Tensor[]{Tensor$.MODULE$.apply$mFc$sp(T$.MODULE$.apply(BoxesRunTime.boxToFloat(-1.0f), (Seq<Object>) Predef$.MODULE$.genericWrapArray(new Object[0])), ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$)}), PaddingParam$.MODULE$.apply$default$2(), ClassTag$.MODULE$.Float())).setEndWhen(Trigger$.MODULE$.maxEpoch(treeLSTMSentimentParam.epoch())).optimize();
        sparkContext.stop();
    }

    private Train$() {
        MODULE$ = this;
        this.log = LoggerFactory.getLogger(getClass());
        LoggerFilter$.MODULE$.redirectSparkInfoLogs(LoggerFilter$.MODULE$.redirectSparkInfoLogs$default$1());
        org.apache.log4j.Logger.getLogger("com.intel.analytics.bigdl.optim").setLevel(Level.INFO);
    }
}
