package com.intel.analytics.bigdl.optim;

import com.intel.analytics.bigdl.dataset.AbstractDataSet;
import com.intel.analytics.bigdl.dataset.DataSet$;
import com.intel.analytics.bigdl.dataset.DistributedDataSet;
import com.intel.analytics.bigdl.dataset.LocalDataSet;
import com.intel.analytics.bigdl.dataset.MiniBatch;
import com.intel.analytics.bigdl.dataset.PaddingParam;
import com.intel.analytics.bigdl.dataset.Sample;
import com.intel.analytics.bigdl.dataset.SampleToMiniBatch$;
import com.intel.analytics.bigdl.nn.abstractnn.AbstractCriterion;
import com.intel.analytics.bigdl.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.nn.abstractnn.Activity;
import com.intel.analytics.bigdl.parameters.ParameterProcessor;
import com.intel.analytics.bigdl.tensor.Storage;
import com.intel.analytics.bigdl.tensor.Tensor;
import com.intel.analytics.bigdl.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.utils.Engine$;
import com.intel.analytics.bigdl.utils.OptimizerV1$;
import com.intel.analytics.bigdl.utils.OptimizerV2$;
import com.intel.analytics.bigdl.utils.OptimizerVersion;
import com.intel.analytics.bigdl.utils.Table;
import org.apache.log4j.Logger;
import org.apache.spark.rdd.RDD;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Some;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.TraversableOnce;
import scala.collection.immutable.Iterable$;
import scala.collection.immutable.Map;
import scala.collection.mutable.ArrayBuffer;
import scala.collection.mutable.ArrayOps;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.reflect.package$;
import scala.runtime.BoxesRunTime;
import scala.runtime.Null$;

/* compiled from: Optimizer.scala */
/* loaded from: input_file:com/intel/analytics/bigdl/optim/Optimizer$.class */
public final class Optimizer$ {
    public static Optimizer$ MODULE$;
    private final Logger com$intel$analytics$bigdl$optim$Optimizer$$logger;

    static {
        new Optimizer$();
    }

    /** Returns the logger created for this object in the constructor. */
    public Logger com$intel$analytics$bigdl$optim$Optimizer$$logger() {
        return this.com$intel$analytics$bigdl$optim$Optimizer$$logger;
    }

    /**
     * Builds a log prefix of the form
     * {@code [Epoch <i> <i2>/<j>][Iteration <i3>][Wall Clock <seconds>s]}.
     *
     * @param i  epoch number
     * @param i2 count shown before the slash (presumably records processed so far — confirm with callers)
     * @param j  count shown after the slash (presumably total records — confirm with callers)
     * @param i3 iteration number
     * @param j2 elapsed time; it is divided by 1e9 and printed with unit "s", so presumably nanoseconds
     * @return the formatted header string
     */
    public String header(int i, int i2, long j, int i3, long j2) {
        return new StringBuilder(36).append("[Epoch ").append(i).append(" ").append(i2).append("/").append(j).append("][Iteration ").append(i3).append("][Wall Clock ").append(j2 / 1.0E9d).append("s]").toString();
    }

    /**
     * Validates that the named sub-modules can be optimized independently.
     *
     * <p>For every name in {@code seq}, requires that the sub-module exists, owns at least one
     * parameter, and that its flattened parameter tensor shares storage with the whole model's
     * flattened parameters. When more than one sub-module is given, their parameter ranges are
     * sorted by storage offset and must not overlap.
     *
     * @throws IllegalArgumentException (via {@code Predef.require}) when any check fails
     */
    public <T> void checkSubModules(AbstractModule<Activity, Activity, T> abstractModule, Seq<String> seq, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        Tuple2<Tensor<T>, Tensor<T>> parameters = abstractModule.getParameters();
        // The whole model's flat weight storage does not depend on the sub-module; resolve it once.
        Storage modelStorage = ((Tensor) parameters._1()).storage();
        Tuple2[] named = (Tuple2[]) ((TraversableOnce) seq.map(str -> {
            Option apply = abstractModule.apply(str);
            Predef$.MODULE$.require(apply.isDefined(), () -> {
                return new StringBuilder(29).append("Optimizer: couldn't find ").append(str).append(" in ").append(abstractModule).toString();
            });
            Tensor tensor = (Tensor) ((AbstractModule) apply.get()).getParameters()._1();
            Predef$.MODULE$.require(tensor.nElement() > 0, () -> {
                return new StringBuilder(92).append("Optimizer: ").append(str).append(" doesn't have").append(" any trainable parameters, please check your model and optimMethods.").toString();
            });
            Storage subStorage = tensor.storage();
            // A sub-module's weights must be a view into the model's single flat storage.
            Predef$.MODULE$.require(modelStorage != null ? modelStorage.equals(subStorage) : subStorage == null, () -> {
                return new StringBuilder(42).append("Optimizer:").append(" ").append(str).append("'s parameter is not contiguous.").toString();
            });
            return new Tuple2(str, tensor);
        }, Seq$.MODULE$.canBuildFrom())).toArray(ClassTag$.MODULE$.apply(Tuple2.class));
        if (named.length == 1) {
            return;
        }
        Tuple2[] sorted = (Tuple2[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(named)).sortWith((tuple2, tuple22) -> {
            return BoxesRunTime.boxToBoolean($anonfun$checkSubModules$5(tuple2, tuple22));
        });
        // After sorting by offset, each range must end at or before the next one starts.
        for (int idx = 0; idx < sorted.length - 1; idx++) {
            Tuple2 current = sorted[idx];
            Tuple2 next = sorted[idx + 1];
            Predef$.MODULE$.require(((Tensor) current._2()).storageOffset() + ((Tensor) current._2()).nElement() <= ((Tensor) next._2()).storageOffset(), () -> {
                return new StringBuilder(87).append("Optimizer: ").append(current._1()).append(" and ").append(next._1()).append("'s parameters are duplicated.").append(" Please check your model and optimMethods.").toString();
            });
        }
    }

    /**
     * Concatenates the hyper-parameter descriptions of every optim method.
     *
     * <p>For an entry whose {@code getHyperParameter()} is empty, the empty string is used. Otherwise
     * the fragment is {@code "<name>'s hyper parameters: <log> "}. Fragments are joined without a
     * separator, in the map's iteration order.
     *
     * @param map optim methods keyed by sub-module name
     * @return the concatenated description, or {@code ""} when {@code map} is empty
     */
    public String getHyperParameterLog(Map<String, OptimMethod<?>> map) {
        // mkString() joins exactly like reduce(_ + _) but does not throw "empty.reduce" on an empty map.
        return ((TraversableOnce) map.map(tuple2 -> {
            if (tuple2 == null) {
                throw new MatchError(tuple2);
            }
            String str = (String) tuple2._1();
            String hyperParameter = ((OptimMethod) tuple2._2()).getHyperParameter();
            return hyperParameter.isEmpty() ? hyperParameter : new StringBuilder(22).append(str).append("'s hyper parameters: ").append(hyperParameter).append(" ").toString();
        }, Iterable$.MODULE$.canBuildFrom())).mkString();
    }

    /**
     * Saves the model to {@code <option>/model<str>} if a directory is defined; otherwise does nothing.
     * {@code z} is forwarded as the second argument of {@code save} (presumably the overwrite flag — confirm).
     */
    public <T> void saveModel(AbstractModule<Activity, Activity, T> abstractModule, Option<String> option, boolean z, String str) {
        if (option.isDefined()) {
            abstractModule.save(new StringBuilder(6).append(option.get()).append("/model").append(str).toString(), z);
        }
    }

    /** Default value ({@code ""}) for the file-name postfix of {@link #saveModel}. */
    public <T> String saveModel$default$4() {
        return "";
    }

    /**
     * Saves the state table to {@code <option>/state<str>} if a directory is defined; otherwise does nothing.
     * {@code z} is forwarded as the second argument of {@code save} (presumably the overwrite flag — confirm).
     */
    public void saveState(Table table, Option<String> option, boolean z, String str) {
        if (option.isDefined()) {
            table.save(new StringBuilder(6).append(option.get()).append("/state").append(str).toString(), z);
        }
    }

    /** Default value ({@code ""}) for the file-name postfix of {@link #saveState}. */
    public String saveState$default$4() {
        return "";
    }

    /**
     * Saves the optim method to {@code <option>/optimMethod<str>} if a directory is defined; otherwise
     * does nothing. {@code z} is forwarded as the second argument of {@code save} (presumably the
     * overwrite flag — confirm).
     */
    public <T> void saveOptimMethod(OptimMethod<T> optimMethod, Option<String> option, boolean z, String str, ClassTag<T> classTag) {
        if (option.isDefined()) {
            optimMethod.save(new StringBuilder(12).append(option.get()).append("/optimMethod").append(str).toString(), z);
        }
    }

    /** Default value ({@code ""}) for the file-name postfix of {@link #saveOptimMethod}. */
    public <T> String saveOptimMethod$default$4() {
        return "";
    }

    /**
     * Creates a distributed optimizer over an RDD of samples batched by {@code SampleToMiniBatch}.
     *
     * @param i             batch size passed to {@code SampleToMiniBatch}
     * @param paddingParam  feature padding; {@code null} means none
     * @param paddingParam2 label padding; {@code null} means none
     * @return a {@code DistriOptimizer} or {@code DistriOptimizerV2}, depending on the engine's optimizer version
     * @throws MatchError if the engine reports an unknown optimizer version
     */
    public <T> Optimizer<T, MiniBatch<T>> apply(AbstractModule<Activity, Activity, T> abstractModule, RDD<Sample<T>> rdd, AbstractCriterion<Activity, Activity, T> abstractCriterion, int i, PaddingParam<T> paddingParam, PaddingParam<T> paddingParam2, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        // Option.apply maps null to None and anything else to Some, matching the original null checks.
        Option featurePadding = Option.apply(paddingParam);
        Option labelPadding = Option.apply(paddingParam2);
        OptimizerVersion optimizerVersion = Engine$.MODULE$.getOptimizerVersion();
        DistributedDataSet dataSet = DataSet$.MODULE$.rdd(rdd, DataSet$.MODULE$.rdd$default$2(), ClassTag$.MODULE$.apply(Sample.class))
                .$minus$greater(SampleToMiniBatch$.MODULE$.apply(i, featurePadding, labelPadding, SampleToMiniBatch$.MODULE$.apply$default$4(), classTag, tensorNumeric), ClassTag$.MODULE$.apply(MiniBatch.class))
                .toDistributed();
        return newDistributedOptimizer(optimizerVersion, abstractModule, dataSet, abstractCriterion, classTag, tensorNumeric);
    }

    /**
     * Creates a distributed optimizer over an RDD of samples batched into copies of {@code miniBatch}.
     *
     * @param i         batch size passed to {@code SampleToMiniBatch}
     * @param miniBatch template mini-batch passed to {@code SampleToMiniBatch}
     * @return a {@code DistriOptimizer} or {@code DistriOptimizerV2}, depending on the engine's optimizer version
     * @throws MatchError if the engine reports an unknown optimizer version
     */
    public <T> Optimizer<T, MiniBatch<T>> apply(AbstractModule<Activity, Activity, T> abstractModule, RDD<Sample<T>> rdd, AbstractCriterion<Activity, Activity, T> abstractCriterion, int i, MiniBatch<T> miniBatch, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        OptimizerVersion optimizerVersion = Engine$.MODULE$.getOptimizerVersion();
        DistributedDataSet dataSet = DataSet$.MODULE$.rdd(rdd, DataSet$.MODULE$.rdd$default$2(), ClassTag$.MODULE$.apply(Sample.class))
                .$minus$greater(SampleToMiniBatch$.MODULE$.apply(miniBatch, i, None$.MODULE$, classTag, tensorNumeric), ClassTag$.MODULE$.apply(MiniBatch.class))
                .toDistributed();
        return newDistributedOptimizer(optimizerVersion, abstractModule, dataSet, abstractCriterion, classTag, tensorNumeric);
    }

    /**
     * Creates an optimizer that matches the kind of dataset given.
     *
     * @return a distributed optimizer (version chosen by the engine) for a {@code DistributedDataSet},
     *         or a {@code LocalOptimizer} for a {@code LocalDataSet}
     * @throws UnsupportedOperationException for any other dataset type (including {@code null})
     * @throws MatchError if the engine reports an unknown optimizer version
     */
    public <T, D> Optimizer<T, D> apply(AbstractModule<Activity, Activity, T> abstractModule, AbstractDataSet<D, ?> abstractDataSet, AbstractCriterion<Activity, Activity, T> abstractCriterion, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        if (abstractDataSet instanceof DistributedDataSet) {
            OptimizerVersion optimizerVersion = Engine$.MODULE$.getOptimizerVersion();
            return newDistributedOptimizer(optimizerVersion, abstractModule, ((DistributedDataSet) abstractDataSet).toDistributed(), abstractCriterion, classTag, tensorNumeric);
        }
        if (abstractDataSet instanceof LocalDataSet) {
            return new LocalOptimizer(abstractModule, ((LocalDataSet) abstractDataSet).toLocal(), abstractCriterion, classTag, tensorNumeric);
        }
        String kind = abstractDataSet == null ? "null" : abstractDataSet.getClass().getName();
        throw new UnsupportedOperationException("Optimizer: unsupported dataset type " + kind);
    }

    /** Default feature padding ({@code null}, meaning none) for the RDD {@code apply} overload. */
    public <T> Null$ apply$default$5() {
        return null;
    }

    /** Default label padding ({@code null}, meaning none) for the RDD {@code apply} overload. */
    public <T> Null$ apply$default$6() {
        return null;
    }

    /**
     * Returns the index of the first processor whose class is {@code T}, as given by {@code classTag}
     * (an instance of a subclass also matches).
     *
     * @return the index of the first match, or -1 if none matches
     */
    public <T extends ParameterProcessor> int findIndex(ArrayBuffer<ParameterProcessor> arrayBuffer, ClassTag<T> classTag) {
        Class<?> target = package$.MODULE$.classTag(classTag).runtimeClass();
        for (int idx = 0; idx < arrayBuffer.size(); idx++) {
            if (target.isInstance(arrayBuffer.apply(idx))) {
                return idx;
            }
        }
        return -1;
    }

    /** Sort predicate for {@code checkSubModules}: orders (name, tensor) pairs by the tensor's storage offset. */
    public static final /* synthetic */ boolean $anonfun$checkSubModules$5(Tuple2 tuple2, Tuple2 tuple22) {
        return ((Tensor) tuple2._2()).storageOffset() < ((Tensor) tuple22._2()).storageOffset();
    }

    /**
     * Builds the distributed optimizer implementation for the given engine optimizer version.
     * Shared by every {@code apply} overload so the V1/V2 dispatch lives in one place.
     *
     * @throws MatchError if {@code optimizerVersion} is neither {@code OptimizerV1} nor {@code OptimizerV2}
     */
    private <T> Optimizer newDistributedOptimizer(OptimizerVersion optimizerVersion, AbstractModule<Activity, Activity, T> model, DistributedDataSet dataSet, AbstractCriterion<Activity, Activity, T> criterion, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        if (OptimizerV1$.MODULE$.equals(optimizerVersion)) {
            return new DistriOptimizer(model, dataSet, criterion, classTag, tensorNumeric);
        }
        if (OptimizerV2$.MODULE$.equals(optimizerVersion)) {
            return new DistriOptimizerV2(model, dataSet, criterion, classTag, tensorNumeric);
        }
        throw new MatchError(optimizerVersion);
    }

    private Optimizer$() {
        MODULE$ = this;
        this.com$intel$analytics$bigdl$optim$Optimizer$$logger = Logger.getLogger(getClass());
    }
}
