package com.intel.analytics.bigdl.utils.tf;

import com.intel.analytics.bigdl.nn.SpatialBatchNormalization;
import com.intel.analytics.bigdl.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.tensor.Tensor;
import com.intel.analytics.bigdl.tensor.Tensor$;
import com.intel.analytics.bigdl.tensor.TensorNumericMath$TensorNumeric$NumericInt$;
import java.nio.ByteOrder;
import org.tensorflow.framework.NodeDef;
import scala.Predef$;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;
import scala.runtime.RichInt$;

/* compiled from: BigDLToTensorflow.scala */
/* loaded from: input_file:com/intel/analytics/bigdl/utils/tf/BatchNorm2DToTF$.class */
public final class BatchNorm2DToTF$ implements BigDLToTensorflow {
    /** Singleton instance; assigned by the private constructor when the class is initialized. */
    public static BatchNorm2DToTF$ MODULE$;

    static {
        new BatchNorm2DToTF$();
    }

    /**
     * Exports an evaluation-mode {@link SpatialBatchNormalization} as a subgraph of elementwise
     * TensorFlow nodes (const / reshape / rsqrt / multiply / subtract / add).
     *
     * <p>The running statistics (and, when present, the affine weight/bias) are emitted as
     * constants and reshaped to a broadcastable shape. That shape has {@code nDim} entries, all
     * {@code 1} except position 2 (1-based), which holds {@code runningVar.size(1)}.
     *
     * <p>NOTE(review): {@code rsqrt} is applied directly to the running variance; no epsilon term
     * is added here. If the BigDL layer adds {@code eps} at inference time, the exported results
     * may differ slightly. TODO confirm.
     *
     * @param abstractModule the layer to export; it is cast to {@link SpatialBatchNormalization}
     * @param seq the layer's input nodes; exactly one is required
     * @param byteOrder byte order used when serializing the constant tensors
     * @return the output node first, followed by every intermediate node that was created
     * @throws IllegalArgumentException (from {@code Predef.require}) if there is not exactly one
     *     input or if the layer is in training mode
     */
    @Override // com.intel.analytics.bigdl.utils.tf.BigDLToTensorflow
    public Seq<NodeDef> toTFDef(AbstractModule<?, ?, ?> abstractModule, Seq<NodeDef> seq, ByteOrder byteOrder) {
        Predef$.MODULE$.require(seq.length() == 1, () -> "BatchNorm only accept one input");
        SpatialBatchNormalization bn = (SpatialBatchNormalization) abstractModule;
        Predef$.MODULE$.require(!bn.isTraining(), () -> "Only support evaluate mode batch norm");

        Tensor<Object> shape = broadcastShape(bn);
        NodeDef input = (NodeDef) seq.apply(0);
        String name = bn.getName();

        // A null weight means the layer has no learned scale/offset, so emit the reduced graph.
        if (bn.weight() == null) {
            return buildWithoutAffine(bn, input, shape, name, byteOrder);
        }
        return buildWithAffine(bn, input, shape, name, byteOrder);
    }

    /**
     * Builds the integer shape used to reshape the per-channel statistics so they broadcast
     * against the input.
     *
     * <p>Position 2 (1-based) carries the channel count; it is assumed that the channel axis is
     * the second dimension (NCHW-style). TODO confirm for other data formats. This also assumes
     * {@code nDim >= 2}; otherwise the {@code update} at index 2 fails.
     */
    private static Tensor<Object> broadcastShape(SpatialBatchNormalization bn) {
        int rank = bn.nDim();
        Tensor<Object> shape = Tensor$.MODULE$.apply(rank, ClassTag$.MODULE$.Int(), TensorNumericMath$TensorNumeric$NumericInt$.MODULE$);
        // Tensor indices are 1-based.
        for (int i = 0; i < rank; i++) {
            shape.setValue(i + 1, BoxesRunTime.boxToInteger(1));
        }
        // Pass the boxed Integer directly: the tensor's element type is Object, and an unboxed int
        // would not type-check against a wildcard-typed tensor.
        shape.update(2, BoxesRunTime.boxToInteger(bn.runningVar().size(1)));
        return shape;
    }

    /**
     * Emits {@code output = input * rsqrt(var) - mean * rsqrt(var)} for a layer without
     * weight/bias.
     */
    private static Seq<NodeDef> buildWithoutAffine(SpatialBatchNormalization bn, NodeDef input, Tensor<Object> shape, String name, ByteOrder byteOrder) {
        NodeDef varShape = Tensorflow$.MODULE$.m1452const(shape, name + "/reshape_1/shape", byteOrder);
        NodeDef meanShape = Tensorflow$.MODULE$.m1452const(shape, name + "/reshape_2/shape", byteOrder);
        NodeDef varConst = Tensorflow$.MODULE$.m1452const(bn.runningVar(), name + "/var", byteOrder);
        NodeDef meanConst = Tensorflow$.MODULE$.m1452const(bn.runningMean(), name + "/mean", byteOrder);
        NodeDef varReshaped = Tensorflow$.MODULE$.reshape(varConst, varShape, name + "/reshape_1");
        NodeDef meanReshaped = Tensorflow$.MODULE$.reshape(meanConst, meanShape, name + "/reshape_2");
        NodeDef invStd = Tensorflow$.MODULE$.rsqrt(varReshaped, name + "/sqrtvar");
        NodeDef scaledInput = Tensorflow$.MODULE$.multiply(input, invStd, name + "/mul1");
        NodeDef scaledMean = Tensorflow$.MODULE$.multiply(meanReshaped, invStd, name + "/mul2");
        NodeDef output = Tensorflow$.MODULE$.subtract(scaledInput, scaledMean, name + "/output");
        return seqOf(output, scaledMean, scaledInput, meanReshaped, meanShape, meanConst, invStd, varReshaped, varShape, varConst);
    }

    /**
     * Emits {@code output = input * k + (offset - mean * k)} with {@code k = scale * rsqrt(var)}
     * for a layer with weight (scale) and bias (offset).
     */
    private static Seq<NodeDef> buildWithAffine(SpatialBatchNormalization bn, NodeDef input, Tensor<Object> shape, String name, ByteOrder byteOrder) {
        NodeDef varShape = Tensorflow$.MODULE$.m1452const(shape, name + "/reshape_1/shape", byteOrder);
        NodeDef meanShape = Tensorflow$.MODULE$.m1452const(shape, name + "/reshape_2/shape", byteOrder);
        NodeDef scaleShape = Tensorflow$.MODULE$.m1452const(shape, name + "/reshape_3/shape", byteOrder);
        NodeDef offsetShape = Tensorflow$.MODULE$.m1452const(shape, name + "/reshape_4/shape", byteOrder);
        NodeDef varConst = Tensorflow$.MODULE$.m1452const(bn.runningVar(), name + "/var", byteOrder);
        NodeDef meanConst = Tensorflow$.MODULE$.m1452const(bn.runningMean(), name + "/mean", byteOrder);
        NodeDef scaleConst = Tensorflow$.MODULE$.m1452const(bn.weight(), name + "/scale", byteOrder);
        NodeDef offsetConst = Tensorflow$.MODULE$.m1452const(bn.bias(), name + "/offset", byteOrder);
        NodeDef varReshaped = Tensorflow$.MODULE$.reshape(varConst, varShape, name + "/reshape_1");
        NodeDef meanReshaped = Tensorflow$.MODULE$.reshape(meanConst, meanShape, name + "/reshape_2");
        NodeDef scaleReshaped = Tensorflow$.MODULE$.reshape(scaleConst, scaleShape, name + "/reshape_3");
        NodeDef offsetReshaped = Tensorflow$.MODULE$.reshape(offsetConst, offsetShape, name + "/reshape_4");
        NodeDef invStd = Tensorflow$.MODULE$.rsqrt(varReshaped, name + "/sqrtvar");
        NodeDef factor = Tensorflow$.MODULE$.multiply(scaleReshaped, invStd, name + "/mul0");
        NodeDef scaledInput = Tensorflow$.MODULE$.multiply(input, factor, name + "/mul1");
        NodeDef scaledMean = Tensorflow$.MODULE$.multiply(meanReshaped, factor, name + "/mul2");
        NodeDef shift = Tensorflow$.MODULE$.subtract(offsetReshaped, scaledMean, name + "/sub");
        NodeDef output = Tensorflow$.MODULE$.add(scaledInput, shift, name + "/output");
        return seqOf(output, shift, scaledMean, scaledInput, factor, offsetReshaped, meanReshaped, scaleReshaped, offsetShape, meanShape, scaleShape, offsetConst, scaleConst, meanConst, invStd, varReshaped, varShape, varConst);
    }

    /** Wraps the given nodes, in order, into a Scala {@code Seq}. */
    private static Seq<NodeDef> seqOf(NodeDef... nodes) {
        return Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(nodes));
    }

    /**
     * Sets element {@code i + 1} (1-based) of {@code tensor} to {@code 1} and returns what
     * {@code setValue} returns.
     *
     * <p>This is a compiler-generated lambda body from the original Scala source. It is no longer
     * called inside this class but is kept public so the binary interface is unchanged.
     */
    public static final /* synthetic */ Tensor $anonfun$toTFDef$29(Tensor tensor, int i) {
        return tensor.setValue(i + 1, BoxesRunTime.boxToInteger(1));
    }

    private BatchNorm2DToTF$() {
        MODULE$ = this;
    }
}
