package com.intel.analytics.bigdl.utils.intermediate;

import com.intel.analytics.bigdl.nn.Graph;
import com.intel.analytics.bigdl.nn.Graph$;
import com.intel.analytics.bigdl.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.nn.abstractnn.Activity;
import com.intel.analytics.bigdl.nn.mkldnn.BlasWrapper;
import com.intel.analytics.bigdl.nn.mkldnn.DnnGraph$;
import com.intel.analytics.bigdl.nn.mkldnn.InputWrapper;
import com.intel.analytics.bigdl.nn.mkldnn.Output$;
import com.intel.analytics.bigdl.tensor.FloatType$;
import com.intel.analytics.bigdl.tensor.TensorDataType;
import com.intel.analytics.bigdl.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.tensor.TensorNumericMath$TensorNumeric$NumericFloat$;
import com.intel.analytics.bigdl.utils.Engine$;
import com.intel.analytics.bigdl.utils.EngineType;
import com.intel.analytics.bigdl.utils.MklBlas$;
import com.intel.analytics.bigdl.utils.MklDnn$;
import com.intel.analytics.bigdl.utils.Node;
import scala.Array$;
import scala.MatchError;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.mutable.ArrayBuffer;
import scala.collection.mutable.ArrayOps;
import scala.collection.mutable.HashMap;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/* compiled from: IRConverter.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005\u001db!\u0002\t\u0012\u0001Ui\u0002\u0002C\u0013\u0001\u0005\u0003\u0005\u000b\u0011B\u0014\t\u0011Y\u0002!1!Q\u0001\f]B\u0001\"\u0010\u0001\u0003\u0002\u0003\u0006YA\u0010\u0005\u0006)\u0002!\t!\u0016\u0005\b7\u0002\u0011\r\u0011\"\u0003]\u0011\u0019a\u0007\u0001)A\u0005;\"9Q\u000e\u0001b\u0001\n\u0013q\u0007B\u0002:\u0001A\u0003%q\u000eC\u0004t\u0001\t\u0007I\u0011\u00028\t\rQ\u0004\u0001\u0015!\u0003p\u0011\u0015)\b\u0001\"\u0003w\u0011\u0015Q\b\u0001\"\u0003|\u0011\u001d\t\u0019\u0002\u0001C\u0001\u0003+Aq!a\t\u0001\t\u0013\t)\u0002C\u0004\u0002&\u0001!I!!\u0006\u0003\u0017%\u00136i\u001c8wKJ$XM\u001d\u0006\u0003%M\tA\"\u001b8uKJlW\rZ5bi\u0016T!\u0001F\u000b\u0002\u000bU$\u0018\u000e\\:\u000b\u0005Y9\u0012!\u00022jO\u0012d'B\u0001\r\u001a\u0003%\tg.\u00197zi&\u001c7O\u0003\u0002\u001b7\u0005)\u0011N\u001c;fY*\tA$A\u0002d_6,\"AH\u0017\u0014\u0005\u0001y\u0002C\u0001\u0011$\u001b\u0005\t#\"\u0001\u0012\u0002\u000bM\u001c\u0017\r\\1\n\u0005\u0011\n#AB!osJ+g-A\u0004J%\u001e\u0014\u0018\r\u001d5\u0004\u0001A\u0019\u0001&K\u0016\u000e\u0003EI!AK\t\u0003\u000f%\u0013vI]1qQB\u0011A&\f\u0007\u0001\t\u0015q\u0003A1\u00010\u0005\u0005!\u0016C\u0001\u00194!\t\u0001\u0013'\u0003\u00023C\t9aj\u001c;iS:<\u0007C\u0001\u00115\u0013\t)\u0014EA\u0002B]f\f!\"\u001a<jI\u0016t7-\u001a\u00132!\rA4hK\u0007\u0002s)\u0011!(I\u0001\be\u00164G.Z2u\u0013\ta\u0014H\u0001\u0005DY\u0006\u001c8\u000fV1h\u0003\t)g\u000fE\u0002@#.r!\u0001\u0011(\u000f\u0005\u0005ceB\u0001\"L\u001d\t\u0019%J\u0004\u0002E\u0013:\u0011Q\tS\u0007\u0002\r*\u0011qIJ\u0001\u0007yI|w\u000e\u001e 
\n\u0003qI!AG\u000e\n\u0005aI\u0012B\u0001\f\u0018\u0013\tiU#\u0001\u0004uK:\u001cxN]\u0005\u0003\u001fB\u000b\u0011\u0003V3og>\u0014h*^7fe&\u001cW*\u0019;i\u0015\tiU#\u0003\u0002S'\niA+\u001a8t_JtU/\\3sS\u000eT!a\u0014)\u0002\rqJg.\u001b;?)\t1&\fF\u0002X1f\u00032\u0001\u000b\u0001,\u0011\u00151D\u0001q\u00018\u0011\u0015iD\u0001q\u0001?\u0011\u0015)C\u00011\u0001(\u0003!\tG\u000e\u001c(pI\u0016\u001cX#A/\u0011\u0007y\u001bW-D\u0001`\u0015\t\u0001\u0017-A\u0004nkR\f'\r\\3\u000b\u0005\t\f\u0013AC2pY2,7\r^5p]&\u0011Am\u0018\u0002\f\u0003J\u0014\u0018-\u001f\"vM\u001a,'\u000fE\u0002gO&l\u0011aE\u0005\u0003QN\u0011AAT8eKB\u0019\u0001F[\u0016\n\u0005-\f\"!C%S\u000b2,W.\u001a8u\u0003%\tG\u000e\u001c(pI\u0016\u001c\b%\u0001\u0005je&s\u0007/\u001e;t+\u0005y\u0007c\u0001\u0011qK&\u0011\u0011/\t\u0002\u0006\u0003J\u0014\u0018-_\u0001\nSJLe\u000e];ug\u0002\n\u0011\"\u001b:PkR\u0004X\u000f^:\u0002\u0015%\u0014x*\u001e;qkR\u001c\b%\u0001\u0003j]&$H#A<\u0011\u0005\u0001B\u0018BA=\"\u0005\u0011)f.\u001b;\u0002\u0011\u001d,GOT8eKN$Ba\u001e?\u0002\u0010!)Q\u0010\u0004a\u0001}\u00061\u0011N\u001c9viN\u0004Ba`A\u0005K:!\u0011\u0011AA\u0003\u001d\r)\u00151A\u0005\u0002E%\u0019\u0011qA\u0011\u0002\u000fA\f7m[1hK&!\u00111BA\u0007\u0005\r\u0019V-\u001d\u0006\u0004\u0003\u000f\t\u0003BBA\t\u0019\u0001\u0007Q,A\u0006o_\u0012,7OQ;gM\u0016\u0014\u0018a\u0002;p\u000fJ\f\u0007\u000f\u001b\u000b\u0003\u0003/\u0001R!!\u0007\u0002 -j!!a\u0007\u000b\u0007\u0005uQ#\u0001\u0002o]&!\u0011\u0011EA\u000e\u0005\u00159%/\u00199i\u0003)!x\u000e\u00128o\u000fJ\f\u0007\u000f[\u0001\fi>\u0014E.Y:He\u0006\u0004\b\u000e")
/* loaded from: input_file:com/intel/analytics/bigdl/utils/intermediate/IRConverter.class */
/**
 * Converts an {@link IRGraph} into an executable BigDL {@link Graph}.
 *
 * <p>The target representation is chosen in {@link #toGraph()} from
 * {@code Engine.getEngineType()}:
 * <ul>
 *   <li>{@code MklBlas}: IR nodes are converted by {@code IRToBlas} and wired with
 *       {@code Graph.dynamic}.</li>
 *   <li>{@code MklDnn}: IR nodes are converted by {@code IRToDnn} and wired into a
 *       {@code DnnGraph} (float data only).</li>
 * </ul>
 * Any other engine type is rejected.
 *
 * @param <T> element type handled by the supplied {@code ClassTag}/{@code TensorNumeric}
 */
public class IRConverter<T> {
    private final IRGraph<T> IRgraph;
    private final ClassTag<T> evidence$1;
    private final TensorNumericMath.TensorNumeric<T> ev;
    /**
     * IR nodes collected by {@link #init()}: every node reachable from the graph inputs through
     * {@code nextNodes()}, followed by any graph output that traversal did not reach.
     */
    private final ArrayBuffer<Node<IRElement<T>>> allNodes = new ArrayBuffer<>();
    private final Node<IRElement<T>>[] irInputs;
    private final Node<IRElement<T>>[] irOutputs;

    private ArrayBuffer<Node<IRElement<T>>> allNodes() {
        return this.allNodes;
    }

    private Node<IRElement<T>>[] irInputs() {
        return this.irInputs;
    }

    private Node<IRElement<T>>[] irOutputs() {
        return this.irOutputs;
    }

    /** Fills {@link #allNodes} from the graph inputs, then appends any unreached outputs. */
    private void init() {
        getNodes(Predef$.MODULE$.wrapRefArray(irInputs()), allNodes());
        for (Node<IRElement<T>> output : irOutputs()) {
            $anonfun$init$1(this, output);
        }
    }

    /**
     * Depth-first collection of {@code seq} and everything downstream of it into
     * {@code arrayBuffer}. Nodes already present in the buffer are skipped, so shared successors
     * are recorded once.
     */
    private void getNodes(Seq<Node<IRElement<T>>> seq, ArrayBuffer<Node<IRElement<T>>> arrayBuffer) {
        scala.collection.Iterator<Node<IRElement<T>>> it = seq.iterator();
        while (it.hasNext()) {
            $anonfun$getNodes$1(this, arrayBuffer, it.next());
        }
    }

    /**
     * Builds the executable graph for the current engine type.
     *
     * @return a BLAS {@code Graph.dynamic} graph under {@code MklBlas}, or a {@code DnnGraph}
     *     under {@code MklDnn}
     * @throws UnsupportedOperationException if the engine type is neither of the above
     * @throws IllegalArgumentException (from {@code Predef.require}) if the IR cannot be converted
     *     for the engine, or if {@code MklDnn} is used with non-float data
     */
    public Graph<T> toGraph() {
        EngineType engineType = Engine$.MODULE$.getEngineType();
        if (MklBlas$.MODULE$.equals(engineType)) {
            boolean convertible =
                IRToBlas$.MODULE$.apply(this.evidence$1, this.ev).convertingCheck(allNodesArray());
            Predef$.MODULE$.require(convertible, () -> "IR graph can not be converted to Blas layer");
            return toBlasGraph();
        }
        if (!MklDnn$.MODULE$.equals(engineType)) {
            throw new UnsupportedOperationException(
                "Only support engineType mkldnn/mklblas, but get " + engineType);
        }
        TensorDataType dataType = this.ev.getType();
        Predef$.MODULE$.require(
            FloatType$.MODULE$.equals(dataType), () -> "Mkldnn engine only supports float data");
        boolean convertible = IRToDnn$.MODULE$
            .apply(ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$)
            .convertingCheck(allNodesArray());
        Predef$.MODULE$.require(convertible, () -> "IR graph can not be converted to Dnn layer");
        return toDnnGraph();
    }

    /**
     * Converts the IR nodes with {@code IRToDnn} and assembles a {@code DnnGraph}.
     *
     * <p>Every converted input gets a fresh {@code InputWrapper} node in front of it; those
     * wrappers become the graph inputs. Every converted output that is not a {@code BlasWrapper}
     * gets an {@code Output} node appended, whose format comes from
     * {@code IRgraph.outputFormats()} at the same index.
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private Graph<T> toDnnGraph() {
        HashMap<Node<IRElement<Object>>, Node<AbstractModule<Activity, Activity, Object>>> nodeMap =
            IRToDnn$.MODULE$
                .apply(ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$)
                .convert(allNodesArray());
        Node[] dnnInputs = lookupConverted(nodeMap, irInputs());
        Node[] dnnOutputs = lookupConverted(nodeMap, irOutputs());

        Node[] realInputs = new Node[dnnInputs.length];
        for (int i = 0; i < dnnInputs.length; i++) {
            Node wrapper = new Node(new InputWrapper());
            // The converted input node receives data from its new InputWrapper predecessor.
            dnnInputs[i].from(wrapper, dnnInputs[i].from$default$2());
            realInputs[i] = wrapper;
        }

        Node[] realOutputs = new Node[dnnOutputs.length];
        for (int i = 0; i < dnnOutputs.length; i++) {
            Node model = dnnOutputs[i];
            if (model.element() instanceof BlasWrapper) {
                realOutputs[i] = model;
            } else {
                int format = BoxesRunTime.unboxToInt(this.IRgraph.outputFormats().apply(i));
                Node outputNode =
                    new Node(Output$.MODULE$.apply(format, Output$.MODULE$.apply$default$2()));
                realOutputs[i] = model.add(outputNode, model.add$default$2());
            }
        }

        // DnnGraph is typed over Float; under MklDnn, toGraph() has already required T to be
        // float, so these unchecked casts mirror the original Scala asInstanceOf calls.
        return (Graph<T>) (Graph) DnnGraph$.MODULE$.apply(
            (Seq) Predef$.MODULE$.wrapRefArray(realInputs),
            (Seq) Predef$.MODULE$.wrapRefArray(realOutputs),
            (scala.Option) this.IRgraph.variables(),
            this.IRgraph.generateBackward());
    }

    /** Converts the IR nodes with {@code IRToBlas} and assembles a {@code Graph.dynamic} graph. */
    @SuppressWarnings("unchecked")
    private Graph<T> toBlasGraph() {
        HashMap<Node<IRElement<T>>, Node<AbstractModule<Activity, Activity, T>>> nodeMap =
            IRToBlas$.MODULE$.apply(this.evidence$1, this.ev).convert(allNodesArray());
        Node<AbstractModule<Activity, Activity, T>>[] inputs = lookupConverted(nodeMap, irInputs());
        Node<AbstractModule<Activity, Activity, T>>[] outputs = lookupConverted(nodeMap, irOutputs());
        return Graph$.MODULE$.dynamic(inputs, outputs, this.IRgraph.variables(),
            this.IRgraph.generateBackward(), this.evidence$1, this.ev);
    }

    /** Snapshot of {@link #allNodes} as an array, the form the converters accept. */
    @SuppressWarnings("rawtypes")
    private Node[] allNodesArray() {
        return (Node[]) allNodes().toArray(ClassTag$.MODULE$.apply(Node.class));
    }

    /**
     * Maps each IR node to its converted counterpart in {@code nodeMap}, preserving order.
     * {@code Option.get()} throws {@code NoSuchElementException} if a node was not converted.
     */
    @SuppressWarnings("rawtypes")
    private static Node[] lookupConverted(HashMap nodeMap, Node[] irNodes) {
        Node[] converted = new Node[irNodes.length];
        for (int i = 0; i < irNodes.length; i++) {
            converted[i] = (Node) nodeMap.get(irNodes[i]).get();
        }
        return converted;
    }

    /**
     * Appends {@code node} to the converter's node list unless it is already present.
     * Kept with its compiler-generated name and public signature for compatibility.
     */
    public static final void $anonfun$init$1(IRConverter iRConverter, Node node) {
        if (iRConverter.allNodes().contains(node)) {
            return;
        }
        iRConverter.allNodes().append(Predef$.MODULE$.wrapRefArray(new Node[]{node}));
    }

    /**
     * Visits one node of the depth-first collection. If {@code node} is new to
     * {@code arrayBuffer}, it is appended and its {@code nextNodes()} are visited recursively.
     * Kept with its compiler-generated name and public signature for compatibility.
     */
    public static final void $anonfun$getNodes$1(IRConverter iRConverter, ArrayBuffer arrayBuffer, Node node) {
        if (arrayBuffer.contains(node)) {
            return;
        }
        arrayBuffer.append(Predef$.MODULE$.wrapRefArray(new Node[]{node}));
        iRConverter.getNodes(node.nextNodes(), arrayBuffer);
    }

    /**
     * Creates a converter for {@code iRGraph} and immediately collects its IR nodes.
     *
     * @param iRGraph the IR graph to convert
     * @param classTag runtime class tag for {@code T}
     * @param tensorNumeric numeric operations for {@code T}
     */
    public IRConverter(IRGraph<T> iRGraph, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        this.IRgraph = iRGraph;
        this.evidence$1 = classTag;
        this.ev = tensorNumeric;
        this.irInputs = (Node[]) iRGraph.inputs().toArray(ClassTag$.MODULE$.apply(Node.class));
        this.irOutputs = (Node[]) iRGraph.outputs().toArray(ClassTag$.MODULE$.apply(Node.class));
        init();
    }
}
