package com.intel.analytics.bigdl.utils.serializer;

import com.intel.analytics.bigdl.nn.Container;
import com.intel.analytics.bigdl.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.nn.abstractnn.Activity;
import com.intel.analytics.bigdl.serialization.Bigdl;
import com.intel.analytics.bigdl.tensor.QuantizedTensor;
import com.intel.analytics.bigdl.tensor.Tensor;
import com.intel.analytics.bigdl.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.utils.File$;
import com.intel.analytics.bigdl.utils.FileReader;
import com.intel.analytics.bigdl.utils.FileReader$;
import com.intel.analytics.bigdl.utils.Table;
import com.intel.analytics.bigdl.utils.serializer.converters.DataReaderWriter$;
import com.intel.analytics.bigdl.utils.serializer.converters.TensorConverter$;
import com.intel.analytics.shaded.protobuf_v_3_5_1.CodedInputStream;
import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.security.DigestInputStream;
import java.security.MessageDigest;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.IterableLike;
import scala.collection.JavaConverters$;
import scala.collection.immutable.StringOps;
import scala.collection.mutable.HashMap;
import scala.collection.mutable.HashSet;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.RichInt$;

/* compiled from: ModuleLoader.scala */
/* loaded from: input_file:com/intel/analytics/bigdl/utils/serializer/ModuleLoader$.class */
/**
 * Singleton that loads serialized BigDL modules and copies loaded parameters into an
 * existing module definition.
 *
 * <p>The single instance is published through {@link #MODULE$}; the private constructor assigns
 * it when the static initializer runs.
 */
public final class ModuleLoader$ {
    /** The singleton instance, assigned by the private constructor. */
    public static ModuleLoader$ MODULE$;

    static {
        new ModuleLoader$();
    }

    /**
     * Parses a serialized module definition from {@code str} and materializes the module.
     *
     * <p>When {@code str2} is {@code null}, tensor storages are taken from the parsed module
     * itself (proto storage type). Otherwise they are read from the separate file {@code str2}
     * (BigDL storage type) before the module is built.
     *
     * @param str path of the serialized module definition; read fully into memory
     * @param str2 path of the separate weight file, or {@code null} if there is none
     * @param classTag class tag forwarded to the deserializers
     * @param tensorNumeric numeric helper forwarded to the deserializers
     * @return the module produced by {@code ModuleSerializer.load}
     */
    public <T> AbstractModule<Activity, Activity, T> loadFromFile(String str, String str2, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        Bigdl.BigDLModule.Builder moduleBuilder = Bigdl.BigDLModule.newBuilder();
        CodedInputStream codedInput = CodedInputStream.newInstance(new ByteArrayInputStream(File$.MODULE$.readBytes(str)));
        // Raise the stream's size limit to the largest int; presumably so that large serialized
        // modules are not rejected by the default limit.
        codedInput.setSizeLimit(Integer.MAX_VALUE);
        moduleBuilder.mergeFrom(codedInput);
        Bigdl.BigDLModule moduleProto = moduleBuilder.build();
        HashMap storages = new HashMap();

        DeserializeContext context;
        if (str2 == null) {
            context = new DeserializeContext(moduleProto, storages, ProtoStorageType$.MODULE$, DeserializeContext$.MODULE$.apply$default$4());
            initTensorStorage(context, classTag, tensorNumeric);
        } else {
            context = new DeserializeContext(moduleProto, storages, BigDLStorage$.MODULE$, DeserializeContext$.MODULE$.apply$default$4());
            initTensorStorage(context, str2, classTag, tensorNumeric);
        }
        return ModuleSerializer$.MODULE$.load(context, classTag, tensorNumeric).module();
    }

    /**
     * Default value of the second argument of {@link #loadFromFile}.
     *
     * @return always {@code null}, meaning "no separate weight file"
     */
    public <T> String loadFromFile$default$2() {
        return null;
    }

    /**
     * Fills the context's storage map from the weight file {@code str} and verifies the file's
     * trailing checksum.
     *
     * <p>Layout as consumed here: an int that must equal {@code SerConst.MAGIC_NO}, an int record
     * count, then for each record an int storage id, an int data-type code and an int passed to
     * the matching {@code DataReaderWriter} together with the stream. All of that is covered by a
     * message digest of type {@code SerConst.DIGEST_TYPE}. It is followed by an int digest length
     * and the digest bytes, which are not themselves digested.
     *
     * @param deserializeContext context whose {@code storages()} map is populated
     * @param str path of the weight file
     * @param classTag unused by this overload
     * @param tensorNumeric unused by this overload
     * @throws IllegalArgumentException (via {@code require}) on a magic-number or checksum mismatch
     */
    private <T> void initTensorStorage(DeserializeContext deserializeContext, String str, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        int MAGIC_NO = SerConst$.MODULE$.MAGIC_NO();
        HashMap<Object, Object> storages = deserializeContext.storages();
        FileReader fileReader = null;
        InputStream inputStream = null;
        try {
            fileReader = FileReader$.MODULE$.apply(str);
            inputStream = fileReader.open();
            DigestInputStream digestInputStream = new DigestInputStream(inputStream, MessageDigest.getInstance(SerConst$.MODULE$.DIGEST_TYPE()));
            DataInputStream dataInputStream = new DataInputStream(digestInputStream);

            digestInputStream.on(true);
            int readInt = dataInputStream.readInt();
            Predef$.MODULE$.require(readInt == MAGIC_NO, () -> {
                return "Magic number mismatch, expected " + MAGIC_NO + ", actual " + readInt;
            });

            int recordCount = dataInputStream.readInt();
            for (int record = 0; record < recordCount; record++) {
                int storageId = dataInputStream.readInt();
                int dataTypeCode = dataInputStream.readInt();
                // The reader is resolved before the next int is consumed, matching the on-disk order.
                storages.update(BoxesRunTime.boxToInteger(storageId), DataReaderWriter$.MODULE$.apply(BigDLDataType$.MODULE$.apply(dataTypeCode)).read(dataInputStream, dataInputStream.readInt()));
            }

            // Everything after this point is the stored checksum and must not feed the digest.
            digestInputStream.on(false);
            byte[] computedDigest = digestInputStream.getMessageDigest().digest();

            // Validate the stored length before allocating, so a corrupt length field cannot
            // trigger a negative-size or oversized allocation.
            int storedLength = dataInputStream.readInt();
            Predef$.MODULE$.require(computedDigest.length == storedLength, () -> {
                return "checksum error, size mismatch";
            });

            byte[] storedDigest = new byte[storedLength];
            // readFully, unlike read, keeps reading until the whole buffer is filled.
            dataInputStream.readFully(storedDigest);
            Predef$.MODULE$.require(MessageDigest.isEqual(computedDigest, storedDigest), () -> {
                return "check sum error, please check weight file";
            });
        } finally {
            // Nested so that a failure closing the stream still closes the reader.
            try {
                if (inputStream != null) {
                    inputStream.close();
                }
            } finally {
                if (fileReader != null) {
                    fileReader.close();
                }
            }
        }
    }

    /**
     * Fills the context's storage map from the tensors embedded in the module definition.
     *
     * <p>Iterates the name/attribute map stored under {@code SerConst.GLOBAL_STORAGE} in the
     * module's attributes and registers each entry through
     * {@link #$anonfun$initTensorStorage$6}.
     *
     * @param deserializeContext context providing the module definition and the storage map
     * @param classTag class tag forwarded to the tensor converter
     * @param tensorNumeric numeric helper forwarded to the tensor converter
     */
    public <T> void initTensorStorage(DeserializeContext deserializeContext, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        ((IterableLike) JavaConverters$.MODULE$.mapAsScalaMapConverter(deserializeContext.bigdlModule().getAttrMap().get(SerConst$.MODULE$.GLOBAL_STORAGE()).getNameAttrListValue().getAttrMap()).asScala()).foreach(entry -> {
            $anonfun$initTensorStorage$6(deserializeContext, classTag, tensorNumeric, (Tuple2) entry);
            return BoxedUnit.UNIT;
        });
    }

    /**
     * Loads a module from {@code str} and copies its parameters into {@code abstractModule}.
     *
     * @param abstractModule module definition that receives the loaded parameters
     * @param str path of the serialized module; loaded without a separate weight file
     * @param hashSet names of the layers to copy, or {@code null} to use every layer name found
     *     in {@code abstractModule}
     * @param classTag class tag forwarded to the loader
     * @param tensorNumeric numeric helper forwarded to the loader
     */
    public <T> void loadFromDefinition(AbstractModule<Activity, Activity, T> abstractModule, String str, HashSet<String> hashSet, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        AbstractModule<Activity, Activity, T> loadedModule = loadFromFile(str, loadFromFile$default$2(), classTag, tensorNumeric);
        HashSet<String> layerNames = hashSet;
        if (layerNames == null) {
            layerNames = new HashSet<>();
            getAllLayers(abstractModule, layerNames, classTag);
        }
        copyParams(abstractModule, loadedModule, layerNames, classTag);
    }

    /**
     * Default value of the third argument of {@link #loadFromDefinition}.
     *
     * @return always {@code null}, meaning "all layers"
     */
    public <T> HashSet<String> loadFromDefinition$default$3() {
        return null;
    }

    /**
     * Adds the name of {@code abstractModule} to {@code hashSet} and, when it is a
     * {@code Container}, recurses into each of its sub-modules.
     *
     * @param abstractModule root of the module tree to walk
     * @param hashSet accumulator that receives the names
     * @param classTag unused here; the recursion passes its own tag
     */
    private <T> void getAllLayers(AbstractModule<Activity, Activity, T> abstractModule, HashSet<String> hashSet, ClassTag<T> classTag) {
        hashSet.add(abstractModule.getName());
        if (!(abstractModule instanceof Container)) {
            return;
        }
        ((Container) abstractModule).modules().foreach(child -> {
            $anonfun$getAllLayers$1(hashSet, (AbstractModule) child);
            return BoxedUnit.UNIT;
        });
    }

    /**
     * Copies the parameters of every layer named in {@code hashSet} from {@code abstractModule2}
     * into {@code abstractModule}, using the two modules' parameter tables.
     *
     * @param abstractModule module receiving the parameters
     * @param abstractModule2 loaded module supplying the parameters
     * @param hashSet names of the layers to copy
     * @param classTag class tag forwarded to the per-layer copy
     */
    private <T> void copyParams(AbstractModule<Activity, Activity, T> abstractModule, AbstractModule<Activity, Activity, T> abstractModule2, HashSet<String> hashSet, ClassTag<T> classTag) {
        Table targetParams = abstractModule.getParametersTable();
        Table loadedParams = abstractModule2.getParametersTable();
        hashSet.foreach(layerName -> {
            $anonfun$copyParams$1(targetParams, loadedParams, classTag, layerName);
            return BoxedUnit.UNIT;
        });
    }

    /**
     * Copies the {@code "weight"} and {@code "bias"} entries of one layer from {@code table2}
     * into {@code table}.
     *
     * @param table parameter table of the layer receiving the values
     * @param table2 parameter table of the loaded layer supplying the values
     * @param classTag class tag forwarded to {@link #copyParam}
     */
    private <T> void copyParams(Table table, Table table2, ClassTag<T> classTag) {
        copyParam(table, table2, "weight", classTag);
        copyParam(table, table2, "bias", classTag);
    }

    /**
     * Copies one named parameter from {@code table2} into {@code table}; does nothing when
     * {@code table} has no entry for {@code str}.
     *
     * <p>The parameter may be a single {@code Tensor} or a {@code Tensor[]}. For arrays, each
     * loaded tensor is copied into the tensor at the same index of the target array.
     *
     * @param table parameter table receiving the values
     * @param table2 parameter table supplying the loaded values
     * @param str parameter name
     * @param classTag unused
     * @throws IllegalArgumentException (via {@code require}) if the loaded value is an array but
     *     the target value is not
     */
    private <T> void copyParam(Table table, Table table2, String str, ClassTag<T> classTag) {
        if (!table.contains(str)) {
            return;
        }
        Object loadedValue = table2.get(str).get();
        if (!(loadedValue instanceof Tensor[])) {
            ((Tensor) table.get(str).get()).copy((Tensor) loadedValue);
            return;
        }
        Object targetValue = table.get(str).get();
        Predef$.MODULE$.require(targetValue instanceof Tensor[], () -> {
            return "param type mismatch!";
        });
        // The source must come from the loaded table; reading both arrays from the target
        // table would copy every tensor onto itself and silently drop the loaded values.
        Tensor[] loadedTensors = (Tensor[]) loadedValue;
        Tensor[] targetTensors = (Tensor[]) targetValue;
        for (int index = 0; index < loadedTensors.length; index++) {
            targetTensors[index].copy(loadedTensors[index]);
        }
    }

    /**
     * Registers one embedded tensor and its storage in the context's storage map.
     *
     * <p>The entry key is parsed as an int and used as the key for the tensor; the storage id
     * from the tensor's proto is used as the key for its storage. Only {@code DENSE} and
     * {@code QUANT} tensor types are supported.
     *
     * @param deserializeContext context providing the storage map
     * @param classTag class tag forwarded to the tensor converter
     * @param tensorNumeric numeric helper forwarded to the tensor converter
     * @param tuple2 (key string, attribute value) entry of the global storage map
     * @throws UnsupportedOperationException for any other tensor type
     */
    public static final /* synthetic */ void $anonfun$initTensorStorage$6(DeserializeContext deserializeContext, ClassTag classTag, TensorNumericMath.TensorNumeric tensorNumeric, Tuple2 tuple2) {
        HashMap<Object, Object> storages = deserializeContext.storages();
        int tensorKey = new StringOps(Predef$.MODULE$.augmentString((String) tuple2._1())).toInt();
        Bigdl.AttrValue attrValue = (Bigdl.AttrValue) tuple2._2();
        Bigdl.BigDLTensor tensorProto = attrValue.getTensorValue();
        int storageId = tensorProto.getStorage().getId();
        Tensor tensor = (Tensor) TensorConverter$.MODULE$.getAttributeValue(deserializeContext, attrValue, classTag, tensorNumeric);
        Bigdl.TensorType tensorType = tensorProto.getTensorType();

        Object storage;
        if (Bigdl.TensorType.DENSE.equals(tensorType)) {
            storage = tensor.storage();
        } else if (Bigdl.TensorType.QUANT.equals(tensorType)) {
            storage = ((QuantizedTensor) tensor).getStorage();
        } else {
            throw new UnsupportedOperationException("Unsupported Tensor Type");
        }
        storages.update(BoxesRunTime.boxToInteger(tensorKey), tensor);
        storages.update(BoxesRunTime.boxToInteger(storageId), storage);
    }

    /**
     * Recursion step of {@link #getAllLayers}: collects the names of {@code abstractModule} and
     * its descendants into {@code hashSet}, using a class tag for {@code Object}.
     */
    public static final /* synthetic */ void $anonfun$getAllLayers$1(HashSet hashSet, AbstractModule abstractModule) {
        MODULE$.getAllLayers(abstractModule, hashSet, ClassTag$.MODULE$.apply(Object.class));
    }

    /**
     * Per-layer step of the module-level {@code copyParams}: when {@code table} has an entry for
     * {@code str}, requires that {@code table2} has one too and copies that layer's parameters.
     *
     * @throws IllegalArgumentException (via {@code require}) if {@code table2} lacks the layer
     */
    public static final /* synthetic */ void $anonfun$copyParams$1(Table table, Table table2, ClassTag classTag, String str) {
        if (!table.contains(str)) {
            return;
        }
        Predef$.MODULE$.require(table2.contains(str), () -> {
            return str + " does not exist in loaded module";
        });
        MODULE$.copyParams((Table) table.get(str).get(), (Table) table2.get(str).get(), classTag);
    }

    /** Publishes this instance as the singleton. */
    private ModuleLoader$() {
        MODULE$ = this;
    }
}
