package com.intel.analytics.bigdl.optim;

import com.intel.analytics.bigdl.dataset.MiniBatch;
import com.intel.analytics.bigdl.dataset.PaddingParam;
import com.intel.analytics.bigdl.dataset.Sample;
import com.intel.analytics.bigdl.dataset.SampleToMiniBatch;
import com.intel.analytics.bigdl.dataset.SampleToMiniBatch$;
import com.intel.analytics.bigdl.dataset.Transformer;
import com.intel.analytics.bigdl.models.utils.ModelBroadcast;
import com.intel.analytics.bigdl.models.utils.ModelBroadcast$;
import com.intel.analytics.bigdl.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.nn.abstractnn.Activity;
import com.intel.analytics.bigdl.tensor.ConvertableTo$ConvertableToInt$;
import com.intel.analytics.bigdl.tensor.Tensor;
import com.intel.analytics.bigdl.tensor.Tensor$;
import com.intel.analytics.bigdl.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.transform.vision.image.DistributedImageFrame;
import com.intel.analytics.bigdl.transform.vision.image.ImageFeature;
import com.intel.analytics.bigdl.transform.vision.image.ImageFeature$;
import com.intel.analytics.bigdl.transform.vision.image.ImageFrame$;
import com.intel.analytics.bigdl.utils.Engine$;
import com.intel.analytics.bigdl.utils.EngineType;
import com.intel.analytics.bigdl.utils.MklDnn$;
import com.intel.analytics.bigdl.utils.T$;
import com.intel.analytics.bigdl.utils.Table;
import com.intel.analytics.bigdl.utils.intermediate.ConversionUtils$;
import org.apache.spark.SparkContext;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.rdd.RDD;
import scala.Array$;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Serializable;
import scala.Some;
import scala.Tuple2;
import scala.collection.Iterator;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.mutable.ArrayOps;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.RichInt$;

/* compiled from: Predictor.scala */
/* loaded from: input_file:com/intel/analytics/bigdl/optim/Predictor$.class */
/**
 * Java form of the Scala companion object {@code Predictor} (decompiled from Predictor.scala).
 *
 * <p>Provides the factory and constructor defaults for {@link Predictor}, plus helpers used for
 * batched inference:
 * <ul>
 *   <li>running a model over a local sequence of samples or image features;</li>
 *   <li>running a model over Spark RDDs via {@code mapPartitions} with a broadcast model;</li>
 *   <li>splitting a batched model output back into one {@link Activity} per sample.</li>
 * </ul>
 *
 * <p>{@link #MODULE$} is the singleton instance. It is assigned by the private constructor, which
 * the static initializer invokes once when the class is loaded.
 */
public final class Predictor$ implements Serializable {
    /** The singleton instance of this companion object (set by the private constructor). */
    public static Predictor$ MODULE$;

    static {
        new Predictor$();
    }

    /** Default value of the second {@code Predictor} constructor argument: {@code None}. */
    public <T> None$ $lessinit$greater$default$2() {
        return None$.MODULE$;
    }

    /** Default value of the third {@code Predictor} constructor argument: {@code 4}. */
    public <T> int $lessinit$greater$default$3() {
        return 4;
    }

    /**
     * Creates a new {@link Predictor}, forwarding all arguments to its constructor unchanged.
     *
     * @param abstractModule the model to run
     * @param option optional padding parameters handed to the predictor
     * @param i integer passed to the constructor (presumably the batch size per partition;
     *     TODO confirm against {@code Predictor})
     * @param classTag class tag for the numeric type {@code T}
     * @param tensorNumeric numeric operations for {@code T}
     * @return the new predictor
     */
    public <T> Predictor<T> apply(AbstractModule<Activity, Activity, T> abstractModule, Option<PaddingParam<T>> option, int i, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        return new Predictor<>(abstractModule, option, i, classTag, tensorNumeric);
    }

    /** Default value of {@code apply}'s second argument (padding): {@code None}. */
    public <T> None$ apply$default$2() {
        return None$.MODULE$;
    }

    /** Default value of {@code apply}'s third argument: {@code 4}. */
    public <T> int apply$default$3() {
        return 4;
    }

    /**
     * Runs the model over the valid features of one local batch and stores each prediction back
     * into its feature.
     *
     * <p>Only features whose {@code isValid()} is true are fed to the model. Each of them is
     * converted to the {@link Sample} stored under {@code ImageFeature.sample()}, and its output is
     * written under key {@code str2}. Invalid features are left untouched but are still part of
     * the returned sequence.
     *
     * @param abstractModule the model to run
     * @param seq the image features of this batch
     * @param str name of the layer whose output is taken, or {@code null} for the model output
     * @param str2 key under which each prediction is stored in its {@link ImageFeature}
     * @param transformer converts the samples into mini-batches
     * @param z share-buffer flag forwarded to {@link #splitBatch}
     * @param classTag class tag for {@code T}
     * @param tensorNumeric numeric operations for {@code T}
     * @return {@code seq} itself (the same instance, with valid features updated)
     */
    public <T> Seq<ImageFeature> predictImageBatch(AbstractModule<Activity, Activity, T> abstractModule, Seq<ImageFeature> seq, String str, String str2, Transformer<Sample<T>, MiniBatch<T>> transformer, boolean z, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        Seq seq2 = (Seq) seq.filter(imageFeature -> {
            return BoxesRunTime.boxToBoolean(imageFeature.isValid());
        });
        // Pair each valid feature with its prediction (same order) and store the result in place.
        seq2.toIterator().zip(predictSamples(abstractModule, (Seq) seq2.map(imageFeature2 -> {
            return (Sample) imageFeature2.apply(ImageFeature$.MODULE$.sample());
        }, Seq$.MODULE$.canBuildFrom()), transformer, z, str, classTag, tensorNumeric)).foreach(tuple2 -> {
            $anonfun$predictImageBatch$3(str2, tuple2);
            return BoxedUnit.UNIT;
        });
        return seq;
    }

    /**
     * Lazily runs the model over local samples and yields one output per sample.
     *
     * <p>The samples are grouped into mini-batches by {@code transformer}. For each mini-batch
     * the whole model is run forward. The output read back is the output of the layer named
     * {@code str}, or of the model itself when {@code str} is {@code null}.
     *
     * @param abstractModule the model to run
     * @param seq the samples to predict
     * @param transformer converts samples into mini-batches
     * @param z share-buffer flag forwarded to {@link #splitBatch}
     * @param str optional layer name whose output is returned; {@code null} means the model output
     * @param classTag class tag for {@code T}
     * @param tensorNumeric numeric operations for {@code T}
     * @return an iterator over per-sample outputs, in input order
     * @throws IllegalArgumentException if {@code str} is non-null and no such layer exists
     */
    public <T> Iterator<Activity> predictSamples(AbstractModule<Activity, Activity, T> abstractModule, Seq<Sample<T>> seq, Transformer<Sample<T>, MiniBatch<T>> transformer, boolean z, String str, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        AbstractModule<Activity, Activity, T> abstractModule2;
        if (str == null) {
            abstractModule2 = abstractModule;
        } else {
            Option<AbstractModule<Activity, Activity, T>> apply = abstractModule.apply(str);
            Predef$.MODULE$.require(apply.isDefined(), () -> {
                return new StringBuilder(32).append("cannot find layer that map name ").append(str).toString();
            });
            abstractModule2 = (AbstractModule) apply.get();
        }
        AbstractModule<Activity, Activity, T> abstractModule3 = abstractModule2;
        return transformer.apply(seq.toIterator()).flatMap(miniBatch -> {
            return new ArrayOps.ofRef($anonfun$predictSamples$2(abstractModule, abstractModule3, z, classTag, tensorNumeric, miniBatch));
        });
    }

    /**
     * Splits a batched output tensor into one activity per sample along dimension 1.
     *
     * <p>When {@code z} is true the input tensor itself is split; otherwise a clone of it is
     * split. For a batch size of 1 the result is a single squeezed tensor, and the size of
     * dimension 1 is not checked.
     *
     * @param tensor the batched output
     * @param z share-buffer flag: {@code true} to use {@code tensor} directly, {@code false} to clone
     * @param i expected batch size
     * @param classTag class tag for {@code T}
     * @param tensorNumeric numeric operations for {@code T}
     * @return {@code i} per-sample tensors
     * @throws IllegalArgumentException if {@code i != 1} and {@code i} differs from the size of
     *     dimension 1
     */
    public <T> Activity[] splitTensor(Tensor<T> tensor, boolean z, int i, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        Tensor[] split;
        Tensor<T> m1158clone = z ? tensor : tensor.m1158clone();
        if (i == 1) {
            split = new Tensor[]{m1158clone.squeeze()};
        } else {
            int size = m1158clone.size(1);
            Predef$.MODULE$.require(i == size, () -> {
                return new StringBuilder(50).append("The batchSize is required to be ").append(size).append(", while actual is ").append(i).toString();
            });
            split = m1158clone.split(1);
        }
        return split;
    }

    /**
     * Splits a batched output (a tensor, or a possibly nested table of them) into per-sample
     * activities.
     *
     * <p>A tensor is delegated to {@link #splitTensor}. For a table, every element (keys
     * {@code 1..length}) is split recursively. The j-th result table then receives, in element
     * order, the j-th slice of each element.
     *
     * @param activity the batched output
     * @param z share-buffer flag forwarded to {@link #splitTensor}
     * @param i the batch size
     * @param classTag class tag for {@code T}
     * @param tensorNumeric numeric operations for {@code T}
     * @return {@code i} per-sample activities (a {@code Table[]} at runtime for table input)
     * @throws IllegalArgumentException if an element does not split into exactly {@code i} parts
     */
    public <T> Activity[] splitBatch(Activity activity, boolean z, int i, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        // Typed as Activity[]: splitTensor returns Activity[], which cannot be assigned to Table[].
        // The Table[] built in the table branch is assignable here via array covariance.
        Activity[] result;
        if (activity.isTensor()) {
            result = splitTensor(activity.toTensor(tensorNumeric), z, i, classTag, tensorNumeric);
        } else {
            Table table = activity.toTable();
            Table[] tableArr = new Table[i];
            RichInt$.MODULE$.to$extension0(Predef$.MODULE$.intWrapper(1), table.length()).foreach$mVc$sp(i2 -> {
                Activity[] splitBatch = MODULE$.splitBatch((Activity) table.apply(BoxesRunTime.boxToInteger(i2)), z, i, classTag, tensorNumeric);
                int length = splitBatch.length;
                Predef$.MODULE$.require(i == length, () -> {
                    return new StringBuilder(50).append("The batchSize is required to be ").append(length).append(", while actual is ").append(i).toString();
                });
                // Index named j: redeclaring the lambda parameter i2 here is a compile error.
                for (int j = 0; j < i; j++) {
                    if (tableArr[j] == null) {
                        tableArr[j] = T$.MODULE$.apply();
                    }
                    tableArr[j].insert(splitBatch[j]);
                }
            });
            result = tableArr;
        }
        return result;
    }

    /**
     * Predicts every image feature of a distributed frame and stores the prediction in the feature.
     *
     * <p>The model returned by {@code abstractModule.evaluate2()} is converted by
     * {@code ConversionUtils} and broadcast once. The RDD is coalesced. The total batch size is
     * {@code originalPartitions * i}, and each coalesced partition is processed in groups of
     * {@code total / coalescedPartitions} features.
     *
     * @param distributedImageFrame the frame to predict
     * @param str optional layer name whose output is stored; {@code null} means the model output
     * @param z share-buffer flag forwarded to {@link #splitBatch}
     * @param str2 key under which each prediction is stored
     * @param i batch size per partition of the original RDD
     * @param abstractModule the model to run
     * @param option optional padding parameters used when building mini-batches
     * @param classTag class tag for {@code T}
     * @param tensorNumeric numeric operations for {@code T}
     * @return a new frame over the predicted features
     */
    public <T> DistributedImageFrame predictImage(DistributedImageFrame distributedImageFrame, String str, boolean z, String str2, int i, AbstractModule<Activity, Activity, T> abstractModule, Option<PaddingParam<T>> option, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        Activity dummyData = getDummyData(distributedImageFrame.rdd(), i, classTag, tensorNumeric);
        int length = distributedImageFrame.rdd().partitions().length * i;
        RDD<T> coalesce = ConversionUtils$.MODULE$.coalesce(distributedImageFrame.rdd(), ClassTag$.MODULE$.apply(ImageFeature.class));
        ModelBroadcast<T> broadcast = ModelBroadcast$.MODULE$.apply(classTag, tensorNumeric).broadcast(coalesce.sparkContext(), ConversionUtils$.MODULE$.convert(abstractModule.evaluate2(), classTag), dummyData);
        int length2 = coalesce.partitions().length;
        SparkContext sparkContext = coalesce.sparkContext();
        Option<Object> some = new Some<>(BoxesRunTime.boxToInteger(length2));
        // Broadcast (SampleToMiniBatch, shareBuffer); only the transformer (_1) is read below.
        Broadcast broadcast2 = sparkContext.broadcast(new Tuple2(SampleToMiniBatch$.MODULE$.apply(length, option, SampleToMiniBatch$.MODULE$.apply$default$3(), some, classTag, tensorNumeric), BoxesRunTime.boxToBoolean(z)), ClassTag$.MODULE$.apply(Tuple2.class));
        int i2 = length / length2;
        return ImageFrame$.MODULE$.rdd(coalesce.mapPartitions(iterator -> {
            AbstractModule value = broadcast.value(false, true, dummyData);
            // Each partition works on its own copy of the broadcast transformer.
            Transformer cloneTransformer = ((Transformer) ((Tuple2) broadcast2.value())._1()).cloneTransformer();
            return iterator.grouped(i2).flatMap(seq -> {
                MODULE$.predictImageBatch(value, seq, str, str2, cloneTransformer, z, classTag, tensorNumeric);
                return seq;
            });
        }, coalesce.mapPartitions$default$2(), ClassTag$.MODULE$.apply(ImageFeature.class)));
    }

    /**
     * Runs the model over an RDD of samples and returns one output per sample.
     *
     * <p>If {@code i > 0} it is the total batch size and must be divisible by the number of
     * partitions of {@code rdd}. Otherwise the total batch size is {@code i2 * partitions}.
     *
     * @param rdd the samples to predict
     * @param i total batch size, or a non-positive value to derive it from {@code i2}
     * @param z share-buffer flag forwarded to {@link #splitBatch}
     * @param abstractModule the model to run
     * @param i2 batch size per partition, used when {@code i <= 0}
     * @param option optional padding parameters used when building mini-batches
     * @param classTag class tag for {@code T}
     * @param tensorNumeric numeric operations for {@code T}
     * @return an RDD of per-sample outputs
     * @throws IllegalArgumentException if {@code i > 0} and is not a multiple of the partition count
     */
    public <T> RDD<Activity> predict(RDD<Sample<T>> rdd, int i, boolean z, AbstractModule<Activity, Activity, T> abstractModule, int i2, Option<PaddingParam<T>> option, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        int i3;
        int length = rdd.partitions().length;
        if (i > 0) {
            Predef$.MODULE$.require(i % length == 0, () -> {
                return new StringBuilder(71).append("Predictor.predict: total batch size ").append(i).append(" ").append("should be divided by partitionNum ").append(length).toString();
            });
            i3 = i;
        } else {
            i3 = i2 * length;
        }
        int i4 = i3;
        Activity dummyData = getDummyData(rdd, i4 / length, classTag, tensorNumeric);
        ModelBroadcast<T> broadcast = ModelBroadcast$.MODULE$.apply(classTag, tensorNumeric).broadcast(rdd.sparkContext(), ConversionUtils$.MODULE$.convert(abstractModule.evaluate2(), classTag), dummyData);
        RDD<T> coalesce = ConversionUtils$.MODULE$.coalesce(rdd, ClassTag$.MODULE$.apply(Sample.class));
        Broadcast broadcast2 = coalesce.sparkContext().broadcast(SampleToMiniBatch$.MODULE$.apply(i4, option, SampleToMiniBatch$.MODULE$.apply$default$3(), new Some<>(BoxesRunTime.boxToInteger(coalesce.partitions().length)), classTag, tensorNumeric), ClassTag$.MODULE$.apply(SampleToMiniBatch.class));
        return coalesce.mapPartitions(iterator -> {
            AbstractModule value = broadcast.value(false, true, dummyData);
            return ((Transformer) broadcast2.value()).cloneTransformer().apply(iterator).flatMap(miniBatch -> {
                return new ArrayOps.ofRef($anonfun$predict$3(value, z, classTag, tensorNumeric, miniBatch));
            });
        }, coalesce.mapPartitions$default$2(), ClassTag$.MODULE$.apply(Activity.class));
    }

    /** Default for {@code predictSamples}'s layer-name argument: {@code null} (model output). */
    public <T> String predictSamples$default$5() {
        return null;
    }

    /** Default for {@code predictImage}'s layer-name argument: {@code null} (model output). */
    public <T> String predictImage$default$2() {
        return null;
    }

    /** Default for {@code predictImage}'s share-buffer flag: {@code false}. */
    public <T> boolean predictImage$default$3() {
        return false;
    }

    /** Default for {@code predictImage}'s prediction key: {@code ImageFeature.predict()}. */
    public <T> String predictImage$default$4() {
        return ImageFeature$.MODULE$.predict();
    }

    /** Default for {@code predict}'s total batch size: {@code -1} (derive from per-partition size). */
    public <T> int predict$default$2() {
        return -1;
    }

    /** Default for {@code predict}'s share-buffer flag: {@code false}. */
    public <T> boolean predict$default$3() {
        return false;
    }

    /**
     * Predicts a class index for each sample.
     *
     * <p>Runs {@link #predict} with the share-buffer flag set to {@code true}. For each 1-D
     * output it returns the index of the maximum along dimension 1 (as given by
     * {@code Tensor.max}), converted to {@code int}.
     *
     * @param rdd the samples to classify
     * @param i total batch size, or a non-positive value to derive it from {@code i2}
     * @param abstractModule the model to run
     * @param i2 batch size per partition, used when {@code i <= 0}
     * @param option optional padding parameters used when building mini-batches
     * @param classTag class tag for {@code T}
     * @param tensorNumeric numeric operations for {@code T}
     * @return an RDD of boxed {@code Integer} class indices
     * @throws IllegalArgumentException (during evaluation) if an output is not 1-dimensional
     */
    public <T> RDD<Object> predictClass(RDD<Sample<T>> rdd, int i, AbstractModule<Activity, Activity, T> abstractModule, int i2, Option<PaddingParam<T>> option, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        RDD<Activity> predict = predict(rdd, i, true, abstractModule, i2, option, classTag, tensorNumeric);
        return predict.mapPartitions(iterator -> {
            return iterator.map(activity -> {
                return BoxesRunTime.boxToInteger($anonfun$predictClass$2(tensorNumeric, activity));
            });
        }, predict.mapPartitions$default$2(), ClassTag$.MODULE$.Int());
    }

    /** Default for {@code predictClass}'s total batch size: {@code -1}. */
    public <T> int predictClass$default$2() {
        return -1;
    }

    /**
     * Builds the example input handed to the model broadcast.
     *
     * <p>Only when the engine type is {@code MklDnn} and multi-model mode is on, this draws up
     * to {@code i} elements from {@code rdd} (without replacement). Image features are converted
     * to their stored {@link Sample}, the samples are batched into one {@link MiniBatch}, and that
     * batch's input is returned. In every other case an empty tensor is returned.
     *
     * <p>NOTE(review): {@code head()} will throw if no sample could be drawn (e.g. empty RDD).
     *
     * @param rdd source of {@link Sample} or {@link ImageFeature} elements
     * @param i number of elements to sample and the mini-batch size
     * @param classTag class tag for {@code T}
     * @param tensorNumeric numeric operations for {@code T}
     * @return the dummy input activity
     */
    public <T, R> Activity getDummyData(RDD<R> rdd, int i, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        EngineType engineType = Engine$.MODULE$.getEngineType();
        MklDnn$ mklDnn$ = MklDnn$.MODULE$;
        if (engineType != null ? engineType.equals(mklDnn$) : mklDnn$ == null) {
            if (Engine$.MODULE$.isMultiModels()) {
                Sample[] sampleArr = (Sample[]) Predef$.MODULE$.genericArrayOps(rdd.takeSample(false, i, rdd.takeSample$default$3())).map(obj -> {
                    return obj instanceof ImageFeature ? (Sample) ((ImageFeature) obj).apply(ImageFeature$.MODULE$.sample()) : (Sample) obj;
                }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Sample.class)));
                Option<Object> some = new Some<>(BoxesRunTime.boxToInteger(1));
                return ((MiniBatch) SampleToMiniBatch$.MODULE$.apply(i, SampleToMiniBatch$.MODULE$.apply$default$2(), SampleToMiniBatch$.MODULE$.apply$default$3(), some, classTag, tensorNumeric).apply(new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(sampleArr)).toIterator()).toSeq().head()).getInput();
            }
        }
        return Tensor$.MODULE$.apply(classTag, tensorNumeric);
    }

    /** Serialization hook: resolves any deserialized copy to the {@link #MODULE$} singleton. */
    private Object readResolve() {
        return MODULE$;
    }

    /** Lambda body of {@code predictImageBatch}: stores {@code tuple2._2()} in feature {@code tuple2._1()} under key {@code str}. */
    public static final /* synthetic */ void $anonfun$predictImageBatch$3(String str, Tuple2 tuple2) {
        ((ImageFeature) tuple2._1()).update(str, tuple2._2());
    }

    /**
     * Lambda body of {@code predictSamples}: runs {@code abstractModule} forward on the batch,
     * then splits the output of {@code abstractModule2} (the selected layer) per sample.
     */
    public static final /* synthetic */ Object[] $anonfun$predictSamples$2(AbstractModule abstractModule, AbstractModule abstractModule2, boolean z, ClassTag classTag, TensorNumericMath.TensorNumeric tensorNumeric, MiniBatch miniBatch) {
        abstractModule.forward(miniBatch.getInput());
        return Predef$.MODULE$.refArrayOps(MODULE$.splitBatch(abstractModule2.output(), z, miniBatch.size(), classTag, tensorNumeric));
    }

    /** Lambda body of {@code predict}: runs the model forward on the batch and splits the result per sample. */
    public static final /* synthetic */ Object[] $anonfun$predict$3(AbstractModule abstractModule, boolean z, ClassTag classTag, TensorNumericMath.TensorNumeric tensorNumeric, MiniBatch miniBatch) {
        return Predef$.MODULE$.refArrayOps(MODULE$.splitBatch(abstractModule.forward(miniBatch.getInput()), z, miniBatch.size(), classTag, tensorNumeric));
    }

    /**
     * Lambda body of {@code predictClass}: requires a 1-D output and returns the int-converted
     * first element of the index tensor from {@code max(1)}.
     */
    public static final /* synthetic */ int $anonfun$predictClass$2(TensorNumericMath.TensorNumeric tensorNumeric, Activity activity) {
        Tensor tensor = activity.toTensor(tensorNumeric);
        Predef$.MODULE$.require(tensor.dim() == 1, () -> {
            return new StringBuilder(76).append("Predictor.predictClass:").append("Only support one sample has one label, but got ").append(tensor.dim()).append(" label").toString();
        });
        return BoxesRunTime.unboxToInt(tensorNumeric.toType(((Tensor) tensor.max(1)._2()).mo1135valueAt(1), ConvertableTo$ConvertableToInt$.MODULE$));
    }

    /** Registers this instance as the singleton; invoked only from the static initializer. */
    private Predictor$() {
        MODULE$ = this;
    }
}
