package net.sansa_stack.ml.spark.kernel;

import org.apache.jena.graph.Triple;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.mllib.classification.LogisticRegressionModel;
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS;
import org.apache.spark.mllib.evaluation.MulticlassMetrics;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.functions$;
import scala.MatchError;
import scala.Predef$;
import scala.Tuple2;
import scala.Tuple3;
import scala.collection.immutable.$colon;
import scala.collection.immutable.Nil$;
import scala.reflect.ClassTag$;
import scala.reflect.api.Mirror;
import scala.reflect.api.TypeCreator;
import scala.reflect.api.Types;
import scala.reflect.api.Universe;
import scala.reflect.runtime.package$;
import scala.runtime.BoxesRunTime;
import scala.runtime.DoubleRef;
import scala.runtime.RichInt$;

/* compiled from: RDFFastTreeGraphKernelUtil.scala */
/* loaded from: input_file:net/sansa_stack/ml/spark/kernel/RDFFastTreeGraphKernelUtil$.class */
/**
 * Singleton holding helper routines for the RDF fast tree graph kernel:
 * turning an RDD of Jena triples into a DataFrame, deriving an
 * (instance, label) DataFrame, and running a repeated random-split evaluation
 * of a multinomial logistic regression on labeled points.
 *
 * <p>This is the Java view of a Scala {@code object}; all members are reached
 * through {@link #MODULE$}.
 */
public final class RDFFastTreeGraphKernelUtil$ {
    /** The only instance of this class. Final so it cannot be reassigned by other code. */
    public static final RDFFastTreeGraphKernelUtil$ MODULE$ = new RDFFastTreeGraphKernelUtil$();

    /** Not instantiable from outside; use {@link #MODULE$}. */
    private RDFFastTreeGraphKernelUtil$() {
    }

    /**
     * Converts an RDD of triples into a three-column DataFrame of strings.
     *
     * <p>Each triple is mapped to the {@code toString()} of its subject, predicate and
     * object node, in that order.
     *
     * @param sparkSession session whose implicits supply the tuple encoder
     * @param rdd          triples to convert
     * @param str          name of the column holding the subject string
     * @param str2         name of the column holding the predicate string
     * @param str3         name of the column holding the object string
     * @return a DataFrame with the columns {@code str}, {@code str2}, {@code str3}
     */
    public Dataset<Row> triplesToDF(SparkSession sparkSession, RDD<Triple> rdd, String str, String str2, String str3) {
        // The anonymous TypeCreator below describes the type scala.Tuple3[String, String, String];
        // it is what the product encoder is derived from. It is compiler-generated and kept verbatim.
        return sparkSession.implicits().rddToDatasetHolder(rdd.map(triple -> {
            return new Tuple3(triple.getSubject().toString(), triple.getPredicate().toString(), triple.getObject().toString());
        }, ClassTag$.MODULE$.apply(Tuple3.class)), sparkSession.implicits().newProductEncoder(package$.MODULE$.universe().TypeTag().apply(package$.MODULE$.universe().runtimeMirror(getClass().getClassLoader()), new TypeCreator() { // from class: net.sansa_stack.ml.spark.kernel.RDFFastTreeGraphKernelUtil$$typecreator5$1
            public <U extends Universe> Types.TypeApi apply(Mirror<U> mirror) {
                Universe universe = mirror.universe();
                return universe.internal().reificationSupport().TypeRef(universe.internal().reificationSupport().ThisType(mirror.staticPackage("scala").asModule().moduleClass()), mirror.staticClass("scala.Tuple3"), new $colon.colon(mirror.staticClass("java.lang.String").asType().toTypeConstructor(), new $colon.colon(mirror.staticClass("java.lang.String").asType().toTypeConstructor(), new $colon.colon(mirror.staticClass("java.lang.String").asType().toTypeConstructor(), Nil$.MODULE$))));
            }
        }))).toDF(Predef$.MODULE$.wrapRefArray(new String[]{str, str2, str3}));
    }

    /** Scala default-argument accessor: default for the 3rd parameter of {@code triplesToDF}. */
    public String triplesToDF$default$3() {
        return "subject";
    }

    /** Scala default-argument accessor: default for the 4th parameter of {@code triplesToDF}. */
    public String triplesToDF$default$4() {
        return "predicate";
    }

    /** Scala default-argument accessor: default for the 5th parameter of {@code triplesToDF}. */
    public String triplesToDF$default$5() {
        return "object";
    }

    /**
     * Builds a two-column DataFrame {@code (instance, label)} from the given columns.
     *
     * <p>The distinct ({@code str}, {@code str2}) pairs are taken, {@code str2} is encoded
     * to a numeric {@code label} column with a {@link StringIndexer}, and the rows are then
     * grouped by {@code str}. When one instance carries several labels, the largest label
     * index is the one that is kept.
     *
     * @param dataset input DataFrame containing at least the columns {@code str} and {@code str2}
     * @param str     name of the column identifying the instance
     * @param str2    name of the column holding the raw (string) label
     * @return a DataFrame with the columns {@code instance} and {@code label}
     */
    public Dataset<Row> getInstanceAndLabelDF(Dataset<Row> dataset, String str, String str2) {
        Dataset<Row> distinct = dataset.select(str, Predef$.MODULE$.wrapRefArray(new String[]{str2})).distinct();
        Dataset<Row> indexed = new StringIndexer()
                .setInputCol(str2)
                .setOutputCol("label")
                .fit(distinct)
                .transform(distinct)
                .drop(str2);
        return indexed
                .groupBy(str, Predef$.MODULE$.wrapRefArray(new String[0]))
                .agg(functions$.MODULE$.max("label").as("label"), Predef$.MODULE$.wrapRefArray(new Column[0]))
                .toDF(Predef$.MODULE$.wrapRefArray(new String[]{"instance", "label"}));
    }

    /** Scala default-argument accessor: default for the 2nd parameter of {@code getInstanceAndLabelDF}. */
    public String getInstanceAndLabelDF$default$2() {
        return "subject";
    }

    /** Scala default-argument accessor: default for the 3rd parameter of {@code getInstanceAndLabelDF}. */
    public String getInstanceAndLabelDF$default$3() {
        return "object";
    }

    /**
     * Trains and validates a logistic regression {@code i2} times on random 90/10 splits
     * of {@code rdd} and prints the average accuracy together with two timings.
     *
     * <p>Run number {@code k} (1-based) uses {@code k} as the seed of its random split.
     * The input RDD is cached and counted up front; that phase is reported as
     * "Feature Computation/Read", the training/validation loop as "Model learning/testing".
     * The input RDD is left cached on return because it belongs to the caller.
     *
     * <p>If {@code i2} is not positive, no run is executed and the printed average is
     * {@code NaN}.
     *
     * @param rdd labeled points to learn from
     * @param i   number of classes handed to the logistic regression
     * @param i2  number of train/validate runs to average over
     */
    public void predictLogisticRegressionMLLIB(RDD<LabeledPoint> rdd, int i, int i2) {
        long startNanos = System.nanoTime();
        rdd.cache();
        // count() forces materialisation so the read phase is timed separately from learning.
        Predef$.MODULE$.println(new Tuple2("data count", BoxesRunTime.boxToLong(rdd.count())));
        long dataReadyNanos = System.nanoTime();

        double accuracySum = 0.0d;
        // long counter: the loop terminates even when i2 == Integer.MAX_VALUE.
        for (long seed = 1; seed <= i2; seed++) {
            accuracySum += trainAndValidate$1(rdd, seed, i)._2$mcD$sp();
        }
        long finishedNanos = System.nanoTime();

        Predef$.MODULE$.println("Average Accuracy: " + (accuracySum / i2));
        printTime("Feature Computation/Read", startNanos, dataReadyNanos);
        printTime("Model learning/testing", dataReadyNanos, finishedNanos);
    }

    /** Scala default-argument accessor: default for the 2nd parameter of {@code predictLogisticRegressionMLLIB}. */
    public int predictLogisticRegressionMLLIB$default$2() {
        return 2;
    }

    /** Scala default-argument accessor: default for the 3rd parameter of {@code predictLogisticRegressionMLLIB}. */
    public int predictLogisticRegressionMLLIB$default$3() {
        return 5;
    }

    /**
     * Prints the elapsed time between two {@link System#nanoTime()} readings in seconds,
     * formatted as {@code "<str>: <seconds> s"}.
     *
     * @param str label printed in front of the duration
     * @param j   start reading in nanoseconds
     * @param j2  end reading in nanoseconds
     */
    public void printTime(String str, long j, long j2) {
        double seconds = (j2 - j) / 1.0E9d;
        Predef$.MODULE$.println(str + ": " + seconds + " s");
    }

    /**
     * Performs one train/validate run: splits {@code rdd} 90/10 with the given seed,
     * trains an L-BFGS logistic regression on the larger part and measures accuracy on
     * the smaller part.
     *
     * @param rdd labeled points to split
     * @param j   seed for the random split
     * @param i   number of classes handed to the logistic regression
     * @return the trained model paired with its accuracy (boxed double) on the validation part
     */
    private static Tuple2<LogisticRegressionModel, Object> trainAndValidate$1(RDD<LabeledPoint> rdd, long j, int i) {
        RDD<LabeledPoint>[] randomSplit = rdd.randomSplit(new double[]{0.9d, 0.1d}, j);
        RDD<LabeledPoint> training = randomSplit[0].cache();
        RDD<LabeledPoint> validation = randomSplit[1];
        try {
            LogisticRegressionModel model = new LogisticRegressionWithLBFGS().setNumClasses(i).run(training);
            // accuracy() is evaluated here, so the training split is no longer needed afterwards.
            double accuracy = new MulticlassMetrics(validation.map(labeledPoint -> {
                return new Tuple2.mcDD.sp(labeledPoint.label(), model.predict(labeledPoint.features()));
            }, ClassTag$.MODULE$.apply(Tuple2.class))).accuracy();
            return new Tuple2<>(model, BoxesRunTime.boxToDouble(accuracy));
        } finally {
            // Release the per-run cache; otherwise every run leaves a cached split behind.
            training.unpersist(false);
        }
    }
}
