package com.intel.analytics.bigdl.optim;

import com.intel.analytics.bigdl.dataset.MiniBatch;
import com.intel.analytics.bigdl.dataset.Sample;
import com.intel.analytics.bigdl.dataset.SampleToMiniBatch$;
import com.intel.analytics.bigdl.dataset.Transformer;
import com.intel.analytics.bigdl.models.utils.ModelBroadcast;
import com.intel.analytics.bigdl.models.utils.ModelBroadcast$;
import com.intel.analytics.bigdl.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.nn.abstractnn.Activity;
import com.intel.analytics.bigdl.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.utils.intermediate.ConversionUtils$;
import org.apache.spark.SparkContext;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.rdd.RDD;
import scala.Array$;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Serializable;
import scala.Some;
import scala.Tuple2;
import scala.collection.mutable.ArrayOps;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;
import scala.runtime.ScalaRunTime$;

/* compiled from: Evaluator.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005ut!B\b\u0011\u0011\u0003Yb!B\u000f\u0011\u0011\u0003q\u0002\"\u0002\u0015\u0002\t\u0003I\u0003\"\u0002\u0016\u0002\t\u0003Y\u0003\"CA5\u0003\u0005\u0005I\u0011BA6\r\u0011i\u0002\u0003A\u0018\t\u0011E*!\u0011!Q\u0001\nIB\u0001bT\u0003\u0003\u0004\u0003\u0006Y\u0001\u0015\u0005\t-\u0016\u0011\t\u0011)A\u0006/\"1\u0001&\u0002C\u0001!\tDq\u0001[\u0003C\u0002\u0013%\u0011\u000e\u0003\u0004n\u000b\u0001\u0006IA\u001b\u0005\u0006]\u0016!\ta\u001c\u0005\n\u0003_)\u0011\u0013!C\u0001\u0003cA\u0001\"a\u0012\u0006\t\u0003\u0011\u0012\u0011J\u0001\n\u000bZ\fG.^1u_JT!!\u0005\n\u0002\u000b=\u0004H/[7\u000b\u0005M!\u0012!\u00022jO\u0012d'BA\u000b\u0017\u0003%\tg.\u00197zi&\u001c7O\u0003\u0002\u00181\u0005)\u0011N\u001c;fY*\t\u0011$A\u0002d_6\u001c\u0001\u0001\u0005\u0002\u001d\u00035\t\u0001CA\u0005Fm\u0006dW/\u0019;peN\u0019\u0011aH\u0013\u0011\u0005\u0001\u001aS\"A\u0011\u000b\u0003\t\nQa]2bY\u0006L!\u0001J\u0011\u0003\r\u0005s\u0017PU3g!\t\u0001c%\u0003\u0002(C\ta1+\u001a:jC2L'0\u00192mK\u00061A(\u001b8jiz\"\u0012aG\u0001\u0006CB\u0004H._\u000b\u0004Y\u0005eCcA\u0017\u0002fQ)a&a\u0017\u0002bA!A$BA,+\t\u0001diE\u0002\u0006?\u0015\nQ!\\8eK2\u00042aM!E\u001d\t!tH\u0004\u00026}9\u0011a'\u0010\b\u0003oqr!\u0001O\u001e\u000e\u0003eR!A\u000f\u000e\u0002\rq\u0012xn\u001c;?\u0013\u0005I\u0012BA\f\u0019\u0013\t)b#\u0003\u0002\u0014)%\u0011\u0001IE\u0001\ba\u0006\u001c7.Y4f\u0013\t\u00115I\u0001\u0004N_\u0012,H.\u001a\u0006\u0003\u0001J\u0001\"!\u0012$\r\u0001\u0011)q)\u0002b\u0001\u0011\n\tA+\u0005\u0002J\u0019B\u0011\u0001ES\u0005\u0003\u0017\u0006\u0012qAT8uQ&tw\r\u0005\u0002!\u001b&\u0011a*\t\u0002\u0004\u0003:L\u0018AC3wS\u0012,gnY3%eA\u0019\u0011\u000b\u0016#\u000e\u0003IS!aU\u0011\u0002\u000fI,g\r\\3di&\u0011QK\u0015\u0002\t\u00072\f7o\u001d+bO\u0006\u0011QM\u001e\t\u00041~#eBA-]\u001d\t!$,\u0003\u0002\\%\u00051A/\u001a8t_JL!!\u00180\u0002#Q+gn]8s\u001dVlWM]5d\u001b\u0006$\bN\u0003\u0002\\%%\u0011\u0001-\u0019\u0002\u000e)\u0016t7o\u001c:O
k6,'/[2\u000b\u0005usFCA2h)\r!WM\u001a\t\u00049\u0015!\u0005\"B(\n\u0001\b\u0001\u0006\"\u0002,\n\u0001\b9\u0006\"B\u0019\n\u0001\u0004\u0011\u0014!\u00052bi\u000eD\u0007+\u001a:QCJ$\u0018\u000e^5p]V\t!\u000e\u0005\u0002!W&\u0011A.\t\u0002\u0004\u0013:$\u0018A\u00052bi\u000eD\u0007+\u001a:QCJ$\u0018\u000e^5p]\u0002\nA\u0001^3tiR1\u0001\u000f`A\u0010\u0003K\u00012\u0001I9t\u0013\t\u0011\u0018EA\u0003BeJ\f\u0017\u0010\u0005\u0003!iZL\u0018BA;\"\u0005\u0019!V\u000f\u001d7feA\u0011Ad^\u0005\u0003qB\u0011\u0001CV1mS\u0012\fG/[8o%\u0016\u001cX\u000f\u001c;\u0011\u0007qQH)\u0003\u0002|!\t\u0001b+\u00197jI\u0006$\u0018n\u001c8NKRDw\u000e\u001a\u0005\u0006{2\u0001\rA`\u0001\bI\u0006$\u0018m]3u!\u0015y\u0018\u0011CA\u000b\u001b\t\t\tA\u0003\u0003\u0002\u0004\u0005\u0015\u0011a\u0001:eI*!\u0011qAA\u0005\u0003\u0015\u0019\b/\u0019:l\u0015\u0011\tY!!\u0004\u0002\r\u0005\u0004\u0018m\u00195f\u0015\t\ty!A\u0002pe\u001eLA!a\u0005\u0002\u0002\t\u0019!\u000b\u0012#\u0011\u000b\u0005]\u00111\u0004#\u000e\u0005\u0005e!BA?\u0013\u0013\u0011\ti\"!\u0007\u0003\rM\u000bW\u000e\u001d7f\u0011\u001d\t\t\u0003\u0004a\u0001\u0003G\t\u0001B^'fi\"|Gm\u001d\t\u0004AEL\b\"CA\u0014\u0019A\u0005\t\u0019AA\u0015\u0003%\u0011\u0017\r^2i'&TX\r\u0005\u0003!\u0003WQ\u0017bAA\u0017C\t1q\n\u001d;j_:\fa\u0002^3ti\u0012\"WMZ1vYR$3'\u0006\u0002\u00024)\"\u0011\u0011FA\u001bW\t\t9\u0004\u0005\u0003\u0002:\u0005\rSBAA\u001e\u0015\u0011\ti$a\u0010\u0002\u0013Ut7\r[3dW\u0016$'bAA!C\u0005Q\u0011M\u001c8pi\u0006$\u0018n\u001c8\n\t\u0005\u0015\u00131\b\u0002\u0012k:\u001c\u0007.Z2lK\u00124\u0016M]5b]\u000e,\u0017!\u0004;fgRl\u0015N\\5CCR\u001c\u0007\u000eF\u0003q\u0003\u0017\n)\u0006\u0003\u0004~\u001d\u0001\u0007\u0011Q\n\t\u0006\u007f\u0006E\u0011q\n\t\u0006\u0003/\t\t\u0006R\u0005\u0005\u0003'\nIBA\u0005NS:L')\u0019;dQ\"9\u0011\u0011\u0005\bA\u0002\u0005\r\u0002cA#\u0002Z\u0011)qi\u0001b\u0001\u0011\"I\u0011QL\u0002\u0002\u0002\u0003\u000f\u0011qL\u0001\u000bKZLG-\u001a8dK\u0012\n\u0004\u0003B)U\u0003/BaAV\u0002A\
u0004\u0005\r\u0004\u0003\u0002-`\u0003/Ba!M\u0002A\u0002\u0005\u001d\u0004\u0003B\u001aB\u0003/\n1B]3bIJ+7o\u001c7wKR\u0011\u0011Q\u000e\t\u0005\u0003_\nI(\u0004\u0002\u0002r)!\u00111OA;\u0003\u0011a\u0017M\\4\u000b\u0005\u0005]\u0014\u0001\u00026bm\u0006LA!a\u001f\u0002r\t1qJ\u00196fGR\u0004")
/* loaded from: input_file:com/intel/analytics/bigdl/optim/Evaluator.class */
/**
 * Evaluates a model over a Spark {@link RDD} with a set of {@link ValidationMethod}s and returns
 * one merged {@link ValidationResult} per method.
 *
 * <p>NOTE(review): this file appears to be decompiler output of {@code Evaluator.scala} (see the
 * "compiled from" / "loaded from" comments and the {@code ScalaSignature} annotation). Names such
 * as {@code evidence$2}, {@code Evaluator$.MODULE$}, {@code evaluate2()} and {@code m1003clone()}
 * look like compiler-synthesised or decompiler-renamed members. Changes presumably belong in the
 * Scala source rather than here. Confirm before editing.
 *
 * @param <T> numeric element type of the model's tensors
 */
public class Evaluator<T> implements Serializable {
    // Model under evaluation. Each test call broadcasts a converted copy of it (see test/testMiniBatch).
    private final AbstractModule<Activity, Activity, T> model;
    // ClassTag for T (the Scala implicit/context-bound parameter; synthetic name).
    private final ClassTag<T> evidence$2;
    // Numeric type-class instance for T, forwarded to model broadcasting and batching helpers.
    private final TensorNumericMath.TensorNumeric<T> ev;
    // Per-partition batch size used by test() when the caller passes no explicit batch size.
    private final int batchPerPartition = 4;

    /**
     * Static forwarder to the Scala companion object's {@code apply}. It builds an evaluator for
     * {@code abstractModule}.
     *
     * @param abstractModule model to evaluate
     * @param classTag       ClassTag for {@code T}
     * @param tensorNumeric  numeric type-class instance for {@code T}
     * @return the evaluator produced by {@code Evaluator$.MODULE$.apply}
     */
    public static <T> Evaluator<T> apply(AbstractModule<Activity, Activity, T> abstractModule, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        return Evaluator$.MODULE$.apply(abstractModule, classTag, tensorNumeric);
    }

    /** Scala-style accessor for the default per-partition batch size (always 4). */
    private int batchPerPartition() {
        return this.batchPerPartition;
    }

    /**
     * Evaluates the model on an RDD of {@link Sample}s.
     *
     * <p>The samples are grouped into mini-batches on each partition and run through a broadcast
     * copy of the model. Each validation method is applied to every batch output, and the
     * per-batch results are summed element-wise across the whole dataset.
     *
     * @param rdd                 the samples to evaluate on
     * @param validationMethodArr validation methods to apply; each partition works on its own
     *                            clone of these
     * @param option              optional total batch size as a boxed {@code Integer}. When empty,
     *                            it defaults to {@code batchPerPartition() * rdd.partitions().length}.
     * @return one {@code (result, method)} pair per entry of {@code validationMethodArr}, in the
     *         same order. Each result is the {@code $plus}-merge of that method's result over all
     *         mini-batches.
     */
    public Tuple2<ValidationResult, ValidationMethod<T>>[] test(RDD<Sample<T>> rdd, ValidationMethod<T>[] validationMethodArr, Option<Object> option) {
        // Partition count of the input as given, i.e. before coalescing.
        int length = rdd.partitions().length;
        // Total batch size across all partitions.
        int unboxToInt = BoxesRunTime.unboxToInt(option.getOrElse(() -> {
            return this.batchPerPartition() * length;
        }));
        // Dummy input sized to one partition's share of the batch, passed to the model broadcast.
        // NOTE(review): this is int division. It throws ArithmeticException if rdd has no
        // partitions, and yields 0 when the batch size is smaller than the partition count.
        // Confirm those inputs are rejected upstream.
        Activity dummyData = Predictor$.MODULE$.getDummyData(rdd, unboxToInt / length, this.evidence$2, this.ev);
        // Broadcast the converted model; evaluate2() is presumably Scala's evaluate(), i.e.
        // switching the model to evaluation mode. Confirm.
        ModelBroadcast<T> broadcast = ModelBroadcast$.MODULE$.apply(this.evidence$2, this.ev).broadcast(rdd.sparkContext(), ConversionUtils$.MODULE$.convert(this.model.evaluate2(), this.evidence$2), dummyData);
        RDD<T> coalesce = ConversionUtils$.MODULE$.coalesce(rdd, ClassTag$.MODULE$.apply(Sample.class));
        SparkContext sparkContext = coalesce.sparkContext();
        // The mini-batch transformer is configured with the *coalesced* partition count.
        // NOTE(review): dummyData above used the pre-coalesce count, so the two per-partition
        // sizes can differ if coalesce changes the partitioning. Confirm this is intended.
        Option<Object> some = new Some<>(BoxesRunTime.boxToInteger(coalesce.partitions().length));
        // Broadcast (validation methods, SampleToMiniBatch transformer) together as one tuple.
        Broadcast broadcast2 = sparkContext.broadcast(new Tuple2(validationMethodArr, SampleToMiniBatch$.MODULE$.apply(unboxToInt, SampleToMiniBatch$.MODULE$.apply$default$2(), SampleToMiniBatch$.MODULE$.apply$default$3(), some, this.evidence$2, this.ev)), ClassTag$.MODULE$.apply(Tuple2.class));
        return (Tuple2[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) coalesce.mapPartitions(iterator -> {
            // Per-partition model instance. The two boolean flags are interpreted by
            // ModelBroadcast.value, whose meaning is not visible in this file.
            AbstractModule<Activity, Activity, T> value = broadcast.value(false, true, dummyData);
            // Clone the validation methods so this partition does not share instances with others.
            ValidationMethod[] validationMethodArr2 = (ValidationMethod[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) ((Tuple2) broadcast2.value())._1())).map(validationMethod -> {
                return validationMethod.m1003clone();
            }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(ValidationMethod.class)));
            // Clone the batching transformer, turn the partition's samples into mini-batches, and
            // emit one ValidationResult[] (one entry per method) per mini-batch.
            return ((Transformer) ((Tuple2) broadcast2.value())._2()).cloneTransformer().apply(iterator).map(miniBatch -> {
                Activity forward = value.forward(miniBatch.getInput());
                return (ValidationResult[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(validationMethodArr2)).map(validationMethod2 -> {
                    return validationMethod2.apply(forward, miniBatch.getTarget());
                }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(ValidationResult.class)));
            });
        }, coalesce.mapPartitions$default$2(), ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(ValidationResult.class))).reduce((validationResultArr, validationResultArr2) -> {
            // Merge two per-batch result arrays position by position with ValidationResult.$plus
            // (Scala `+`).
            // The MatchError branch is the decompiled form of the Scala `case (l, r)` pattern. zip
            // only yields non-null tuples, so it is not expected to trigger.
            // NOTE(review): if no mini-batch is produced at all (e.g. an empty dataset), Spark's
            // RDD.reduce presumably throws. Confirm callers never pass empty data.
            return (ValidationResult[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(validationResultArr)).zip(Predef$.MODULE$.wrapRefArray(validationResultArr2), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class))))).map(tuple2 -> {
                if (tuple2 != null) {
                    return ((ValidationResult) tuple2._1()).$plus((ValidationResult) tuple2._2());
                }
                throw new MatchError(tuple2);
            }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(ValidationResult.class)));
        }))).zip(Predef$.MODULE$.wrapRefArray(validationMethodArr), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class)));
        // ^ Pair each merged result with the caller's original (uncloned) validation method.
    }

    /** Default value of {@code test}'s batch-size parameter: {@code None}, meaning "use the per-partition default". */
    public Option<Object> test$default$3() {
        return None$.MODULE$;
    }

    /**
     * Evaluates the model on an RDD whose elements are already {@link MiniBatch}es.
     *
     * <p>Each batch is forwarded through a broadcast copy of the model, and every validation
     * method is applied to the output. The per-batch results are summed element-wise across the
     * dataset.
     *
     * @param rdd                 pre-batched evaluation data
     * @param validationMethodArr validation methods to apply
     * @return one {@code (result, method)} pair per entry of {@code validationMethodArr}, in the
     *         same order
     */
    public Tuple2<ValidationResult, ValidationMethod<T>>[] testMiniBatch(RDD<MiniBatch<T>> rdd, ValidationMethod<T>[] validationMethodArr) {
        RDD<T> coalesce = ConversionUtils$.MODULE$.coalesce(rdd, ClassTag$.MODULE$.apply(MiniBatch.class));
        // Unlike test(), the model is broadcast without dummy input.
        ModelBroadcast<T> broadcast = ModelBroadcast$.MODULE$.apply(this.evidence$2, this.ev).broadcast(coalesce.sparkContext(), ConversionUtils$.MODULE$.convert(this.model.evaluate2(), this.evidence$2));
        Broadcast broadcast2 = coalesce.sparkContext().broadcast(validationMethodArr, ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(ValidationMethod.class)));
        return (Tuple2[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) coalesce.mapPartitions(iterator -> {
            AbstractModule<Activity, Activity, T> value = broadcast.value(broadcast.value$default$1(), broadcast.value$default$2());
            // NOTE(review): test() clones the validation methods per partition, but here the
            // broadcast instances are used directly. If ValidationMethod holds mutable state and
            // tasks share the broadcast value, this may be unsafe. Confirm.
            ValidationMethod[] validationMethodArr2 = (ValidationMethod[]) broadcast2.value();
            return iterator.map(miniBatch -> {
                Activity forward = value.forward(miniBatch.getInput());
                return (ValidationResult[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(validationMethodArr2)).map(validationMethod -> {
                    return validationMethod.apply(forward, miniBatch.getTarget());
                }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(ValidationResult.class)));
            });
        }, coalesce.mapPartitions$default$2(), ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(ValidationResult.class))).reduce((validationResultArr, validationResultArr2) -> {
            // Same element-wise $plus merge as in test(). The MatchError branch is decompiled
            // pattern-match boilerplate.
            return (ValidationResult[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(validationResultArr)).zip(Predef$.MODULE$.wrapRefArray(validationResultArr2), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class))))).map(tuple2 -> {
                if (tuple2 != null) {
                    return ((ValidationResult) tuple2._1()).$plus((ValidationResult) tuple2._2());
                }
                throw new MatchError(tuple2);
            }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(ValidationResult.class)));
        }))).zip(Predef$.MODULE$.wrapRefArray(validationMethodArr), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class)));
    }

    /**
     * Creates an evaluator for {@code abstractModule}.
     *
     * @param abstractModule model to evaluate
     * @param classTag       ClassTag for {@code T}
     * @param tensorNumeric  numeric type-class instance for {@code T}
     */
    public Evaluator(AbstractModule<Activity, Activity, T> abstractModule, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        this.model = abstractModule;
        this.evidence$2 = classTag;
        this.ev = tensorNumeric;
    }
}
