package com.intel.analytics.bigdl.utils;

import com.intel.analytics.bigdl.nn.Container;
import com.intel.analytics.bigdl.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.nn.abstractnn.Activity;
import com.intel.analytics.bigdl.nn.tf.Const;
import com.intel.analytics.bigdl.tensor.QuantizedTensor;
import com.intel.analytics.bigdl.tensor.QuantizedTensor$;
import com.intel.analytics.bigdl.tensor.QuantizedType$;
import com.intel.analytics.bigdl.tensor.Storage;
import com.intel.analytics.bigdl.tensor.Storage$;
import com.intel.analytics.bigdl.tensor.Tensor;
import com.intel.analytics.bigdl.tensor.Tensor$;
import com.intel.analytics.bigdl.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.tensor.TensorType;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectStreamClass;
import org.apache.commons.lang3.SerializationException;
import scala.Array$;
import scala.MatchError;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.TraversableLike;
import scala.collection.TraversableOnce;
import scala.collection.immutable.Map;
import scala.collection.immutable.Nil$;
import scala.collection.mutable.ArrayBuffer;
import scala.collection.mutable.ArrayBuffer$;
import scala.collection.mutable.ArrayOps;
import scala.math.Numeric$IntIsIntegral$;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.IntRef;
import scala.runtime.ObjectRef;
import scala.runtime.RichInt$;
import scala.runtime.ScalaRunTime$;
import scala.util.Try$;

/* compiled from: Util.scala */
/* loaded from: input_file:com/intel/analytics/bigdl/utils/Util$.class */
/**
 * Java view of the Scala {@code object Util} (decompiled from Util.scala).
 *
 * <p>Provides quick-select over {@code long[]}, in-place element shifting for arbitrary JVM
 * arrays, and helpers that move weight/bias and {@code Const} tensors out of and back into
 * modules.
 */
public final class Util$ {
    /** Singleton instance; assigned by the private constructor during static initialization. */
    public static Util$ MODULE$;

    static {
        new Util$();
    }

    /**
     * Returns the {@code i3}-th largest value (1-based) in {@code jArr[i..i2]} using quick-select.
     *
     * <p>Elements inside the range are permuted in place. The rank is not validated against the
     * range size.
     *
     * @param jArr array to search; elements in {@code [i, i2]} are reordered
     * @param i    inclusive lower bound of the range
     * @param i2   inclusive upper bound of the range
     * @param i3   rank k of the element to find (1 = largest)
     * @return the k-th largest value, or {@link Long#MAX_VALUE} when {@code i3 == 0}
     */
    public long kthLargest(long[] jArr, int i, int i2, int i3) {
        int lo = i;
        int hi = i2;
        int k = i3;
        while (k != 0) {
            int pivotIndex = randomPartition(jArr, lo, hi);
            // Everything left of pivotIndex (within [lo, hi]) is strictly larger than the pivot.
            int largerCount = pivotIndex - lo;
            if (largerCount == k - 1) {
                return jArr[pivotIndex];
            }
            if (largerCount > k - 1) {
                // Answer lies among the larger elements; rank is unchanged.
                hi = pivotIndex - 1;
            } else {
                // Skip the larger elements and the pivot itself.
                k -= largerCount + 1;
                lo = pivotIndex + 1;
            }
        }
        return Long.MAX_VALUE;
    }

    /** Swaps {@code jArr[i]} and {@code jArr[i2]} in place. */
    public void swap(long[] jArr, int i, int i2) {
        long tmp = jArr[i];
        jArr[i] = jArr[i2];
        jArr[i2] = tmp;
    }

    /**
     * Partitions {@code jArr[i..i2]} around the pivot {@code jArr[i2]} so that elements strictly
     * greater than the pivot come first, then the pivot, then the remaining elements.
     *
     * @return the final index of the pivot
     */
    private int partition(long[] jArr, int i, int i2) {
        long pivot = jArr[i2];
        int store = i;
        for (int j = i; j < i2; j++) {
            if (jArr[j] > pivot) {
                swap(jArr, store, j);
                store++;
            }
        }
        swap(jArr, store, i2);
        return store;
    }

    /**
     * Moves a randomly chosen element of {@code jArr[i..i2]} to position {@code i2}, then
     * partitions the range around it.
     *
     * @return the final index of the pivot
     */
    private int randomPartition(long[] jArr, int i, int i2) {
        int n = (i2 - i) + 1;
        // Math.random() is in [0, 1): scaling by n gives an offset in [0, n - 1].
        // (A modulo, Math.random() % n, would always truncate to 0 and pin the pivot to jArr[i].)
        int offset = n > 0 ? (int) (Math.random() * n) : 0;
        swap(jArr, i + offset, i2);
        return partition(jArr, i, i2);
    }

    /**
     * Moves the element at index {@code i} to index {@code i2} through successive adjacent swaps,
     * shifting every element in between one step toward {@code i}. The array is modified in place.
     *
     * @param obj a JVM array of any element type (typed {@code Object} because it is a Scala
     *            {@code Array[B]})
     * @param i   source index; must be in {@code [0, length)}
     * @param i2  destination index; must be in {@code [0, length)}
     * @return the same array instance
     * @throws IllegalArgumentException if either index is out of range (raised by
     *         {@code Predef.require})
     */
    public <B> Object shift(Object obj, int i, int i2) {
        Predef$.MODULE$.require(i < ScalaRunTime$.MODULE$.array_length(obj) && i >= 0, () -> {
            return new StringBuilder(30).append("invalid from ").append(i).append(" array length is ").append(ScalaRunTime$.MODULE$.array_length(obj)).toString();
        });
        Predef$.MODULE$.require(i2 < ScalaRunTime$.MODULE$.array_length(obj) && i2 >= 0, () -> {
            return new StringBuilder(28).append("invalid to ").append(i2).append(" array length is ").append(ScalaRunTime$.MODULE$.array_length(obj)).toString();
        });
        if (i < i2) {
            for (int j = i; j < i2; j++) {
                swapArrayElements(obj, j, j + 1);
            }
        } else {
            // Also covers i == i2, where the loop body never runs.
            for (int j = i; j > i2; j--) {
                swapArrayElements(obj, j, j - 1);
            }
        }
        return obj;
    }

    /** Swaps two elements of an arbitrary (primitive or reference) JVM array. */
    private static void swapArrayElements(Object array, int a, int b) {
        Object tmp = ScalaRunTime$.MODULE$.array_apply(array, a);
        ScalaRunTime$.MODULE$.array_update(array, a, ScalaRunTime$.MODULE$.array_apply(array, b));
        ScalaRunTime$.MODULE$.array_update(array, b, tmp);
    }

    /**
     * Builds detached copies of a module's weight tensors, then clears the originals.
     *
     * <p>When no weight is quantized and the summed element count of all weights equals the
     * length of the first weight's storage array, every returned tensor is a view (same offset,
     * size and stride) onto a single {@code Storage} built from that array. Otherwise, each
     * non-quantized weight gets its own {@code Storage} built from its own storage array.
     * Quantized weights are rebuilt through {@code QuantizedTensor$.apply} from their storage and
     * per-row statistics. Afterwards, both the weights ({@code tuple2._1}) and the gradients
     * ({@code tuple2._2}) are cleared with {@link #clearTensor}.
     *
     * <p>NOTE(review): null weight entries are skipped and stay null in the result. However, the
     * quantization scan and the element-count sum below still dereference every entry they reach.
     *
     * @param tuple2 (weights, gradients), as returned by a module's {@code parameters()}
     * @return new tensors index-aligned with {@code tuple2._1}; empty if there are no weights
     */
    public <T> Tensor<T>[] getAndClearWeightBias(Tuple2<Tensor<T>[], Tensor<T>[]> tuple2, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        Tensor<T>[] weights = tuple2._1();
        if (weights.length == 0) {
            return new Tensor[0];
        }

        // Stops at the first quantized weight, like Scala's `exists`.
        boolean anyQuantized = false;
        for (Tensor<T> w : weights) {
            if ($anonfun$getAndClearWeightBias$1(w)) {
                anyQuantized = true;
                break;
            }
        }

        boolean shareStorage;
        Storage<T> sharedStorage;
        if (anyQuantized) {
            shareStorage = false;
            sharedStorage = null;
        } else {
            sharedStorage = Storage$.MODULE$.apply(weights[0].storage().array(), classTag);
            int totalElements = 0;
            for (Tensor<T> w : weights) {
                totalElements += w.nElement();
            }
            // All weights fit exactly into the first weight's backing array: they can share it.
            shareStorage = totalElements == sharedStorage.length();
        }

        Tensor<T>[] result = new Tensor[weights.length];
        // The index advances for null entries too. The previous loop only incremented inside the
        // null check and never terminated on a null weight.
        for (int i = 0; i < weights.length; i++) {
            Tensor tensor3 = weights[i];
            if (tensor3 == null) {
                continue;
            }
            if (isQuantized(tensor3)) {
                QuantizedTensor quantizedTensor = (QuantizedTensor) tensor3;
                result[i] = QuantizedTensor$.MODULE$.apply(quantizedTensor.getStorage(), quantizedTensor.maxOfRow(), quantizedTensor.minOfRow(), quantizedTensor.sumOfRow(), quantizedTensor.size(), quantizedTensor.params(), classTag, tensorNumeric);
            } else if (shareStorage) {
                result[i] = Tensor$.MODULE$.apply(sharedStorage, tensor3.storageOffset(), tensor3.size(), tensor3.stride(), classTag, tensorNumeric);
            } else {
                result[i] = Tensor$.MODULE$.apply(Storage$.MODULE$.apply(tensor3.storage().array(), classTag), tensor3.storageOffset(), tensor3.size(), tensor3.stride(), classTag, tensorNumeric);
            }
        }
        clearTensor(weights, classTag, tensorNumeric);
        clearTensor(tuple2._2(), classTag, tensorNumeric);
        return result;
    }

    /**
     * Detaches the values of every {@code Const} module in {@code container}.
     *
     * <p>Collects the modules returned by {@code container.findModules("Const")}. It then
     * shallow-clones each one's value tensor and calls {@code set()} on the original value
     * tensor, then returns a map from module name to the clone.
     *
     * @return module name to the (shallow-cloned) constant value
     * @throws IllegalArgumentException if two Const modules share a name (raised by
     *         {@code Predef.require})
     */
    public <T> Map<String, Tensor<?>> getAndClearConsts(Container<?, ?, T> container, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        ArrayBuffer arrayBuffer = (ArrayBuffer) ((TraversableLike) container.findModules("Const").map(abstractModule -> {
            return (Const) abstractModule;
        }, ArrayBuffer$.MODULE$.canBuildFrom())).map(r5 -> {
            return new Tuple2(r5, r5.value().shallowClone());
        }, ArrayBuffer$.MODULE$.canBuildFrom());
        arrayBuffer.foreach(tuple2 -> {
            return ((Const) tuple2._1()).value().set();
        });
        Map<String, Tensor<?>> map = ((TraversableOnce) arrayBuffer.map(tuple22 -> {
            return new Tuple2(((AbstractModule) tuple22._1()).getName(), tuple22._2());
        }, ArrayBuffer$.MODULE$.canBuildFrom())).toMap(Predef$.MODULE$.$conforms());
        // A smaller map than the module list means duplicate names collapsed into one key.
        Predef$.MODULE$.require(map.size() == arrayBuffer.length(), () -> {
            return new StringBuilder(59).append(container).append("'s Const node's name is duplicated,").append("please check your model.").toString();
        });
        return map;
    }

    /**
     * Restores {@code Const} values previously taken by {@link #getAndClearConsts}.
     *
     * <p>For each module returned by {@code container.findModules("Const")}, this sets its value
     * tensor to {@code map.apply(name)}. Scala's {@code Map.apply} throws if a name is missing.
     */
    public <T> void putConsts(Container<?, ?, T> container, Map<String, Tensor<?>> map, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        ((ArrayBuffer) container.findModules("Const").map(abstractModule -> {
            return (Const) abstractModule;
        }, ArrayBuffer$.MODULE$.canBuildFrom())).foreach(r5 -> {
            return r5.value().set((Tensor) map.apply(r5.getName()));
        });
    }

    /**
     * Clears every non-null tensor in {@code tensorArr} by calling {@code set()} with no arguments.
     * Quantized tensors first have {@code release()} called on their quantized view.
     */
    private <T> void clearTensor(Tensor<T>[] tensorArr, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        for (Tensor<T> tensor : tensorArr) {
            if (tensor == null) {
                continue;
            }
            if (isQuantized(tensor)) {
                tensor.mo1159toQuantizedTensor().release();
            }
            tensor.set();
        }
    }

    /**
     * Points each non-null weight of {@code abstractModule} at the matching entry of
     * {@code tensorArr}, using {@link #clearAndSet$1}.
     *
     * @param tensorArr replacement weights, index-aligned with {@code parameters()._1}
     */
    public <T> void putWeightBias(Tensor<T>[] tensorArr, AbstractModule<Activity, Activity, T> abstractModule, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        Tensor[] localWeights = (Tensor[]) abstractModule.parameters()._1();
        for (int i = 0; i < localWeights.length; i++) {
            if (localWeights[i] != null) {
                clearAndSet$1(localWeights[i], tensorArr[i]);
            }
        }
    }

    /**
     * Re-initializes the gradient tensors of {@code abstractModule} on one freshly allocated
     * {@code Storage}. Its length is the summed element count of all gradients.
     *
     * <p>For each non-null local weight at index {@code i}:
     * <ul>
     *   <li>if {@code tensorArr[i]} is quantized, the gradient is set to a new tensor from
     *       {@code Tensor$.apply(1, ...)} (presumably a one-element placeholder);</li>
     *   <li>otherwise the gradient becomes a view on the shared storage with the offset, size and
     *       stride of {@code tensorArr[i]}.</li>
     * </ul>
     *
     * <p>NOTE(review): this assumes the layouts in {@code tensorArr} fit inside the summed gradient
     * length. Confirm with callers.
     *
     * @param tensorArr broadcast weights, index-aligned with {@code parameters()._1}
     */
    public <T> void initGradWeightBias(Tensor<T>[] tensorArr, AbstractModule<Activity, Activity, T> abstractModule, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        Tuple2<Tensor<T>[], Tensor<T>[]> parameters = abstractModule.parameters();
        if (parameters == null) {
            throw new MatchError(parameters);
        }
        Tensor[] localWeights = (Tensor[]) parameters._1();
        Tensor[] localGradWeights = (Tensor[]) parameters._2();

        int totalGradElements = 0;
        for (Tensor grad : localGradWeights) {
            totalGradElements += grad.nElement();
        }
        Storage<T> gradStorage = Storage$.MODULE$.apply(totalGradElements, classTag);

        for (int i = 0; i < localWeights.length; i++) {
            if (localWeights[i] == null) {
                continue;
            }
            Tensor<T> broadcast = tensorArr[i];
            if (QuantizedType$.MODULE$.equals(broadcast.getTensorType())) {
                localGradWeights[i].set(Tensor$.MODULE$.apply(1, classTag, tensorNumeric));
            } else {
                localGradWeights[i].set(gradStorage, broadcast.storageOffset(), broadcast.size(), broadcast.stride());
            }
        }
    }

    /**
     * Deserializes a Java-serialized object from {@code bArr}.
     *
     * @throws IllegalArgumentException if {@code bArr} is null
     * @throws SerializationException   on class-resolution or I/O failure
     */
    public <T> T deserialize(byte[] bArr, ClassTag<T> classTag) {
        if (bArr == null) {
            throw new IllegalArgumentException("The byte[] must not be null");
        }
        return (T) deserialize(new ByteArrayInputStream(bArr), classTag);
    }

    /**
     * Reads one Java-serialized object from {@code inputStream}.
     *
     * <p>Classes are resolved first through the class loader that loaded this class (without
     * initializing them). If that fails, {@link ObjectInputStream}'s default resolution is used.
     * The wrapping stream is always closed, and errors from {@code close()} are ignored.
     *
     * <p>NOTE(review): this is Java native deserialization. Do not pass untrusted data; consider
     * installing an {@code ObjectInputFilter} if the source can be attacker-controlled.
     *
     * @throws IllegalArgumentException if {@code inputStream} is null
     * @throws SerializationException   wrapping ClassNotFoundException, IOException or
     *                                  ClassCastException
     */
    public <T> T deserialize(final InputStream inputStream, ClassTag<T> classTag) {
        if (inputStream == null) {
            throw new IllegalArgumentException("The InputStream must not be null");
        }
        ObjectInputStream in = null;
        try {
            in = new ObjectInputStream(inputStream) {
                @Override
                public Class<?> resolveClass(ObjectStreamClass objectStreamClass) throws IOException, ClassNotFoundException {
                    try {
                        return Class.forName(objectStreamClass.getName(), false, this.getClass().getClassLoader());
                    } catch (ClassNotFoundException | RuntimeException e) {
                        // Fall back to the default resolution.
                        return super.resolveClass(objectStreamClass);
                    }
                }
            };
            return (T) in.readObject();
        } catch (ClassNotFoundException | IOException | ClassCastException e) {
            throw new SerializationException(e);
        } finally {
            closeQuietly(in);
        }
    }

    /** Best-effort close; exceptions from {@code close()} are intentionally ignored. */
    private static void closeQuietly(ObjectInputStream in) {
        if (in == null) {
            return;
        }
        try {
            in.close();
        } catch (Exception ignored) {
            // Best effort: a failed close must not mask the deserialization result.
        }
    }

    /** Returns whether {@code tensor}'s type is {@code QuantizedType}; null-safe on the type. */
    private static boolean isQuantized(Tensor<?> tensor) {
        TensorType tensorType = tensor.getTensorType();
        QuantizedType$ quantizedType$ = QuantizedType$.MODULE$;
        return tensorType != null ? tensorType.equals(quantizedType$) : quantizedType$ == null;
    }

    /** Quantized-type predicate kept for binary compatibility; delegates to {@link #isQuantized}. */
    public static final /* synthetic */ boolean $anonfun$getAndClearWeightBias$1(Tensor tensor) {
        return isQuantized(tensor);
    }

    /**
     * Points {@code tensor} at {@code tensor2} via {@code set}. If both are quantized and their
     * native storages differ, the old quantized tensor is released first.
     */
    private static final void clearAndSet$1(Tensor tensor, Tensor tensor2) {
        if (isQuantized(tensor) && isQuantized(tensor2)) {
            QuantizedTensor quantizedTensor = (QuantizedTensor) tensor;
            if (quantizedTensor.getNativeStorage() != ((QuantizedTensor) tensor2).getNativeStorage()) {
                quantizedTensor.release();
            }
        }
        tensor.set(tensor2);
    }

    /** Quantized-type predicate kept for binary compatibility; delegates to {@link #isQuantized}. */
    public static final /* synthetic */ boolean $anonfun$initGradWeightBias$2(Tensor tensor) {
        return isQuantized(tensor);
    }

    private Util$() {
        MODULE$ = this;
    }
}
