package com.intel.analytics.bigdl.nn.mkldnn;

import com.intel.analytics.bigdl.nn.MklInt8Convertible;
import com.intel.analytics.bigdl.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.nn.abstractnn.Activity;
import com.intel.analytics.bigdl.tensor.Tensor;
import com.intel.analytics.bigdl.tensor.Tensor$;
import com.intel.analytics.bigdl.tensor.TensorNumericMath$TensorNumeric$NumericFloat$;
import com.intel.analytics.bigdl.utils.Node;
import scala.Predef$;
import scala.collection.IterableLike;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.TraversableLike;
import scala.collection.TraversableOnce;
import scala.collection.generic.GenericTraversableTemplate;
import scala.collection.immutable.$colon;
import scala.collection.immutable.Nil$;
import scala.collection.immutable.StringOps;
import scala.collection.mutable.ArrayOps;
import scala.math.Ordering$Float$;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.RichInt$;

/* compiled from: Fusion.scala */
/* loaded from: input_file:com/intel/analytics/bigdl/nn/mkldnn/Fusion$.class */
/**
 * Decompiled (JADX) Java view of the Scala {@code object Fusion} from Fusion.scala: a singleton
 * holding graph-rewrite passes for the mkldnn backend. The passes mutate {@code Node} elements in
 * place. They typically move a layer's work into a neighbouring {@code SpatialConvolution} or
 * {@code SpatialBatchNormalization}, and replace the absorbed layer's element with an
 * {@code Identity}.
 *
 * <p>The only field is the {@code MODULE$} singleton reference. Every pass except
 * {@code fuseScale} is gated by the {@code bigdl.mkldnn.fusion} system property
 * (see {@code fuse()}).
 *
 * <p>NOTE(review): this is decompiler output, not hand-written Java. The {@code BoxedUnit.UNIT}
 * locals, the {@code $anonfun$...} static methods and the {@code MODULE$} field are Scala compiler
 * encodings. A few lines, flagged below, are JADX mis-typings that would not compile as Java.
 * Treat the Scala source as authoritative.
 */
public final class Fusion$ {
    /** Scala singleton instance; assigned by the private constructor during static initialization. */
    public static Fusion$ MODULE$;

    static {
        new Fusion$();
    }

    /**
     * Reports whether fusion passes are enabled.
     *
     * <p>Reads the {@code bigdl.mkldnn.fusion} system property (default {@code "true"}) on every
     * call, so later passes see any change to the property. Parsing goes through Scala's
     * {@code StringOps.toBoolean}, which accepts only "true"/"false" (case-insensitive) and throws
     * {@code IllegalArgumentException} for anything else.
     *
     * @return {@code true} if fusion is enabled
     */
    private boolean fuse() {
        return new StringOps(Predef$.MODULE$.augmentString(System.getProperty("bigdl.mkldnn.fusion", "true"))).toBoolean();
    }

    /**
     * Tries to fold {@code node} into its producer when fusion is enabled.
     *
     * <p>Dispatch is on the node's element:
     * <ul>
     *   <li>a {@code ReLU} is handed to {@code fusionRelu};</li>
     *   <li>a {@code SpatialBatchNormalization} is handed to {@code fusionBN};</li>
     *   <li>any other element is left untouched.</li>
     * </ul>
     * Does nothing when {@code fuse()} is false.
     *
     * @param node graph node whose element may be absorbed by a predecessor
     */
    public void fuseModule(Node<AbstractModule<Activity, Activity, Object>> node) {
        if (fuse()) {
            AbstractModule<Activity, Activity, Object> element = node.element();
            // The BoxedUnit locals below are decompiler residue of a Scala match returning Unit.
            if (element instanceof ReLU) {
                fusionRelu(node);
                BoxedUnit boxedUnit = BoxedUnit.UNIT;
            } else if (!(element instanceof SpatialBatchNormalization)) {
                BoxedUnit boxedUnit2 = BoxedUnit.UNIT;
            } else {
                fusionBN(node);
                BoxedUnit boxedUnit3 = BoxedUnit.UNIT;
            }
        }
    }

    /**
     * Tries to fold a {@code CAddTable} node into one of its convolution inputs when fusion is
     * enabled. Elements of other types are ignored.
     *
     * @param node graph node that may hold a {@code CAddTable}
     */
    public void fuseCAdd(Node<AbstractModule<Activity, Activity, Object>> node) {
        if (fuse()) {
            if (!(node.element() instanceof CAddTable)) {
                BoxedUnit boxedUnit = BoxedUnit.UNIT;
            } else {
                fusionCAddTable(node);
                BoxedUnit boxedUnit2 = BoxedUnit.UNIT;
            }
        }
    }

    /**
     * Folds a batch-normalization node into the convolution(s) that directly feed it.
     *
     * <p>A predecessor qualifies when its element is a {@code SpatialConvolution} with neither ReLU
     * nor batch-norm already fused. For each qualifying predecessor this method:
     * <ol>
     *   <li>enables ReLU on the convolution if the BN has ReLU fused;</li>
     *   <li>rewrites the convolution's weight and bias via {@code fusionConvBn};</li>
     *   <li>replaces this BN node's element with a Float {@code Identity}.</li>
     * </ol>
     * Other predecessors are skipped; the lambda's {@code null} return is the decompiled
     * {@code case _ => null}.
     *
     * <p>The BN module is captured before the loop, so if several predecessors qualify, each one
     * receives the same BN fold. NOTE(review): presumably a BN node never has more than one
     * input. Confirm.
     *
     * @param node node whose element is a {@code SpatialBatchNormalization}
     */
    private void fusionBN(Node<AbstractModule<Activity, Activity, Object>> node) {
        SpatialBatchNormalization spatialBatchNormalization = (SpatialBatchNormalization) node.element();
        node.prevNodes().foreach(node2 -> {
            BoxedUnit boxedUnit;
            BoxedUnit boxedUnit2;
            AbstractModule abstractModule = (AbstractModule) node2.element();
            if (abstractModule instanceof SpatialConvolution) {
                SpatialConvolution spatialConvolution = (SpatialConvolution) abstractModule;
                // Skip convolutions that already absorbed a ReLU or a BN.
                if (spatialConvolution.relu() || spatialConvolution.batchNorm()) {
                    boxedUnit2 = BoxedUnit.UNIT;
                } else {
                    if (spatialBatchNormalization.relu()) {
                        spatialConvolution.setReLU(true);
                    } else {
                        BoxedUnit boxedUnit3 = BoxedUnit.UNIT;
                    }
                    MODULE$.fusionConvBn(spatialConvolution, spatialBatchNormalization);
                    node.element_$eq(Identity$.MODULE$.apply$mFc$sp(ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$));
                    boxedUnit2 = BoxedUnit.UNIT;
                }
                boxedUnit = boxedUnit2;
            } else {
                boxedUnit = null;
            }
            return boxedUnit;
        });
    }

    /**
     * Folds a ReLU node into the convolution or batch-norm that produces its input.
     *
     * <p>Each predecessor is first resolved through single-input {@code Identity} chains via
     * {@code findPrevious}. If the resolved element is a {@code SpatialConvolution} or
     * {@code SpatialBatchNormalization} without ReLU already fused, this method:
     * <ol>
     *   <li>enables ReLU on it;</li>
     *   <li>copies the ReLU's output scales onto it;</li>
     *   <li>replaces the ReLU node's element with a Float {@code Identity}.</li>
     * </ol>
     *
     * <p>NOTE(review): {@code node.element()} is re-cast to {@code ReLU} on every iteration, but a
     * successful fusion replaces it with {@code Identity}. If a ReLU node ever had two fusible
     * predecessors, the second cast would throw {@code ClassCastException}.
     *
     * <p>NOTE(review): there is no check that the producer's output has no other consumers. Those
     * consumers would also see the fused ReLU. Confirm this is guarded elsewhere.
     *
     * @param node node whose element is a {@code ReLU}
     */
    private void fusionRelu(Node<AbstractModule<Activity, Activity, Object>> node) {
        node.prevNodes().foreach(node2 -> {
            BoxedUnit boxedUnit;
            BoxedUnit boxedUnit2;
            BoxedUnit boxedUnit3;
            AbstractModule<Activity, Activity, Object> element = MODULE$.findPrevious(node2).element();
            if (element instanceof SpatialConvolution) {
                SpatialConvolution spatialConvolution = (SpatialConvolution) element;
                if (spatialConvolution.relu()) {
                    boxedUnit3 = BoxedUnit.UNIT;
                } else {
                    spatialConvolution.setReLU(true);
                    spatialConvolution.setOutputScales(((ReLU) node.element()).getOutputScales());
                    node.element_$eq(Identity$.MODULE$.apply$mFc$sp(ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$));
                    boxedUnit3 = BoxedUnit.UNIT;
                }
                boxedUnit = boxedUnit3;
            } else if (element instanceof SpatialBatchNormalization) {
                SpatialBatchNormalization spatialBatchNormalization = (SpatialBatchNormalization) element;
                if (spatialBatchNormalization.relu()) {
                    boxedUnit2 = BoxedUnit.UNIT;
                } else {
                    spatialBatchNormalization.setReLU(true);
                    spatialBatchNormalization.setOutputScales(((ReLU) node.element()).getOutputScales());
                    node.element_$eq(Identity$.MODULE$.apply$mFc$sp(ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$));
                    boxedUnit2 = BoxedUnit.UNIT;
                }
                boxedUnit = boxedUnit2;
            } else {
                boxedUnit = null;
            }
            return boxedUnit;
        });
    }

    /**
     * Walks backwards through pass-through {@code Identity} nodes that have exactly one
     * predecessor.
     *
     * @param node starting node
     * @return the first node that is not an {@code Identity}, or that has zero or several
     *     predecessors; this may be {@code node} itself, and may still be an {@code Identity}
     */
    private Node<AbstractModule<Activity, Activity, Object>> findPrevious(Node<AbstractModule<Activity, Activity, Object>> node) {
        while ((node.element() instanceof Identity) && node.prevNodes().length() == 1) {
            node = (Node) node.prevNodes().apply(0);
        }
        return node;
    }

    /**
     * Collects the nearest non-{@code Identity} descendants of {@code node}.
     *
     * <p>An {@code Identity} node fans out recursively through all of its {@code nextNodes};
     * one with no successors contributes nothing. A non-{@code Identity} node returns a
     * one-element list containing itself. {@code new $colon.colon<>(node, Nil$.MODULE$)} is the
     * JADX rendering of Scala's {@code node :: Nil}.
     *
     * @param node starting node
     * @return the nearest non-Identity nodes at or after {@code node}
     */
    /* JADX INFO: Access modifiers changed from: private */
    public Seq<Node<AbstractModule<Activity, Activity, Object>>> findNext(Node<AbstractModule<Activity, Activity, Object>> node) {
        return node.element() instanceof Identity ? (Seq) node.nextNodes().flatMap(node2 -> {
            return MODULE$.findNext(node2);
        }, Seq$.MODULE$.canBuildFrom()) : new $colon.colon<>(node, Nil$.MODULE$);
    }

    /**
     * Fuses a two-input {@code CAddTable} into one of its convolution inputs (conv + sum).
     *
     * <p>Applies only when the element is a {@code CAddTable} with exactly two predecessors. Each
     * input is resolved through {@code findPrevious}. The first input is chosen if it is a
     * {@code SpatialConvolution}; otherwise the second is tried. Either way, the chosen conv must
     * pass {@code requirements} (no sum op yet). When the first input is a conv that fails
     * {@code requirements}, the second input is NOT tried.
     *
     * <p>On success:
     * <ul>
     *   <li>the conv module moves into this add node;</li>
     *   <li>the other input's element becomes its sum operand via {@code setSumOp};</li>
     *   <li>the conv's old node becomes a Float {@code Identity};</li>
     *   <li>a directly following {@code ReLU} (first successor only) is fused as well.</li>
     * </ul>
     * The fused conv's output scales are then pushed to the real producer of the other input.
     *
     * @param node node expected to hold a {@code CAddTable}
     */
    private void fusionCAddTable(Node<AbstractModule<Activity, Activity, Object>> node) {
        if ((node.element() instanceof CAddTable) && node.prevNodes().length() == 2) {
            Node<AbstractModule<Activity, Activity, Object>>[] nodeArr = (Node[]) node.prevNodes().toArray(ClassTag$.MODULE$.apply(Node.class));
            Node<AbstractModule<Activity, Activity, Object>> findPrevious = findPrevious(nodeArr[0]);
            Node<AbstractModule<Activity, Activity, Object>> findPrevious2 = findPrevious(nodeArr[1]);
            // node2: the conv node to fuse into (null = no fusion).
            // i: index into nodeArr of the OTHER input, which becomes the sum operand.
            Node<AbstractModule<Activity, Activity, Object>> node2 = null;
            int i = 0;
            if (findPrevious.element() instanceof SpatialConvolution) {
                if (requirements(findPrevious)) {
                    node2 = findPrevious;
                }
                i = 1;
            } else if (findPrevious2.element() instanceof SpatialConvolution) {
                if (requirements(findPrevious2)) {
                    node2 = findPrevious2;
                }
                i = 0;
            }
            if (node2 != null) {
                // Move the conv into the add node. The sum operand index passed is 1-based.
                node.element_$eq(node2.element());
                SpatialConvolution spatialConvolution = (SpatialConvolution) node.element();
                spatialConvolution.setSumOp(nodeArr[i].element(), i + 1);
                node2.element_$eq(Identity$.MODULE$.apply$mFc$sp(ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$));
                // NOTE(review): apply(0) throws if the add node has no successors (e.g. it is a
                // graph output). Only the first successor is checked for a fusible ReLU.
                Node node3 = (Node) node.nextNodes().apply(0);
                if ((node3.element() instanceof ReLU) && !spatialConvolution.relu()) {
                    ((SpatialConvolution) node.element()).setReLU(true);
                    ((SpatialConvolution) node.element()).setOutputScales(((ReLU) node3.element()).getOutputScales());
                    node3.element_$eq(new Identity());
                }
                // Push the fused conv's output scales to the real producer of the sum operand,
                // presumably so both summands are quantized with the same scales.
                Node<AbstractModule<Activity, Activity, Object>> findPrevious3 = findPrevious(nodeArr[i]);
                AbstractModule<Activity, Activity, Object> element = findPrevious3.element();
                if (element instanceof SpatialConvolution) {
                    ((SpatialConvolution) element).setOutputScales(((SpatialConvolution) node.element()).getOutputScales());
                    BoxedUnit boxedUnit = BoxedUnit.UNIT;
                } else {
                    if (!(element instanceof ReLU)) {
                        BoxedUnit boxedUnit2 = BoxedUnit.UNIT;
                        return;
                    }
                    ((ReLU) element).setOutputScales(((SpatialConvolution) node.element()).getOutputScales());
                    // Also update the input scales of the ReLU's other int8-convertible consumers
                    // (found through Identity nodes), excluding this add node.
                    ((IterableLike) ((TraversableLike) findPrevious3.nextNodes().flatMap(node4 -> {
                        return MODULE$.findNext(node4);
                    }, Seq$.MODULE$.canBuildFrom())).filter(node5 -> {
                        return BoxesRunTime.boxToBoolean($anonfun$fusionCAddTable$2(node, node5));
                    })).foreach(node6 -> {
                        $anonfun$fusionCAddTable$3(node, node6);
                        return BoxedUnit.UNIT;
                    });
                    BoxedUnit boxedUnit3 = BoxedUnit.UNIT;
                }
            }
        }
    }

    /**
     * Checks whether a conv node can still absorb a sum.
     *
     * @param node node whose element is cast to {@code SpatialConvolution}
     * @return {@code true} if the convolution has no sum op assigned yet ({@code sum()} is false)
     */
    private boolean requirements(Node<AbstractModule<Activity, Activity, Object>> node) {
        return !((SpatialConvolution) node.element()).sum();
    }

    /**
     * Folds a batch normalization into a convolution's weight and bias, using the BN's running
     * mean and variance.
     *
     * <p>For each output channel {@code i} in {@code [0, bn.nOutput())}, define
     * {@code std = sqrt(runningVariance[i] + eps)}, {@code gamma = weightAndBias[i]} and
     * {@code beta = weightAndBias[nOutput + i]}. Then:
     * <pre>
     *   W'[i] = W[i] / std * gamma
     *   b'[i] = gamma / std * b[i] + beta - gamma * runningMean[i] / std
     * </pre>
     *
     * <p>All inputs are first copied into fresh Float tensors. The convolution's dense weight and
     * bias are then re-pointed to the new tensors with {@code set}. After that,
     * {@code flushWeightScales} is called on the new weight, and the BN's output scales are
     * copied onto the convolution. The BN's own tensors are only read.
     *
     * @param spatialConvolution convolution to modify; marked with {@code setBatchNorm(true)}
     * @param spatialBatchNormalization batch normalization whose parameters are folded in
     * @throws IllegalArgumentException (via Scala {@code require}) if {@code std} is 0 for a channel
     */
    private void fusionConvBn(SpatialConvolution spatialConvolution, SpatialBatchNormalization spatialBatchNormalization) {
        spatialConvolution.setBatchNorm(true);
        // copy = running variance, copy2 = running mean, copy3 = conv weight, copy4 = conv bias,
        // copy5 = BN weightAndBias (first nOutput entries gamma, next nOutput entries beta).
        Tensor apply = Tensor$.MODULE$.apply(ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
        Tensor copy = apply.resize(spatialBatchNormalization.runningVariance().size(), apply.resize$default$2()).copy(spatialBatchNormalization.runningVariance().dense());
        Tensor apply2 = Tensor$.MODULE$.apply(ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
        Tensor copy2 = apply2.resize(spatialBatchNormalization.runningMean().size(), apply2.resize$default$2()).copy(spatialBatchNormalization.runningMean().dense());
        Tensor apply3 = Tensor$.MODULE$.apply(ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
        Tensor<Object> copy3 = apply3.resize(spatialConvolution.weight().size(), apply3.resize$default$2()).copy(spatialConvolution.weight().dense());
        Tensor apply4 = Tensor$.MODULE$.apply(ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
        Tensor<Object> copy4 = apply4.resize(spatialConvolution.bias().size(), apply4.resize$default$2()).copy(spatialConvolution.bias().dense());
        Tensor copy5 = Tensor$.MODULE$.apply(ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$).resizeAs(spatialBatchNormalization.weightAndBias().dense()).copy(spatialBatchNormalization.weightAndBias().dense());
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), spatialBatchNormalization.nOutput()).foreach$mVc$sp(i -> {
            Tensor select;
            // storageOffset() is 1-based, hence the "- 1" when indexing the raw storage array.
            float sqrt = (float) Math.sqrt(((float[]) copy.storage().array())[(i + copy.storageOffset()) - 1] + spatialBatchNormalization.eps());
            Predef$.MODULE$.require(((double) sqrt) != 0.0d, () -> {
                return new StringBuilder(33).append("the eps of ").append(spatialBatchNormalization.getName()).append(" should be more than 0").toString();
            });
            float f = ((float[]) copy5.storage().array())[(copy5.storageOffset() - 1) + i];
            float f2 = ((float[]) copy5.storage().array())[(copy5.storageOffset() - 1) + spatialBatchNormalization.nOutput() + i];
            // Select the weight slice for output channel i. With groups, this assumes a layout
            // with the group index along dim 1 and the within-group channel along dim 2.
            // TODO confirm.
            if (spatialConvolution.nGroup() == 1) {
                select = copy3.select(1, i + 1);
            } else {
                int nOutputPlane = spatialConvolution.nOutputPlane() / spatialConvolution.nGroup();
                select = copy3.select(1, (i / nOutputPlane) + 1).select(2, (i % nOutputPlane) + 1);
            }
            Tensor tensor = select;
            // NOTE(review): in Scala this is a scalar divide; the (Tensor) cast is a JADX mis-typing.
            tensor.div((Tensor) BoxesRunTime.boxToFloat(sqrt));
            tensor.mul(BoxesRunTime.boxToFloat(f));
            // Bias and mean are indexed by i without storageOffset. This appears to rely on
            // copy2/copy4 being freshly allocated (offset 1). Confirm.
            ((float[]) copy4.storage().array())[i] = (((f / sqrt) * ((float[]) copy4.storage().array())[i]) + f2) - ((f * ((float[]) copy2.storage().array())[i]) / sqrt);
        });
        spatialConvolution.weight().dense().set(copy3);
        spatialConvolution.bias().dense().set(copy4);
        spatialConvolution.flushWeightScales(spatialConvolution.weight().dense());
        spatialConvolution.setOutputScales(spatialBatchNormalization.getOutputScales());
    }

    /**
     * Sets {@code negativeInput} to false on a convolution whose inputs all come from ReLU
     * outputs.
     *
     * <p>Ancestors are gathered with {@code findAllNonIdentityPrevs}, which passes through
     * Identity, MaxPooling, AvgPooling and JoinTable nodes. The flag is cleared only when every
     * such ancestor is either a {@code ReLU} or a {@code SpatialConvolution} with ReLU fused.
     * Requires fusion to be enabled.
     *
     * <p>NOTE(review): {@code forall} over an empty ancestor set is true, so a conv with no real
     * ancestors also gets {@code negativeInput = false}. Confirm that is intended.
     *
     * @param node candidate node; ignored unless its element is a {@code SpatialConvolution}
     */
    public void setNegativeInputOfConv(Node<AbstractModule<Activity, Activity, Object>> node) {
        if (fuse() && (node.element() instanceof SpatialConvolution) && ((IterableLike) ((TraversableLike) node.prevNodes().flatMap(node2 -> {
            return MODULE$.findAllNonIdentityPrevs(node2);
        }, Seq$.MODULE$.canBuildFrom())).map(node3 -> {
            return BoxesRunTime.boxToBoolean($anonfun$setNegativeInputOfConv$2(node3));
        }, Seq$.MODULE$.canBuildFrom())).forall(obj -> {
            return BoxesRunTime.boxToBoolean($anonfun$setNegativeInputOfConv$3(BoxesRunTime.unboxToBoolean(obj)));
        })) {
            ((SpatialConvolution) node.element()).negativeInput_$eq(false);
        }
    }

    /**
     * Unifies the output scales of the convolutions feeding a {@code JoinTable}. The method name's
     * "Prevous" spelling is kept for compatibility.
     *
     * <p>Requires fusion to be enabled. The preceding convolutions are the
     * {@code SpatialConvolution} ancestors found by {@code findAllNonIdentityPrevs}. Nothing is
     * done unless at least one of them {@code needQuantize()}. Otherwise:
     * <ol>
     *   <li>All of them must share the same {@code getOutputDimMask()}; this is enforced via
     *       Scala {@code require}.</li>
     *   <li>The shared scales are computed. If the join has no {@code SpatialConvolution}
     *       successor (looked up through Identity nodes), they are a single row holding the
     *       element-wise max, across the preceding convs, of each conv's flattened output scales.
     *       Otherwise they are the input scales of the first such successor.</li>
     *   <li>Every preceding convolution's output scales are set to that value.</li>
     * </ol>
     *
     * <p>NOTE(review): Scala {@code transpose} requires all flattened scale arrays to have the
     * same length, and throws otherwise. The {@code (float[][]) ((Object[]) new float[]{...})}
     * expression is a JADX rendering of Scala {@code Array(row)} and is not valid Java.
     *
     * @param node candidate node; ignored unless its element is a {@code JoinTable}
     */
    /* JADX WARN: Multi-variable type inference failed */
    public void setScalesPrevousJoinTable(Node<AbstractModule<Activity, Activity, Object>> node) {
        if (fuse() && (node.element() instanceof JoinTable)) {
            Seq seq = (Seq) ((TraversableLike) ((TraversableLike) node.prevNodes().flatMap(node2 -> {
                return MODULE$.findAllNonIdentityPrevs(node2);
            }, Seq$.MODULE$.canBuildFrom())).filter(node3 -> {
                return BoxesRunTime.boxToBoolean($anonfun$setScalesPrevousJoinTable$2(node3));
            })).map(node4 -> {
                return (SpatialConvolution) node4.element();
            }, Seq$.MODULE$.canBuildFrom());
            if (seq.exists(spatialConvolution -> {
                return BoxesRunTime.boxToBoolean(spatialConvolution.needQuantize());
            })) {
                Predef$.MODULE$.require(((TraversableOnce) seq.map(spatialConvolution2 -> {
                    return BoxesRunTime.boxToInteger(spatialConvolution2.getOutputDimMask());
                }, Seq$.MODULE$.canBuildFrom())).toSet().size() == 1, () -> {
                    return "all preceding convolutions must have the same mask";
                });
                // Convolution consumers of the join, looked up through Identity nodes.
                Seq seq2 = (Seq) ((TraversableLike) node.nextNodes().flatMap(node5 -> {
                    return MODULE$.findNext(node5);
                }, Seq$.MODULE$.canBuildFrom())).filter(node6 -> {
                    return BoxesRunTime.boxToBoolean($anonfun$setScalesPrevousJoinTable$8(node6));
                });
                float[][] inputScales = seq2.isEmpty() ? (float[][]) ((Object[]) new float[]{(float[]) ((TraversableOnce) ((GenericTraversableTemplate) seq.map(spatialConvolution3 -> {
                    return (float[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(spatialConvolution3.getOutputScales())).flatten(fArr -> {
                        return Predef$.MODULE$.wrapFloatArray(fArr);
                    }, ClassTag$.MODULE$.Float());
                }, Seq$.MODULE$.canBuildFrom())).transpose(fArr -> {
                    return new ArrayOps.ofFloat($anonfun$setScalesPrevousJoinTable$11(fArr));
                }).map(seq3 -> {
                    return BoxesRunTime.boxToFloat($anonfun$setScalesPrevousJoinTable$12(seq3));
                }, Seq$.MODULE$.canBuildFrom())).toArray(ClassTag$.MODULE$.Float())}) : ((MklInt8Convertible) ((IterableLike) seq2.map(node7 -> {
                    return (SpatialConvolution) node7.element();
                }, Seq$.MODULE$.canBuildFrom())).head()).getInputScales();
                seq.foreach(spatialConvolution4 -> {
                    spatialConvolution4.setOutputScales(inputScales);
                    return BoxedUnit.UNIT;
                });
            }
        }
    }

    /**
     * Folds a BLAS-wrapped {@code nn.Scale} into the batch normalizations that feed it.
     *
     * <p>Applies only when the element is a {@code BlasWrapper} around a
     * {@code com.intel.analytics.bigdl.nn.Scale} and every predecessor is a
     * {@code SpatialBatchNormalization}. For each BN, with {@code s} and {@code t} being the
     * Scale's first and second parameter tensors, the BN's {@code weightAndBias} is updated in
     * place:
     * <pre>
     *   first nOutput entries  *= s
     *   second nOutput entries  = second * s + t
     * </pre>
     * The Scale node's element is then replaced with a Float {@code Identity}.
     *
     * <p>NOTE(review): unlike the other public passes, this one is not gated by {@code fuse()}.
     * Also, {@code forall} on an empty predecessor list is true, so a Scale with no
     * predecessors would simply be replaced by Identity. Confirm both are intended.
     *
     * @param node candidate node
     */
    public void fuseScale(Node<AbstractModule<Activity, Activity, Object>> node) {
        AbstractModule<Activity, Activity, Object> element = node.element();
        if (!(element instanceof BlasWrapper) || !(((BlasWrapper) element).module() instanceof com.intel.analytics.bigdl.nn.Scale)) {
            BoxedUnit boxedUnit = BoxedUnit.UNIT;
        } else if (node.prevNodes().forall(node2 -> {
            return BoxesRunTime.boxToBoolean($anonfun$fuseScale$1(node2));
        })) {
            node.prevNodes().foreach(node3 -> {
                SpatialBatchNormalization spatialBatchNormalization = (SpatialBatchNormalization) node3.element();
                // narrow/narrow2 are views over the first and second halves of weightAndBias.
                Tensor<Object> dense = spatialBatchNormalization.weightAndBias().dense();
                Tensor<Object> narrow = dense.narrow(1, 1, spatialBatchNormalization.nOutput());
                Tensor<Object> narrow2 = dense.narrow(1, spatialBatchNormalization.nOutput() + 1, spatialBatchNormalization.nOutput());
                com.intel.analytics.bigdl.nn.Scale scale = (com.intel.analytics.bigdl.nn.Scale) ((BlasWrapper) node.element()).module();
                Tensor<Object> tensor = ((Tensor[]) scale.parameters()._1())[0];
                Tensor<Object> tensor2 = ((Tensor[]) scale.parameters()._1())[1];
                narrow.cmul(tensor);
                narrow2.cmul(tensor);
                narrow2.add(tensor2);
                return spatialBatchNormalization.weightAndBias().dense().set(dense);
            });
            node.element_$eq(Identity$.MODULE$.apply$mFc$sp(ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$));
            BoxedUnit boxedUnit2 = BoxedUnit.UNIT;
        }
    }

    /**
     * Collects the nearest ancestors of {@code node} that are not Identity, MaxPooling,
     * AvgPooling or JoinTable.
     *
     * <p>Those four kinds are walked through recursively over all of their predecessors; one with
     * no predecessors contributes nothing. Any other node returns a one-element list containing
     * itself.
     *
     * @param node starting node
     * @return the nearest "real" nodes at or before {@code node}
     */
    /* JADX INFO: Access modifiers changed from: private */
    public Seq<Node<AbstractModule<Activity, Activity, Object>>> findAllNonIdentityPrevs(Node<AbstractModule<Activity, Activity, Object>> node) {
        return ((node.element() instanceof Identity) || (node.element() instanceof MaxPooling) || (node.element() instanceof AvgPooling) || (node.element() instanceof JoinTable)) ? (Seq) node.prevNodes().flatMap(node2 -> {
            return MODULE$.findAllNonIdentityPrevs(node2);
        }, Seq$.MODULE$.canBuildFrom()) : new $colon.colon<>(node, Nil$.MODULE$);
    }

    /**
     * Lambda body for {@code fusionCAddTable}: keeps consumers other than the add node itself
     * whose element is {@code MklInt8Convertible}. The first condition is Scala's null-safe
     * {@code node2 != node}.
     */
    public static final /* synthetic */ boolean $anonfun$fusionCAddTable$2(Node node, Node node2) {
        if (node2 != null ? !node2.equals(node) : node != null) {
            if (node2.element() instanceof MklInt8Convertible) {
                return true;
            }
        }
        return false;
    }

    /**
     * Lambda body for {@code fusionCAddTable}: sets {@code node2}'s input scales to the output
     * scales of the conv now held by {@code node}.
     */
    public static final /* synthetic */ void $anonfun$fusionCAddTable$3(Node node, Node node2) {
        ((MklInt8Convertible) node2.element()).setInputScales(((SpatialConvolution) node.element()).getOutputScales());
    }

    /**
     * Lambda body for {@code setNegativeInputOfConv}. Returns the conv's {@code relu()} flag for
     * a {@code SpatialConvolution}; for any other element, returns whether it is a {@code ReLU}.
     */
    public static final /* synthetic */ boolean $anonfun$setNegativeInputOfConv$2(Node node) {
        AbstractModule abstractModule = (AbstractModule) node.element();
        return abstractModule instanceof SpatialConvolution ? ((SpatialConvolution) node.element()).relu() : abstractModule instanceof ReLU;
    }

    /** Lambda body for {@code setNegativeInputOfConv}: identity predicate used by {@code forall}. */
    public static final /* synthetic */ boolean $anonfun$setNegativeInputOfConv$3(boolean z) {
        return z;
    }

    /** Lambda body for {@code setScalesPrevousJoinTable}: is the element a {@code SpatialConvolution}? */
    public static final /* synthetic */ boolean $anonfun$setScalesPrevousJoinTable$2(Node node) {
        return node.element() instanceof SpatialConvolution;
    }

    /** Lambda body for {@code setScalesPrevousJoinTable}: is the element a {@code SpatialConvolution}? */
    public static final /* synthetic */ boolean $anonfun$setScalesPrevousJoinTable$8(Node node) {
        return node.element() instanceof SpatialConvolution;
    }

    /**
     * Lambda body for {@code setScalesPrevousJoinTable}: wraps a row for {@code transpose}.
     * NOTE(review): the {@code float[]} return type is a JADX mis-typing of Scala's
     * {@code floatArrayOps} result.
     */
    public static final /* synthetic */ float[] $anonfun$setScalesPrevousJoinTable$11(float[] fArr) {
        return Predef$.MODULE$.floatArrayOps(fArr);
    }

    /** Lambda body for {@code setScalesPrevousJoinTable}: maximum of a column of scales. */
    public static final /* synthetic */ float $anonfun$setScalesPrevousJoinTable$12(Seq seq) {
        return BoxesRunTime.unboxToFloat(seq.max(Ordering$Float$.MODULE$));
    }

    /** Lambda body for {@code fuseScale}: is the element a {@code SpatialBatchNormalization}? */
    public static final /* synthetic */ boolean $anonfun$fuseScale$1(Node node) {
        return node.element() instanceof SpatialBatchNormalization;
    }

    /** Registers this instance as the Scala singleton {@code MODULE$}. */
    private Fusion$() {
        MODULE$ = this;
    }
}
