package com.intel.analytics.bigdl.optim;

import com.intel.analytics.bigdl.dataset.AbstractDataSet;
import com.intel.analytics.bigdl.dataset.DistributedDataSet;
import com.intel.analytics.bigdl.dataset.MiniBatch;
import com.intel.analytics.bigdl.models.utils.ModelBroadcast;
import com.intel.analytics.bigdl.models.utils.ModelBroadcast$;
import com.intel.analytics.bigdl.nn.Container;
import com.intel.analytics.bigdl.nn.abstractnn.AbstractCriterion;
import com.intel.analytics.bigdl.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.nn.abstractnn.Activity;
import com.intel.analytics.bigdl.nn.mkldnn.MklDnnContainer;
import com.intel.analytics.bigdl.nn.mkldnn.Phase$TrainingPhase$;
import com.intel.analytics.bigdl.optim.DistriOptimizer;
import com.intel.analytics.bigdl.parameters.AllReduceParameter;
import com.intel.analytics.bigdl.tensor.ConvertableTo$ConvertableToDouble$;
import com.intel.analytics.bigdl.tensor.Tensor;
import com.intel.analytics.bigdl.tensor.Tensor$;
import com.intel.analytics.bigdl.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.utils.BlockManagerParameterSynchronizer;
import com.intel.analytics.bigdl.utils.DistriParameterSynchronizer;
import com.intel.analytics.bigdl.utils.Engine$;
import com.intel.analytics.bigdl.utils.EngineType;
import com.intel.analytics.bigdl.utils.MklBlas$;
import com.intel.analytics.bigdl.utils.MklDnn$;
import com.intel.analytics.bigdl.utils.T$;
import com.intel.analytics.bigdl.utils.Table;
import com.intel.analytics.bigdl.utils.ThreadPool;
import com.intel.analytics.bigdl.utils.Util$;
import com.intel.analytics.bigdl.visualization.TrainSummary;
import com.intel.analytics.bigdl.visualization.ValidationSummary;
import java.util.concurrent.Future;
import org.apache.log4j.Logger;
import org.apache.spark.SparkContext;
import org.apache.spark.TaskContext$;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.rdd.RDD;
import scala.Array$;
import scala.Function0;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Predef$ArrowAssoc$;
import scala.Serializable;
import scala.Some;
import scala.Tuple2;
import scala.Tuple3;
import scala.Tuple4;
import scala.Tuple6;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.TraversableLike;
import scala.collection.TraversableOnce;
import scala.collection.immutable.IndexedSeq$;
import scala.collection.immutable.Iterable$;
import scala.collection.immutable.Map;
import scala.collection.immutable.Map$;
import scala.collection.immutable.Nil$;
import scala.collection.immutable.StringOps;
import scala.collection.mutable.ArrayBuffer;
import scala.collection.mutable.ArrayBuffer$;
import scala.collection.mutable.ArrayOps;
import scala.collection.mutable.Buffer;
import scala.collection.mutable.Buffer$;
import scala.collection.mutable.HashMap;
import scala.package$;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.DoubleRef;
import scala.runtime.IntRef;
import scala.runtime.LongRef;
import scala.runtime.ObjectRef;
import scala.runtime.RichInt$;
import scala.runtime.ScalaRunTime$;

/* compiled from: ParallelOptimizer.scala */
/* loaded from: input_file:com/intel/analytics/bigdl/optim/ParallelOptimizer$.class */
public final class ParallelOptimizer$ extends AbstractOptimizer {
    // Singleton instance of this Scala object. It is not assigned in the static
    // initializer below; presumably the constructor (not visible here) stores
    // `this` into it -- TODO confirm against the constructor.
    public static ParallelOptimizer$ MODULE$;
    // log4j logger exposed through logger(); initialised outside the visible code.
    private final Logger logger;

    // Instantiates the object once when the class is loaded. The reference is
    // discarded here, so any publication of the instance must happen inside the
    // constructor.
    static {
        new ParallelOptimizer$();
    }

    /**
     * Accessor for the logger held by this object.
     *
     * @return the {@code Logger} stored in the {@code logger} field
     */
    public Logger logger() {
        return logger;
    }

    /**
     * Runs the driver-side training loop until {@code trigger} fires on the driver state table.
     *
     * <p>Each pass of the loop zips the training data with the cached per-partition models,
     * runs {@code bigdl.parallelOptimizer.iterationPerTime} forward/backward steps on every
     * partition, reduces the (finished model count, loss sum, record count) triples, and then
     * either accepts the iteration (updating state, logging, validating, summarising and
     * checkpointing) or discards it when too many model updates were dropped.
     *
     * @param abstractModule     model handed to the train-summary helper and to {@code checkpoint}
     * @param distributedDataSet training data; counted once, shuffled at the start and after each epoch
     * @param i                  sub-model count used when the engine is MklBlas; also forwarded to {@code validate}
     * @param table              configuration; must contain {@code dropPercentage}, {@code warmupIterationNum},
     *                           {@code computeThresholdbatchSize} and {@code maxDropPercentage}
     * @param trigger            end condition evaluated against the driver state table
     * @param metrics            metrics sink for the per-iteration computing times
     * @param rdd                cached per-partition training state (models, criterions, timing list)
     * @param map                optimisation methods keyed by name; their state tables are kept in sync
     * @param option             forwarded to {@code validate} (first argument)
     * @param option2            forwarded to {@code validate} (second argument)
     * @param option3            forwarded to {@code validate}; when defined, "score" is copied into each optim state
     * @param option4            forwarded to {@code checkpoint} (first argument)
     * @param option5            forwarded to {@code checkpoint} (second argument)
     * @param option6            optional train summary; the summary helper runs when it is defined
     * @param option7            forwarded to {@code validate}
     * @param z                  forwarded to {@code checkpoint} (third argument)
     * @param classTag           class tag for {@code T}
     * @param tensorNumeric      numeric operations for {@code T}; used to convert the loss to double
     * @throws MatchError               if the engine type is neither MklBlas nor MklDnn
     * @throws IllegalArgumentException if the resolved sub-model count is not 1 (via {@code require})
     */
    public <T> void optimize(AbstractModule<Activity, Activity, T> abstractModule, DistributedDataSet<MiniBatch<T>> distributedDataSet, int i, Table table, Trigger trigger, Metrics metrics, RDD<DistriOptimizer.Cache<T>> rdd, Map<String, OptimMethod<T>> map, Option<Trigger> option, Option<AbstractDataSet<MiniBatch<T>, ?>> option2, Option<ValidationMethod<T>[]> option3, Option<Trigger> option4, Option<String> option5, Option<TrainSummary> option6, Option<ValidationSummary> option7, boolean z, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        SparkContext sparkContext = distributedDataSet.originRDD().sparkContext();
        int length = distributedDataSet.originRDD().partitions().length;
        // Accumulated wall-clock time in nanoseconds; `j` remembers its value at the last epoch end.
        LongRef create = LongRef.create(0L);
        long j = 0;
        // Seed each optim method's state with defaults for any bookkeeping key it does not have yet.
        map.values().foreach(optimMethod -> {
            if (!optimMethod.state().contains("epoch")) {
                optimMethod.state().update("epoch", BoxesRunTime.boxToInteger(1));
            }
            if (!optimMethod.state().contains("neval")) {
                optimMethod.state().update("neval", BoxesRunTime.boxToInteger(1));
            }
            if (!optimMethod.state().contains("Loss")) {
                optimMethod.state().update("Loss", BoxesRunTime.boxToFloat(Float.POSITIVE_INFINITY));
            }
            if (!optimMethod.state().contains("score")) {
                optimMethod.state().update("score", BoxesRunTime.boxToFloat(0.0f));
            }
            return !optimMethod.state().contains("recordsProcessedThisEpoch") ? optimMethod.state().update("recordsProcessedThisEpoch", BoxesRunTime.boxToInteger(0)) : BoxedUnit.UNIT;
        });
        // Sub-model count: the caller's value for MklBlas, always 1 for MklDnn.
        EngineType engineType = Engine$.MODULE$.getEngineType();
        final int i3;
        if (MklBlas$.MODULE$.equals(engineType)) {
            i3 = i;
        } else if (MklDnn$.MODULE$.equals(engineType)) {
            i3 = 1;
        } else {
            throw new MatchError(engineType);
        }
        Predef$.MODULE$.require(i3 == 1, () -> {
            return "currently only single model supported especially for mkldnn";
        });
        // Driver state table, initialised from the first optim method's state.
        Table apply = T$.MODULE$.apply(
                Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("epoch"), ((OptimMethod) map.values().head()).state().apply("epoch")),
                (Seq<Tuple2<Object, Object>>) Predef$.MODULE$.wrapRefArray(new Tuple2[]{
                        Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("neval"), ((OptimMethod) map.values().head()).state().apply("neval")),
                        Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("Loss"), ((OptimMethod) map.values().head()).state().apply("Loss")),
                        Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("score"), ((OptimMethod) map.values().head()).state().apply("score")),
                        Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("parallelism"), BoxesRunTime.boxToInteger(i3))}));
        logger().info("Count dataset");
        long nanoTime = System.nanoTime();
        // Total number of records: the sum of all mini-batch sizes in the un-looped data.
        int unboxToInt = BoxesRunTime.unboxToInt(distributedDataSet.data(false).map(miniBatch -> {
            return BoxesRunTime.boxToInteger(miniBatch.size());
        }, ClassTag$.MODULE$.Int()).reduce((i4, i5) -> {
            return i4 + i5;
        }));
        logger().info(new StringBuilder(39).append("Count dataset complete. Time elapsed: ").append((System.nanoTime() - nanoTime) / 1.0E9d).append("s").toString());
        if (unboxToInt != distributedDataSet.size()) {
            logger().warn("If the dataset is built directly from RDD[Minibatch], the data in each minibatch is fixed, and a single minibatch is randomly selected in each partition. If the dataset is transformed from RDD[Sample], each minibatch will be constructed on the fly from random samples, which is better for convergence.");
        }
        logger().info(new StringBuilder(7).append("config ").append(table).toString());
        // Records processed so far in the current epoch (resumed from the optim state).
        IntRef create2 = IntRef.create(BoxesRunTime.unboxToInt(((OptimMethod) map.values().head()).state().apply("recordsProcessedThisEpoch")));
        if (create2.elem == 0) {
            long nanoTime2 = System.nanoTime();
            logger().info("Shuffle data");
            distributedDataSet.shuffle();
            logger().info(new StringBuilder(30).append("Shuffle data complete. Takes ").append((System.nanoTime() - nanoTime2) / 1.0E9d).append("s").toString());
        }
        // create3: latest computed time threshold; create4: value passed as the second argument
        // of invokeAndWait2; create5: number of accepted iterations so far.
        LongRef create3 = LongRef.create(Long.MAX_VALUE);
        LongRef create4 = LongRef.create(Long.MAX_VALUE);
        IntRef create5 = IntRef.create(0);
        double unboxToDouble = BoxesRunTime.unboxToDouble(table.get("dropPercentage").get());
        int unboxToInt2 = BoxesRunTime.unboxToInt(table.get("warmupIterationNum").get());
        int unboxToInt3 = BoxesRunTime.unboxToInt(table.get("computeThresholdbatchSize").get());
        double unboxToDouble2 = BoxesRunTime.unboxToDouble(table.get("maxDropPercentage").get());
        // Number of mini-batches each partition trains on per pass of the driver loop.
        int i6 = new StringOps(Predef$.MODULE$.augmentString(System.getProperty("bigdl.parallelOptimizer.iterationPerTime", "1"))).toInt();
        // Model updates expected per pass if nothing is dropped.
        int i7 = length * i3 * i6;
        // Dropped model updates accumulated since the threshold was last recomputed.
        int i8 = 0;
        // Per-sub-model loss slots, written by the training closure.
        ObjectRef create6 = ObjectRef.create(new double[i3]);
        long nanoTime3 = System.nanoTime();
        RDD<T> data = distributedDataSet.data(true);
        while (!trigger.apply(apply)) {
            DoubleRef create7 = DoubleRef.create(0.0d);
            IntRef create8 = IntRef.create(0);
            metrics.set("computing time for each node", (ArrayBuffer<Object>) ArrayBuffer$.MODULE$.apply(Nil$.MODULE$), sparkContext);
            metrics.set("computing time average", 0.0d, sparkContext, length);
            long nanoTime4 = System.nanoTime();
            Tuple3 tuple3 = (Tuple3) data.zipPartitions(rdd, true, (iterator, iterator2) -> {
                int i9 = 0;
                DistriOptimizer.Cache cache = (DistriOptimizer.Cache) iterator2.next();
                ObjectRef create9 = ObjectRef.create((Object) null);
                for (int i10 = 0; i10 < i6; i10++) {
                    create9.elem = (MiniBatch) iterator.next();
                    long nanoTime5 = System.nanoTime();
                    // Once past warm-up plus one threshold window, apply the computed threshold.
                    if (unboxToDouble > 0.0d && create5.elem > (unboxToInt2 + unboxToInt3) - 1) {
                        create4.elem = create3.elem;
                    }
                    // Offset of this iteration's slots inside the module timing list.
                    int i11 = (create5.elem % unboxToInt3) * i3;
                    ThreadPool m1282default = Engine$.MODULE$.m1282default();
                    Buffer invokeAndWait2 = m1282default.invokeAndWait2(Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Function0[]{() -> {
                        long nanoTime6 = System.nanoTime();
                        AbstractModule abstractModule2 = cache.localModels()[0];
                        abstractModule2.training2();
                        AbstractCriterion abstractCriterion = cache.localCriterions()[0];
                        Activity input = ((MiniBatch) create9.elem).getInput();
                        Activity target = ((MiniBatch) create9.elem).getTarget();
                        Activity forward = abstractModule2.forward(input);
                        ((double[]) create6.elem)[0] = BoxesRunTime.unboxToDouble(tensorNumeric.toType(abstractCriterion.forward(forward, target), ConvertableTo$ConvertableToDouble$.MODULE$));
                        abstractModule2.backward(input, abstractCriterion.backward(forward, target));
                        cache.moduleTimeList()[0 + i11] = System.nanoTime() - nanoTime6;
                        // The task yields its sub-model index, used below to look up its loss slot.
                        return 0;
                    }})), create4.elem, m1282default.invokeAndWait2$default$3());
                    long nanoTime6 = System.nanoTime() - nanoTime5;
                    metrics.add("computing time average", nanoTime6);
                    metrics.add("computing time for each node", nanoTime6);
                    // Keep only the tasks that were not cancelled and collect their results.
                    Buffer buffer = (Buffer) ((TraversableLike) invokeAndWait2.filter(future -> {
                        return BoxesRunTime.boxToBoolean($anonfun$optimize$7(future));
                    })).map(future2 -> {
                        return BoxesRunTime.boxToInteger($anonfun$optimize$8(future2));
                    }, Buffer$.MODULE$.canBuildFrom());
                    int size = buffer.size();
                    i9 += size;
                    create8.elem += size * ((MiniBatch) create9.elem).size();
                    // Fixed: the decompiled `while (true)` had no exit once the index reached
                    // `size`, so it spun forever; a bounded loop restores the intended sum.
                    for (int k = 0; k < size; k++) {
                        create7.elem += ((double[]) create6.elem)[BoxesRunTime.unboxToInt(buffer.apply(k))];
                    }
                }
                create.elem += System.nanoTime() - nanoTime4;
                return package$.MODULE$.Iterator().single(new Tuple3(BoxesRunTime.boxToInteger(i9), BoxesRunTime.boxToDouble(create7.elem), BoxesRunTime.boxToInteger(create8.elem)));
            }, ClassTag$.MODULE$.apply(DistriOptimizer.Cache.class), ClassTag$.MODULE$.apply(Tuple3.class)).reduce((tuple32, tuple33) -> {
                return new Tuple3(BoxesRunTime.boxToInteger(BoxesRunTime.unboxToInt(tuple32._1()) + BoxesRunTime.unboxToInt(tuple33._1())), BoxesRunTime.boxToDouble(BoxesRunTime.unboxToDouble(tuple32._2()) + BoxesRunTime.unboxToDouble(tuple33._2())), BoxesRunTime.boxToInteger(BoxesRunTime.unboxToInt(tuple32._3()) + BoxesRunTime.unboxToInt(tuple33._3())));
            });
            if (tuple3 == null) {
                throw new MatchError(tuple3);
            }
            // (finished model updates, summed loss, records trained) across all partitions.
            int unboxToInt4 = BoxesRunTime.unboxToInt(tuple3._1());
            double unboxToDouble3 = BoxesRunTime.unboxToDouble(tuple3._2());
            int unboxToInt5 = BoxesRunTime.unboxToInt(tuple3._3());
            i8 += i7 - unboxToInt4;
            if (unboxToDouble == 0.0d || unboxToInt4 >= i7 * (1.0d - unboxToDouble2)) {
                apply.update("numFinishedModel", BoxesRunTime.boxToInteger(unboxToInt4));
                create2.elem += unboxToInt5;
                // Fixed: the decompiler lost this timestamp and left an undeclared `r0` in the
                // throughput and log expressions below; capture it once and reuse it.
                long end = System.nanoTime();
                create.elem += end - nanoTime4;
                apply.update("Loss", BoxesRunTime.boxToDouble(unboxToDouble3 / unboxToInt4));
                map.foreach(tuple2 -> {
                    $anonfun$optimize$10(tuple2);
                    return BoxedUnit.UNIT;
                });
                apply.update("LearningRate", BoxesRunTime.boxToFloat((float) ((OptimMethod) ((Tuple2) map.head())._2()).getLearningRate()));
                apply.update("Throughput", BoxesRunTime.boxToFloat(unboxToInt5 / ((float) ((end - nanoTime4) / 1.0E9d))));
                String header = Optimizer$.MODULE$.header(BoxesRunTime.unboxToInt(apply.apply("epoch")), create2.elem, unboxToInt, BoxesRunTime.unboxToInt(apply.apply("neval")), create.elem);
                logger().info(new StringBuilder(71).append(header).append(" Trained ").append(unboxToInt5).append(" records in ").append((end - nanoTime4) / 1.0E9d).append(" ").append("seconds. Throughput is ").append(apply.apply("Throughput")).append(" records/second. Loss is ").append(apply.apply("Loss")).append(".").toString());
                logger().debug(new StringBuilder(1).append("\n").append(metrics.summary(metrics.summary$default$1(), metrics.summary$default$2())).toString());
                logger().debug(new StringBuilder(17).append("Dropped modules: ").append(i7 - unboxToInt4).toString());
                create6.elem = new double[i3];
                create5.elem++;
                // After warm-up, recompute the time threshold once per threshold window.
                if (unboxToDouble > 0.0d && create5.elem > unboxToInt2 && create5.elem % unboxToInt3 == 0) {
                    long[] jArr = (long[]) rdd.mapPartitions(iterator3 -> {
                        return new ArrayOps.ofLong(Predef$.MODULE$.longArrayOps(((DistriOptimizer.Cache) iterator3.next()).moduleTimeList())).iterator();
                    }, rdd.mapPartitions$default$2(), ClassTag$.MODULE$.Long()).collect();
                    int i9 = (int) (unboxToDouble * unboxToInt3 * i7);
                    if (i9 > i8) {
                        create3.elem = Util$.MODULE$.kthLargest(jArr, 0, jArr.length - 1, i9 - i8);
                    } else {
                        create3.elem = (long) (create3.elem * 1.01d);
                    }
                    logger().info(new StringBuilder(11).append("threshold: ").append(create3.elem).toString());
                    // Zero the timing lists on every partition; count() forces the job to run.
                    rdd.mapPartitions(iterator4 -> {
                        long[] moduleTimeList = ((DistriOptimizer.Cache) iterator4.next()).moduleTimeList();
                        int i10 = 0;
                        while (true) {
                            int i11 = i10;
                            if (i11 >= moduleTimeList.length) {
                                return package$.MODULE$.Iterator().empty();
                            }
                            moduleTimeList[i11] = 0;
                            i10 = i11 + 1;
                        }
                    }, rdd.mapPartitions$default$2(), classTag).count();
                    i8 = 0;
                }
                apply.update("neval", BoxesRunTime.boxToInteger(BoxesRunTime.unboxToInt(apply.apply("neval")) + i6));
                // Epoch boundary: every counted record has been processed.
                if (create2.elem >= unboxToInt) {
                    create.elem = (j + System.nanoTime()) - nanoTime3;
                    j = create.elem;
                    nanoTime3 = System.nanoTime();
                    logger().info(new StringBuilder(39).append(header).append(" Epoch finished. Wall clock time is ").append(create.elem / 1000000.0d).append(" ms").toString());
                    apply.update("epoch", BoxesRunTime.boxToInteger(BoxesRunTime.unboxToInt(apply.apply("epoch")) + 1));
                    distributedDataSet.shuffle();
                    data = distributedDataSet.data(true);
                    create2.elem = 0;
                }
                // Mirror the driver state back into every optim method's state table.
                map.map(tuple22 -> {
                    if (tuple22 == null) {
                        throw new MatchError(tuple22);
                    }
                    OptimMethod optimMethod2 = (OptimMethod) tuple22._2();
                    optimMethod2.state().update("recordsProcessedThisEpoch", BoxesRunTime.boxToInteger(create2.elem));
                    optimMethod2.state().update("epoch", apply.apply("epoch"));
                    optimMethod2.state().update("neval", apply.apply("neval"));
                    optimMethod2.state().update("Loss", apply.apply("Loss"));
                    return option3.isDefined() ? optimMethod2.state().update("score", apply.apply("score")) : BoxedUnit.UNIT;
                }, Iterable$.MODULE$.canBuildFrom());
                if (trigger.apply(apply)) {
                    logger().info("training finished, updating all layers parameters");
                    rdd.mapPartitions(iterator5 -> {
                        Function0[] function0Arr = (Function0[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(((DistriOptimizer.Cache) iterator5.next()).localModels())).map(abstractModule2 -> {
                            return () -> {
                                MODULE$.updateLayerParameters(abstractModule2, classTag);
                            };
                        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Function0.class)));
                        ThreadPool m1282default = Engine$.MODULE$.m1282default();
                        m1282default.invokeAndWait2(Predef$.MODULE$.wrapRefArray(function0Arr), m1282default.invokeAndWait2$default$2(), m1282default.invokeAndWait2$default$3());
                        return package$.MODULE$.Iterator().empty();
                    }, rdd.mapPartitions$default$2(), classTag).collect();
                }
                validate$default$9();
                validate(option, option2, option3, i, rdd, apply, option7, header, null);
                option6.foreach(trainSummary -> {
                    $anonfun$optimize$17(rdd, apply, abstractModule, classTag, tensorNumeric, trainSummary);
                    return BoxedUnit.UNIT;
                });
                checkpoint(option4, option5, z, create.elem, rdd, apply, null, map, abstractModule, classTag, tensorNumeric);
            } else {
                logger().info(new StringBuilder(214).append("Warning! Not enough training samples were successfully processed in this ").append("iteration due to some slow tasks. The gradients computed in this iteration will be ").append("discarded. Only ").append(unboxToInt4).append("/").append(i7).append(" threads successfully ").append("completed training.").toString());
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Calls {@code updateParameter()} on the given module and, when the module is a
     * {@code Container}, runs the per-child helper on each of its sub-modules.
     *
     * @param abstractModule module whose parameters are updated
     * @param classTag       class tag for {@code T}, forwarded to the per-child helper
     */
    public <T> void updateLayerParameters(AbstractModule<Activity, Activity, T> abstractModule, ClassTag<T> classTag) {
        abstractModule.updateParameter();
        if (!(abstractModule instanceof Container)) {
            // Not a container: there are no children to visit.
            return;
        }
        Container container = (Container) abstractModule;
        container.modules().foreach(child -> {
            $anonfun$updateLayerParameters$1(classTag, child);
            return BoxedUnit.UNIT;
        });
    }

    /**
     * Builds, persists and materialises the per-partition training cache.
     *
     * <p>The criterion, config table, validation methods and optim methods are broadcast
     * together, the model is broadcast separately, and each partition of the dataset's origin
     * RDD then produces a single {@code DistriOptimizer.CacheV1} holding its local models,
     * tensors, criterions, state tables, a zeroed timing list, validation methods, cloned
     * optim methods and a {@code BlockManagerParameterSynchronizer}.
     *
     * @param abstractModule     model to broadcast
     * @param distributedDataSet dataset whose origin RDD defines the partitions
     * @param abstractCriterion  criterion shipped to every partition
     * @param table              config table; must contain {@code computeThresholdbatchSize}
     * @param i                  expected number of partitions of the origin RDD
     * @param i2                 core count passed to {@code Engine.setNodeAndCore}; also the
     *                           sub-model count when the engine is MklBlas
     * @param z                  when true, a failed singleton check is enforced via {@code require}
     *                           instead of only being logged as a warning
     * @param option             validation methods shipped to every partition
     * @param map                optim methods; each partition receives clones
     * @param map2               not read by this method
     * @param classTag           class tag for {@code T}
     * @param tensorNumeric      numeric operations for {@code T}
     * @return the persisted RDD named "Thread Model RDD", already forced with {@code count()}
     * @throws MatchError if the engine type is neither MklBlas nor MklDnn
     */
    public <T> RDD<DistriOptimizer.CacheV1<T>> com$intel$analytics$bigdl$optim$ParallelOptimizer$$initThreadModels(AbstractModule<Activity, Activity, T> abstractModule, DistributedDataSet<MiniBatch<T>> distributedDataSet, AbstractCriterion<Activity, Activity, T> abstractCriterion, Table table, int i, int i2, boolean z, Option<ValidationMethod<T>[]> option, Map<String, OptimMethod<T>> map, scala.collection.mutable.Map<String, Object> map2, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        SparkContext sc = distributedDataSet.originRDD().sparkContext();
        Broadcast sharedState = sc.broadcast(new Tuple4(abstractCriterion, table, option, map), ClassTag$.MODULE$.apply(Tuple4.class));
        ModelBroadcast<T> modelBroadcast = ModelBroadcast$.MODULE$.apply(classTag, tensorNumeric).broadcast(sc, abstractModule);
        // Result intentionally unused; the call is kept exactly where the original made it.
        abstractModule.getParameters();

        // Sub-model count: the caller's core count for MklBlas, always 1 for MklDnn.
        EngineType engine = Engine$.MODULE$.getEngineType();
        final int subModelCount;
        if (MklBlas$.MODULE$.equals(engine)) {
            subModelCount = i2;
        } else if (MklDnn$.MODULE$.equals(engine)) {
            subModelCount = 1;
        } else {
            throw new MatchError(engine);
        }

        boolean partitionsMatch = distributedDataSet.originRDD().partitions().length == i;
        Predef$.MODULE$.require(partitionsMatch, () -> {
            return new StringBuilder(71).append("Passed in rdd partition number ").append(distributedDataSet.originRDD().partitions().length).append(" is not equal to configured node number ").append(i).toString();
        });

        int thresholdBatch = BoxesRunTime.unboxToInt(table.get("computeThresholdbatchSize").get());
        int nodes = Engine$.MODULE$.nodeNumber();
        String blocksProperty = System.getProperty("bigdl.parallelOptimizer.parameterBlocks", "10");
        int parameterBlocks = new StringOps(Predef$.MODULE$.augmentString(blocksProperty)).toInt();
        RDD<?> sourceRdd = distributedDataSet.originRDD();

        RDD<DistriOptimizer.CacheV1<T>> threadModels = sourceRdd.mapPartitions(ignoredRecords -> {
            int pid = TaskContext$.MODULE$.getPartitionId();
            Tuple4 bundle = (Tuple4) sharedState.value();
            if (bundle == null) {
                throw new MatchError(bundle);
            }
            AbstractCriterion localCriterion = (AbstractCriterion) bundle._1();
            Table localConfig = (Table) bundle._2();
            Option localValidation = (Option) bundle._3();
            Map localOptims = (Map) bundle._4();

            if (!Engine$.MODULE$.checkSingleton()) {
                if (!z) {
                    MODULE$.logger().warn("Partitions of the training data are not evenlydistributed across the executors in the Spark cluster; are there sufficient trainingdata to be distributed?");
                } else {
                    Predef$.MODULE$.require(Engine$.MODULE$.checkSingleton(), () -> {
                        return "Partitions of the training data are not evenlydistributed across the executors in the Spark cluster; are there sufficient trainingdata to be distributed? Set property \"bigdl.check.singleton\" to false to skip this check";
                    });
                }
            }

            Engine$.MODULE$.setNodeAndCore(nodes, i2);
            BlockManagerParameterSynchronizer synchronizer = new BlockManagerParameterSynchronizer(pid, nodes, classTag, tensorNumeric);

            // One 6-tuple per sub-model, produced by the per-index helper.
            Tuple6[] replicas = (Tuple6[]) ((TraversableOnce) RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), subModelCount).map(index -> {
                return $anonfun$initThreadModels$4(modelBroadcast, pid, classTag, synchronizer, parameterBlocks, localCriterion, localConfig, localValidation, tensorNumeric, BoxesRunTime.unboxToInt(index));
            }, IndexedSeq$.MODULE$.canBuildFrom())).toArray(ClassTag$.MODULE$.apply(Tuple6.class));
            MODULE$.logger().info(new StringBuilder(26).append("model thread pool size is ").append(Engine$.MODULE$.model().getPoolSize()).toString());

            // Split the tuples column-wise, in the order the cache constructor expects.
            AbstractModule[] models = (AbstractModule[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(replicas)).map(replica -> {
                return (AbstractModule) replica._1();
            }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(AbstractModule.class)));
            Tensor[] secondColumn = (Tensor[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(replicas)).map(replica -> {
                return (Tensor) replica._2();
            }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tensor.class)));
            Tensor[] thirdColumn = (Tensor[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(replicas)).map(replica -> {
                return (Tensor) replica._3();
            }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tensor.class)));
            AbstractCriterion[] criterions = (AbstractCriterion[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(replicas)).map(replica -> {
                return (AbstractCriterion) replica._4();
            }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(AbstractCriterion.class)));
            Table[] states = (Table[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(replicas)).map(replica -> {
                return (Table) replica._5();
            }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Table.class)));
            long[] moduleTimes = new long[subModelCount * thresholdBatch];
            Option[] validations = (Option[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(replicas)).map(replica -> {
                return (Option) replica._6();
            }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Option.class)));
            // Every partition works on its own clones of the optim methods.
            Map clonedOptims = (Map) localOptims.map(entry -> {
                return new Tuple2(entry._1(), ((OptimMethod) entry._2()).m984clone());
            }, Map$.MODULE$.canBuildFrom());

            return package$.MODULE$.Iterator().single(new DistriOptimizer.CacheV1(models, secondColumn, thirdColumn, criterions, states, moduleTimes, validations, clonedOptims, synchronizer));
        }, sourceRdd.mapPartitions$default$2(), ClassTag$.MODULE$.apply(DistriOptimizer.CacheV1.class)).persist();

        threadModels.setName("Thread Model RDD");
        logger().info("Cache thread models...");
        // count() forces evaluation so the cache is populated before returning.
        threadModels.count();
        logger().info("Cache thread models... done");
        return threadModels;
    }

    /**
     * Flattens a module tree into the list of non-container modules that have parameters,
     * visiting a container's children in the order {@code modules()} yields them.
     *
     * @param abstractModule root of the (sub)tree to flatten
     * @param classTag       class tag for {@code T}, forwarded to the recursive calls
     * @return a new buffer; empty for a non-container module whose {@code parameters()} is null
     */
    public <T> ArrayBuffer<AbstractModule<Activity, Activity, T>> com$intel$analytics$bigdl$optim$ParallelOptimizer$$getExecutionOrder(AbstractModule<Activity, Activity, T> abstractModule, ClassTag<T> classTag) {
        ArrayBuffer<AbstractModule<Activity, Activity, T>> ordered = new ArrayBuffer<>();
        if (!(abstractModule instanceof Container)) {
            // Leaf module: it contributes itself only when it actually owns parameters.
            if (abstractModule.parameters() != null) {
                ordered.$plus$eq(abstractModule);
            }
            return ordered;
        }
        // Container: append each child's own flattened order.
        ((Container) abstractModule).modules().foreach(child ->
            ordered.$plus$plus$eq(MODULE$.com$intel$analytics$bigdl$optim$ParallelOptimizer$$getExecutionOrder(child, classTag)));
        return ordered;
    }

    /**
     * Walks the parameterised layers from last to first and registers a range of the flattened
     * weight and gradient tensors with the synchronizer for each slot not yet present in
     * {@code map}.
     *
     * <p>A layer's slot is derived from its storage offset and the slice size
     * {@code nElement / i - 1}. Each registered range starts at the layer's offset and extends
     * to the start of the previously registered range (initially the end of the gradient tensor).
     *
     * @param abstractModule              model whose flattened parameters are partitioned
     * @param distriParameterSynchronizer synchronizer to initialise and attach to the chosen layers
     * @param map                         slot index to storage offset; read to skip taken slots and updated here
     * @param i                           divisor used for the slice size and upper bound for slot indices
     * @param classTag                    class tag for {@code T}
     */
    private <T> void setDistriPartitionsynchronizer(AbstractModule<Activity, Activity, T> abstractModule, DistriParameterSynchronizer<T> distriParameterSynchronizer, scala.collection.mutable.Map<Object, Object> map, int i, ClassTag<T> classTag) {
        Tensor flatWeights = (Tensor) abstractModule.getParameters()._1();
        Tensor flatGradients = (Tensor) abstractModule.getParameters()._2();
        int totalElements = flatGradients.nElement();
        ArrayBuffer<AbstractModule<Activity, Activity, T>> layers = com$intel$analytics$bigdl$optim$ParallelOptimizer$$getExecutionOrder(abstractModule, classTag);
        int sliceSize = (totalElements / i) - 1;
        // Exclusive end of the next range to register; moves backwards as ranges are claimed.
        int rangeEnd = totalElements;
        int position = layers.length() - 1;
        while (position >= 0) {
            AbstractModule layer = (AbstractModule) layers.apply(position);
            if (layer.parameters() != null) {
                int offset = ((Tensor) layer.getParameters()._1()).storageOffset() - 1;
                int slot;
                if (offset == 0) {
                    slot = 0;
                } else {
                    slot = ((offset - 1) / sliceSize) + 1;
                }
                int rangeLength = rangeEnd - offset;
                boolean slotFree = !map.contains(BoxesRunTime.boxToInteger(slot));
                if (slot < i && slotFree) {
                    map.put(BoxesRunTime.boxToInteger(slot), BoxesRunTime.boxToInteger(offset));
                    distriParameterSynchronizer.init(layer.getName(), rangeLength, layers.length() - position, flatWeights.narrow(1, offset + 1, rangeLength), flatGradients.narrow(1, offset + 1, rangeLength));
                    layer.setParameterSynchronizer(distriParameterSynchronizer);
                    rangeEnd = offset;
                }
            }
            position--;
        }
    }

    /**
     * Sets the id on the given module and, when the module is a {@code Container}, runs the
     * per-child helper with the same id on each of its sub-modules.
     *
     * @param abstractModule module to label
     * @param i              id to assign
     * @param classTag       class tag for {@code T}, forwarded to the per-child helper
     */
    private <T> void setModelId(AbstractModule<Activity, Activity, T> abstractModule, int i, ClassTag<T> classTag) {
        abstractModule.setId(i);
        if (!(abstractModule instanceof Container)) {
            // Not a container: there are no children to label.
            return;
        }
        Container container = (Container) abstractModule;
        container.modules().foreach(child -> {
            $anonfun$setModelId$1(i, classTag, child);
            return BoxedUnit.UNIT;
        });
    }

    /**
     * Pulls trained state from the cached partition models back into {@code abstractModule}.
     *
     * <p>The extra parameters are taken from the first cached model. Each partition then
     * contributes one contiguous slice of its first local model's flattened weights, keyed by
     * its synchronizer's partition id; the slices are merged into one map and handed, per
     * partition index, to a helper together with the driver model's flattened weight tensor.
     * Slice lengths divide the weight storage evenly, with the first
     * {@code storageLength % partitions} partitions taking one extra element.
     *
     * @param rdd                cached per-partition training state
     * @param allReduceParameter not read by this implementation
     * @param abstractModule     driver-side model to update
     * @param classTag           class tag for {@code T}
     * @param tensorNumeric      numeric operations for {@code T}
     * @return {@code abstractModule}, the same instance that was passed in
     */
    @Override // com.intel.analytics.bigdl.optim.AbstractOptimizer
    public <T> AbstractModule<Activity, Activity, T> getModel(RDD<DistriOptimizer.Cache<T>> rdd, AllReduceParameter<T> allReduceParameter, AbstractModule<Activity, Activity, T> abstractModule, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        final int partitionCount = rdd.partitions().length;

        // Extra parameters come from the first local model of the first cached entry.
        Tensor[] extraParameters = (Tensor[]) rdd.map(entry -> {
            return ((AbstractModule) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(entry.localModels())).head()).getExtraParameter();
        }, ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(Tensor.class))).first();
        abstractModule.setExtraParameter(extraParameters);

        // Run the per-index helper once for every gradient tensor of the driver model.
        Tuple2<Tensor<T>[], Tensor<T>[]> layerParameters = abstractModule.parameters();
        int gradientCount = ((Tensor[]) layerParameters._2()).length;
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), gradientCount).foreach(index -> {
            return $anonfun$getModel$2(layerParameters, BoxesRunTime.unboxToInt(index));
        });

        Tensor driverWeights = (Tensor) abstractModule.getParameters()._1();
        scala.reflect.package$.MODULE$.classTag(classTag);
        int storageLength = ScalaRunTime$.MODULE$.array_length(driverWeights.storage().array());
        final int baseSlice = storageLength / partitionCount;
        final int remainder = storageLength % partitionCount;

        // Each partition emits a single-entry map: partition id -> copy of its weight slice.
        Map slices = (Map) rdd.mapPartitions(cacheIterator -> {
            DistriOptimizer.Cache cached = (DistriOptimizer.Cache) cacheIterator.next();
            Tensor localWeights = (Tensor) ((AbstractModule) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(cached.localModels())).head()).getParameters()._1();
            int pid = ((BlockManagerParameterSynchronizer) cached.parameterSynchronizer()).partitionID();
            int sliceStart = (pid * baseSlice) + Math.min(pid, remainder);
            int sliceLength = baseSlice;
            if (pid < remainder) {
                sliceLength += 1;
            }
            Tensor sliceCopy = Tensor$.MODULE$.apply(sliceLength, classTag, tensorNumeric);
            sliceCopy.copy(localWeights.narrow(1, sliceStart + 1, sliceLength));
            return package$.MODULE$.Iterator().single(Predef$.MODULE$.Map().apply(Predef$.MODULE$.wrapRefArray(new Tuple2[]{Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(BoxesRunTime.boxToInteger(pid)), sliceCopy)})));
        }, rdd.mapPartitions$default$2(), ClassTag$.MODULE$.apply(Map.class)).reduce((left, right) -> {
            return left.$plus$plus(right);
        });

        // Hand every partition's slice to the helper along with the driver weight tensor.
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), partitionCount).map(index -> {
            return $anonfun$getModel$5(driverWeights, baseSlice, remainder, slices, BoxesRunTime.unboxToInt(index));
        }, IndexedSeq$.MODULE$.canBuildFrom());
        return abstractModule;
    }

    /**
     * Reports whether {@code future} is still live, i.e. has not been cancelled.
     *
     * @param future future to inspect
     * @return {@code true} when {@code future.isCancelled()} is {@code false}
     */
    public static final /* synthetic */ boolean $anonfun$optimize$7(Future future) {
        boolean cancelled = future.isCancelled();
        return !cancelled;
    }

    /**
     * Obtains the result of {@code future} through {@code Future.get()} and unboxes it to an
     * {@code int} with {@code BoxesRunTime.unboxToInt}.
     *
     * @param future future whose result is read
     * @return the unboxed result
     */
    public static final /* synthetic */ int $anonfun$optimize$8(Future future) {
        Object boxedResult = future.get();
        return BoxesRunTime.unboxToInt(boxedResult);
    }

    /**
     * Calls {@code updateHyperParameter()} on the second element of {@code tuple2}, which is
     * cast to {@code OptimMethod}. The first element is not read.
     *
     * @param tuple2 pair whose second element is the optimisation method to update
     */
    public static final /* synthetic */ void $anonfun$optimize$10(Tuple2 tuple2) {
        OptimMethod optimMethod = (OptimMethod) tuple2._2();
        optimMethod.updateHyperParameter();
    }

    /**
     * Forwards to {@code MODULE$.saveSummary}, reordering the arguments so that
     * {@code trainSummary} comes first and passing {@code null} as the fourth argument.
     *
     * @param rdd            forwarded as the second argument
     * @param table          forwarded as the third argument
     * @param abstractModule forwarded as the fifth argument
     * @param classTag       forwarded as the sixth argument
     * @param tensorNumeric  forwarded as the seventh argument
     * @param trainSummary   forwarded as the first argument
     */
    public static final /* synthetic */ void $anonfun$optimize$17(RDD rdd, Table table, AbstractModule abstractModule, ClassTag classTag, TensorNumericMath.TensorNumeric tensorNumeric, TrainSummary trainSummary) {
        // The fourth argument is deliberately left as null here; see saveSummary for how it is used.
        MODULE$.saveSummary(trainSummary, rdd, table, null, abstractModule, classTag, tensorNumeric);
    }

    /**
     * Forwards to {@code MODULE$.updateLayerParameters(abstractModule, classTag)}; the argument
     * order is swapped relative to this method's own parameter list.
     *
     * @param classTag       class tag passed through as the second argument
     * @param abstractModule module passed through as the first argument
     */
    public static final /* synthetic */ void $anonfun$updateLayerParameters$1(ClassTag classTag, AbstractModule abstractModule) {
        MODULE$.updateLayerParameters(abstractModule, classTag);
    }

    /**
     * Builds one six-element tuple from a broadcast model and freshly cloned companions.
     *
     * <p>The model is obtained with {@code modelBroadcast.value(true, false)}. When it is an
     * {@code MklDnnContainer} it is compiled for the training phase. It then receives the id
     * {@code i} and is wired to {@code blockManagerParameterSynchronizer} together with a newly
     * created {@code HashMap} and {@code i2}.
     *
     * <p>The returned tuple holds, in order: the model, two tensors created with
     * {@code Tensor$.apply(0, classTag, tensorNumeric)}, a clone of {@code abstractCriterion},
     * a clone of {@code table}, and either {@code Some} array of cloned validation methods
     * (when {@code option} is defined) or {@code None}.
     *
     * @param i3 not referenced in this method body
     * @return the six-element tuple described above
     */
    public static final /* synthetic */ Tuple6 $anonfun$initThreadModels$4(ModelBroadcast modelBroadcast, int i, ClassTag classTag, BlockManagerParameterSynchronizer blockManagerParameterSynchronizer, int i2, AbstractCriterion abstractCriterion, Table table, Option option, TensorNumericMath.TensorNumeric tensorNumeric, int i3) {
        Serializable localModel = modelBroadcast.value(true, false);
        if (localModel instanceof MklDnnContainer) {
            ((MklDnnContainer) localModel).compile(Phase$TrainingPhase$.MODULE$);
        }
        MODULE$.setModelId(localModel, i, classTag);
        MODULE$.setDistriPartitionsynchronizer(localModel, blockManagerParameterSynchronizer, new HashMap(), i2, classTag);

        // Tuple components are created in the same order in which they appear in the tuple.
        Tensor firstTensor = Tensor$.MODULE$.apply(0, classTag, tensorNumeric);
        Tensor secondTensor = Tensor$.MODULE$.apply(0, classTag, tensorNumeric);
        AbstractCriterion criterionCopy = abstractCriterion.cloneCriterion();
        Object tableCopy = table.m1303clone();
        Option validationCopies;
        if (option.isDefined()) {
            validationCopies = new Some(new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) option.get())).map(method -> {
                return method.m1003clone();
            }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(ValidationMethod.class))));
        } else {
            validationCopies = None$.MODULE$;
        }
        return new Tuple6(localModel, firstTensor, secondTensor, criterionCopy, tableCopy, validationCopies);
    }

    /**
     * Forwards to {@code MODULE$.setModelId(abstractModule, i, classTag)}; used as the per-element
     * callback inside {@code setModelId} when it walks a container's modules.
     *
     * @param i              id to assign
     * @param classTag       class tag passed through unchanged
     * @param abstractModule module that receives the id
     */
    public static final /* synthetic */ void $anonfun$setModelId$1(int i, ClassTag classTag, AbstractModule abstractModule) {
        MODULE$.setModelId(abstractModule, i, classTag);
    }

    /**
     * Resizes the tensor at index {@code i} of the pair's second array to match the tensor at
     * the same index of the pair's first array.
     *
     * @param tuple2 pair of {@code Tensor[]} arrays
     * @param i      index used in both arrays
     * @return whatever {@code resizeAs} returns for the tensor taken from the second array
     */
    public static final /* synthetic */ Tensor $anonfun$getModel$2(Tuple2 tuple2, int i) {
        Tensor toResize = ((Tensor[]) tuple2._2())[i];
        Tensor template = ((Tensor[]) tuple2._1())[i];
        return toResize.resizeAs(template);
    }

    /**
     * Copies the map entry stored under key {@code i3} into one range of {@code tensor}.
     *
     * <p>The range is taken along dimension 1. It starts at
     * {@code tensor.storageOffset() + i3 * i + min(i3, i2)} and covers {@code i} elements,
     * plus one more when {@code i3 < i2}.
     *
     * @param tensor destination tensor that is narrowed and written to
     * @param i      base range length
     * @param i2     number of leading ranges that are one element longer
     * @param map    source of the data, looked up by boxed {@code i3}
     * @param i3     index of the range
     * @return the result of {@code copy} on the narrowed tensor
     */
    public static final /* synthetic */ Tensor $anonfun$getModel$5(Tensor tensor, int i, int i2, Map map, int i3) {
        int sliceStart = tensor.storageOffset() + (i3 * i) + Math.min(i3, i2);
        int sliceLength = i;
        if (i3 < i2) {
            sliceLength += 1;
        }
        Tensor destination = tensor.narrow(1, sliceStart, sliceLength);
        Tensor collected = (Tensor) map.apply(BoxesRunTime.boxToInteger(i3));
        return destination.copy(collected);
    }

    /**
     * Private constructor: stores this instance in {@code MODULE$} and then creates the logger
     * from the instance's runtime class.
     */
    private ParallelOptimizer$() {
        // MODULE$ is assigned before any other field is initialised.
        MODULE$ = this;
        this.logger = Logger.getLogger(getClass());
    }
}
