package com.intel.analytics.bigdl.optim;

import com.intel.analytics.bigdl.dataset.DistributedDataSet;
import com.intel.analytics.bigdl.dataset.MiniBatch;
import com.intel.analytics.bigdl.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.nn.abstractnn.Activity;
import com.intel.analytics.bigdl.utils.Engine$;
import com.intel.analytics.bigdl.utils.MklBlas$;
import com.intel.analytics.bigdl.utils.ThreadPool;
import java.io.Serializable;
import org.apache.log4j.Logger;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.rdd.RDD;
import scala.Array$;
import scala.Function0;
import scala.Function1;
import scala.Function2;
import scala.MatchError;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.Iterator;
import scala.collection.TraversableOnce;
import scala.collection.immutable.IndexedSeq;
import scala.collection.immutable.IndexedSeq$;
import scala.collection.mutable.ArrayOps;
import scala.math.package$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;
import scala.runtime.RichInt$;
import scala.runtime.ScalaRunTime$;

/* compiled from: DistriValidator.scala */
@ScalaSignature(bytes = "\u0006\u0001E<QAC\u0006\t\u0002Y1Q\u0001G\u0006\t\u0002eAQ\u0001I\u0001\u0005\u0002\u0005BqAI\u0001C\u0002\u0013\u00051\u0005\u0003\u0004/\u0003\u0001\u0006I\u0001\n\u0004\u00051-\u0001q\u0006\u0003\u0005F\u000b\t\u0005\t\u0015!\u0003G\u0011!AVA!A!\u0002\u0013I\u0006B\u0002\u0011\u0006\t\u0003YA\fC\u0003a\u000b\u0011\u0005\u0013-A\bESN$(/\u001b,bY&$\u0017\r^8s\u0015\taQ\"A\u0003paRLWN\u0003\u0002\u000f\u001f\u0005)!-[4eY*\u0011\u0001#E\u0001\nC:\fG.\u001f;jGNT!AE\n\u0002\u000b%tG/\u001a7\u000b\u0003Q\t1aY8n\u0007\u0001\u0001\"aF\u0001\u000e\u0003-\u0011q\u0002R5tiJLg+\u00197jI\u0006$xN]\n\u0003\u0003i\u0001\"a\u0007\u0010\u000e\u0003qQ\u0011!H\u0001\u0006g\u000e\fG.Y\u0005\u0003?q\u0011a!\u00118z%\u00164\u0017A\u0002\u001fj]&$h\bF\u0001\u0017\u0003\u0019awnZ4feV\tA\u0005\u0005\u0002&Y5\taE\u0003\u0002(Q\u0005)An\\45U*\u0011\u0011FK\u0001\u0007CB\f7\r[3\u000b\u0003-\n1a\u001c:h\u0013\ticE\u0001\u0004M_\u001e<WM]\u0001\bY><w-\u001a:!+\t\u0001dg\u0005\u0002\u0006cA!qC\r\u001b@\u0013\t\u00194BA\u0005WC2LG-\u0019;peB\u0011QG\u000e\u0007\u0001\t\u00159TA1\u00019\u0005\u0005!\u0016CA\u001d=!\tY\"(\u0003\u0002<9\t9aj\u001c;iS:<\u0007CA\u000e>\u0013\tqDDA\u0002B]f\u00042\u0001Q\"5\u001b\u0005\t%B\u0001\"\u000e\u0003\u001d!\u0017\r^1tKRL!\u0001R!\u0003\u00135Kg.\u001b\"bi\u000eD\u0017!B7pI\u0016d\u0007cA$Vi9\u0011\u0001j\u0015\b\u0003\u0013Js!AS)\u000f\u0005-\u0003fB\u0001'P\u001b\u0005i%B\u0001(\u0016\u0003\u0019a$o\\8u}%\tA#\u0003\u0002\u0013'%\u0011\u0001#E\u0005\u0003\u001d=I!\u0001V\u0007\u0002\u000fA\f7m[1hK&\u0011ak\u0016\u0002\u0007\u001b>$W\u000f\\3\u000b\u0005Qk\u0011a\u00023bi\u0006\u001cV\r\u001e\t\u0004\u0001j{\u0014BA.B\u0005I!\u0015n\u001d;sS\n,H/\u001a3ECR\f7+\u001a;\u0015\u0007usv\fE\u0002\u0018\u000bQBQ!\u0012\u0005A\u0002\u0019CQ\u0001\u0017\u0005A\u0002e\u000bA\u0001^3tiR\u0011!M\u001c\t\u00047\r,\u0017B\u00013\u001d\u0005\u0015\t%O]1z!\u0011Yb\r[6\n\u0005\u001dd\"A\u0002+va2,'\u0007\u0005\u0002\u0018S&\u0011!n\u0003\u0002\u0011-\u
0006d\u0017\u000eZ1uS>t'+Z:vYR\u00042a\u000675\u0013\ti7B\u0001\tWC2LG-\u0019;j_:lU\r\u001e5pI\")q.\u0003a\u0001a\u0006Aa/T3uQ>$7\u000fE\u0002\u001cG.\u0004")
/* loaded from: input_file:com/intel/analytics/bigdl/optim/DistriValidator.class */
public class DistriValidator<T> extends Validator<T, MiniBatch<T>> {
    // Validator that evaluates a model on a Spark-distributed data set of mini-batches.
    // The model and the validation methods are broadcast once. Every partition works on
    // private copies, and the per-method results are summed across batches and partitions.

    /** The model under validation; broadcast after {@code evaluate2()} is called on it. */
    private final AbstractModule<Activity, Activity, T> model;
    /** The validation data; read with {@code data(false)}. */
    private final DistributedDataSet<MiniBatch<T>> dataSet;

    /**
     * Returns the logger shared by all {@code DistriValidator} instances.
     *
     * @return the logger held by the {@code DistriValidator$} companion object
     */
    public static Logger logger() {
        return DistriValidator$.MODULE$.logger();
    }

    /**
     * Runs every validation method over the whole distributed validation set.
     *
     * <p>The model (after {@code evaluate2()}) and the methods are broadcast once. Each
     * partition then makes {@code Engine.coreNumber()} private copies of both. It cuts every
     * mini-batch into at most that many contiguous slices and evaluates the slices in parallel
     * on the engine's default thread pool. Finally it adds up the per-method results
     * element-wise. The per-batch results are combined the same way with {@code RDD.reduce}.
     *
     * <p>NOTE(review): {@code RDD.reduce} presumably fails on an empty validation set; confirm
     * that callers never pass one.
     *
     * @param validationMethodArr the validation methods to apply
     * @return {@code (result, method)} pairs in the order of {@code validationMethodArr}
     * @throws IllegalArgumentException if the engine type is not {@code MklBlas}
     */
    @Override // com.intel.analytics.bigdl.optim.Validator
    public Tuple2<ValidationResult, ValidationMethod<T>>[] test(ValidationMethod<T>[] validationMethodArr) {
        // Check the engine before touching Spark. An unsupported engine should not cost a
        // broadcast of the whole model, and it should fail with a readable message.
        Object engineType = Engine$.MODULE$.getEngineType();
        if (!MklBlas$.MODULE$.equals(engineType)) {
            throw new IllegalArgumentException(
                    "DistriValidator requires the MklBlas engine type, but got " + engineType);
        }
        // The core number also decides how many model copies each partition makes.
        int coreNumber = Engine$.MODULE$.coreNumber();
        int nodeNumber = Engine$.MODULE$.nodeNumber();

        RDD<MiniBatch<T>> data = this.dataSet.data(false);
        Broadcast<Tuple2<?, ?>> broadcast = data.sparkContext().broadcast(
                new Tuple2<>(this.model.evaluate2(), validationMethodArr),
                ClassTag$.MODULE$.apply(Tuple2.class));

        RDD<ValidationResult[]> batchResults = data.mapPartitions(
                DistriValidator.<T>partitionEvaluator(broadcast, nodeNumber, coreNumber),
                data.mapPartitions$default$2(),
                ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(ValidationResult.class)));

        // Spark serializes the reduce function and ships it to the executors. A Java method
        // reference typed as plain scala.Function2 is not serializable, hence the cast.
        ValidationResult[] totals = batchResults.reduce(
                (Function2<ValidationResult[], ValidationResult[], ValidationResult[]> & Serializable)
                        DistriValidator::mergeResults);

        // Pair each total with its method. Like Scala's zip, stop at the shorter array.
        int pairCount = Math.min(totals.length, validationMethodArr.length);
        @SuppressWarnings("unchecked")
        Tuple2<ValidationResult, ValidationMethod<T>>[] pairs = new Tuple2[pairCount];
        for (int i = 0; i < pairCount; i++) {
            pairs[i] = new Tuple2<>(totals[i], validationMethodArr[i]);
        }
        return pairs;
    }

    /**
     * Builds the function that {@code mapPartitions} runs on each partition.
     *
     * <p>The method is static so the lambda cannot capture {@code this}. The lambda is cast to
     * {@code Serializable} because Spark must serialize it to send it to the executors, and a
     * Java lambda whose target is plain {@code scala.Function1} is not serializable.
     *
     * @param broadcast  holds {@code (model, validation methods)}
     * @param nodeNumber node number passed to {@code Engine.setNodeAndCore} on the executor
     * @param coreNumber core number passed to {@code Engine.setNodeAndCore}; also the number
     *                   of private model/method copies made per partition
     * @param <T>        numeric type of the model
     * @return a function mapping each mini-batch of a partition to its per-method results
     */
    private static <T> Function1<Iterator<MiniBatch<T>>, Iterator<ValidationResult[]>> partitionEvaluator(
            Broadcast<Tuple2<?, ?>> broadcast, int nodeNumber, int coreNumber) {
        return (Function1<Iterator<MiniBatch<T>>, Iterator<ValidationResult[]>> & Serializable) partition -> {
            Engine$.MODULE$.setNodeAndCore(nodeNumber, coreNumber);
            AbstractModule localModel = (AbstractModule) broadcast.value()._1();
            ValidationMethod[] localMethods = (ValidationMethod[]) broadcast.value()._2();
            DistriValidator$.MODULE$.logger().info(
                    "model thread pool size is " + Engine$.MODULE$.model().getPoolSize());

            // Give each concurrent slice its own model copy and its own method copies.
            AbstractModule[] workingModels = new AbstractModule[coreNumber];
            for (int slot = 1; slot <= coreNumber; slot++) {
                workingModels[slot - 1] = $anonfun$test$2(localModel, slot);
            }
            ValidationMethod[][] workingMethods = new ValidationMethod[coreNumber][];
            for (int slot = 1; slot <= coreNumber; slot++) {
                workingMethods[slot - 1] = $anonfun$test$3(localMethods, slot);
            }
            return partition.map(batch -> evaluateBatch(batch, workingModels, workingMethods, coreNumber));
        };
    }

    /**
     * Evaluates one mini-batch by spreading it over the per-partition model copies.
     *
     * <p>The batch is cut into {@code subModelNumber} contiguous slices, and the first
     * {@code size % subModelNumber} slices get one extra sample. When the batch is smaller
     * than {@code subModelNumber}, it is cut into {@code size} one-sample slices instead. The
     * slices run on the engine's default thread pool, and their results are added up
     * element-wise.
     *
     * <p>NOTE(review): an empty batch yields no tasks, and the final {@code reduce} presumably
     * fails on it; confirm that batches are never empty.
     *
     * @param batch          the mini-batch to evaluate
     * @param workingModels  one model copy per slot
     * @param workingMethods one array of validation-method copies per slot
     * @param subModelNumber number of slots
     * @return the summed result of each validation method for this batch
     */
    private static ValidationResult[] evaluateBatch(MiniBatch<?> batch, AbstractModule[] workingModels,
            ValidationMethod[][] workingMethods, int subModelNumber) {
        int stackSize = batch.size() / subModelNumber;
        int extraSize = batch.size() % subModelNumber;
        int parallelism = stackSize == 0 ? extraSize : subModelNumber;
        ThreadPool pool = Engine$.MODULE$.m1282default();
        IndexedSeq tasks = (IndexedSeq) RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), parallelism)
                .map(slot -> $anonfun$test$6(stackSize, extraSize, batch, workingModels, workingMethods,
                                BoxesRunTime.unboxToInt(slot)),
                        IndexedSeq$.MODULE$.canBuildFrom());
        return (ValidationResult[]) pool.invokeAndWait(tasks, pool.invokeAndWait$default$2())
                .reduce((left, right) -> mergeResults((ValidationResult[]) left, (ValidationResult[]) right));
    }

    /**
     * Adds two per-method result arrays element-wise with {@code ValidationResult.$plus}.
     *
     * <p>Like Scala's {@code zip}, only the first {@code min(left.length, right.length)}
     * positions are combined.
     *
     * @param left  results for one batch or partition
     * @param right results for another batch or partition
     * @return a new array with {@code left[i].$plus(right[i])} at each position
     */
    private static ValidationResult[] mergeResults(ValidationResult[] left, ValidationResult[] right) {
        int length = Math.min(left.length, right.length);
        ValidationResult[] merged = new ValidationResult[length];
        for (int i = 0; i < length; i++) {
            merged[i] = left[i].$plus(right[i]);
        }
        return merged;
    }

    /**
     * Makes one private copy of the model for a worker slot.
     *
     * @param abstractModule the model to copy
     * @param i              1-based slot index; not used
     * @return {@code abstractModule.cloneModule().evaluate2()}. NOTE(review): evaluate2 looks
     *         like the decompiled name of {@code evaluate()}; confirm.
     */
    public static final /* synthetic */ AbstractModule $anonfun$test$2(AbstractModule abstractModule, int i) {
        return abstractModule.cloneModule().evaluate2();
    }

    /**
     * Makes private copies of the validation methods for a worker slot.
     *
     * @param validationMethodArr the methods to copy
     * @param i                   1-based slot index; not used
     * @return a new array holding {@code m1003clone()} of each method, in the same order
     */
    public static final /* synthetic */ ValidationMethod[] $anonfun$test$3(ValidationMethod[] validationMethodArr, int i) {
        ValidationMethod[] copies = new ValidationMethod[validationMethodArr.length];
        for (int k = 0; k < validationMethodArr.length; k++) {
            copies[k] = validationMethodArr[k].m1003clone();
        }
        return copies;
    }

    /**
     * Builds the task that evaluates one slice of a mini-batch on one worker slot.
     *
     * @param stackSize  base slice length ({@code batch.size() / slots})
     * @param extraSize  number of leading slices that get one extra sample
     * @param miniBatch  the batch to slice
     * @param models     model copy per slot
     * @param methods    validation-method copies per slot
     * @param slot       0-based slot index; selects the slice, model and methods
     * @return a task returning one {@code ValidationResult} per method of the slot
     */
    public static final /* synthetic */ Function0<ValidationResult[]> $anonfun$test$6(int stackSize, int extraSize,
            MiniBatch miniBatch, AbstractModule[] models, ValidationMethod[][] methods, int slot) {
        return () -> {
            // The +1 suggests MiniBatch.slice takes a 1-based offset. TODO: confirm.
            int offset = slot * stackSize + Math.min(slot, extraSize) + 1;
            int length = stackSize + (slot < extraSize ? 1 : 0);
            // Wildcard instead of T: T is the class type parameter and is not in scope here.
            MiniBatch<?> slice = miniBatch.slice(offset, length);
            Activity input = slice.getInput();
            Activity target = slice.getTarget();
            Activity output = models[slot].forward(input);
            ValidationMethod[] slotMethods = methods[slot];
            ValidationResult[] results = new ValidationResult[slotMethods.length];
            for (int k = 0; k < slotMethods.length; k++) {
                results[k] = slotMethods[k].apply(output, target);
            }
            return results;
        };
    }

    /**
     * Creates a validator for {@code abstractModule} over {@code distributedDataSet}.
     *
     * <p>NOTE(review): the decompiler moved the {@code super(...)} call to the top, while the
     * Scala original assigned the fields first. This only matters if {@code Validator}'s
     * constructor reads these fields through overridden methods; confirm.
     *
     * @param abstractModule     the model to validate
     * @param distributedDataSet the distributed validation data
     */
    public DistriValidator(AbstractModule<Activity, Activity, T> abstractModule, DistributedDataSet<MiniBatch<T>> distributedDataSet) {
        super(abstractModule, distributedDataSet);
        this.model = abstractModule;
        this.dataSet = distributedDataSet;
    }
}
