package com.intel.analytics.bigdl.utils.tf;

import com.intel.analytics.bigdl.nn.SpatialConvolution;
import com.intel.analytics.bigdl.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.nn.abstractnn.DataFormat;
import com.intel.analytics.bigdl.nn.abstractnn.DataFormat$NCHW$;
import com.intel.analytics.bigdl.tensor.Tensor;
import com.intel.analytics.bigdl.tensor.Tensor$;
import com.intel.analytics.bigdl.tensor.TensorNumericMath$TensorNumeric$NumericInt$;
import java.nio.ByteOrder;
import org.tensorflow.framework.NodeDef;
import scala.MatchError;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.TraversableLike;
import scala.collection.immutable.IndexedSeq;
import scala.collection.immutable.IndexedSeq$;
import scala.collection.mutable.ArrayBuffer;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;
import scala.runtime.RichInt$;

/* compiled from: BigDLToTensorflow.scala */
/* loaded from: input_file:com/intel/analytics/bigdl/utils/tf/SpatialConvolutionToTF$.class */
/**
 * Exports a BigDL {@link SpatialConvolution} module as a list of TensorFlow {@link NodeDef}s.
 *
 * <p>An ungrouped convolution becomes filter constant → identity reader → conv2D (→ biasAdd when
 * the module has a bias). A grouped convolution (NCHW only) becomes split → one such chain per
 * group → concat. In every case the first element of the returned sequence is the node that
 * carries the convolution's result.
 */
public final class SpatialConvolutionToTF$ implements BigDLToTensorflow {
    /**
     * Singleton instance (Scala {@code object} encoding); assigned by the private constructor,
     * which the static initializer invokes once.
     */
    public static SpatialConvolutionToTF$ MODULE$;

    static {
        new SpatialConvolutionToTF$();
    }

    private SpatialConvolutionToTF$() {
        MODULE$ = this;
    }

    /**
     * Emits the TensorFlow nodes that reproduce the given convolution module.
     *
     * @param abstractModule module to convert; it is cast to {@link SpatialConvolution}
     * @param seq must contain exactly one node, the convolution's input
     * @param byteOrder byte order handed to the filter and bias constants (the int split/concat
     *     axis constants of the grouped path always use little-endian)
     * @return the generated nodes, result node first
     * @throws IllegalArgumentException (via Scala {@code require}) if {@code seq} does not hold
     *     exactly one node, or if the module is grouped but its format is not NCHW
     */
    @Override // com.intel.analytics.bigdl.utils.tf.BigDLToTensorflow
    public Seq<NodeDef> toTFDef(AbstractModule<?, ?, ?> abstractModule, Seq<NodeDef> seq, ByteOrder byteOrder) {
        Predef$.MODULE$.require(seq.length() == 1, () -> "SpatialConvolution only accept one input");
        SpatialConvolution conv = (SpatialConvolution) abstractModule;
        if (conv.nGroup() == 1) {
            return singleGroupToTF(conv, seq, byteOrder);
        }
        return groupedToTF(conv, seq, byteOrder);
    }

    /**
     * Converts an ungrouped convolution into one conv2D, followed by a biasAdd if a bias exists.
     *
     * <p>Node order: {@code [biasAdd, biasReader, bias, conv2D, filterReader, filter]} with a bias,
     * {@code [conv2D, filterReader, filter]} without.
     */
    private static Seq<NodeDef> singleGroupToTF(SpatialConvolution conv, Seq<NodeDef> inputs, ByteOrder byteOrder) {
        String name = conv.getName();
        Tensor kernel = conv.weight().select(1, 1);

        TensorflowDataFormat tfFormat;
        Tensor filterTensor;
        if (isNCHW(conv.format())) {
            tfFormat = TensorflowDataFormat$NCHW$.MODULE$;
            filterTensor = nchwKernelToTF(kernel);
        } else {
            // NOTE(review): non-NCHW kernels are exported unpermuted, presumably because BigDL's
            // NHWC weight layout already matches TensorFlow's filter layout — confirm.
            tfFormat = TensorflowDataFormat$NHWC$.MODULE$;
            filterTensor = kernel;
        }

        NodeDef filter = Tensorflow$.MODULE$.m3261const(filterTensor, name + "/filter", byteOrder);
        NodeDef filterReader = Tensorflow$.MODULE$.identity(filter, name + "/filterReader");
        NodeDef conv2D = Tensorflow$.MODULE$.conv2D((NodeDef) inputs.apply(0), filterReader,
                conv.strideW(), conv.strideH(), conv.kernelW(), conv.kernelH(),
                conv.padW(), conv.padH(), tfFormat, name + "/conv2D");

        Tensor biasTensor = conv.bias();
        if (biasTensor == null) {
            return nodeSeq(conv2D, filterReader, filter);
        }
        NodeDef bias = Tensorflow$.MODULE$.m3261const(biasTensor, name + "/bias", byteOrder);
        NodeDef biasReader = Tensorflow$.MODULE$.identity(bias, name + "/biasReader");
        NodeDef biasAdd = Tensorflow$.MODULE$.biasAdd(conv2D, biasReader, tfFormat, name + "/biasAdd");
        return nodeSeq(biasAdd, biasReader, bias, conv2D, filterReader, filter);
    }

    /**
     * Converts a grouped convolution (NCHW only) into split → per-group conv2D → concat.
     *
     * <p>The result starts with the {@code /concat/output} node, followed by the split-dim
     * constant, the split outputs, the concat-axis constant, and then every group's nodes in
     * group order.
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private static Seq<NodeDef> groupedToTF(SpatialConvolution conv, Seq<NodeDef> inputs, ByteOrder byteOrder) {
        Predef$.MODULE$.require(isNCHW(conv.format()), () -> "Only NCHW support conv group");
        String name = conv.getName();
        int groups = conv.nGroup();
        ArrayBuffer<NodeDef> nodes = new ArrayBuffer<>();

        // Split the input into `groups` slices along axis 1 (presumably the channel axis of NCHW).
        NodeDef splitDim = Tensorflow$.MODULE$.m3261const(intScalar(1), name + "/split_dim", ByteOrder.LITTLE_ENDIAN);
        Seq<NodeDef> splits = Tensorflow$.MODULE$.split(splitDim, (NodeDef) inputs.apply(0), groups, name + "/split");
        nodes.$plus$eq(splitDim);
        nodes.appendAll(splits);

        NodeDef concatAxis = Tensorflow$.MODULE$.m3261const(intScalar(1), name + "/concat/axis", ByteOrder.LITTLE_ENDIAN);
        nodes.$plus$eq(concatAxis);

        // One conv chain per group; the axis constant is the last concat input.
        ArrayBuffer<NodeDef> concatInputs = new ArrayBuffer<>();
        for (int g = 0; g < groups; g++) {
            concatInputs.$plus$eq($anonfun$toTFDef$6(conv, byteOrder, splits, nodes, g));
        }
        concatInputs.$plus$eq(concatAxis);
        NodeDef output = Tensorflow$.MODULE$.concat(concatInputs.toIndexedSeq(), name + "/concat/output");

        // Raw view: Scala's CanBuildFrom type parameters do not unify under Java generics.
        Seq head = nodeSeq(output);
        return (Seq) head.$plus$plus(nodes, Seq$.MODULE$.canBuildFrom());
    }

    /**
     * Builds the nodes for group {@code i} of a grouped convolution.
     *
     * <p>Appends the group's nodes to {@code arrayBuffer}. With a bias, they are
     * {@code [biasAdd, biasReader, bias, conv2D, filterReader, filter]}; without one, they are
     * {@code [conv2D, filterReader, filter]}. It then returns the node carrying the group's output.
     * The compiler-generated name and public static signature are retained so the class's binary
     * interface is unchanged.
     *
     * @param spatialConvolution the grouped convolution being exported
     * @param byteOrder byte order for the filter and bias constants
     * @param seq per-group input slices from the split node; element {@code i} feeds this group
     * @param arrayBuffer sink collecting every node created here
     * @param i zero-based group index
     * @return the biasAdd node when the module has a bias, otherwise the conv2D node
     */
    public static final /* synthetic */ NodeDef $anonfun$toTFDef$6(SpatialConvolution spatialConvolution, ByteOrder byteOrder, Seq seq, ArrayBuffer arrayBuffer, int i) {
        String prefix = spatialConvolution.getName() + "/group" + i;

        // Tensor indices are 1-based here, hence slice i + 1 of the weight for group i.
        Tensor kernel = nchwKernelToTF(spatialConvolution.weight().select(1, i + 1));
        NodeDef filter = Tensorflow$.MODULE$.m3261const(kernel, prefix + "/filter", byteOrder);
        NodeDef filterReader = Tensorflow$.MODULE$.identity(filter, prefix + "/filterReader");
        NodeDef conv2D = Tensorflow$.MODULE$.conv2D((NodeDef) seq.apply(i), filterReader,
                spatialConvolution.strideW(), spatialConvolution.strideH(),
                spatialConvolution.kernelW(), spatialConvolution.kernelH(),
                spatialConvolution.padW(), spatialConvolution.padH(),
                TensorflowDataFormat$NCHW$.MODULE$, prefix + "/conv2D");

        Tensor fullBias = spatialConvolution.bias();
        if (fullBias == null) {
            arrayBuffer.append(Predef$.MODULE$.wrapRefArray(new NodeDef[]{conv2D, filterReader, filter}));
            return conv2D;
        }

        // This group's bias slice: nOutputPlane / nGroup entries starting at a 1-based offset.
        int nOut = spatialConvolution.nOutputPlane();
        int nGroup = spatialConvolution.nGroup();
        Tensor groupBias = fullBias.narrow(1, (i * nOut) / nGroup + 1, nOut / nGroup);
        NodeDef bias = Tensorflow$.MODULE$.m3261const(groupBias, prefix + "/bias", byteOrder);
        NodeDef biasReader = Tensorflow$.MODULE$.identity(bias, prefix + "/biasReader");
        NodeDef biasAdd = Tensorflow$.MODULE$.biasAdd(conv2D, biasReader, TensorflowDataFormat$NCHW$.MODULE$, prefix + "/biasAdd");
        arrayBuffer.append(Predef$.MODULE$.wrapRefArray(new NodeDef[]{biasAdd, biasReader, bias, conv2D, filterReader, filter}));
        return biasAdd;
    }

    /**
     * Reorders a 4-D kernel's dimensions from {@code (d1, d2, d3, d4)} to {@code (d3, d4, d2, d1)}
     * and makes the result contiguous.
     *
     * <p>Assuming the NCHW kernel slice is {@code (out, in, kH, kW)} — TODO confirm against
     * SpatialConvolution — the result is {@code (kH, kW, in, out)}.
     */
    private static Tensor nchwKernelToTF(Tensor kernel) {
        return kernel
                .transpose(2, 3)
                .transpose(3, 4)
                .transpose(1, 2)
                .transpose(2, 3)
                .transpose(3, 4)
                .contiguous();
    }

    /** Builds an Int scalar tensor holding {@code value}. */
    private static Tensor intScalar(int value) {
        return Tensor$.MODULE$.scalar(BoxesRunTime.boxToInteger(value), ClassTag$.MODULE$.Int(), TensorNumericMath$TensorNumeric$NumericInt$.MODULE$);
    }

    /** Null-safe check of whether {@code format} equals {@code DataFormat.NCHW}. */
    private static boolean isNCHW(DataFormat format) {
        DataFormat$NCHW$ nchw = DataFormat$NCHW$.MODULE$;
        return format == null ? nchw == null : format.equals(nchw);
    }

    /** Wraps the given nodes, in order, into a Scala {@code Seq} via {@code Seq.apply}. */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private static Seq<NodeDef> nodeSeq(NodeDef... nodes) {
        return (Seq) Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(nodes));
    }
}
