package com.intel.analytics.bigdl.utils.tf;

import com.intel.analytics.bigdl.nn.CAddTable$;
import com.intel.analytics.bigdl.nn.CMulTable$;
import com.intel.analytics.bigdl.nn.Dropout$;
import com.intel.analytics.bigdl.nn.Graph;
import com.intel.analytics.bigdl.nn.Input$;
import com.intel.analytics.bigdl.nn.JoinTable$;
import com.intel.analytics.bigdl.nn.Linear$;
import com.intel.analytics.bigdl.nn.LogSoftMax$;
import com.intel.analytics.bigdl.nn.Mean$;
import com.intel.analytics.bigdl.nn.Padding$;
import com.intel.analytics.bigdl.nn.ReLU$;
import com.intel.analytics.bigdl.nn.Reshape$;
import com.intel.analytics.bigdl.nn.Scale$;
import com.intel.analytics.bigdl.nn.Sigmoid$;
import com.intel.analytics.bigdl.nn.SoftMax$;
import com.intel.analytics.bigdl.nn.SpatialAveragePooling;
import com.intel.analytics.bigdl.nn.SpatialAveragePooling$;
import com.intel.analytics.bigdl.nn.SpatialBatchNormalization$;
import com.intel.analytics.bigdl.nn.SpatialConvolution$;
import com.intel.analytics.bigdl.nn.SpatialCrossMapLRN$;
import com.intel.analytics.bigdl.nn.SpatialMaxPooling;
import com.intel.analytics.bigdl.nn.SpatialMaxPooling$;
import com.intel.analytics.bigdl.nn.Squeeze$;
import com.intel.analytics.bigdl.nn.Tanh$;
import com.intel.analytics.bigdl.nn.TemporalConvolution$;
import com.intel.analytics.bigdl.nn.View$;
import com.intel.analytics.bigdl.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.nn.abstractnn.Activity;
import com.intel.analytics.bigdl.tensor.Tensor;
import com.intel.analytics.bigdl.tensor.Tensor$;
import com.intel.analytics.bigdl.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.utils.FileWriter;
import com.intel.analytics.bigdl.utils.FileWriter$;
import com.intel.analytics.bigdl.utils.Node;
import com.intel.analytics.bigdl.utils.T$;
import com.intel.analytics.bigdl.utils.Table;
import com.intel.analytics.shaded.protobuf_v_3_5_1.CodedOutputStream;
import java.io.OutputStream;
import java.nio.ByteOrder;
import org.apache.log4j.Logger;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;
import scala.Predef$;
import scala.Predef$ArrowAssoc$;
import scala.Tuple2;
import scala.collection.IterableLike;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.TraversableOnce;
import scala.collection.immutable.Nil$;
import scala.collection.immutable.Set;
import scala.collection.mutable.ArrayBuffer;
import scala.collection.mutable.ArrayBuffer$;
import scala.collection.mutable.ArrayOps;
import scala.collection.mutable.HashMap;
import scala.collection.mutable.Map;
import scala.collection.mutable.Map$;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.IntRef;

/* compiled from: TensorflowSaver.scala */
/* loaded from: input_file:com/intel/analytics/bigdl/utils/tf/TensorflowSaver$.class */
/**
 * Singleton (Scala companion object) that exports a BigDL {@link Graph} as a TensorFlow
 * {@link GraphDef} protobuf file.
 *
 * <p>Each BigDL module is translated by a {@link BigDLToTensorflow} converter looked up by the
 * module's runtime class name in {@link #maps()}. Additional converters can be added with
 * {@link #register(String, BigDLToTensorflow)}. The registry is a plain mutable map with no
 * synchronization in this class.
 */
public final class TensorflowSaver$ {
    public static TensorflowSaver$ MODULE$;
    private final Logger logger;
    /** Module class name (e.g. {@code com.intel.analytics.bigdl.nn.ReLU}) -> converter. */
    private final Map<String, BigDLToTensorflow> maps;

    static {
        new TensorflowSaver$();
    }

    /**
     * Converts {@code graph} into a TensorFlow GraphDef and writes it to {@code str}.
     *
     * <p>The GraphDef receives nodes in this order: the input NodeDefs ({@code seq}), then the
     * converted nodes of every module in forward-execution order, then the extra nodes in
     * {@code set}.
     *
     * @param graph     BigDL graph to export
     * @param seq       NodeDefs for the graph inputs; they are paired positionally with
     *                  {@code graph.inputs()} (via {@code zip}, so a count mismatch is not
     *                  reported here)
     * @param str       destination path, opened through {@code FileWriter}
     * @param byteOrder byte order forwarded to every converter's {@code toTFDef}
     * @param set       extra NodeDefs appended after the converted graph nodes
     */
    public <T> void saveGraphWithNodeDef(Graph<T> graph, Seq<NodeDef> seq, String str, ByteOrder byteOrder, Set<NodeDef> set) {
        // BigDL module name -> NodeDefs that feed that module, seeded with the input placeholders.
        HashMap hashMap = new HashMap();
        ((IterableLike) graph.inputs().zip(seq, Seq$.MODULE$.canBuildFrom())).foreach(tuple2 -> {
            $anonfun$saveGraphWithNodeDef$1(hashMap, tuple2);
            return BoxedUnit.UNIT;
        });

        GraphDef.Builder newBuilder = GraphDef.newBuilder();
        seq.foreach(nodeDef -> {
            return newBuilder.addNode(nodeDef);
        });
        new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(graph.getSortedForwardExecutions())).foreach(node -> {
            $anonfun$saveGraphWithNodeDef$3(hashMap, byteOrder, newBuilder, node);
            return BoxedUnit.UNIT;
        });
        set.foreach(nodeDef2 -> {
            return newBuilder.addNode(nodeDef2);
        });

        // Closing order matches the original (stream first, then writer).
        // The writer is closed even if closing the stream fails, and a failure while closing
        // is attached as a suppressed exception instead of masking an earlier write error.
        FileWriter fileWriter = FileWriter$.MODULE$.apply(str);
        try {
            try (OutputStream outputStream = fileWriter.create(true)) {
                CodedOutputStream codedOutput = CodedOutputStream.newInstance(outputStream);
                GraphDef build = newBuilder.build();
                // GraphDef.toString() renders the whole graph; only pay for it when it is logged.
                if (logger().isDebugEnabled()) {
                    logger().debug("Graph definition is:");
                    logger().debug(build.toString());
                }
                build.writeTo(codedOutput);
                codedOutput.flush();
                outputStream.flush();
                logger().info(new StringBuilder(33).append("Save as tensorflow model file to ").append(str).toString());
            }
        } finally {
            fileWriter.close();
        }
    }

    /**
     * Exports {@code graph} to a TensorFlow GraphDef file.
     *
     * <p>It creates one placeholder NodeDef per entry of {@code seq} and delegates to
     * {@link #saveGraphWithNodeDef}. If the graph contains a ceil-mode max/average pooling
     * layer with zero padding (see {@code $anonfun$saveGraph$1}), it first runs one forward
     * pass on tensors resized to the declared input shapes. The tensor values are not set
     * here. NOTE(review): presumably the pooling converters need the shapes computed by that
     * pass — confirm.
     *
     * @param graph                graph to export
     * @param seq                  (placeholder name, input shape) pairs, one per graph input
     * @param str                  destination path
     * @param byteOrder            byte order forwarded to the converters
     * @param tensorflowDataFormat not referenced by this method body
     * @param classTag             element class tag used to allocate the dummy input tensors
     * @param tensorNumeric        numeric type class used to allocate the dummy input tensors
     */
    public <T> void saveGraph(Graph<T> graph, Seq<Tuple2<String, Seq<Object>>> seq, String str, ByteOrder byteOrder, TensorflowDataFormat tensorflowDataFormat, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        boolean hasCeilModePooling = ((ArrayBuffer) graph.modules().filter(abstractModule -> {
            return BoxesRunTime.boxToBoolean($anonfun$saveGraph$1(abstractModule));
        })).size() != 0;
        if (hasCeilModePooling) {
            Seq inputTensors = (Seq) seq.map(tuple2 -> {
                Tensor tensor = Tensor$.MODULE$.apply(classTag, tensorNumeric);
                return tensor.resize((int[]) ((TraversableOnce) tuple2._2()).toArray(ClassTag$.MODULE$.Int()), tensor.resize$default$2());
            }, Seq$.MODULE$.canBuildFrom());
            Activity activity;
            if (inputTensors.size() == 1) {
                activity = (Activity) inputTensors.head();
            } else {
                // Multiple inputs are passed as a 1-based Table.
                Table table = T$.MODULE$.apply();
                IntRef index = IntRef.create(1);
                inputTensors.foreach(tensor -> {
                    $anonfun$saveGraph$3(table, index, tensor);
                    return BoxedUnit.UNIT;
                });
                activity = table;
            }
            graph.forward(activity);
        }
        saveGraphWithNodeDef(graph, (Seq) seq.map(tuple22 -> {
            return Tensorflow$.MODULE$.placeholder(graph.getNumericType(), (Seq) tuple22._2(), (String) tuple22._1());
        }, Seq$.MODULE$.canBuildFrom()), str, byteOrder, saveGraphWithNodeDef$default$5());
    }

    /** Default {@code byteOrder} for {@link #saveGraphWithNodeDef}: little endian. */
    public <T> ByteOrder saveGraphWithNodeDef$default$4() {
        return ByteOrder.LITTLE_ENDIAN;
    }

    /** Default extra-node set for {@link #saveGraphWithNodeDef}: empty. */
    public <T> Set<NodeDef> saveGraphWithNodeDef$default$5() {
        return Predef$.MODULE$.Set().apply(Nil$.MODULE$);
    }

    /** Default {@code byteOrder} for {@link #saveGraph}: little endian. */
    public <T> ByteOrder saveGraph$default$4() {
        return ByteOrder.LITTLE_ENDIAN;
    }

    /** Default data format for {@link #saveGraph}: NHWC. */
    public <T> TensorflowDataFormat saveGraph$default$5() {
        return TensorflowDataFormat$NHWC$.MODULE$;
    }

    /**
     * Registers (or replaces) the converter for modules whose runtime class name is {@code str}.
     * Lookups use {@code module.getClass().getName()}, so {@code str} must be the fully
     * qualified class name.
     */
    public void register(String str, BigDLToTensorflow bigDLToTensorflow) {
        maps().update(str, bigDLToTensorflow);
    }

    private Logger logger() {
        return this.logger;
    }

    private Map<String, BigDLToTensorflow> maps() {
        return this.maps;
    }

    /**
     * Turns a Scala companion-object class name ({@code "pkg.Foo$"}) into the class name of the
     * companion's instances ({@code "pkg.Foo"}). Names without a trailing {@code $} are
     * returned unchanged.
     */
    private String getNameFromObj(String str) {
        return str.endsWith("$") ? str.substring(0, str.length() - 1) : str;
    }

    /** Builds one registry entry keyed by the instance class name of {@code companion}. */
    private Tuple2<String, BigDLToTensorflow> converterEntry(Object companion, BigDLToTensorflow converter) {
        return new Tuple2<>(getNameFromObj(companion.getClass().getName()), converter);
    }

    /** Seeds the input cache: graph input node name -> [its placeholder NodeDef]. */
    public static final /* synthetic */ void $anonfun$saveGraphWithNodeDef$1(HashMap hashMap, Tuple2 tuple2) {
        hashMap.update(((AbstractModule) ((Node) tuple2._1()).element()).getName(), ArrayBuffer$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new NodeDef[]{(NodeDef) tuple2._2()})));
    }

    /**
     * Records {@code seq.apply(0)} as an input of {@code node}, appending to any inputs already
     * cached under the node's name. NOTE(review): this treats the first NodeDef returned by a
     * converter as the producing node's output — confirm every converter follows that order.
     */
    public static final /* synthetic */ void $anonfun$saveGraphWithNodeDef$5(HashMap hashMap, Seq seq, Node node) {
        ArrayBuffer arrayBuffer = (ArrayBuffer) hashMap.getOrElse(((AbstractModule) node.element()).getName(), () -> {
            return ArrayBuffer$.MODULE$.apply(Nil$.MODULE$);
        });
        arrayBuffer.append(Predef$.MODULE$.wrapRefArray(new NodeDef[]{(NodeDef) seq.apply(0)}));
        hashMap.update(((AbstractModule) node.element()).getName(), arrayBuffer);
    }

    /**
     * Converts one graph node.
     *
     * <ul>
     *   <li>Looks up the converter by the module's runtime class name.</li>
     *   <li>Passes the converter the NodeDefs cached for this node's name and adds the
     *       resulting NodeDefs to {@code builder}.</li>
     *   <li>Registers this node's output as an input of each successor.</li>
     * </ul>
     *
     * <p>Both map lookups use Scala {@code Map.apply}, which throws if the module class has no
     * converter or no inputs were recorded for the node.
     */
    public static final /* synthetic */ void $anonfun$saveGraphWithNodeDef$3(HashMap hashMap, ByteOrder byteOrder, GraphDef.Builder builder, Node node) {
        Seq<NodeDef> tFDef = ((BigDLToTensorflow) MODULE$.maps().apply(node.element().getClass().getName())).toTFDef((AbstractModule) node.element(), (Seq) hashMap.apply(((AbstractModule) node.element()).getName()), byteOrder);
        tFDef.foreach(nodeDef -> {
            return builder.addNode(nodeDef);
        });
        node.nextNodes().foreach(node2 -> {
            $anonfun$saveGraphWithNodeDef$5(hashMap, tFDef, node2);
            return BoxedUnit.UNIT;
        });
    }

    /**
     * Returns true for a {@link SpatialMaxPooling} or {@link SpatialAveragePooling} layer that
     * uses ceil mode with zero padding in both spatial dimensions.
     */
    public static final /* synthetic */ boolean $anonfun$saveGraph$1(AbstractModule abstractModule) {
        if (abstractModule instanceof SpatialMaxPooling) {
            SpatialMaxPooling spatialMaxPooling = (SpatialMaxPooling) abstractModule;
            return spatialMaxPooling.ceilMode() && spatialMaxPooling.padH() == 0 && spatialMaxPooling.padW() == 0;
        }
        if (!(abstractModule instanceof SpatialAveragePooling)) {
            return false;
        }
        SpatialAveragePooling spatialAveragePooling = (SpatialAveragePooling) abstractModule;
        return spatialAveragePooling.ceilMode() && spatialAveragePooling.padH() == 0 && spatialAveragePooling.padW() == 0;
    }

    /** Stores {@code tensor} at key {@code intRef.elem} of {@code table}, then increments the key. */
    public static final /* synthetic */ void $anonfun$saveGraph$3(Table table, IntRef intRef, Tensor tensor) {
        table.update(BoxesRunTime.boxToInteger(intRef.elem), tensor);
        intRef.elem++;
    }

    private TensorflowSaver$() {
        MODULE$ = this;
        this.logger = Logger.getLogger(getClass());
        // Keys are derived from each layer's companion object ("...nn.ReLU$" -> "...nn.ReLU"),
        // matching the runtime class name used for lookup in $anonfun$saveGraphWithNodeDef$3.
        this.maps = Map$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Tuple2[]{
            converterEntry(TemporalConvolution$.MODULE$, TemporalConvolutionToTF$.MODULE$),
            converterEntry(ReLU$.MODULE$, ReLUToTF$.MODULE$),
            converterEntry(Linear$.MODULE$, LinearToTF$.MODULE$),
            converterEntry(SpatialConvolution$.MODULE$, SpatialConvolutionToTF$.MODULE$),
            converterEntry(Squeeze$.MODULE$, SqueezeToTF$.MODULE$),
            converterEntry(Tanh$.MODULE$, TanhToTF$.MODULE$),
            converterEntry(Reshape$.MODULE$, ReshapeToTF$.MODULE$),
            converterEntry(View$.MODULE$, ViewToTF$.MODULE$),
            converterEntry(SpatialMaxPooling$.MODULE$, MaxpoolToTF$.MODULE$),
            converterEntry(Padding$.MODULE$, PaddingToTF$.MODULE$),
            converterEntry(SpatialAveragePooling$.MODULE$, AvgpoolToTF$.MODULE$),
            converterEntry(Sigmoid$.MODULE$, SigmoidToTF$.MODULE$),
            converterEntry(Dropout$.MODULE$, DropoutToTF$.MODULE$),
            converterEntry(CAddTable$.MODULE$, CAddTableToTF$.MODULE$),
            converterEntry(CMulTable$.MODULE$, CMultTableToTF$.MODULE$),
            converterEntry(JoinTable$.MODULE$, JoinTableToTF$.MODULE$),
            converterEntry(Mean$.MODULE$, MeanToTF$.MODULE$),
            converterEntry(SoftMax$.MODULE$, SoftMaxToTF$.MODULE$),
            converterEntry(LogSoftMax$.MODULE$, LogSoftMaxToTF$.MODULE$),
            converterEntry(SpatialBatchNormalization$.MODULE$, BatchNorm2DToTF$.MODULE$),
            converterEntry(Input$.MODULE$, InputToTF$.MODULE$),
            converterEntry(Scale$.MODULE$, ScaleToTF$.MODULE$),
            converterEntry(SpatialCrossMapLRN$.MODULE$, LRNToTF$.MODULE$)
        }));
    }
}
