package com.intel.analytics.bigdl.dlframes;

import com.intel.analytics.bigdl.dataset.Sample;
import com.intel.analytics.bigdl.dataset.Sample$;
import com.intel.analytics.bigdl.nn.abstractnn.AbstractCriterion;
import com.intel.analytics.bigdl.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.nn.abstractnn.Activity;
import com.intel.analytics.bigdl.optim.OptimMethod;
import com.intel.analytics.bigdl.optim.Optimizer;
import com.intel.analytics.bigdl.optim.Optimizer$;
import com.intel.analytics.bigdl.optim.Trigger;
import com.intel.analytics.bigdl.optim.Trigger$;
import com.intel.analytics.bigdl.tensor.ConvertableFrom$ConvertableFromDouble$;
import com.intel.analytics.bigdl.tensor.ConvertableFrom$ConvertableFromFloat$;
import com.intel.analytics.bigdl.tensor.Tensor;
import com.intel.analytics.bigdl.tensor.Tensor$;
import com.intel.analytics.bigdl.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.utils.T$;
import com.intel.analytics.bigdl.utils.Table;
import com.intel.analytics.bigdl.visualization.TrainSummary;
import com.intel.analytics.bigdl.visualization.ValidationSummary;
import org.apache.spark.ml.DLTransformerBase;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.DataType;
import scala.Function2;
import scala.MatchError;
import scala.Predef$;
import scala.Predef$ArrowAssoc$;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/* compiled from: DLEstimator.scala */
/* loaded from: input_file:com/intel/analytics/bigdl/dlframes/DLEstimator$mcF$sp.class */
/**
 * Specialized subclass of {@link DLEstimator} as recovered by a decompiler from
 * scalac output. The {@code (float[])} casts used when building tensors suggest
 * this is the float variant (NOTE(review): inferred from the casts and the
 * {@code $mcF$sp} suffix - confirm against the Scala source).
 *
 * <p>Most public methods are thin forwarders to their {@code $mcF$sp}
 * counterparts; the real work happens in {@code internalFit$mcF$sp} and
 * {@code getSamples$3}.
 */
public class DLEstimator$mcF$sp extends DLEstimator<Object> implements DLParams$mcF$sp {
    /** Numeric evidence handed to the optimizer, tensors, samples and models built by this class. */
    public final TensorNumericMath.TensorNumeric<Object> ev$mcF$sp;

    /** Class tag received by the constructor; assigned there but never read inside this class. */
    private final ClassTag<Object> evidence$1;

    /**
     * Forwards to {@code setOptimMethod$mcF$sp}.
     *
     * @param optimMethod optimization method to store on this estimator
     * @return whatever {@code setOptimMethod$mcF$sp} returns
     */
    @Override // com.intel.analytics.bigdl.dlframes.DLEstimator
    /* renamed from: setOptimMethod */
    public DLEstimator<Object> setOptimMethod2(OptimMethod<Object> optimMethod) {
        final DLEstimator<Object> configured = setOptimMethod$mcF$sp(optimMethod);
        return configured;
    }

    /**
     * Stores {@code optimMethod} under the {@code optimMethod()} param via {@code set}.
     *
     * @param optimMethod optimization method to store
     * @return the result of {@code set}, cast to this specialized type
     */
    @Override // com.intel.analytics.bigdl.dlframes.DLEstimator
    public DLEstimator<Object> setOptimMethod$mcF$sp(OptimMethod<Object> optimMethod) {
        final Object updated = set(optimMethod(), optimMethod);
        return (DLEstimator$mcF$sp) updated;
    }

    /**
     * Forwards to {@code internalFit$mcF$sp}.
     *
     * @param dataset training data
     * @return the fitted model wrapper
     */
    @Override // com.intel.analytics.bigdl.dlframes.DLEstimator
    public DLModel<Object> internalFit(Dataset<Row> dataset) {
        final DLModel<Object> fitted = internalFit$mcF$sp(dataset);
        return fitted;
    }

    /**
     * Builds an optimizer over the samples extracted from {@code dataset}, applies the
     * configured state, optimization method, end trigger and optional validation /
     * summaries, runs it, and wraps the optimized module in a {@code DLModel}.
     *
     * @param dataset training data; the columns named by {@code featuresCol()} and
     *                {@code labelCol()} are read from it
     * @return the result of {@code wrapBigDLModel$mcF$sp} applied to the optimized module
     */
    @Override // com.intel.analytics.bigdl.dlframes.DLEstimator
    public DLModel<Object> internalFit$mcF$sp(Dataset<Row> dataset) {
        final String featureColumn = (String) $(featuresCol());
        final String labelColumn = (String) $(labelCol());
        final RDD trainingSamples = getSamples$3(dataset, featureColumn, labelColumn);

        // Optimizer state table: "learningRate" and "learningRateDecay" taken from the params.
        final T$ tableFactory = T$.MODULE$;
        final Tuple2 learningRateEntry = Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(
                Predef$.MODULE$.ArrowAssoc("learningRate"), $(learningRate()));
        final Tuple2 decayEntry = Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(
                Predef$.MODULE$.ArrowAssoc("learningRateDecay"), $(learningRateDecay()));
        final Table optimizerState = tableFactory.apply(
                learningRateEntry,
                (Seq<Tuple2<Object, Object>>) Predef$.MODULE$.wrapRefArray(new Tuple2[]{decayEntry}));

        // An explicitly set endWhen trigger wins; otherwise stop after maxEpoch epochs.
        final Trigger stopTrigger;
        if (isSet(endWhen())) {
            stopTrigger = (Trigger) $(endWhen());
        } else {
            stopTrigger = Trigger$.MODULE$.maxEpoch(BoxesRunTime.unboxToInt($(maxEpoch())));
        }

        final Optimizer$ optimizerFactory = Optimizer$.MODULE$;
        final AbstractModule<Activity, Activity, Object> network = model();
        final AbstractCriterion<Activity, Activity, Object> lossFunction = criterion();
        final int miniBatchSize = BoxesRunTime.unboxToInt($(batchSize()));
        // The decompiled code evaluates the two default-argument getters, discards the
        // results and passes null for both arguments; that call sequence is kept as is.
        Optimizer$.MODULE$.apply$default$5();
        Optimizer$.MODULE$.apply$default$6();
        final Optimizer optimizer = optimizerFactory
                .apply(network, trainingSamples, lossFunction, miniBatchSize, null, null,
                        this.com$intel$analytics$bigdl$dlframes$DLEstimator$$evidence$1, this.ev$mcF$sp)
                .setState(optimizerState)
                .setOptimMethod((OptimMethod) $(optimMethod()))
                .setEndWhen(stopTrigger);

        // Validation (and its summary) is only wired in when a validation trigger is defined.
        if (com$intel$analytics$bigdl$dlframes$DLEstimator$$validationTrigger().isDefined()) {
            final Trigger validationTrigger =
                    (Trigger) com$intel$analytics$bigdl$dlframes$DLEstimator$$validationTrigger().get();
            final RDD validationSamples = getSamples$3(
                    com$intel$analytics$bigdl$dlframes$DLEstimator$$validationDF(), featureColumn, labelColumn);
            optimizer.setValidation(
                    validationTrigger,
                    validationSamples,
                    com$intel$analytics$bigdl$dlframes$DLEstimator$$validationMethods(),
                    com$intel$analytics$bigdl$dlframes$DLEstimator$$validationBatchSize());
            if (com$intel$analytics$bigdl$dlframes$DLEstimator$$validationSummary().isDefined()) {
                optimizer.setValidationSummary(
                        (ValidationSummary) com$intel$analytics$bigdl$dlframes$DLEstimator$$validationSummary().get());
            }
        }

        if (com$intel$analytics$bigdl$dlframes$DLEstimator$$trainSummary().isDefined()) {
            optimizer.setTrainSummary(
                    (TrainSummary) com$intel$analytics$bigdl$dlframes$DLEstimator$$trainSummary().get());
        }

        return wrapBigDLModel$mcF$sp(optimizer.optimize(), featureSize());
    }

    /**
     * Forwards to {@code wrapBigDLModel$mcF$sp}.
     *
     * @param abstractModule module to wrap
     * @param iArr           size array passed through to the model constructor
     * @return the wrapped model
     */
    @Override // com.intel.analytics.bigdl.dlframes.DLEstimator
    public DLModel<Object> wrapBigDLModel(AbstractModule<Activity, Activity, Object> abstractModule, int[] iArr) {
        final DLModel<Object> wrapped = wrapBigDLModel$mcF$sp(abstractModule, iArr);
        return wrapped;
    }

    /**
     * Creates a {@code DLModel$mcF$sp} around {@code abstractModule}, sets this estimator
     * as its parent and copies this estimator's param values onto it.
     *
     * @param abstractModule module to wrap
     * @param iArr           size array passed to the model constructor
     * @return the result of {@code copyValues} on the new model
     */
    @Override // com.intel.analytics.bigdl.dlframes.DLEstimator
    public DLModel<Object> wrapBigDLModel$mcF$sp(AbstractModule<Activity, Activity, Object> abstractModule, int[] iArr) {
        final DLModel$mcF$sp freshModel = new DLModel$mcF$sp(
                abstractModule,
                iArr,
                DLModel$.MODULE$.$lessinit$greater$default$3(),
                this.com$intel$analytics$bigdl$dlframes$DLEstimator$$evidence$1,
                this.ev$mcF$sp);
        return (DLModel) copyValues(freshModel.setParent(this), copyValues$default$2());
    }

    /**
     * Forwards to {@code copy$mcF$sp}.
     *
     * @param paramMap extra params applied to the copy
     * @return the copied estimator
     */
    @Override // com.intel.analytics.bigdl.dlframes.DLEstimator
    /* renamed from: copy */
    public DLEstimator<Object> mo239copy(ParamMap paramMap) {
        final DLEstimator<Object> duplicate = copy$mcF$sp(paramMap);
        return duplicate;
    }

    /**
     * Constructs a new estimator from this one's model, criterion and sizes, then copies
     * param values onto it together with {@code paramMap}.
     *
     * @param paramMap extra params handed to {@code copyValues}
     * @return the result of {@code copyValues} on the new estimator
     */
    @Override // com.intel.analytics.bigdl.dlframes.DLEstimator
    public DLEstimator<Object> copy$mcF$sp(ParamMap paramMap) {
        final DLEstimator$mcF$sp duplicate = new DLEstimator$mcF$sp(
                model(),
                criterion(),
                featureSize(),
                labelSize(),
                DLEstimator$.MODULE$.$lessinit$greater$default$5(),
                this.com$intel$analytics$bigdl$dlframes$DLEstimator$$evidence$1,
                this.ev$mcF$sp);
        return (DLEstimator) copyValues(duplicate, paramMap);
    }

    /**
     * Bridge variant of {@code internalFit} that accepts a raw {@code Dataset}.
     *
     * @param dataset training data
     * @return the result of {@code internalFit}
     */
    @Override // com.intel.analytics.bigdl.dlframes.DLEstimator
    /* renamed from: internalFit */
    public /* bridge */ /* synthetic */ DLTransformerBase mo242internalFit(Dataset dataset) {
        final Dataset<Row> rows = (Dataset<Row>) dataset;
        return internalFit(rows);
    }

    /**
     * Bridge variant of {@code setOptimMethod$mcF$sp} that accepts a raw {@code OptimMethod}.
     *
     * @param optimMethod optimization method to store
     * @return the result of {@code setOptimMethod$mcF$sp}
     */
    @Override // com.intel.analytics.bigdl.dlframes.DLEstimator
    /* renamed from: setOptimMethod$mcF$sp, reason: avoid collision after fix types in other method */
    public /* bridge */ /* synthetic */ DLEstimator<Object> setOptimMethod$mcF$sp2(OptimMethod optimMethod) {
        final OptimMethod<Object> typed = (OptimMethod<Object>) optimMethod;
        return setOptimMethod$mcF$sp(typed);
    }

    /**
     * Turns the rows of {@code dataset} into an RDD of {@code Sample}s in three map stages:
     * extract the feature/label columns as sequences, convert their elements through
     * {@code ev$mcF$sp}, then pack each pair into tensors sized by {@code featureSize()}
     * and {@code labelSize()}.
     *
     * @param dataset source rows
     * @param str     name of the feature column
     * @param str2    name of the label column
     * @return RDD produced by the final map stage (tagged with {@code Sample.class})
     */
    private final RDD getSamples$3(Dataset dataset, String str, String str2) {
        final DataType featureType = dataset.schema().apply(str).dataType();
        final int featureIndex = dataset.schema().fieldIndex(str);
        final DataType labelType = dataset.schema().apply(str2).dataType();
        final int labelIndex = dataset.schema().fieldIndex(str2);
        final Function2<Row, Object, Seq<Object>> readFeatures = getConvertFunc(featureType);
        final Function2<Row, Object, Seq<Object>> readLabels = getConvertFunc(labelType);

        // Stage 1: pull both columns out of each row as (features, labels).
        final RDD rawPairs = dataset.rdd().map(
                row -> new Tuple2(
                        (Seq) readFeatures.apply(row, BoxesRunTime.boxToInteger(featureIndex)),
                        (Seq) readLabels.apply(row, BoxesRunTime.boxToInteger(labelIndex))),
                ClassTag$.MODULE$.apply(Tuple2.class));

        // Stage 2: convert the elements of both sequences via the numeric evidence.
        final RDD convertedPairs = rawPairs.map(tuple2 -> {
            if (tuple2 == null) {
                throw new MatchError(tuple2);
            }
            final Seq rawFeatures = (Seq) tuple2._1();
            final Seq rawLabels = (Seq) tuple2._2();
            final Seq features = toSpecializedFloats(rawFeatures);
            final Seq labels = toSpecializedFloats(rawLabels);
            return new Tuple2(features, labels);
        }, ClassTag$.MODULE$.apply(Tuple2.class));

        // Stage 3: build a Sample from the feature tensor and the label tensor.
        return convertedPairs.map(tuple22 -> {
            if (tuple22 == null) {
                throw new MatchError(tuple22);
            }
            final Sample$ sampleFactory = Sample$.MODULE$;
            final Tensor featureTensor = (Tensor) Tensor$.MODULE$.apply$mFc$sp(
                    (float[]) ((Seq) tuple22._1()).toArray(this.com$intel$analytics$bigdl$dlframes$DLEstimator$$evidence$1),
                    this.featureSize(),
                    this.com$intel$analytics$bigdl$dlframes$DLEstimator$$evidence$1,
                    this.ev$mcF$sp);
            final Tensor labelTensor = (Tensor) Tensor$.MODULE$.apply$mFc$sp(
                    (float[]) ((Seq) tuple22._2()).toArray(this.com$intel$analytics$bigdl$dlframes$DLEstimator$$evidence$1),
                    this.labelSize(),
                    this.com$intel$analytics$bigdl$dlframes$DLEstimator$$evidence$1,
                    this.ev$mcF$sp);
            return sampleFactory.apply(
                    featureTensor,
                    labelTensor,
                    (ClassTag) this.com$intel$analytics$bigdl$dlframes$DLEstimator$$evidence$1,
                    (TensorNumericMath.TensorNumeric) this.ev$mcF$sp);
        }, ClassTag$.MODULE$.apply(Sample.class));
    }

    /**
     * Maps every element of {@code values} through {@code ev$mcF$sp.fromType$mcF$sp},
     * choosing the Double or Float conversion evidence from the runtime type of the
     * first element.
     *
     * <p>{@code head()} is called unconditionally, so an empty sequence is presumably
     * rejected by that call - TODO confirm against the Scala collection semantics.
     *
     * @param values sequence whose first element is expected to be a Double or a Float
     * @return the mapped sequence
     * @throws MatchError if the first element is neither a Double nor a Float
     */
    private Seq toSpecializedFloats(Seq values) {
        final Object first = values.head();
        if (first instanceof Double) {
            return (Seq) values.map(
                    d -> this.ev$mcF$sp.fromType$mcF$sp(BoxesRunTime.boxToDouble(d), ConvertableFrom$ConvertableFromDouble$.MODULE$),
                    Seq$.MODULE$.canBuildFrom());
        }
        if (first instanceof Float) {
            return (Seq) values.map(
                    f -> this.ev$mcF$sp.fromType$mcF$sp(BoxesRunTime.boxToFloat(f), ConvertableFrom$ConvertableFromFloat$.MODULE$),
                    Seq$.MODULE$.canBuildFrom());
        }
        throw new MatchError(first);
    }

    /**
     * Passes every argument to the {@link DLEstimator} constructor and additionally keeps
     * the numeric evidence and class tag in this class's own fields.
     *
     * @param abstractModule    module forwarded to the superclass
     * @param abstractCriterion criterion forwarded to the superclass
     * @param iArr              first size array forwarded to the superclass
     * @param iArr2             second size array forwarded to the superclass
     * @param str               string argument forwarded to the superclass
     * @param classTag          class tag forwarded to the superclass and stored locally
     * @param tensorNumeric     numeric evidence forwarded to the superclass and stored locally
     */
    /* JADX WARN: 'super' call moved to the top of the method (can break code semantics) */
    public DLEstimator$mcF$sp(AbstractModule<Activity, Activity, Object> abstractModule, AbstractCriterion<Activity, Activity, Object> abstractCriterion, int[] iArr, int[] iArr2, String str, ClassTag<Object> classTag, TensorNumericMath.TensorNumeric<Object> tensorNumeric) {
        super(abstractModule, abstractCriterion, iArr, iArr2, str, classTag, tensorNumeric);
        this.ev$mcF$sp = tensorNumeric;
        this.evidence$1 = classTag;
    }
}
