package net.sansa_stack.ml.spark.kernel;

import org.apache.spark.ml.feature.CountVectorizer;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.mllib.linalg.SparseVector;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.util.MLUtils$;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.functions$;
import scala.Predef$;
import scala.Serializable;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.ObjectRef;
import scala.runtime.RichInt$;

/* compiled from: RDFFastTreeGraphKernel_v2.scala */
@ScalaSignature(bytes = "\u0006\u0001\u00055a\u0001\u0002\n\u0014\u0001yA\u0001\u0002\u000b\u0001\u0003\u0006\u0004%\t!\u000b\u0005\tk\u0001\u0011\t\u0011)A\u0005U!A!\b\u0001BC\u0002\u0013\u00051\b\u0003\u0005O\u0001\t\u0005\t\u0015!\u0003=\u0011!y\u0005A!b\u0001\n\u0003Y\u0004\u0002\u0003)\u0001\u0005\u0003\u0005\u000b\u0011\u0002\u001f\t\u0011E\u0003!Q1A\u0005\u0002IC\u0001B\u0016\u0001\u0003\u0002\u0003\u0006Ia\u0015\u0005\u0006/\u0002!\t\u0001\u0017\u0005\u0006?\u0002!\t\u0001\u0019\u0005\u0006C\u0002!\ta\u000f\u0005\u0006E\u0002!\taY\u0004\u0006eNA\ta\u001d\u0004\u0006%MA\t\u0001\u001e\u0005\u0006/:!\t!\u001e\u0005\u0006m:!\ta\u001e\u0005\by:\t\t\u0011\"\u0003~\u0005e\u0011FI\u0012$bgR$&/Z3He\u0006\u0004\bnS3s]\u0016dwL\u001e\u001a\u000b\u0005Q)\u0012AB6fe:,GN\u0003\u0002\u0017/\u0005)1\u000f]1sW*\u0011\u0001$G\u0001\u0003[2T!AG\u000e\u0002\u0017M\fgn]1`gR\f7m\u001b\u0006\u00029\u0005\u0019a.\u001a;\u0004\u0001M\u0019\u0001aH\u0013\u0011\u0005\u0001\u001aS\"A\u0011\u000b\u0003\t\nQa]2bY\u0006L!\u0001J\u0011\u0003\r\u0005s\u0017PU3g!\t\u0001c%\u0003\u0002(C\ta1+\u001a:jC2L'0\u00192mK\u0006a1\u000f]1sWN+7o]5p]V\t!\u0006\u0005\u0002,g5\tAF\u0003\u0002.]\u0005\u00191/\u001d7\u000b\u0005Yy#B\u0001\u00192\u0003\u0019\t\u0007/Y2iK*\t!'A\u0002pe\u001eL!\u0001\u000e\u0017\u0003\u0019M\u0003\u0018M]6TKN\u001c\u0018n\u001c8\u0002\u001bM\u0004\u0018M]6TKN\u001c\u0018n\u001c8!Q\t\u0011q\u0007\u0005\u0002!q%\u0011\u0011(\t\u0002\niJ\fgn]5f]R\f\u0001\u0002\u001e:ja2,GIR\u000b\u0002yA\u0011Qh\u0013\b\u0003}%s!a\u0010%\u000f\u0005\u0001;eBA!G\u001d\t\u0011U)D\u0001D\u0015\t!U$\u0001\u0004=e>|GOP\u0005\u0002e%\u0011\u0001'M\u0005\u0003-=J!!\f\u0018\n\u0005)c\u0013a\u00029bG.\fw-Z\u0005\u0003\u00196\u0013\u0011\u0002R1uC\u001a\u0013\u0018-\\3\u000b\u0005)c\u0013!\u0003;sSBdW\r\u0012$!\u0003)Ign\u001d;b]\u000e,GIR\u0001\fS:\u001cH/\u00198dK\u00123\u0005%\u0001\u0005nCb$U\r\u001d;i+\u0005\u0019\u0006C\u0001\u0011U\u0013\t)\u0016EA\u0002J]R\f\u0011\"\\1y\t\u0016\u0004H\u000f\u001b\u
0011\u0002\rqJg.\u001b;?)\u0015I6\fX/_!\tQ\u0006!D\u0001\u0014\u0011\u0015A\u0013\u00021\u0001+\u0011\u0015Q\u0014\u00021\u0001=\u0011\u0015y\u0015\u00021\u0001=\u0011\u0015\t\u0016\u00021\u0001T\u0003=\u0019w.\u001c9vi\u00164U-\u0019;ve\u0016\u001cH#\u0001\u001f\u0002'\u001d,G/\u0014'GK\u0006$XO]3WK\u000e$xN]:\u0002+\u001d,G/\u0014'MS\nd\u0015MY3mK\u0012\u0004v.\u001b8ugV\tA\rE\u0002fQ*l\u0011A\u001a\u0006\u0003O:\n1A\u001d3e\u0013\tIgMA\u0002S\t\u0012\u0003\"a\u001b9\u000e\u00031T!!\u001c8\u0002\u0015I,wM]3tg&|gN\u0003\u0002p]\u0005)Q\u000e\u001c7jE&\u0011\u0011\u000f\u001c\u0002\r\u0019\u0006\u0014W\r\\3e!>Lg\u000e^\u0001\u001a%\u00123e)Y:u)J,Wm\u0012:ba\"\\UM\u001d8fY~3(\u0007\u0005\u0002[\u001dM\u0019abH\u0013\u0015\u0003M\fQ!\u00199qYf$R!\u0017=zunDQ\u0001\u000b\tA\u0002)BQA\u000f\tA\u0002qBQa\u0014\tA\u0002qBQ!\u0015\tA\u0002M\u000b1B]3bIJ+7o\u001c7wKR\ta\u0010E\u0002��\u0003\u0013i!!!\u0001\u000b\t\u0005\r\u0011QA\u0001\u0005Y\u0006twM\u0003\u0002\u0002\b\u0005!!.\u0019<b\u0013\u0011\tY!!\u0001\u0003\r=\u0013'.Z2u\u0001")
/* loaded from: input_file:net/sansa_stack/ml/spark/kernel/RDFFastTreeGraphKernel_v2.class */
public class RDFFastTreeGraphKernel_v2 implements Serializable {
    /*
     * Builds per-instance feature vectors from path strings.
     *
     * Starting from every row of instanceDF, computeFeatures() repeatedly joins
     * the current frontier against tripleDF (frontier.object = triple.subject),
     * appending ",predicate,object" to the path on each step, for maxDepth
     * rounds. All paths from all rounds are indexed with a StringIndexer,
     * collected per (instance, label) and turned into count vectors with a
     * CountVectorizer.
     *
     * The SQL below reads columns instance/label from instanceDF and
     * subject/predicate/object from tripleDF.
     */

    // Transient: not written out when this object is serialized.
    private final transient SparkSession sparkSession;
    private final Dataset<Row> tripleDF;
    private final Dataset<Row> instanceDF;
    private final int maxDepth;

    /** Factory that delegates to the Scala companion object's {@code apply}. */
    public static RDFFastTreeGraphKernel_v2 apply(SparkSession sparkSession, Dataset<Row> dataset, Dataset<Row> dataset2, int i) {
        return RDFFastTreeGraphKernel_v2$.MODULE$.apply(sparkSession, dataset, dataset2, i);
    }

    /** @return the session passed to the constructor */
    public SparkSession sparkSession() {
        return this.sparkSession;
    }

    /** @return the triple dataset passed to the constructor */
    public Dataset<Row> tripleDF() {
        return this.tripleDF;
    }

    /** @return the instance dataset passed to the constructor */
    public Dataset<Row> instanceDF() {
        return this.instanceDF;
    }

    /** @return the number of expansion rounds performed by {@link #computeFeatures()} */
    public int maxDepth() {
        return this.maxDepth;
    }

    /**
     * Computes the path-count feature vector of every instance.
     *
     * <p>Side effects: marks {@code tripleDF} for caching and (re)registers the
     * temp views {@code instances}, {@code triples} and {@code df} on the
     * session's SQL context.
     *
     * @return a dataset with the columns {@code instance}, {@code label},
     *         {@code paths} plus the {@code features} column added by the
     *         CountVectorizer
     */
    public Dataset<Row> computeFeatures() {
        SQLContext sqlContext = sparkSession().sqlContext();
        tripleDF().cache();
        instanceDF().createOrReplaceTempView("instances");
        tripleDF().createOrReplaceTempView("triples");

        // Round 0: every instance with an empty path, positioned on itself.
        Dataset<Row> allPaths = sqlContext.sql("SELECT instance, label, '' as path, instance as object FROM instances");
        allPaths.createOrReplaceTempView("df");

        // Read the depth once, exactly as the original range construction did.
        final int rounds = maxDepth();
        for (int round = 0; round < rounds; round++) {
            // The view "df" always holds the previous round only, so each query
            // extends the newest paths by one triple.
            Dataset<Row> frontier = sqlContext.sql("SELECT instance, label, CONCAT(df.path, ',', t.predicate, ',', t.object) AS path, t.object FROM df LEFT JOIN triples t WHERE df.object = t.subject");
            allPaths = allPaths.union(frontier);
            frontier.createOrReplaceTempView("df");
        }

        // Replace each distinct path string by a numeric index.
        Dataset<Row> indexed = new StringIndexer()
                .setInputCol("path")
                .setOutputCol("pathIndex")
                .fit(allPaths)
                .transform(allPaths);

        // Keep the index as a string and gather all indices of an instance into one list column.
        Dataset<Row> pathsPerInstance = indexed
                .drop("path")
                .drop("object")
                .selectExpr(Predef$.MODULE$.wrapRefArray(new String[]{"instance", "label", "cast(pathIndex as string) pathIndex"}))
                .orderBy("instance", Predef$.MODULE$.wrapRefArray(new String[0]))
                .groupBy("instance", Predef$.MODULE$.wrapRefArray(new String[]{"label"}))
                .agg(functions$.MODULE$.collect_list("pathIndex").as("paths"), Predef$.MODULE$.wrapRefArray(new Column[0]))
                .toDF(Predef$.MODULE$.wrapRefArray(new String[]{"instance", "label", "paths"}));

        return new CountVectorizer()
                .setInputCol("paths")
                .setOutputCol("features")
                .fit(pathsPerInstance)
                .transform(pathsPerInstance);
    }

    /**
     * Runs {@link #computeFeatures()} and drops the {@code instance} and
     * {@code paths} columns.
     *
     * @return the remaining columns of the computed feature dataset
     */
    public Dataset<Row> getMLFeatureVectors() {
        return labelAndFeatures();
    }

    /**
     * Runs {@link #computeFeatures()}, converts the {@code features} column to
     * an MLlib vector and wraps every row in a {@link LabeledPoint}.
     *
     * <p>Column 0 is read with {@code getDouble} and column 1 is cast to
     * {@link SparseVector}.
     * NOTE(review): this presumes the label column is already double-typed and
     * is the first remaining column - confirm against the instance dataset.
     *
     * @return one labeled point per row of the feature dataset
     */
    public RDD<LabeledPoint> getMLLibLabeledPoints() {
        Dataset<Row> mllibRows = MLUtils$.MODULE$.convertVectorColumnsFromML(
                labelAndFeatures(),
                Predef$.MODULE$.wrapRefArray(new String[]{"features"}));
        // NOTE(review): this function is handed to RDD.map; when this decompiled
        // source is built as Java, verify the lambda is serializable for Spark.
        return mllibRows.rdd().map(
                row -> new LabeledPoint(row.getDouble(0), (SparseVector) row.getAs(1)),
                ClassTag$.MODULE$.apply(LabeledPoint.class));
    }

    /** Feature dataset without the {@code instance} and {@code paths} columns; shared by both ML accessors. */
    private Dataset<Row> labelAndFeatures() {
        return computeFeatures().drop("instance").drop("paths");
    }

    /**
     * Stores the arguments as-is; no validation or copying is performed.
     *
     * @param sparkSession session whose SQL context runs the queries
     * @param dataset      triple dataset, exposed as {@link #tripleDF()}
     * @param dataset2     instance dataset, exposed as {@link #instanceDF()}
     * @param i            number of expansion rounds, exposed as {@link #maxDepth()}
     */
    public RDFFastTreeGraphKernel_v2(SparkSession sparkSession, Dataset<Row> dataset, Dataset<Row> dataset2, int i) {
        this.sparkSession = sparkSession;
        this.tripleDF = dataset;
        this.instanceDF = dataset2;
        this.maxDepth = i;
    }
}
