package com.intel.analytics.bigdl.nn.mkldnn;

import com.intel.analytics.bigdl.nn.MklInt8Convertible;
import com.intel.analytics.bigdl.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.nn.abstractnn.Activity;
import com.intel.analytics.bigdl.tensor.Tensor;
import com.intel.analytics.bigdl.tensor.TensorNumericMath$TensorNumeric$NumericFloat$;
import com.intel.analytics.bigdl.utils.T$;
import com.intel.analytics.bigdl.utils.Table;
import scala.Array$;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.mutable.ArrayOps;
import scala.math.Ordering$Float$;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.ScalaRunTime$;

/* compiled from: Utils.scala */
/* loaded from: input_file:com/intel/analytics/bigdl/nn/mkldnn/Utils$.class */
/**
 * Java view of the Scala {@code object Utils} in {@code com.intel.analytics.bigdl.nn.mkldnn}.
 *
 * <p>Helpers used around {@link MklInt8Convertible} modules:
 * <ul>
 *   <li>propagating quantization scales and masks between {@link MemoryData} descriptors,</li>
 *   <li>reordering an activation into a dense heap layout via a {@link ReorderMemory},</li>
 *   <li>triggering per-module scale calculation.</li>
 * </ul>
 *
 * <p>This is a singleton; use {@link #MODULE$}.
 */
public final class Utils$ {
    /** The singleton instance, assigned by the private constructor during static initialization. */
    public static Utils$ MODULE$;

    static {
        new Utils$();
    }

    /**
     * Copies the scales and mask of {@code from} into {@code to}.
     *
     * <p>The copy happens only when both descriptors are non-null and {@code to} has no scales yet.
     * The scales array is cloned, so the two descriptors do not share it.
     *
     * @param from source descriptor; {@code null} means nothing to copy
     * @param to   target descriptor; {@code null} means nothing to copy
     */
    public void copyMaskAndScales(MemoryData from, MemoryData to) {
        if (from == null || to == null || to.scales().length != 0) {
            return;
        }
        to.setScales(from.scales().clone());
        to.setMask(from.mask());
    }

    /**
     * Propagates scales and masks from a group of source descriptors to a group of targets.
     *
     * <p>Nothing happens unless all of the following hold:
     * <ul>
     *   <li>both arrays are non-null and are distinct array instances,</li>
     *   <li>{@code from} is non-empty and every source already has scales,</li>
     *   <li>no target has scales yet,</li>
     *   <li>the lengths are compatible: equal, or one side has exactly one element.</li>
     * </ul>
     *
     * <p>When copying does happen:
     * <ul>
     *   <li>equal lengths: element i of {@code from} is copied into element i of {@code to}, skipping
     *       targets that have meanwhile gained scales;</li>
     *   <li>many-to-one: the target receives the element-wise maximum of the sources' scales and their
     *       common mask. An {@link IllegalArgumentException} is thrown if the masks differ; this check
     *       runs after the scales have been set;</li>
     *   <li>one-to-many: every target receives the single source's scales and mask.</li>
     * </ul>
     *
     * <p>NOTE(review): unlike the single-descriptor overload, these paths store the source's scales
     * array itself (no clone), so source and targets share it. Confirm that this is intended.
     *
     * @param from source descriptors; {@code null} means nothing to copy
     * @param to   target descriptors; {@code null} means nothing to copy
     */
    public void copyMaskAndScales(MemoryData[] from, MemoryData[] to) {
        if (from == null || to == null) {
            return;
        }
        boolean compatibleLengths = from.length == 1 || to.length == 1 || from.length == to.length;
        // An empty `from` has nothing to copy. Without this guard the many-to-one branch would store
        // an empty scales array and then fail the "same mask" requirement, since an empty mask list
        // has no distinct value.
        boolean needCopy = from != to && from.length > 0 && allHaveScales(from) && noneHaveScales(to);
        if (!compatibleLengths || !needCopy) {
            return;
        }
        if (from.length == to.length) {
            copyPairwise(from, to);
        } else if (to.length == 1) {
            mergeIntoSingle(from, to[0]);
        } else if (to.length > 1) {
            // Lengths are compatible and differ while to.length != 1, so from.length == 1 here.
            broadcastSingle(from[0], to);
        }
    }

    /** Returns true when every descriptor has a non-empty scales array. */
    private static boolean allHaveScales(MemoryData[] descriptors) {
        for (MemoryData descriptor : descriptors) {
            if (descriptor.scales().length == 0) {
                return false;
            }
        }
        return true;
    }

    /** Returns true when no descriptor has any scales. */
    private static boolean noneHaveScales(MemoryData[] descriptors) {
        for (MemoryData descriptor : descriptors) {
            if (descriptor.scales().length != 0) {
                return false;
            }
        }
        return true;
    }

    /** Copies scales and mask index-by-index into targets that still have no scales. */
    private static void copyPairwise(MemoryData[] from, MemoryData[] to) {
        for (int i = 0; i < to.length; i++) {
            MemoryData target = to[i];
            if (target.scales().length == 0) {
                MemoryData source = from[i];
                target.setScales(source.scales());
                target.setMask(source.mask());
            }
        }
    }

    /** Merges the sources' scales (element-wise max) and their common mask into one target. */
    private static void mergeIntoSingle(MemoryData[] from, MemoryData target) {
        // Scala: from.map(_.scales).transpose.map(_.max). The Scala collection calls are kept as-is
        // so that transpose/max semantics (ragged rows, NaN ordering) stay exactly the same.
        Object[] perSourceScales = (Object[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(from)).map(source -> {
            return ((MemoryData) source).scales();
        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(Float.TYPE))));
        Object[] columns = (Object[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(perSourceScales)).transpose(Predef$.MODULE$.$conforms());
        float[] merged = (float[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(columns)).map(column -> {
            return BoxesRunTime.boxToFloat($anonfun$copyMaskAndScales$5((float[]) column));
        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Float()));
        target.setScales(merged);

        // All sources must agree on the mask. `from` is non-empty here (guarded by the caller).
        int mask = from[0].mask();
        boolean sameMask = true;
        for (MemoryData source : from) {
            if (source.mask() != mask) {
                sameMask = false;
                break;
            }
        }
        Predef$.MODULE$.require(sameMask, () -> {
            return "only support the same mask";
        });
        target.setMask(mask);
    }

    /** Gives every target the single source's scales, then its mask. */
    private static void broadcastSingle(MemoryData source, MemoryData[] to) {
        for (MemoryData target : to) {
            target.setScales(source.scales());
        }
        for (MemoryData target : to) {
            target.setMask(source.mask());
        }
    }

    /**
     * Chooses an integer layout code for a dense (heap) copy of {@code memoryData}, based on the rank
     * of its shape.
     *
     * <p>For 2-D shapes the result is 4 when {@code z} is true and 12 otherwise. For 4-D shapes it is
     * 7 when {@code z} is true and 16 otherwise. NOTE(review): these are presumably MKL-DNN
     * {@code Memory.Format} codes (e.g. nc/oi and nchw/oihw); confirm before relying on that.
     *
     * @param memoryData descriptor whose {@code shape()} rank selects the format
     * @param z          selects the first code of each pair; this is the default and is the value
     *                   used by {@link #getDenseIn} and {@link #getDenseOut}
     * @return the layout code
     * @throws UnsupportedOperationException if the shape is neither 2-D nor 4-D
     */
    public int getDefaultFormat(MemoryData memoryData, boolean z) {
        switch (memoryData.shape().length) {
            case 2:
                return z ? 4 : 12;
            case 4:
                return z ? 7 : 16;
            default:
                throw new UnsupportedOperationException("Linear only supports 2-D or 4-D");
        }
    }

    /** Scala default-argument accessor: the default {@code z} of {@link #getDefaultFormat} is true. */
    public boolean getDefaultFormat$default$2() {
        return true;
    }

    /**
     * Reorders {@code tensor}, laid out as described by {@code format}, into a {@link HeapData}
     * layout with the same shape and the default format, and returns the forwarded result.
     *
     * <p>NOTE(review): a new {@link ReorderMemory} is built and initialised on every call. It is
     * unclear from here whether its native resources are released; confirm.
     */
    private Tensor<Object> denseTensor(MemoryData format, Tensor<Object> tensor, boolean z, MklDnnRuntime runtime) {
        HeapData denseLayout = new HeapData(format.shape(), getDefaultFormat(format, z), HeapData$.MODULE$.apply$default$3());
        MemoryData defaultOutput = ReorderMemory$.MODULE$.apply$default$2();
        ReorderMemory reorder = ReorderMemory$.MODULE$.apply(denseLayout, defaultOutput, ReorderMemory$.MODULE$.apply$default$3(denseLayout, defaultOutput));
        reorder.setRuntime(runtime);
        reorder.initFwdPrimitives(new MemoryData[]{format}, Phase$InferencePhase$.MODULE$);
        return reorder.forward(tensor).toTensor(TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
    }

    private boolean denseTensor$default$3() {
        return true;
    }

    /**
     * Converts an activation into dense tensors.
     *
     * <p>With more than one format, {@code activity} must be a table with exactly one entry per
     * format; each entry (keys 1..n) is reordered with its format, and a new table is returned.
     * Otherwise {@code activity} is treated as a single tensor and reordered with {@code formats[0]}.
     *
     * @throws IllegalArgumentException if the table size differs from the number of formats
     */
    @SuppressWarnings("unchecked")
    private Activity denseActivity(MemoryData[] formats, Activity activity, boolean z, MklDnnRuntime runtime) {
        if (formats.length > 1) {
            Predef$.MODULE$.require(formats.length == activity.toTable().length(), () -> {
                return "formats should be the same as activity";
            });
            Table inputs = activity.toTable();
            // The result is a Table, so it is declared as Activity (not Tensor) here.
            Table result = T$.MODULE$.apply();
            for (int i = 1; i <= formats.length; i++) {
                Tensor<Object> element = (Tensor<Object>) inputs.get(BoxesRunTime.boxToInteger(i)).get();
                result.update(BoxesRunTime.boxToInteger(i), denseTensor(formats[i - 1], element, z, runtime));
            }
            return result;
        }
        return denseTensor(formats[0], activity.toTensor(TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$), z, runtime);
    }

    private boolean denseActivity$default$3() {
        return true;
    }

    /**
     * Returns {@code activity} converted with the module's input formats when the module is an
     * {@link MklDnnModule}. For any other module, returns {@code activity} unchanged.
     */
    public Activity getDenseIn(MklInt8Convertible module, Activity activity) {
        if (!(module instanceof MklDnnModule)) {
            return activity;
        }
        MklDnnModule dnnModule = (MklDnnModule) module;
        return denseActivity(dnnModule.inputFormats(), activity, true, dnnModule.getRuntime());
    }

    /**
     * Returns {@code activity} converted with the module's output formats when the module is an
     * {@link MklDnnModule}. For any other module, returns {@code activity} unchanged.
     */
    public Activity getDenseOut(MklInt8Convertible module, Activity activity) {
        if (!(module instanceof MklDnnModule)) {
            return activity;
        }
        MklDnnModule dnnModule = (MklDnnModule) module;
        return denseActivity(dnnModule.outputFormats(), activity, true, dnnModule.getRuntime());
    }

    /**
     * For a {@link SpatialConvolution}, sets {@code negativeInput} to false when the minimum of the
     * dense input is non-negative. Other modules are left untouched.
     */
    private void setConvNegativeInput(MklInt8Convertible module, Activity input) {
        if (module instanceof SpatialConvolution) {
            SpatialConvolution conv = (SpatialConvolution) module;
            // NOTE(review): `mo1126min` is a decompiler-renamed method; presumably Tensor#min()
            // (global minimum). Confirm against the Tensor API.
            float minimum = BoxesRunTime.unboxToFloat(getDenseIn(module, input).toTensor(TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$).mo1126min());
            if (minimum >= 0.0f) {
                conv.negativeInput_$eq(false);
            }
        }
    }

    /**
     * For an {@link MklInt8Convertible} module, calls its {@code calcScales(input)} and then updates
     * a convolution's negative-input flag. Other modules are ignored.
     */
    public void calcScales(AbstractModule<?, ?, ?> module, Activity input) {
        if (module instanceof MklInt8Convertible) {
            MklInt8Convertible convertible = (MklInt8Convertible) module;
            convertible.calcScales(input);
            setConvNegativeInput(convertible, input);
        }
    }

    /**
     * Returns {@code module.output()}.
     *
     * <p>The original code had identical branches for MKL-DNN and other modules. {@code input} is
     * unused and is kept only for signature compatibility.
     */
    public Activity getOutput(AbstractModule<?, ?, ?> module, Activity input) {
        return (Activity) module.output();
    }

    // Scala-generated lambda bodies. They stay public for binary compatibility; only
    // $anonfun$copyMaskAndScales$5 is still used by this class.

    /** Returns true when the descriptor has non-empty scales. */
    public static final /* synthetic */ boolean $anonfun$copyMaskAndScales$1(MemoryData memoryData) {
        return new ArrayOps.ofFloat(Predef$.MODULE$.floatArrayOps(memoryData.scales())).nonEmpty();
    }

    /** Returns true when the descriptor has no scales. */
    public static final /* synthetic */ boolean $anonfun$copyMaskAndScales$2(MemoryData memoryData) {
        return new ArrayOps.ofFloat(Predef$.MODULE$.floatArrayOps(memoryData.scales())).isEmpty();
    }

    /** Copies scales and mask from {@code _2} into {@code _1} when {@code _1} has no scales. */
    public static final /* synthetic */ void $anonfun$copyMaskAndScales$3(Tuple2 tuple2) {
        if (new ArrayOps.ofFloat(Predef$.MODULE$.floatArrayOps(((MemoryData) tuple2._1()).scales())).isEmpty()) {
            ((MemoryData) tuple2._1()).setScales(((MemoryData) tuple2._2()).scales());
            ((MemoryData) tuple2._1()).setMask(((MemoryData) tuple2._2()).mask());
        }
    }

    /** Maximum of {@code fArr} under Scala's {@code Ordering.Float}. */
    public static final /* synthetic */ float $anonfun$copyMaskAndScales$5(float[] fArr) {
        return BoxesRunTime.unboxToFloat(new ArrayOps.ofFloat(Predef$.MODULE$.floatArrayOps(fArr)).max(Ordering$Float$.MODULE$));
    }

    /** Sets the first source's scales on {@code memoryData}. */
    public static final /* synthetic */ void $anonfun$copyMaskAndScales$9(MemoryData[] memoryDataArr, MemoryData memoryData) {
        memoryData.setScales(((MemoryData) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(memoryDataArr)).head()).scales());
    }

    /** Sets the first source's mask on {@code memoryData}. */
    public static final /* synthetic */ void $anonfun$copyMaskAndScales$10(MemoryData[] memoryDataArr, MemoryData memoryData) {
        memoryData.setMask(((MemoryData) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(memoryDataArr)).head()).mask());
    }

    /** Private constructor; publishes the singleton. Invoked once from the static initializer. */
    private Utils$() {
        MODULE$ = this;
    }
}
