package com.intel.analytics.bigdl.optim;

import com.intel.analytics.bigdl.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.nn.abstractnn.Activity;
import com.intel.analytics.bigdl.optim.DistriOptimizer;
import com.intel.analytics.bigdl.parameters.AllReduceParameter;
import com.intel.analytics.bigdl.parameters.ParameterProcessor;
import com.intel.analytics.bigdl.parameters.Util$;
import com.intel.analytics.bigdl.tensor.Tensor;
import com.intel.analytics.bigdl.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.utils.Table;
import org.apache.spark.rdd.RDD;
import org.apache.spark.rdd.RDD$;
import scala.MatchError;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.IterableLike;
import scala.collection.Iterator;
import scala.collection.immutable.Map;
import scala.collection.immutable.Map$;
import scala.collection.mutable.ArrayOps;
import scala.math.Ordering$String$;
import scala.package$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/* compiled from: LarsSGD.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005mc!B\u0004\t\u0001)\u0011\u0002\u0002C\u0010\u0001\u0005\u0003\u0005\u000b\u0011B\u0011\t\u0011U\u0002!\u0011!Q\u0001\nYBQ!\u000f\u0001\u0005\u0002iBQa\u0010\u0001\u0005B\u0001Cq!!\b\u0001\t\u0003\ny\u0002C\u0004\u0002\u001e\u0001!\t%!\u000f\u0003\u001b1\u000b'o\u001d)s_\u000e,7o]8s\u0015\tI!\"A\u0003paRLWN\u0003\u0002\f\u0019\u0005)!-[4eY*\u0011QBD\u0001\nC:\fG.\u001f;jGNT!a\u0004\t\u0002\u000b%tG/\u001a7\u000b\u0003E\t1aY8n'\r\u00011#\u0007\t\u0003)]i\u0011!\u0006\u0006\u0002-\u0005)1oY1mC&\u0011\u0001$\u0006\u0002\u0007\u0003:L(+\u001a4\u0011\u0005iiR\"A\u000e\u000b\u0005qQ\u0011A\u00039be\u0006lW\r^3sg&\u0011ad\u0007\u0002\u0013!\u0006\u0014\u0018-\\3uKJ\u0004&o\\2fgN|'/A\bqCJ\fW.\u0019;feN\u0003H.\u001b;t\u0007\u0001\u0001BAI\u0015-_9\u00111e\n\t\u0003IUi\u0011!\n\u0006\u0003M\u0001\na\u0001\u0010:p_Rt\u0014B\u0001\u0015\u0016\u0003\u0019\u0001&/\u001a3fM&\u0011!f\u000b\u0002\u0004\u001b\u0006\u0004(B\u0001\u0015\u0016!\t\u0011S&\u0003\u0002/W\t11\u000b\u001e:j]\u001e\u0004B\u0001\u0006\u00193e%\u0011\u0011'\u0006\u0002\u0007)V\u0004H.\u001a\u001a\u0011\u0005Q\u0019\u0014B\u0001\u001b\u0016\u0005\rIe\u000e^\u0001\fo\u0016Lw\r\u001b;EK\u000e\f\u0017\u0010\u0005\u0002\u0015o%\u0011\u0001(\u0006\u0002\u0007\t>,(\r\\3\u0002\rqJg.\u001b;?)\rYTH\u0010\t\u0003y\u0001i\u0011\u0001\u0003\u0005\u0006?\r\u0001\r!\t\u0005\u0006k\r\u0001\rAN\u0001\u0012G>dG.Z2u\u000f2|'-\u00197ECR\fWCA!^)\u001d\u0011e-`A\u0002\u0003\u001b!\"a\u0011$\u0011\u0005Q!\u0015BA#\u0016\u0005\u0011)f.\u001b;\t\u000b\u001d#\u00019\u0001%\u0002\u0005\u00154\bcA%Y7:\u0011!*\u0016\b\u0003\u0017Ns!\u0001\u0014*\u000f\u00055\u000bfB\u0001(Q\u001d\t!s*C\u0001\u0012\u0013\ty\u0001#\u0003\u0002\u000e\u001d%\u00111\u0002D\u0005\u0003)*\ta\u0001^3og>\u0014\u0018B\u0001,X\u0003E!VM\\:pe:+X.\u001a:jG6\u000bG\u000f\u001b\u0006\u0003)*I!!\u0017.\u0003\u001bQ+gn]8s\u001dVlWM]5d\u0015\t1v\u000b\u0005\u0002];2\u0001A!\u00020\u0005\u0005\u0004y&!\u0001+\u0012\u0005\u0001\u001
c\u0007C\u0001\u000bb\u0013\t\u0011WCA\u0004O_RD\u0017N\\4\u0011\u0005Q!\u0017BA3\u0016\u0005\r\te.\u001f\u0005\u0006O\u0012\u0001\r\u0001[\u0001\u0007[>$W\r\\:\u0011\u0007%\u0014H/D\u0001k\u0015\tYG.A\u0002sI\u0012T!!\u001c8\u0002\u000bM\u0004\u0018M]6\u000b\u0005=\u0004\u0018AB1qC\u000eDWMC\u0001r\u0003\ry'oZ\u0005\u0003g*\u00141A\u0015#E!\r)(p\u0017\b\u0003mbt!aS<\n\u0005%Q\u0011BA=\t\u0003=!\u0015n\u001d;sS>\u0003H/[7ju\u0016\u0014\u0018BA>}\u0005\u0015\u0019\u0015m\u00195f\u0015\tI\b\u0002C\u0003\u001d\t\u0001\u0007a\u0010E\u0002\u001b\u007fnK1!!\u0001\u001c\u0005I\tE\u000e\u001c*fIV\u001cW\rU1sC6,G/\u001a:\t\u000f\u0005\u0015A\u00011\u0001\u0002\b\u00059Q.\u001a;sS\u000e\u001c\bc\u0001\u001f\u0002\n%\u0019\u00111\u0002\u0005\u0003\u000f5+GO]5dg\"9\u0011q\u0002\u0003A\u0002\u0005E\u0011!B:uCR,\u0007\u0003BA\n\u00033i!!!\u0006\u000b\u0007\u0005]!\"A\u0003vi&d7/\u0003\u0003\u0002\u001c\u0005U!!\u0002+bE2,\u0017!\u00059s_\u000e,7o\u001d)be\u0006lW\r^3sgV!\u0011\u0011EA\u0016)!\t\u0019#!\f\u00022\u0005]BcA\"\u0002&!1q)\u0002a\u0002\u0003O\u0001B!\u0013-\u0002*A\u0019A,a\u000b\u0005\u000by+!\u0019A0\t\rq)\u0001\u0019AA\u0018!\u0011Qr0!\u000b\t\u000f\u0005MR\u00011\u0001\u00026\u0005QQn\u001c3fY\u000e\u000b7\r[3\u0011\tUT\u0018\u0011\u0006\u0005\b\u0003\u001f)\u0001\u0019AA\t+\u0011\tY$!\u0012\u0015\r\u0005u\u0012qIA-)\r\u0019\u0015q\b\u0005\u0007\u000f\u001a\u0001\u001d!!\u0011\u0011\t%C\u00161\t\t\u00049\u0006\u0015C!\u00020\u0007\u0005\u0004y\u0006bBA%\r\u0001\u0007\u00111J\u0001\u0006[>$W\r\u001c\t\u0007\u0003\u001b\n\u0019&a\u0011\u000f\u0007-\u000by%C\u0002\u0002R)\tq\u0001]1dW\u0006<W-\u0003\u0003\u0002V\u0005]#AB'pIVdWMC\u0002\u0002R)Aq!a\u0004\u0007\u0001\u0004\t\t\u0002")
/* loaded from: input_file:com/intel/analytics/bigdl/optim/LarsProcessor.class */
/*
 * Parameter processor paired with the LarsSGD optimizer (compiled from LarsSGD.scala).
 *
 * For every named parameter split it derives one scalar, published in the optimizer state table
 * under "larsScale":
 *
 *     (sqrt(sumSq(gradient)) + weightDecay * sqrt(sumSq(weight))) / sqrt(sumSq(weight))
 *
 * and later hands that scalar to the LarsSGD instance registered for the split.
 *
 * NOTE(review): this is decompiler output of a Scala class; Scala-specialised names such as
 * Tuple2.mcDD.sp and the _1$mcI$sp() accessors are kept exactly as the decompiler rendered them.
 */
public class LarsProcessor implements ParameterProcessor {
    /**
     * Split name mapped to a pair of ints. The code treats the pair as (start, length) in the same
     * coordinate system as AllReduceParameter.localPartitionRange() (the spelling "paramater" is
     * inherited from the compiled field name).
     */
    private final Map<String, Tuple2<Object, Object>> paramaterSplits;

    /** Multiplier applied to the weight norm when the per-split scale is computed. */
    private final double weightDecay;

    /**
     * Computes the per-split scale on the cluster and stores it in {@code table}.
     *
     * <p>Reads "numFinishedModel", "parallelism" and "isGradientUpdated" from {@code table}; writes
     * "isGradientUpdated" (set to true) and "larsScale" (a Scala immutable Map from split name to
     * scale).
     *
     * @param rdd                 RDD of model caches; only used to run one task per partition
     * @param allReduceParameter  source of the local gradient/weight partitions and their range
     * @param metrics             receives the time spent aggregating the gradient partition
     * @param table               optimizer state table, read and updated as described above
     * @param tensorNumeric       numeric evidence forwarded to the sum-of-squares helper
     * @throws MatchError if {@code localPartitionRange()} or a split entry is null
     */
    @Override // com.intel.analytics.bigdl.parameters.ParameterProcessor
    public <T> void collectGlobalData(RDD<DistriOptimizer.Cache<T>> rdd, AllReduceParameter<T> allReduceParameter, Metrics metrics, Table table, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        int numFinishedModel = BoxesRunTime.unboxToInt(table.get("numFinishedModel").get());
        int parallelism = BoxesRunTime.unboxToInt(table.get("parallelism").get());
        boolean isGradientUpdated = BoxesRunTime.unboxToBoolean(table.get("isGradientUpdated").get());
        Map map = new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) RDD$.MODULE$.rddToPairRDDFunctions(rdd.mapPartitions(iterator -> {
            // Aggregate the gradient partition only if an earlier step has not already done so.
            if (!isGradientUpdated) {
                long startNanos = System.nanoTime();
                allReduceParameter.aggregateGradientPartition(numFinishedModel);
                metrics.add("aggregrateGradientParition average executor", System.nanoTime() - startNanos);
            }
            Tuple2<Object, Object> localPartitionRange = allReduceParameter.localPartitionRange();
            if (localPartitionRange == null) {
                throw new MatchError(localPartitionRange);
            }
            int partitionStart = localPartitionRange._1$mcI$sp();
            int partitionLength = localPartitionRange._2$mcI$sp();
            return ((IterableLike) this.paramaterSplits.flatMap(splitEntry -> {
                Iterator partialSums;
                if (splitEntry == null) {
                    throw new MatchError(splitEntry);
                }
                String splitName = (String) splitEntry._1();
                Tuple2 splitRange = (Tuple2) splitEntry._2();
                // Intersection of this executor's partition with the split's [start, start + length).
                int overlapStart = Math.max(partitionStart, splitRange._1$mcI$sp());
                int overlapEnd = Math.min(partitionStart + partitionLength, splitRange._1$mcI$sp() + splitRange._2$mcI$sp());
                if (overlapEnd > overlapStart) {
                    // narrow(dim, index, size): the "+ 1" suggests a 1-based index - TODO confirm.
                    Tensor gradientSlice = allReduceParameter.gradientPartition().narrow(1, (overlapStart - partitionStart) + 1, overlapEnd - overlapStart);
                    Tensor weightSlice = allReduceParameter.weightPartition().narrow(1, (overlapStart - partitionStart) + 1, overlapEnd - overlapStart);
                    double gradientSumSquare = Util$.MODULE$.getSumsquareInParallel(gradientSlice, parallelism, tensorNumeric);
                    // Emitted value is (weight sum of squares, gradient sum of squares).
                    partialSums = package$.MODULE$.Iterator().apply(Predef$.MODULE$.wrapRefArray(new Tuple2[]{new Tuple2(splitName, new Tuple2.mcDD.sp(Util$.MODULE$.getSumsquareInParallel(weightSlice, parallelism, tensorNumeric), gradientSumSquare))}));
                } else {
                    // This partition holds no part of the split.
                    partialSums = package$.MODULE$.Iterator().empty();
                }
                return partialSums;
            }, Map$.MODULE$.canBuildFrom())).toIterator();
        }, rdd.mapPartitions$default$2(), ClassTag$.MODULE$.apply(Tuple2.class)), ClassTag$.MODULE$.apply(String.class), ClassTag$.MODULE$.apply(Tuple2.class), Ordering$String$.MODULE$).reduceByKey((left, right) -> {
            // Add up the partial sums of squares contributed by every partition for one split.
            return new Tuple2.mcDD.sp(left._1$mcD$sp() + right._1$mcD$sp(), left._2$mcD$sp() + right._2$mcD$sp());
        }).map(reducedEntry -> {
            if (reducedEntry == null) {
                throw new MatchError(reducedEntry);
            }
            String splitName = (String) reducedEntry._1();
            Tuple2 sumSquares = (Tuple2) reducedEntry._2();
            double gradientNorm = Math.sqrt(sumSquares._2$mcD$sp());
            double weightNorm = Math.sqrt(sumSquares._1$mcD$sp());
            // NOTE(review): a zero weight norm makes this Infinity or NaN; no guard exists here.
            return new Tuple2(splitName, BoxesRunTime.boxToDouble((gradientNorm + (this.weightDecay * weightNorm)) / weightNorm));
        }, ClassTag$.MODULE$.apply(Tuple2.class)).collect())).toMap(Predef$.MODULE$.$conforms());
        table.update("isGradientUpdated", BoxesRunTime.boxToBoolean(true));
        table.update("larsScale", map);
    }

    /**
     * Pushes the scales stored under "larsScale" into the LarsSGD optim methods of {@code cache}
     * for every split that overlaps the local partition range.
     *
     * @param allReduceParameter  supplies the local partition range
     * @param cache               holder of the per-split optim methods
     * @param table               optimizer state table containing "larsScale"
     * @param tensorNumeric       unused in this method
     * @throws MatchError if the local range or a split entry is null, or if an overlapping
     *                    split's optim method is not a LarsSGD
     */
    @Override // com.intel.analytics.bigdl.parameters.ParameterProcessor
    public <T> void processParameters(AllReduceParameter<T> allReduceParameter, DistriOptimizer.Cache<T> cache, Table table, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        Map map = (Map) table.get("larsScale").get();
        Tuple2<Object, Object> localPartitionRange = allReduceParameter.localPartitionRange();
        if (localPartitionRange == null) {
            throw new MatchError(localPartitionRange);
        }
        int partitionStart = localPartitionRange._1$mcI$sp();
        int partitionLength = localPartitionRange._2$mcI$sp();
        this.paramaterSplits.foreach(splitEntry -> {
            $anonfun$processParameters$1(partitionStart, partitionLength, cache, map, splitEntry);
            return BoxedUnit.UNIT;
        });
    }

    /** Module-based overload of the interface; this processor intentionally does nothing here. */
    @Override // com.intel.analytics.bigdl.parameters.ParameterProcessor
    public <T> void processParameters(AbstractModule<Activity, Activity, T> abstractModule, Table table, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
    }

    /**
     * Compiler-generated body of the per-split callback used by
     * {@code processParameters(AllReduceParameter, Cache, Table, TensorNumeric)}.
     *
     * @param i       start of the local partition range
     * @param i2      length of the local partition range
     * @param cache   holder of the per-split optim methods
     * @param map     split name to scale, as stored under "larsScale"
     * @param tuple2  one (name, (start, length)) entry of the parameter splits
     * @throws MatchError if {@code tuple2} is null or the optim method is not a LarsSGD
     */
    public static final /* synthetic */ void $anonfun$processParameters$1(int i, int i2, DistriOptimizer.Cache cache, Map map, Tuple2 tuple2) {
        if (tuple2 == null) {
            throw new MatchError(tuple2);
        }
        String splitName = (String) tuple2._1();
        Tuple2 splitRange = (Tuple2) tuple2._2();
        int overlapEnd = Math.min(i + i2, splitRange._1$mcI$sp() + splitRange._2$mcI$sp());
        int overlapStart = Math.max(i, splitRange._1$mcI$sp());
        if (overlapEnd <= overlapStart) {
            // The split does not intersect the local partition: nothing to configure.
            return;
        }
        OptimMethod optimMethod = (OptimMethod) cache.optimMethods().apply(splitName);
        if (!(optimMethod instanceof LarsSGD)) {
            throw new MatchError(optimMethod);
        }
        ((LarsSGD) optimMethod).setGradientScale(BoxesRunTime.unboxToDouble(map.apply(splitName)));
    }

    /**
     * @param map  split name mapped to its (start, length) range
     * @param d    weight decay used when computing the per-split scale
     */
    public LarsProcessor(Map<String, Tuple2<Object, Object>> map, double d) {
        this.paramaterSplits = map;
        this.weightDecay = d;
        // Runs the Scala trait initializer of ParameterProcessor.
        ParameterProcessor.$init$(this);
    }
}
