package com.intel.analytics.bigdl.nn;

import com.intel.analytics.bigdl.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.tensor.ConvertableFrom$ConvertableFromDouble$;
import com.intel.analytics.bigdl.tensor.ConvertableFrom$ConvertableFromInt$;
import com.intel.analytics.bigdl.tensor.Tensor;
import com.intel.analytics.bigdl.tensor.Tensor$;
import com.intel.analytics.bigdl.tensor.TensorMath;
import com.intel.analytics.bigdl.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.utils.RandomGenerator$;
import com.intel.analytics.bigdl.utils.Table;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.Seq$;
import scala.collection.TraversableOnce;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/* compiled from: PairwiseDistance.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005ef\u0001\u0002\u000b\u0016\u0001\u0001B\u0001B\u0011\u0001\u0003\u0006\u0004%\ta\u0011\u0005\t\u000f\u0002\u0011\t\u0011)A\u0005\t\"A\u0001\n\u0001B\u0002B\u0003-\u0011\n\u0003\u0005P\u0001\t\u0005\t\u0015a\u0003Q\u0011\u0015!\u0007\u0001\"\u0001f\u0011\u0015a\u0007\u0001\"\u0011n\u0011\u0015\u0001\b\u0001\"\u0003r\u0011\u0015!\b\u0001\"\u0011v\u0011\u0015I\b\u0001\"\u0011{\u0011\u001d\t9\u0001\u0001C!\u0003\u0013Aq!!\u0006\u0001\t\u0003\n9\u0002C\u0004\u0002\u001c\u0001!\t%!\b\b\u000f\u0005-R\u0003#\u0001\u0002.\u00191A#\u0006E\u0001\u0003_Aa\u0001\u001a\b\u0005\u0002\u0005u\u0002bBA \u001d\u0011\u0005\u0011\u0011\t\u0005\n\u0003sr\u0011\u0013!C\u0001\u0003wB\u0011\"!&\u000f#\u0003%\t!a&\t\u0013\u0005\u0015f\"!A\u0005\n\u0005\u001d&\u0001\u0005)bSJ<\u0018n]3ESN$\u0018M\\2f\u0015\t1r#\u0001\u0002o]*\u0011\u0001$G\u0001\u0006E&<G\r\u001c\u0006\u00035m\t\u0011\"\u00198bYf$\u0018nY:\u000b\u0005qi\u0012!B5oi\u0016d'\"\u0001\u0010\u0002\u0007\r|Wn\u0001\u0001\u0016\u0005\u000524C\u0001\u0001#!\u0015\u0019c\u0005\u000b\u00185\u001b\u0005!#BA\u0013\u0016\u0003)\t'm\u001d;sC\u000e$hN\\\u0005\u0003O\u0011\u0012a\"\u00112tiJ\f7\r^'pIVdW\r\u0005\u0002*Y5\t!F\u0003\u0002,/\u0005)Q\u000f^5mg&\u0011QF\u000b\u0002\u0006)\u0006\u0014G.\u001a\t\u0004_I\"T\"\u0001\u0019\u000b\u0005E:\u0012A\u0002;f]N|'/\u0003\u00024a\t1A+\u001a8t_J\u0004\"!\u000e\u001c\r\u0001\u0011)q\u0007\u0001b\u0001q\t\tA+\u0005\u0002:\u007fA\u0011!(P\u0007\u0002w)\tA(A\u0003tG\u0006d\u0017-\u0003\u0002?w\t9aj\u001c;iS:<\u0007C\u0001\u001eA\u0013\t\t5HA\u0002B]f\fAA\\8s[V\tA\t\u0005\u0002;\u000b&\u0011ai\u000f\u0002\u0004\u0013:$\u0018!\u00028pe6\u0004\u0013AC3wS\u0012,gnY3%cA\u0019!*\u0014\u001b\u000e\u0003-S!\u0001T\u001e\u0002\u000fI,g\r\\3di&\u0011aj\u0013\u0002\t\u00072\f7o\u001d+bO\u0006\u0011QM\u001e\t\u0004#\u0006$dB\u0001*`\u001d\t\u0019fL\u0004\u0002U;:\u0011Q\u000b\u0018\b\u0003-ns!a\u0016.\u000e\u0003aS!!W\u0010\u0002\rq\u0012xn\u001c;?\u0013\u0005q\u0012B\u0
001\u000f\u001e\u0013\tQ2$\u0003\u0002\u00193%\u0011\u0011gF\u0005\u0003AB\n\u0011\u0003V3og>\u0014h*^7fe&\u001cW*\u0019;i\u0013\t\u00117MA\u0007UK:\u001cxN\u001d(v[\u0016\u0014\u0018n\u0019\u0006\u0003AB\na\u0001P5oSRtDC\u00014l)\r9\u0017N\u001b\t\u0004Q\u0002!T\"A\u000b\t\u000b!+\u00019A%\t\u000b=+\u00019\u0001)\t\u000f\t+\u0001\u0013!a\u0001\t\u0006aQ\u000f\u001d3bi\u0016|U\u000f\u001e9viR\u0011aF\u001c\u0005\u0006_\u001a\u0001\r\u0001K\u0001\u0006S:\u0004X\u000f^\u0001\t[\u0006$\bn]5h]R\u0011AG\u001d\u0005\u0006g\u001e\u0001\r\u0001N\u0001\u0002q\u0006yQ\u000f\u001d3bi\u0016<%/\u00193J]B,H\u000fF\u0002)m^DQa\u001c\u0005A\u0002!BQ\u0001\u001f\u0005A\u00029\n!b\u001a:bI>+H\u000f];u\u0003!!xn\u0015;sS:<G#A>\u0011\u0007q\f\tA\u0004\u0002~}B\u0011qkO\u0005\u0003\u007fn\na\u0001\u0015:fI\u00164\u0017\u0002BA\u0002\u0003\u000b\u0011aa\u0015;sS:<'BA@<\u0003!\u0019\u0017M\\#rk\u0006dG\u0003BA\u0006\u0003#\u00012AOA\u0007\u0013\r\tya\u000f\u0002\b\u0005>|G.Z1o\u0011\u0019\t\u0019B\u0003a\u0001\u007f\u0005)q\u000e\u001e5fe\u00061Q-];bYN$B!a\u0003\u0002\u001a!1\u00111C\u0006A\u0002}\n\u0001\u0002[1tQ\u000e{G-\u001a\u000b\u0002\t\":\u0001!!\t\u0002(\u0005%\u0002c\u0001\u001e\u0002$%\u0019\u0011QE\u001e\u0003!M+'/[1m-\u0016\u00148/[8o+&#\u0015!\u0002<bYV,g\u0004CbBko\u0007(\u000f\\5\u0002!A\u000b\u0017N]<jg\u0016$\u0015n\u001d;b]\u000e,\u0007C\u00015\u000f'\u0015q\u0011\u0011GA\u001c!\rQ\u00141G\u0005\u0004\u0003kY$AB!osJ+g\rE\u0002;\u0003sI1!a\u000f<\u00051\u0019VM]5bY&T\u0018M\u00197f)\t\ti#A\u0003baBd\u00170\u0006\u0003\u0002D\u0005-C\u0003BA#\u0003o\"b!a\u0012\u0002n\u0005M\u0004\u0003\u00025\u0001\u0003\u0013\u00022!NA&\t%9\u0004\u0003)A\u0001\u0002\u000b\u0007\u0001\b\u000b\u0005\u0002L\u0005=\u0013QKA2!\rQ\u0014\u0011K\u0005\u0004\u0003'Z$aC:qK\u000eL\u0017\r\\5{K\u0012\f\u0014bIA,\u00033\ni&a\u0017\u000f\u0007i\nI&C\u0002\u0002\\m\nQA\u00127pCR\fd\u0001JA0\u0003CbdbA,\u0002b%\tA(M\u0005$\u0003K\n9'a\u001b\u0002j9\u0019!(a\u001a\n\u0007\u0005%4(\u0001\u0004E_V\u0014G.Z
\u0019\u0007I\u0005}\u0013\u0011\r\u001f\t\u0013\u0005=\u0004#!AA\u0004\u0005E\u0014AC3wS\u0012,gnY3%eA!!*TA%\u0011\u0019y\u0005\u0003q\u0001\u0002vA!\u0011+YA%\u0011\u001d\u0011\u0005\u0003%AA\u0002\u0011\u000b1\u0004\n7fgNLg.\u001b;%OJ,\u0017\r^3sI\u0011,g-Y;mi\u0012\nT\u0003BA?\u0003'+\"!a +\u0007\u0011\u000b\ti\u000b\u0002\u0002\u0004B!\u0011QQAH\u001b\t\t9I\u0003\u0003\u0002\n\u0006-\u0015!C;oG\",7m[3e\u0015\r\tiiO\u0001\u000bC:tw\u000e^1uS>t\u0017\u0002BAI\u0003\u000f\u0013\u0011#\u001e8dQ\u0016\u001c7.\u001a3WCJL\u0017M\\2f\t\u00159\u0014C1\u00019\u0003=\t\u0007\u000f\u001d7zI\u0011,g-Y;mi\u0012\nT\u0003BA?\u00033#\u0011b\u000e\n!\u0002\u0003\u0005)\u0019\u0001\u001d)\u0011\u0005e\u0015qJAO\u0003C\u000b\u0014bIA,\u00033\ny*a\u00172\r\u0011\ny&!\u0019=c%\u0019\u0013QMA4\u0003G\u000bI'\r\u0004%\u0003?\n\t\u0007P\u0001\fe\u0016\fGMU3t_24X\r\u0006\u0002\u0002*B!\u00111VA[\u001b\t\tiK\u0003\u0003\u00020\u0006E\u0016\u0001\u00027b]\u001eT!!a-\u0002\t)\fg/Y\u0005\u0005\u0003o\u000biK\u0001\u0004PE*,7\r\u001e")
/* loaded from: input_file:com/intel/analytics/bigdl/nn/PairwiseDistance.class */
/**
 * Module computing the p-norm distance {@code ||x1 - x2||_p} between the two tensors stored at
 * keys 1 and 2 of the input {@link Table}.
 *
 * <p>Vector (1-D) input yields a 1-element output. Batch (2-D) input yields one distance per row
 * of {@code x1}. Any other rank is rejected via {@code require}. The two inputs are presumably
 * expected to have matching shapes; that is not checked here (NOTE(review): confirm).
 *
 * <p>This source is decompiled output of {@code PairwiseDistance.scala}. Names such as
 * {@code mo1123dist}/{@code mo1182fromType}/{@code m1158clone}/{@code mo1137apply} are
 * decompiler-assigned and must match the decompiled view of the library.
 */
public class PairwiseDistance<T> extends AbstractModule<Table, Tensor<T>, T> {
    public static final long serialVersionUID = -4377017408738399127L;
    /** Order p of the distance; the backward pass special-cases p == 1, p == 2 and p > 2. */
    private final int norm;
    private final ClassTag<T> evidence$1;
    private final TensorNumericMath.TensorNumeric<T> ev;

    /** Returns the order p of the p-norm distance computed by this module. */
    public int norm() {
        return this.norm;
    }

    /**
     * Forward pass.
     *
     * <ul>
     *   <li>1-D inputs: output is a 1-element tensor holding {@code dist(x1, x2, p)}.</li>
     *   <li>2-D inputs: output has {@code x1.size(1)} elements,
     *       {@code out(i) = (sum_j |x1(i,j) - x2(i,j)|^p)^(1/p)}.</li>
     *   <li>Otherwise: {@code require(false, ...)} fails.</li>
     * </ul>
     *
     * @param table input table with tensors at keys 1 and 2
     * @return this module's output tensor
     */
    /* JADX WARN: Multi-variable type inference failed */
    @Override // com.intel.analytics.bigdl.nn.abstractnn.AbstractModule
    public Tensor<T> updateOutput(Table table) {
        Tensor input1 = (Tensor) table.apply(BoxesRunTime.boxToInteger(1));
        output().resize(1);
        if (input1.dim() == 1) {
            Tensor input2 = (Tensor) table.apply(BoxesRunTime.boxToInteger(2));
            output().setValue(1, ((TensorMath) input1).mo1123dist(input2, norm()));
        } else if (input1.dim() == 2) {
            Tensor input2 = (Tensor) table.apply(BoxesRunTime.boxToInteger(2));
            // diff = |x1 - x2|, element-wise.
            Tensor diff = Tensor$.MODULE$.apply(this.evidence$1, this.ev);
            diff.resizeAs(input1).zero();
            diff.add(input1, this.ev.mo1182fromType(BoxesRunTime.boxToInteger(-1), ConvertableFrom$ConvertableFromInt$.MODULE$), input2);
            diff.abs();
            // out(i) = (sum_j diff(i,j)^p)^(1/p)
            output().resize(input1.size(1));
            output().zero();
            output().add((Tensor) diff.pow(this.ev.mo1182fromType(BoxesRunTime.boxToInteger(norm()), ConvertableFrom$ConvertableFromInt$.MODULE$)).sum(2));
            output().pow(this.ev.divide(this.ev.mo1182fromType(BoxesRunTime.boxToInteger(1), ConvertableFrom$ConvertableFromInt$.MODULE$), this.ev.mo1182fromType(BoxesRunTime.boxToInteger(norm()), ConvertableFrom$ConvertableFromInt$.MODULE$)));
        } else {
            Predef$.MODULE$.require(false, () -> {
                return new StringBuilder(18).append("PairwiseDistance: ").append(ErrorInfo$.MODULE$.constrainEachInputAsVectorOrBatch()).toString();
            });
        }
        return output();
    }

    /**
     * Sign function used for the p == 1 sub-gradient.
     *
     * <p>Returns {@code one} when {@code t > 0}, otherwise {@code -one}; in particular zero maps to
     * {@code -one}. An earlier "zero" branch was removed: its condition compared the numeric object
     * with a {@code Tuple2} (always false) and the value it computed was discarded. Behavior is
     * unchanged.
     *
     * @param t element value
     * @return {@code one} or {@code -one}
     */
    /* JADX INFO: Access modifiers changed from: private */
    public T mathsign(T t) {
        return this.ev.isGreater(t, this.ev.zero()) ? this.ev.one() : (T) this.ev.negative(this.ev.one());
    }

    /**
     * Backward pass.
     *
     * <p>With {@code d = x1 - x2} and {@code out = ||d||_p}:
     * {@code dOut/dx1_i = d_i * |d_i|^(p-2) / out^(p-1)} (for p == 1 the numerator is
     * {@code sign(d_i)}) and {@code dOut/dx2 = -dOut/dx1}. Both are then scaled by
     * {@code gradOutput}. A 1e-6 term is added to {@code out} before the negative power so the
     * result stays finite when the distance is zero.
     *
     * @param table  input table with tensors at keys 1 and 2 (rank must be <= 2)
     * @param tensor gradient w.r.t. this module's output
     * @return the gradInput table with gradients at keys 1 and 2
     */
    /* JADX WARN: Multi-variable type inference failed */
    @Override // com.intel.analytics.bigdl.nn.abstractnn.AbstractModule
    public Table updateGradInput(Table table, Tensor<T> tensor) {
        Tensor input1 = (Tensor) table.apply(BoxesRunTime.boxToInteger(1));
        Predef$.MODULE$.require(input1.dim() <= 2, () -> {
            return new StringBuilder(19).append("PairwiseDistance : ").append(ErrorInfo$.MODULE$.constrainEachInputAsVectorOrBatch()).toString();
        });
        Tensor input2 = (Tensor) table.apply(BoxesRunTime.boxToInteger(2));
        // Lazily allocate the two gradient slots on first use.
        if (!gradInput().contains(BoxesRunTime.boxToInteger(1))) {
            gradInput().update(BoxesRunTime.boxToInteger(1), Tensor$.MODULE$.apply(this.evidence$1, this.ev));
        }
        if (!gradInput().contains(BoxesRunTime.boxToInteger(2))) {
            gradInput().update(BoxesRunTime.boxToInteger(2), Tensor$.MODULE$.apply(this.evidence$1, this.ev));
        }
        Tensor gradInput1 = (Tensor) gradInput().apply(BoxesRunTime.boxToInteger(1));
        Tensor gradInput2 = (Tensor) gradInput().apply(BoxesRunTime.boxToInteger(2));
        gradInput1.resizeAs(input1);
        gradInput2.resizeAs(input2);

        // gradInput1 = x1 - x2. The scalar is passed as T (not cast to Tensor) so that the
        // add(value, tensor) overload is selected; a Tensor cast would pick add(x, y) and throw
        // ClassCastException at runtime.
        gradInput1.copy(input1).add(this.ev.negative(this.ev.one()), input2);

        if (norm() == 1) {
            // p == 1: numerator is sign(d_i).
            gradInput1.apply1(obj -> {
                return this.mathsign(obj);
            });
        } else if (norm() > 2) {
            // p > 2: numerator is d_i * |d_i|^(p-2). (p == 2 needs no change: numerator is d_i.)
            ((TensorMath) gradInput1).cmul(gradInput1.m1158clone().abs().pow(this.ev.minus(this.ev.mo1182fromType(BoxesRunTime.boxToInteger(norm()), ConvertableFrom$ConvertableFromInt$.MODULE$), this.ev.mo1182fromType(BoxesRunTime.boxToInteger(2), ConvertableFrom$ConvertableFromInt$.MODULE$))));
        }

        // Divide by (out + 1e-6)^(p-1).
        if (input1.dim() > 1) {
            Tensor<T> outExpand = Tensor$.MODULE$.apply(this.evidence$1, this.ev);
            // Scalar add(T) overload: a Tensor cast here would select add(Tensor) and throw.
            outExpand.resize(output().size(1), 1).copy(output()).add(this.ev.mo1182fromType(BoxesRunTime.boxToDouble(1.0E-6d), ConvertableFrom$ConvertableFromDouble$.MODULE$)).pow(this.ev.negative(this.ev.mo1182fromType(BoxesRunTime.boxToInteger(norm() - 1), ConvertableFrom$ConvertableFromInt$.MODULE$)));
            // Broadcast the per-row factor (n x 1) across the columns of gradInput1.
            ((TensorMath) gradInput1).cmul(outExpand.expand(new int[]{gradInput1.size(1), gradInput1.size(2)}));
        } else {
            ((TensorMath) gradInput1).mul(this.ev.pow(this.ev.plus(output().mo1137apply(new int[]{1}), this.ev.mo1182fromType(BoxesRunTime.boxToDouble(1.0E-6d), ConvertableFrom$ConvertableFromDouble$.MODULE$)), this.ev.mo1182fromType(BoxesRunTime.boxToInteger(1 - norm()), ConvertableFrom$ConvertableFromInt$.MODULE$)));
        }

        // Scale by gradOutput: a scalar for vector input, one value per row for batch input.
        if (input1.dim() == 1) {
            ((TensorMath) gradInput1).mul(tensor.mo1137apply(new int[]{1}));
        } else {
            Tensor<T> gradOutputExpanded = Tensor$.MODULE$.apply(this.evidence$1, this.ev);
            Tensor<T> ones = Tensor$.MODULE$.apply(this.evidence$1, this.ev);
            gradOutputExpanded.resizeAs(input1).zero();
            ones.resize(input1.size(2)).fill(this.ev.one());
            // Outer product gradOutput x ones replicates each row's gradient across columns.
            gradOutputExpanded.addr(tensor, ones);
            ((TensorMath) gradInput1).cmul(gradOutputExpanded);
        }

        // dOut/dx2 = -dOut/dx1 (scalar passed as T to select add(value, tensor)).
        gradInput2.zero().add(this.ev.negative(this.ev.one()), gradInput1);
        return gradInput();
    }

    /** Returns the fixed display name {@code "nn.PairwiseDistance"}. */
    @Override // com.intel.analytics.bigdl.nn.abstractnn.AbstractModule
    public String toString() {
        return "nn.PairwiseDistance";
    }

    /** Returns true when {@code obj} is a {@code PairwiseDistance}. */
    @Override // com.intel.analytics.bigdl.nn.abstractnn.AbstractModule
    public boolean canEqual(Object obj) {
        return obj instanceof PairwiseDistance;
    }

    /**
     * Equal when {@code obj} is a {@code PairwiseDistance}, the superclass considers it equal,
     * {@code obj.canEqual(this)} holds, and both have the same {@link #norm()}.
     */
    @Override // com.intel.analytics.bigdl.nn.abstractnn.AbstractModule
    public boolean equals(Object obj) {
        boolean z;
        if (obj instanceof PairwiseDistance) {
            PairwiseDistance pairwiseDistance = (PairwiseDistance) obj;
            z = super.equals(pairwiseDistance) && pairwiseDistance.canEqual(this) && norm() == pairwiseDistance.norm();
        } else {
            z = false;
        }
        return z;
    }

    /** Folds {@code [super.hashCode(), norm]} as {@code acc = 31 * acc + h}, starting from 0. */
    @Override // com.intel.analytics.bigdl.nn.abstractnn.AbstractModule
    public int hashCode() {
        return BoxesRunTime.unboxToInt(((TraversableOnce) Seq$.MODULE$.apply(Predef$.MODULE$.wrapIntArray(new int[]{super.hashCode(), norm()})).map(i -> {
            return BoxesRunTime.boxToInteger(i).hashCode();
        }, Seq$.MODULE$.canBuildFrom())).foldLeft(BoxesRunTime.boxToInteger(0), (i2, i3) -> {
            return (31 * i2) + i3;
        }));
    }

    /**
     * @param i             order p of the distance
     * @param classTag      runtime class tag for T
     * @param tensorNumeric numeric operations for T
     */
    /* JADX WARN: 'super' call moved to the top of the method (can break code semantics) */
    public PairwiseDistance(int i, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        super(ClassTag$.MODULE$.apply(Table.class), ClassTag$.MODULE$.apply(Tensor.class), classTag, tensorNumeric);
        this.norm = i;
        this.evidence$1 = classTag;
        this.ev = tensorNumeric;
    }
}
