package com.intel.analytics.bigdl.optim;

import com.intel.analytics.bigdl.dataset.AbstractDataSet;
import com.intel.analytics.bigdl.dataset.DistributedDataSet;
import com.intel.analytics.bigdl.dataset.MiniBatch;
import com.intel.analytics.bigdl.models.utils.ModelBroadcast;
import com.intel.analytics.bigdl.models.utils.ModelBroadcast$;
import com.intel.analytics.bigdl.nn.Container;
import com.intel.analytics.bigdl.nn.abstractnn.AbstractCriterion;
import com.intel.analytics.bigdl.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.nn.abstractnn.Activity;
import com.intel.analytics.bigdl.optim.DistriOptimizer;
import com.intel.analytics.bigdl.optim.DistriOptimizerV2;
import com.intel.analytics.bigdl.parameters.AllReduceParameter;
import com.intel.analytics.bigdl.parameters.FutureResult;
import com.intel.analytics.bigdl.parameters.ParameterProcessor;
import com.intel.analytics.bigdl.tensor.Tensor;
import com.intel.analytics.bigdl.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.utils.Engine$;
import com.intel.analytics.bigdl.utils.Table;
import com.intel.analytics.bigdl.utils.intermediate.ConversionUtils$;
import com.intel.analytics.bigdl.visualization.TrainSummary;
import com.intel.analytics.bigdl.visualization.ValidationSummary;
import org.apache.spark.SparkContext;
import org.apache.spark.TaskContext$;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.rdd.RDD;
import org.apache.spark.util.DoubleAccumulator;
import scala.Array$;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Predef$ArrowAssoc$;
import scala.Some;
import scala.Tuple2;
import scala.collection.GenTraversableOnce;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.TraversableOnce;
import scala.collection.immutable.IndexedSeq$;
import scala.collection.immutable.Iterable$;
import scala.collection.immutable.Map;
import scala.collection.immutable.Map$;
import scala.collection.immutable.MapLike;
import scala.collection.immutable.Nil$;
import scala.collection.mutable.ArrayBuffer;
import scala.collection.mutable.ArrayBuffer$;
import scala.collection.mutable.ArrayOps;
import scala.package$;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.LazyRef;
import scala.runtime.RichInt$;
import scala.runtime.ScalaRunTime$;

/* compiled from: DistriOptimizerV2.scala */
/* loaded from: input_file:com/intel/analytics/bigdl/optim/DistriOptimizerV2$.class */
public final class DistriOptimizerV2$ extends AbstractOptimizer {
    /**
     * Singleton instance of this Scala object. Presumably assigned by the constructor that the
     * static initializer below invokes (the constructor is not shown in this section — TODO confirm).
     */
    public static DistriOptimizerV2$ MODULE$;
    /** Holder for the lazily created logger; read and written through _logger() / _logger_$eq. */
    private Option<OptimizerLogger> _logger;

    // Scala `object` bootstrap: instantiating the class once when it is loaded.
    static {
        new DistriOptimizerV2$();
    }

    /** Returns the logger holder, which is empty until {@link #logger()} is first called. */
    public Option<OptimizerLogger> _logger() {
        return _logger;
    }

    /** Replaces the logger holder (Scala-style setter for {@code _logger}). */
    public void _logger_$eq(Option<OptimizerLogger> newLogger) {
        _logger = newLogger;
    }

    /**
     * Returns the optimizer logger, installing a new {@code DistriLogger} the first time the holder
     * is found empty. No synchronization is performed.
     */
    public OptimizerLogger logger() {
        if (!_logger().isDefined()) {
            _logger_$eq(new Some(new DistriLogger()));
        }
        Option<OptimizerLogger> holder = _logger();
        return holder.get();
    }

    /**
     * Runs the distributed training loop until {@code trigger} fires on the training state.
     *
     * <p>The first optimizer in {@code masterCache.optimMethods()} supplies the state that is
     * loaded into {@code trainingContext} and that seeds the {@code TrainingTrace}. When that
     * state's RECORDS_PROCESSED entry is 0, the dataset is shuffled once before training starts.
     * Each loop pass then:
     * <ol>
     *   <li>runs one {@code iteration};</li>
     *   <li>reshuffles and re-reads the training RDD once all samples of the epoch are
     *       processed;</li>
     *   <li>validates and checkpoints;</li>
     *   <li>calls {@code $anonfun$optimize$1} for the optional TrainSummary.</li>
     * </ol>
     *
     * <p>NOTE(review): {@code driverStatesUpdate}, reached through {@code iteration}, calls
     * {@code trainingTrace.startNewEpoch()} under the same
     * {@code hasCompleteAllSamples(recordsOfEpoch, model)} condition that guards the reshuffle below.
     * If {@code startNewEpoch} resets {@code recordsOfEpoch}, this reshuffle can never trigger —
     * confirm against TrainingTrace.
     */
    public <T> void optimize(MasterCache<T> masterCache, RDD<DistriOptimizerV2.Cache<T>> rdd, DistributedDataSet<MiniBatch<T>> distributedDataSet, Trigger trigger, Option<Trigger> option, Option<AbstractDataSet<MiniBatch<T>, ?>> option2, Option<ValidationMethod<T>[]> option3, Option<Trigger> option4, Option<String> option5, Option<TrainSummary> option6, Option<ValidationSummary> option7, boolean z, TrainingContext<T> trainingContext, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        OptimMethod headOptimMethod = (OptimMethod) masterCache.optimMethods().values().head();
        trainingContext.loadState(headOptimMethod.state());
        logger().info("config " + trainingContext.state());

        // Only shuffle up front when no records have been processed yet (i.e. not resuming).
        int recordsAlreadyProcessed = BoxesRunTime.unboxToInt(headOptimMethod.state().apply(StateEntry$.MODULE$.RECORDS_PROCESSED()));
        if (recordsAlreadyProcessed == 0) {
            long shuffleStartNanos = System.nanoTime();
            logger().info("Shuffle data");
            distributedDataSet.shuffle();
            double shuffleSeconds = (System.nanoTime() - shuffleStartNanos) / 1.0E9d;
            logger().info("Shuffle data complete. Takes " + shuffleSeconds + "s");
        }

        SparkContext sc = distributedDataSet.originRDD().sparkContext();
        RDD<MiniBatch<T>> trainingData = distributedDataSet.data(true);
        TrainingTrace trace = TrainingTrace$.MODULE$.apply(headOptimMethod.state());

        while (!trigger.apply(trainingContext.state())) {
            iteration(sc, trainingData, rdd, masterCache, trainingContext, trace, classTag, tensorNumeric);

            boolean epochDone = trainingContext.hasCompleteAllSamples(trace.recordsOfEpoch(), masterCache.model());
            if (epochDone) {
                // A full pass over the samples finished: reshuffle and take a fresh view of the data.
                distributedDataSet.shuffle();
                trainingData = distributedDataSet.data(true);
            }

            validate(option, option2, option3, trainingContext.subModelNumber(), rdd, trainingContext.state(), option7,
                    Optimizer$.MODULE$.header(trace.epochs(), trace.recordsOfEpoch(), trainingContext.numSamples(), trace.iterations(), trace.trainingTakes()),
                    masterCache.parameter());
            checkpoint(option4, option5, z, trace.trainingTakes(), rdd, trainingContext.state(), masterCache.parameter(),
                    masterCache.optimMethods(), masterCache.model(), classTag, tensorNumeric);
            option6.foreach(summary -> {
                $anonfun$optimize$1(rdd, trainingContext, masterCache, classTag, tensorNumeric, summary);
                return BoxedUnit.UNIT;
            });
        }
    }

    /**
     * Resets the per-iteration training metrics.
     *
     * <p>The two per-node series start as empty buffers. Every other counter starts at 0.0. As the
     * final argument of {@code metrics.set}, the counters receive either {@code i}
     * ({@code iteration} passes the master cache's partition count) or the engine's node number.
     */
    private void initMetrics(SparkContext sparkContext, Metrics metrics, int i) {
        // Per-node series: start empty.
        metrics.set(COMPUTING_TIME_EACH_NODE$.MODULE$.value(), (ArrayBuffer<Object>) ArrayBuffer$.MODULE$.apply(Nil$.MODULE$), sparkContext);
        metrics.set(GET_WEIGHTS_EACH_NODE$.MODULE$.value(), (ArrayBuffer<Object>) ArrayBuffer$.MODULE$.apply(Nil$.MODULE$), sparkContext);

        final double zero = 0.0d;

        // Counters registered with the caller-supplied count.
        metrics.set(COMPUTING_TIME_AVERAGE$.MODULE$.value(), zero, sparkContext, i);
        metrics.set(AGGREGATE_GRADIENT_TIME$.MODULE$.value(), zero, sparkContext, i);
        metrics.set(GET_WEIGHTS_AVERAGE$.MODULE$.value(), zero, sparkContext, i);

        // Counters registered with the engine's node number.
        final int nodeNumber = Engine$.MODULE$.nodeNumber();
        metrics.set(PUT_GRADIENT$.MODULE$.value(), zero, sparkContext, nodeNumber);
        metrics.set(AGGREGATE_PARTITION_GRADIENT$.MODULE$.value(), zero, sparkContext, nodeNumber);
        metrics.set(COMPUTE_WEIGHT_AVERAGE$.MODULE$.value(), zero, sparkContext, nodeNumber);
        metrics.set(SEND_WEIGHTS_AVERAGE$.MODULE$.value(), zero, sparkContext, nodeNumber);
    }

    /**
     * Runs one distributed training iteration.
     *
     * <p>Two fresh accumulators ("loss sum" and "record number") are created and the iteration
     * metrics are reset. Inside {@code trainingTrace.traceIteration}, {@code rdd} is zipped
     * partition-wise with the replica caches in {@code rdd2}, and each task:
     * <ol>
     *   <li>reads the single Cache of its partition;</li>
     *   <li>requests its weight slice (a FutureResult), fetches the next mini-batches and waits for
     *       the weights, timing this as GET_WEIGHTS_*;</li>
     *   <li>trains locally;</li>
     *   <li>adds its loss and record count to the accumulators;</li>
     *   <li>emits its success count.</li>
     * </ol>
     * The summed success count and the accumulated loss are then passed to {@code parameterSync}.
     * Finally, {@code driverStatesUpdate} receives the accumulated record count.
     *
     * @param sparkContext  context used to create the per-iteration accumulators
     * @param rdd           training mini-batches
     * @param rdd2          per-partition replica caches (one Cache is read per partition)
     * @param masterCache   driver-side cache (metrics, partition count, parameters)
     * @param trainingContext training state and helpers
     * @param trainingTrace trace that times this iteration and tracks records/epochs
     */
    private <T> void iteration(SparkContext sparkContext, RDD<MiniBatch<T>> rdd, RDD<DistriOptimizerV2.Cache<T>> rdd2, MasterCache<T> masterCache, TrainingContext<T> trainingContext, TrainingTrace trainingTrace, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        DoubleAccumulator doubleAccumulator = sparkContext.doubleAccumulator("loss sum");
        DoubleAccumulator doubleAccumulator2 = sparkContext.doubleAccumulator("record number");
        Metrics metrics = masterCache.metrics();
        initMetrics(sparkContext, metrics, masterCache.partitionNum());
        trainingTrace.traceIteration(() -> {
            // Run the Spark job FIRST. `reduce` is the action that executes the tasks adding to
            // doubleAccumulator. Java evaluates call arguments left to right, so reading the
            // accumulator inline as parameterSync's first argument (before this job) would see it
            // before any task had contributed its loss.
            int successModels = BoxesRunTime.unboxToInt(rdd.zipPartitions(rdd2, true, (iterator, iterator2) -> {
                DistriOptimizerV2.Cache cache = (DistriOptimizerV2.Cache) iterator2.next();
                // This partition's slice of the first replica's flat weights.
                Tensor narrow = ((Tensor) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(cache.modelWeights())).head()).narrow(1, cache.parameter().paramOffset(), cache.parameter().size());
                DistriOptimizerV2.TrainingResults train = MODULE$.train(cache, (MiniBatch[]) TrainingTrace$.MODULE$.time(() -> {
                    // Request weights, fetch the batches, then wait for the weights to arrive.
                    FutureResult<Object> weights = cache.parameter().getWeights(narrow);
                    MiniBatch[] fetchBatch = trainingContext.fetchBatch(iterator, classTag);
                    weights.waitResult();
                    return fetchBatch;
                }, metrics, new MetricEntry[]{GET_WEIGHTS_AVERAGE$.MODULE$, GET_WEIGHTS_EACH_NODE$.MODULE$}), trainingContext, metrics, classTag, tensorNumeric);
                doubleAccumulator.add(train.loss());
                doubleAccumulator2.add(train.records());
                return package$.MODULE$.Iterator().single(BoxesRunTime.boxToInteger(train.successed()));
            }, ClassTag$.MODULE$.apply(DistriOptimizerV2.Cache.class), ClassTag$.MODULE$.Int()).reduce((i, i2) -> {
                return i + i2;
            }));
            MODULE$.parameterSync(Predef$.MODULE$.Double2double(doubleAccumulator.value()), successModels, masterCache, rdd2, trainingContext, classTag, tensorNumeric);
        });
        driverStatesUpdate(masterCache, (int) Predef$.MODULE$.Double2double(doubleAccumulator2.value()), trainingContext, trainingTrace, metrics, classTag, tensorNumeric);
    }

    /**
     * Builds, persists and materializes one executor-side training cache per partition of the
     * training data, and broadcasts the model.
     *
     * <p>Driver side:
     * <ul>
     *   <li>packs the master cache's criterion, validation methods, optim methods, parameter splits
     *       and parameter processors into a TrainingConfig and broadcasts it;</li>
     *   <li>converts the master model with {@code ConversionUtils.convert};</li>
     *   <li>calls {@code getParameters()} on the converted model (return value unused);</li>
     *   <li>broadcasts the converted model through {@code ModelBroadcast}.</li>
     * </ul>
     *
     * <p>Per partition of {@code distributedDataSet.originRDD()}:
     * <ul>
     *   <li>creates {@code subModelNumber} replicas via {@code $anonfun$initCacheOfSlave$2};</li>
     *   <li>logs the model thread-pool size;</li>
     *   <li>calls {@code parameter.init} with the first replica's weights narrowed to the
     *       parameter's offset/size;</li>
     *   <li>emits exactly one {@code DistriOptimizerV2.Cache}. It holds the replicas' models,
     *       weights, gradients, criteria, states and validation methods, a zero-filled
     *       {@code long[subModelNumber]}, a copy of every optim method, {@code null} for one
     *       constructor argument, the parameter, and the parameter splits/processors.</li>
     * </ul>
     *
     * <p>The copies of the optim methods are made via {@code m2819clone()}, a JADX-renamed method,
     * presumably OptimMethod.clone — confirm.
     *
     * <p>NOTE(review): the lambdas below are Spark closures serialized to executors. Keep their
     * captured locals (broadcast, broadcast2, subModelNumber, state, parameter, classTag,
     * tensorNumeric) unchanged when editing.
     *
     * @return (persisted RDD named "Thread Model RDD" with one Cache per partition, model broadcast)
     */
    public <T> Tuple2<RDD<DistriOptimizerV2.Cache<T>>, ModelBroadcast<T>> com$intel$analytics$bigdl$optim$DistriOptimizerV2$$initCacheOfSlave(MasterCache<T> masterCache, DistributedDataSet<MiniBatch<T>> distributedDataSet, TrainingContext<T> trainingContext, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        // Everything except the model that every replica needs, shipped once via broadcast.
        DistriOptimizerV2$TrainingConfig$1<T> apply = TrainingConfig$3(new LazyRef()).apply(masterCache.criterion(), masterCache.validationMethods(), masterCache.optimMethods(), masterCache.parameterSplits(), masterCache.parameterProcessers());
        SparkContext sparkContext = distributedDataSet.originRDD().sparkContext();
        Broadcast broadcast = sparkContext.broadcast(apply, ClassTag$.MODULE$.apply(DistriOptimizerV2$TrainingConfig$1.class));
        AbstractModule<Activity, Activity, T> convert = ConversionUtils$.MODULE$.convert(masterCache.model(), classTag);
        // Return value intentionally unused; only the call itself is kept.
        convert.getParameters();
        ModelBroadcast<T> broadcast2 = ModelBroadcast$.MODULE$.apply(classTag, tensorNumeric).broadcast(sparkContext, convert);
        // NOTE(review): results discarded (decompiled unused vals); kept in case the calls
        // validate Engine initialization — confirm before removing.
        Engine$.MODULE$.nodeNumber();
        Engine$.MODULE$.coreNumber();
        AllReduceParameter<T> parameter = masterCache.parameter();
        int subModelNumber = trainingContext.subModelNumber();
        Table state = trainingContext.state();
        RDD<?> originRDD = distributedDataSet.originRDD();
        RDD persist = originRDD.mapPartitions(iterator -> {
            int partitionId = TaskContext$.MODULE$.getPartitionId();
            DistriOptimizerV2$TrainingConfig$1 distriOptimizerV2$TrainingConfig$1 = (DistriOptimizerV2$TrainingConfig$1) broadcast.value();
            // One replica per sub-model on this partition.
            Replica[] replicaArr = (Replica[]) ((TraversableOnce) RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), subModelNumber).map(obj -> {
                return $anonfun$initCacheOfSlave$2(broadcast2, distriOptimizerV2$TrainingConfig$1, state, partitionId, classTag, BoxesRunTime.unboxToInt(obj));
            }, IndexedSeq$.MODULE$.canBuildFrom())).toArray(ClassTag$.MODULE$.apply(Replica.class));
            MODULE$.logger().info(new StringBuilder(26).append("model thread pool size is ").append(Engine$.MODULE$.model().getPoolSize()).toString());
            // Seed the parameter from the first replica's weights, narrowed to the parameter's offset/size.
            parameter.init(((Replica) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(replicaArr)).head()).weights().narrow(1, parameter.paramOffset(), parameter.size()), tensorNumeric);
            // Single Cache per partition: columns of the replica array plus per-partition extras.
            return package$.MODULE$.Iterator().single(new DistriOptimizerV2.Cache((AbstractModule[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(replicaArr)).map(replica -> {
                return replica.model();
            }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(AbstractModule.class))), (Tensor[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(replicaArr)).map(replica2 -> {
                return replica2.weights();
            }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tensor.class))), (Tensor[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(replicaArr)).map(replica3 -> {
                return replica3.gradients();
            }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tensor.class))), (AbstractCriterion[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(replicaArr)).map(replica4 -> {
                return replica4.criterion();
            }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(AbstractCriterion.class))), (Table[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(replicaArr)).map(replica5 -> {
                return replica5.state();
            }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Table.class))), new long[subModelNumber], (Option[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(replicaArr)).map(replica6 -> {
                return replica6.validationMethods();
            }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Option.class))), (Map) distriOptimizerV2$TrainingConfig$1.optimMethods().map(tuple2 -> {
                return new Tuple2(tuple2._1(), ((OptimMethod) tuple2._2()).m2819clone());
            }, Map$.MODULE$.canBuildFrom()), null, parameter, distriOptimizerV2$TrainingConfig$1.parameterSplits(), distriOptimizerV2$TrainingConfig$1.parameterProcessers()));
        }, originRDD.mapPartitions$default$2(), ClassTag$.MODULE$.apply(DistriOptimizerV2.Cache.class)).persist();
        persist.setName("Thread Model RDD");
        logger().info("Cache thread models...");
        // count() forces the per-partition caches to be built and persisted now.
        persist.count();
        logger().info("Cache thread models... done");
        return new Tuple2<>(persist, broadcast2);
    }

    /**
     * Sets id {@code i} on {@code abstractModule}. When the module is a {@code Container}, each
     * child module is additionally passed to {@code $anonfun$setModelId$1} along with the same
     * {@code i} and {@code classTag}.
     */
    private <T> void setModelId(AbstractModule<Activity, Activity, T> abstractModule, int i, ClassTag<T> classTag) {
        abstractModule.setId(i);
        if (!(abstractModule instanceof Container)) {
            return;
        }
        Container container = (Container) abstractModule;
        container.modules().foreach(child -> {
            $anonfun$setModelId$1(i, classTag, child);
            return BoxedUnit.UNIT;
        });
    }

    /**
     * Copies the state held in the executor-side caches back into {@code abstractModule}.
     *
     * <p>Steps:
     * <ol>
     *   <li>sets the module's extra parameters from the first local model of the first element of
     *       {@code rdd};</li>
     *   <li>applies {@code $anonfun$getModel$2} to every index of the module's gradient
     *       tensors;</li>
     *   <li>collects every partition's {@code weightPartition()} and {@code gradientPartition()} of
     *       {@code allReduceParameter}, keyed by partition id;</li>
     *   <li>for each partition index, calls {@code $anonfun$getModel$6}. It receives the flat
     *       weight/gradient tensors, the collected maps, the per-partition slice size
     *       ({@code size / partitions}) and the remainder ({@code size % partitions}).</li>
     * </ol>
     *
     * @return {@code abstractModule} itself, mutated in place
     * @throws IllegalArgumentException (via {@code Predef.require}) when {@code allReduceParameter}
     *         has fewer elements than {@code rdd} has partitions
     */
    @Override // com.intel.analytics.bigdl.optim.AbstractOptimizer
    public <T> AbstractModule<Activity, Activity, T> getModel(RDD<DistriOptimizer.Cache<T>> rdd, AllReduceParameter<T> allReduceParameter, AbstractModule<Activity, Activity, T> abstractModule, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        int length = rdd.partitions().length;
        abstractModule.setExtraParameter((Tensor[]) rdd.map(cache -> {
            return ((AbstractModule) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(cache.localModels())).head()).getExtraParameter();
        }, ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(Tensor.class))).first());
        Tuple2<Tensor<T>[], Tensor<T>[]> parameters = abstractModule.parameters();
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), ((Tensor[]) parameters._2()).length).foreach(obj -> {
            return $anonfun$getModel$2(parameters, BoxesRunTime.unboxToInt(obj));
        });

        // Flat weight / gradient tensors of the module (null check mirrors the Scala tuple pattern).
        Tuple2<Tensor<T>, Tensor<T>> parameters2 = abstractModule.getParameters();
        if (parameters2 == null) {
            throw new MatchError(parameters2);
        }
        Tensor tensor = (Tensor) parameters2._1();
        Tensor tensor2 = (Tensor) parameters2._2();

        // partitionId -> weight partition, partitionId -> gradient partition, merged across partitions.
        Tuple2 tuple22 = (Tuple2) rdd.mapPartitions(iterator -> {
            int partitionId = TaskContext$.MODULE$.getPartitionId();
            return package$.MODULE$.Iterator().single(new Tuple2(Predef$.MODULE$.Map().apply(Predef$.MODULE$.wrapRefArray(new Tuple2[]{Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(BoxesRunTime.boxToInteger(partitionId)), allReduceParameter.weightPartition())})), Predef$.MODULE$.Map().apply(Predef$.MODULE$.wrapRefArray(new Tuple2[]{Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(BoxesRunTime.boxToInteger(partitionId)), allReduceParameter.gradientPartition())}))));
        }, rdd.mapPartitions$default$2(), ClassTag$.MODULE$.apply(Tuple2.class)).reduce((tuple23, tuple24) -> {
            return new Tuple2(((MapLike) tuple23._1()).$plus$plus((GenTraversableOnce) tuple24._1()), ((MapLike) tuple23._2()).$plus$plus((GenTraversableOnce) tuple24._2()));
        });
        if (tuple22 == null) {
            throw new MatchError(tuple22);
        }
        Map map = (Map) tuple22._1();
        Map map2 = (Map) tuple22._2();

        int size = allReduceParameter.size() / length;
        Predef$.MODULE$.require(size != 0, () -> {
            return "parameter length should not less than partition number";
        });
        int size2 = allReduceParameter.size() % length;
        // foreach, not map: the per-partition step runs only for its side effects, so there is no
        // need to build and discard an IndexedSeq of results (matches the foreach above).
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), length).foreach(obj2 -> {
            return $anonfun$getModel$6(allReduceParameter, size, size2, tensor, map, tensor2, map2, BoxesRunTime.unboxToInt(obj2));
        });
        return abstractModule;
    }

    /**
     * Runs one local training step on the executor that owns {@code cache}.
     *
     * <p>Steps (the timed ones are recorded into {@code metrics}):
     * <ol>
     *   <li>{@code trainingContext.train} trains the local models on {@code miniBatchArr}. It
     *       returns a Seq whose elements are used as {@code LossWithElapsedTime}.</li>
     *   <li>The losses in that Seq are summed. Each element's elapsed time is written to
     *       {@code cache.moduleTimeList()} at the element's position in the Seq.</li>
     *   <li>The gradients of the models named by each element's {@code index()} are aggregated.
     *       With an empty Seq, {@code modelGradients()[0]} is zeroed and used instead. The result
     *       is pushed via {@code cache.parameter().putGradients}.</li>
     *   <li>One task per sub-model is submitted to {@code Engine.default()} (see the JADX note
     *       inside).</li>
     * </ol>
     *
     * @return TrainingResults constructed from (seq.size(), summed loss,
     *         seq.size() * size of the first mini-batch). The caller reads these back as
     *         successed()/loss()/records() — presumably in this order; confirm.
     */
    private <T> DistriOptimizerV2.TrainingResults train(DistriOptimizerV2.Cache<T> cache, MiniBatch<T>[] miniBatchArr, TrainingContext<T> trainingContext, Metrics metrics, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        // NOTE(review): only the first mini-batch's size is used for the record count below; if the
        // batches in miniBatchArr can differ in size, the reported records are approximate.
        int size = ((MiniBatch) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(miniBatchArr)).head()).size();
        // Receives the futures returned by Engine.default().invoke; not read again in this method.
        ArrayBuffer arrayBuffer = new ArrayBuffer();
        // Local forward/backward on the replicas, timed as COMPUTING_TIME_*.
        Seq seq = (Seq) TrainingTrace$.MODULE$.time(() -> {
            return trainingContext.train(miniBatchArr, cache.localModels(), cache.localCriterions(), classTag, tensorNumeric);
        }, metrics, new MetricEntry[]{COMPUTING_TIME_EACH_NODE$.MODULE$, COMPUTING_TIME_AVERAGE$.MODULE$});
        double d = 0.0d;
        int i = 0;
        // Decompiled loop over seq. The tail accumulates loss and elapsed time per element; the
        // exit branch (i >= seq.size()) aggregates/pushes gradients and returns.
        while (true) {
            int i2 = i;
            if (i2 >= seq.size()) {
                // Aggregate gradients of the reporting models (or a zeroed gradient when none reported).
                Tensor tensor = (Tensor) TrainingTrace$.MODULE$.time(() -> {
                    if (seq.nonEmpty()) {
                        return trainingContext.aggregate((Tensor[]) ((TraversableOnce) seq.map(lossWithElapsedTime -> {
                            return cache.modelGradients()[lossWithElapsedTime.index()];
                        }, Seq$.MODULE$.canBuildFrom())).toArray(ClassTag$.MODULE$.apply(Tensor.class)), classTag);
                    }
                    cache.modelGradients()[0].zero();
                    return cache.modelGradients()[0];
                }, metrics, new MetricEntry[]{AGGREGATE_GRADIENT_TIME$.MODULE$});
                TrainingTrace$.MODULE$.time(() -> {
                    cache.parameter().putGradients(tensor);
                }, metrics, new MetricEntry[]{PUT_GRADIENT$.MODULE$});
                // NOTE(review): JADX failed to generate the lambda below, so this method does not
                // compile as shown. Per the bytecode dump in the comments, each element i of
                // 0 until subModelNumber maps to $anonfun$train$5(cache, i). That returns a
                // scala.Function0 which invokes the static void $anonfun$train$6(cache, i).
                // Reconstruct from the original Scala source rather than guessing its body.
                arrayBuffer.$plus$plus$eq(Engine$.MODULE$.m3091default().invoke((Seq) RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), trainingContext.subModelNumber()).map(obj -> {
                    return ()
                    /*  JADX ERROR: Method code generation error
                        jadx.core.utils.exceptions.CodegenException: Error generate insn: 0x0008: RETURN
                          (wrap:scala.Function0:0x0005: INVOKE_CUSTOM
                          (r3v0 'cache' com.intel.analytics.bigdl.optim.DistriOptimizerV2$Cache)
                          (wrap:int:0x0002: INVOKE (r4v0 'obj' java.lang.Object) STATIC call: scala.runtime.BoxesRunTime.unboxToInt(java.lang.Object):int A[WRAPPED])
                         A[MD:(com.intel.analytics.bigdl.optim.DistriOptimizerV2$Cache, int):scala.runtime.java8.JFunction0$mcV$sp (s), WRAPPED]
                         handle type: INVOKE_STATIC
                         lambda: scala.runtime.java8.JFunction0.mcV.sp.apply$mcV$sp():void
                         call insn: INVOKE (r0 I:com.intel.analytics.bigdl.optim.DistriOptimizerV2$Cache), (r1 I:int) STATIC call: com.intel.analytics.bigdl.optim.DistriOptimizerV2$.$anonfun$train$6(com.intel.analytics.bigdl.optim.DistriOptimizerV2$Cache, int):void A[MD:(com.intel.analytics.bigdl.optim.DistriOptimizerV2$Cache, int):void (m)])
                         in method: com.intel.analytics.bigdl.optim.DistriOptimizerV2$.$anonfun$train$5$adapted(com.intel.analytics.bigdl.optim.DistriOptimizerV2$Cache, java.lang.Object):scala.Function0, file: input_file:com/intel/analytics/bigdl/optim/DistriOptimizerV2$.class
                        (jadx stack trace elided)
                        Caused by: jadx.core.utils.exceptions.JadxRuntimeException: Unexpected argument type in lambda call: InsnWrapArg
                        */
                    /* Bytecode-level body of the lambda (as dumped by jadx):
                        r0 = r3
                        r1 = r4
                        int r1 = scala.runtime.BoxesRunTime.unboxToInt(r1)
                        scala.Function0 r0 = $anonfun$train$5(r0, r1)
                        return r0
                    */
                    throw new UnsupportedOperationException("Method not decompiled: com.intel.analytics.bigdl.optim.DistriOptimizerV2$.$anonfun$train$5$adapted(com.intel.analytics.bigdl.optim.DistriOptimizerV2$Cache, java.lang.Object):scala.Function0");
                }, IndexedSeq$.MODULE$.canBuildFrom())));
                return new DistriOptimizerV2.TrainingResults(seq.size(), d, seq.size() * size);
            }
            d += ((LossWithElapsedTime) seq.apply(i2)).loss();
            // NOTE(review): indexed by position in seq, whereas the gradient aggregation above uses
            // lossWithElapsedTime.index(); the two differ if seq does not list every model in order.
            cache.moduleTimeList()[i2] = ((LossWithElapsedTime) seq.apply(i2)).elapsed();
            i = i2 + 1;
        }
    }

    /**
     * Copies training progress from {@code table} into the state of every optimizer in {@code map}.
     *
     * <p>Which keys are copied:
     * <ul>
     *   <li>EPOCH, NEVAL and LOSS are always copied;</li>
     *   <li>SCORE is copied only when {@code z} is true (driverStatesUpdate passes whether
     *       validation methods are defined);</li>
     *   <li>RECORDS_PROCESSED is copied only into optimizer states that already contain that
     *       key.</li>
     * </ul>
     *
     * @param map   optimizers keyed by name (only the values are used)
     * @param table source training state
     * @param z     whether SCORE is copied as well
     * @throws MatchError if {@code map} yields a null entry (mirrors the Scala tuple pattern)
     */
    private <T> void updateStates(Map<String, OptimMethod<T>> map, Table table, boolean z) {
        // foreach, not map: the loop runs purely for its side effects, so there is no reason to
        // build (and immediately discard) a collection of per-entry results.
        map.foreach(tuple2 -> {
            if (tuple2 == null) {
                throw new MatchError(tuple2);
            }
            OptimMethod optimMethod = (OptimMethod) tuple2._2();
            optimMethod.state().update(StateEntry$.MODULE$.EPOCH(), table.apply(StateEntry$.MODULE$.EPOCH()));
            optimMethod.state().update(StateEntry$.MODULE$.NEVAL(), table.apply(StateEntry$.MODULE$.NEVAL()));
            optimMethod.state().update(StateEntry$.MODULE$.LOSS(), table.apply(StateEntry$.MODULE$.LOSS()));
            if (z) {
                optimMethod.state().update(StateEntry$.MODULE$.SCORE(), table.apply(StateEntry$.MODULE$.SCORE()));
            }
            if (optimMethod.state().keySet().contains(StateEntry$.MODULE$.RECORDS_PROCESSED())) {
                optimMethod.state().update(StateEntry$.MODULE$.RECORDS_PROCESSED(), table.apply(StateEntry$.MODULE$.RECORDS_PROCESSED()));
            }
            return BoxedUnit.UNIT;
        });
    }

    /**
     * Updates the driver-side training state after one iteration and logs progress.
     *
     * <p>In order:
     * <ol>
     *   <li>applies {@code $anonfun$driverStatesUpdate$1} to every optimizer entry;</li>
     *   <li>advances the trace's record count by {@code i} while building the log header;</li>
     *   <li>logs throughput, loss and hyper-parameters, and the metrics summary at debug
     *       level;</li>
     *   <li>writes THROUGHPUT, NEVAL (iterations + 1) and LEARNING_RATE (from the first optimizer)
     *       into the training state;</li>
     *   <li>starts a new epoch once all samples are processed;</li>
     *   <li>writes EPOCH and RECORDS_PROCESSED;</li>
     *   <li>propagates the state to every optimizer via {@code updateStates}.</li>
     * </ol>
     *
     * @param i number of records trained in this iteration
     */
    private <T> void driverStatesUpdate(MasterCache<T> masterCache, int i, TrainingContext<T> trainingContext, TrainingTrace trainingTrace, Metrics metrics, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        Map<String, OptimMethod<T>> optimMethods = masterCache.optimMethods();
        boolean hasValidation = masterCache.validationMethods().isDefined();
        optimMethods.foreach(entry -> {
            $anonfun$driverStatesUpdate$1(entry);
            return BoxedUnit.UNIT;
        });

        long wallClockNanos = trainingTrace.trainingTakes();
        long iterationNanos = trainingTrace.iterationTakes();
        float iterationSeconds = (float) (iterationNanos / 1.0E9d);
        float throughput = i / iterationSeconds;
        // epochs() is read before updateRecords(i) advances the trace (argument order matters).
        String header = Optimizer$.MODULE$.header(trainingTrace.epochs(), trainingTrace.updateRecords(i).recordsOfEpoch(), trainingContext.numSamples(), trainingTrace.iterations(), wallClockNanos);

        logger().info(header + " Trained " + i + " records in " + iterationSeconds + " seconds. "
                + "Throughput is " + throughput + " records/second. "
                + "Loss is " + BoxesRunTime.unboxToFloat(trainingContext.state().apply(StateEntry$.MODULE$.LOSS())) + ". "
                + Optimizer$.MODULE$.getHyperParameterLog(optimMethods));
        logger().debug("\n" + metrics.summary(metrics.summary$default$1(), metrics.summary$default$2()));

        trainingContext.state().update(StateEntry$.MODULE$.THROUGHPUT(), BoxesRunTime.boxToFloat(throughput));
        trainingContext.state().update(StateEntry$.MODULE$.NEVAL(), BoxesRunTime.boxToInteger(trainingTrace.iterations() + 1));
        OptimMethod firstOptimMethod = (OptimMethod) ((Tuple2) optimMethods.head())._2();
        trainingContext.state().update(StateEntry$.MODULE$.LEARNING_RATE(), BoxesRunTime.boxToFloat((float) firstOptimMethod.getLearningRate()));

        boolean epochComplete = trainingContext.hasCompleteAllSamples(trainingTrace.recordsOfEpoch(), masterCache.model());
        if (epochComplete) {
            trainingTrace.startNewEpoch();
            logger().info(header + " Epoch finished. Wall clock time is " + wallClockNanos / 1000000.0d + " ms");
        }

        trainingContext.state().update(StateEntry$.MODULE$.EPOCH(), BoxesRunTime.boxToInteger(trainingTrace.epochs()));
        trainingContext.state().update(StateEntry$.MODULE$.RECORDS_PROCESSED(), BoxesRunTime.boxToInteger(trainingTrace.recordsOfEpoch()));
        updateStates(optimMethods, trainingContext.state(), hasValidation);
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Synchronizes parameters across all model partitions after a training iteration.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>Stores {@code i} as {@code NUM_FINISHED_MODELS} in the driver state and resets
     *       {@code IS_GRADIENT_UPDATED} to {@code false}.</li>
     *   <li>Calls {@code collectGlobalData} on every driver-side parameter processor. A processor
     *       may set {@code IS_GRADIENT_UPDATED} in the state.</li>
     *   <li>Runs one Spark job over {@code rdd} (forced by {@code count()}). For each partition it:
     *       <ul>
     *         <li>aggregates that partition's gradient slice, unless a processor already flagged
     *             the gradients as updated;</li>
     *         <li>runs {@code processParameters} on the executor-side processors;</li>
     *         <li>refreshes the optimizer states;</li>
     *         <li>builds, per optim method, the {@code ParamSegments} overlapping the local
     *             partition;</li>
     *         <li>updates the weight partition through {@code trainingContext.update};</li>
     *         <li>sends the weight partition.</li>
     *       </ul></li>
     *   <li>Sets {@code IS_GRADIENT_UPDATED} to {@code true} and stores {@code (float) d / i} as
     *       {@code LOSS}.</li>
     * </ol>
     *
     * @param d               value divided by {@code i} to obtain the reported loss; presumably
     *                        the loss summed over the finished models — TODO confirm with caller
     * @param i               number of finished models; used as divisor and passed to
     *                        {@code aggregateGradientPartition}. NOTE(review): {@code i == 0}
     *                        would give an infinite/NaN loss — confirm callers guarantee i > 0
     * @param masterCache     driver-side cache providing metrics, the shared parameter, validation
     *                        methods and parameter processors
     * @param rdd             per-partition training caches; each partition's first element is used
     * @param trainingContext holds the mutable training state and performs the weight update
     * @param classTag        element class tag, forwarded to the update and to the Spark job
     * @param tensorNumeric   numeric type class for {@code T}
     */
    public <T> void parameterSync(double d, int i, MasterCache<T> masterCache, RDD<DistriOptimizerV2.Cache<T>> rdd, TrainingContext<T> trainingContext, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        Metrics metrics = masterCache.metrics();
        AllReduceParameter<T> parameter = masterCache.parameter();
        boolean isDefined = masterCache.validationMethods().isDefined();
        trainingContext.state().update(StateEntry$.MODULE$.NUM_FINISHED_MODELS(), BoxesRunTime.boxToInteger(i));
        trainingContext.state().update(StateEntry$.MODULE$.IS_GRADIENT_UPDATED(), BoxesRunTime.boxToBoolean(false));
        // Driver side: processors may gather global data and may flag that gradients were already updated.
        new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(masterCache.parameterProcessers())).foreach(parameterProcessor -> {
            $anonfun$parameterSync$1(rdd, parameter, metrics, trainingContext, tensorNumeric, parameterProcessor);
            return BoxedUnit.UNIT;
        });
        // Read on the driver so the executor closure captures a plain boolean.
        boolean gradientAlreadyUpdated = BoxesRunTime.unboxToBoolean(trainingContext.state().apply(StateEntry$.MODULE$.IS_GRADIENT_UPDATED()));
        rdd.mapPartitions(iterator -> {
            DistriOptimizerV2.Cache cache = (DistriOptimizerV2.Cache) iterator.next();
            Map optimMethods = cache.optimMethods();
            ParameterProcessor[] parameterProcessers = cache.parameterProcessers();
            Map<String, Tuple2<Object, Object>> parameterSplits = cache.parameterSplits();
            Tuple2<Object, Object> localPartitionRange = cache.parameter().localPartitionRange();
            if (localPartitionRange == null) {
                throw new MatchError(localPartitionRange);
            }
            // Local slice of the flat parameter vector; it appears to be (offset, length), judging
            // by the offset + length arithmetic below — TODO confirm against AllReduceParameter.
            int localOffset = localPartitionRange._1$mcI$sp();
            int localLength = localPartitionRange._2$mcI$sp();
            if (!gradientAlreadyUpdated) {
                TrainingTrace$.MODULE$.time(() -> {
                    cache.parameter().aggregateGradientPartition(i);
                }, metrics, new MetricEntry[]{AGGREGATE_PARTITION_GRADIENT$.MODULE$});
            }
            new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(parameterProcessers)).foreach(localProcessor -> {
                $anonfun$parameterSync$4(parameter, cache, trainingContext, tensorNumeric, localProcessor);
                return BoxedUnit.UNIT;
            });
            MODULE$.updateStates(optimMethods, trainingContext.state(), isDefined);
            // For each optim method, intersect its parameter split with the local partition.
            // The segment start is 1-based relative to the local partition.
            // NOTE: the lambda parameter and the split local must have distinct names; Java
            // forbids redeclaring a lambda parameter inside the lambda body.
            Map segments = (Map) optimMethods.map(entry -> {
                if (entry == null) {
                    throw new MatchError(entry);
                }
                String name = (String) entry._1();
                OptimMethod optimMethod = (OptimMethod) entry._2();
                Tuple2 split = (Tuple2) parameterSplits.apply(name);
                int segmentStart = Math.max(localOffset, split._1$mcI$sp());
                int segmentEnd = Math.min(localOffset + localLength, split._1$mcI$sp() + split._2$mcI$sp());
                return new Tuple2(name, new ParamSegments((segmentStart - localOffset) + 1, segmentEnd - segmentStart, optimMethod));
            }, Map$.MODULE$.canBuildFrom());
            Tensor weightPartition = cache.parameter().weightPartition();
            Tensor gradientPartition = cache.parameter().gradientPartition();
            double averageLoss = d / i;
            TrainingTrace$.MODULE$.time(() -> {
                trainingContext.update(segments, weightPartition, gradientPartition, averageLoss, classTag, tensorNumeric);
            }, metrics, new MetricEntry[]{COMPUTE_WEIGHT_AVERAGE$.MODULE$});
            TrainingTrace$.MODULE$.time(() -> {
                cache.parameter().sendWeightPartition();
            }, metrics, new MetricEntry[]{SEND_WEIGHTS_AVERAGE$.MODULE$});
            return package$.MODULE$.Iterator().empty();
        }, rdd.mapPartitions$default$2(), classTag).count();
        trainingContext.state().update(StateEntry$.MODULE$.IS_GRADIENT_UPDATED(), BoxesRunTime.boxToBoolean(true));
        trainingContext.state().update(StateEntry$.MODULE$.LOSS(), BoxesRunTime.boxToFloat(((float) d) / i));
    }

    /**
     * Lambda body used by {@code optimize}: forwards the training summary to {@code saveSummary},
     * together with the state held by the training context and the master cache.
     */
    public static final /* synthetic */ void $anonfun$optimize$1(RDD models, TrainingContext context, MasterCache master, ClassTag tag, TensorNumericMath.TensorNumeric ev, TrainSummary summary) {
        MODULE$.saveSummary(
            summary,
            models,
            context.state(),
            master.parameter(),
            master.model(),
            tag,
            ev);
    }

    /**
     * Slow path of the local lazy {@code TrainingConfig} object.
     *
     * <p>Under the {@code lazyRef} monitor, returns the existing value if it is already
     * initialized; otherwise creates a new {@code TrainingConfig} instance and stores it.
     */
    private static final /* synthetic */ DistriOptimizerV2$TrainingConfig$2$ TrainingConfig$lzycompute$1(LazyRef ref) {
        synchronized (ref) {
            if (ref.initialized()) {
                return (DistriOptimizerV2$TrainingConfig$2$) ref.value();
            }
            return (DistriOptimizerV2$TrainingConfig$2$) ref.initialize(new DistriOptimizerV2$TrainingConfig$2$());
        }
    }

    /**
     * Accessor for the local lazy {@code TrainingConfig} object.
     *
     * <p>Returns the stored value when {@code lazyRef} is already initialized. Otherwise it
     * delegates to the synchronized initializer.
     */
    private final DistriOptimizerV2$TrainingConfig$2$ TrainingConfig$3(LazyRef ref) {
        if (!ref.initialized()) {
            return TrainingConfig$lzycompute$1(ref);
        }
        return (DistriOptimizerV2$TrainingConfig$2$) ref.value();
    }

    /**
     * Builds one {@code Replica} for an executor-side training cache.
     *
     * <p>The replica is assembled in this order:
     * <ol>
     *   <li>obtains a model instance from {@code modelBroadcast.value(true, <default>)};</li>
     *   <li>clones the criterion and the given table;</li>
     *   <li>clones each validation method, when any are configured;</li>
     *   <li>takes the (weights, gradients) pair returned by {@code getParameters()};</li>
     *   <li>tags the model via {@code setModelId}.</li>
     * </ol>
     *
     * @param modelBroadcast broadcast model source
     * @param distriOptimizerV2$TrainingConfig$1 supplies the criterion and optional validation methods
     * @param table          table cloned into the replica; presumably the optimizer state — TODO confirm
     * @param i              id passed to {@code setModelId}; presumably the partition id — TODO confirm
     * @param classTag       element class tag forwarded to {@code setModelId}
     * @param i2             unused here (index supplied by the caller's iteration)
     * @return the assembled replica
     * @throws MatchError if {@code getParameters()} returns {@code null}
     */
    public static final /* synthetic */ Replica $anonfun$initCacheOfSlave$2(ModelBroadcast modelBroadcast, DistriOptimizerV2$TrainingConfig$1 distriOptimizerV2$TrainingConfig$1, Table table, int i, ClassTag classTag, int i2) {
        AbstractModule value = modelBroadcast.value(true, modelBroadcast.value$default$2());
        AbstractCriterion cloneCriterion = distriOptimizerV2$TrainingConfig$1.criterion().cloneCriterion();
        Table m3112clone = table.m3112clone();
        // Typed as Option: the conditional yields either a Some or None$. A Some-typed local does
        // not compile because None$ is not a subtype of Some.
        Option validationMethods = distriOptimizerV2$TrainingConfig$1.validationMethods().isDefined() ? new Some(new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) distriOptimizerV2$TrainingConfig$1.validationMethods().get())).map(validationMethod -> {
            return validationMethod.m2812clone();
        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(ValidationMethod.class)))) : None$.MODULE$;
        Tuple2 parameters = value.getParameters();
        if (parameters == null) {
            throw new MatchError(parameters);
        }
        Tensor weights = (Tensor) parameters._1();
        Tensor gradients = (Tensor) parameters._2();
        MODULE$.setModelId(value, i, classTag);
        return new Replica(value, weights, gradients, cloneCriterion, m3112clone, validationMethods);
    }

    /**
     * Lambda body used by {@code setModelId}: applies the same id to one child module by
     * calling {@code setModelId} on it.
     */
    public static final /* synthetic */ void $anonfun$setModelId$1(int modelId, ClassTag tag, AbstractModule child) {
        MODULE$.setModelId(child, modelId, tag);
    }

    /**
     * Lambda body used by {@code getModel}.
     *
     * <p>Resizes the {@code idx}-th tensor of the pair's second array to the shape of the
     * {@code idx}-th tensor of the first array, and returns the result of {@code resizeAs}.
     */
    public static final /* synthetic */ Tensor $anonfun$getModel$2(Tuple2 pair, int idx) {
        Tensor[] targets = (Tensor[]) pair._2();
        Tensor target = targets[idx];
        Tensor[] templates = (Tensor[]) pair._1();
        return target.resizeAs(templates[idx]);
    }

    /**
     * Lambda body used by {@code getModel}: copies partition {@code partitionIdx} back into the
     * two flat tensors.
     *
     * <p>Each partition holds {@code baseSize} elements. The first {@code remainder} partitions
     * hold one extra element, which gives the start/length computed below.
     *
     * <p>The weight slice of the first flat tensor is filled from {@code firstById}, and the slice
     * of the second flat tensor from {@code secondById}. Both maps are keyed by the boxed
     * partition index. The second copy's result is returned.
     */
    public static final /* synthetic */ Tensor $anonfun$getModel$6(AllReduceParameter allReduceParameter, int baseSize, int remainder, Tensor firstFlat, Map firstById, Tensor secondFlat, Map secondById, int partitionIdx) {
        int start = allReduceParameter.paramOffset() + partitionIdx * baseSize + Math.min(partitionIdx, remainder);
        int extra = partitionIdx < remainder ? 1 : 0;
        int length = baseSize + extra;
        Tensor firstSource = (Tensor) firstById.apply(BoxesRunTime.boxToInteger(partitionIdx));
        firstFlat.narrow(1, start, length).copy(firstSource);
        Tensor secondSource = (Tensor) secondById.apply(BoxesRunTime.boxToInteger(partitionIdx));
        return secondFlat.narrow(1, start, length).copy(secondSource);
    }

    /**
     * Lambda body used by {@code driverStatesUpdate}: calls {@code updateHyperParameter()} on the
     * optim method stored as the second element of a (name, optim method) entry.
     */
    public static final /* synthetic */ void $anonfun$driverStatesUpdate$1(Tuple2 entry) {
        OptimMethod method = (OptimMethod) entry._2();
        method.updateHyperParameter();
    }

    /**
     * Driver-side lambda body of {@code parameterSync}.
     *
     * <p>Lets one parameter processor collect global data. The processor receives the model RDD,
     * the shared parameter, the metrics and the current training state.
     */
    public static final /* synthetic */ void $anonfun$parameterSync$1(RDD models, AllReduceParameter sharedParameter, Metrics driverMetrics, TrainingContext context, TensorNumericMath.TensorNumeric ev, ParameterProcessor processor) {
        processor.collectGlobalData(
            models,
            sharedParameter,
            driverMetrics,
            context.state(),
            ev);
    }

    /**
     * Executor-side lambda body of {@code parameterSync}.
     *
     * <p>Lets one parameter processor process the parameters of a partition cache, given the
     * current training state.
     */
    public static final /* synthetic */ void $anonfun$parameterSync$4(AllReduceParameter sharedParameter, DistriOptimizerV2.Cache partitionCache, TrainingContext context, TensorNumericMath.TensorNumeric ev, ParameterProcessor processor) {
        processor.processParameters(
            sharedParameter,
            partitionCache,
            context.state(),
            ev);
    }

    /**
     * Singleton constructor: registers this instance as {@code MODULE$}, then clears the
     * cached logger to {@code None}.
     */
    private DistriOptimizerV2$() {
        DistriOptimizerV2$.MODULE$ = this;
        _logger = None$.MODULE$;
    }
}
