package com.intel.analytics.bigdl.nn;

import com.intel.analytics.bigdl.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.tensor.Tensor;
import com.intel.analytics.bigdl.tensor.TensorDataType;
import com.intel.analytics.bigdl.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.utils.Engine$;
import com.intel.analytics.bigdl.utils.Table;
import com.intel.analytics.bigdl.utils.ThreadPool;
import scala.Function0;
import scala.Predef$;
import scala.collection.Seq$;
import scala.collection.TraversableOnce;
import scala.concurrent.Future;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/* compiled from: JoinTable.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005-h\u0001\u0002\r\u001a\u0001\u0011B\u0001B\u0013\u0001\u0003\u0006\u0004%\ta\u0013\u0005\t\u001f\u0002\u0011\t\u0011)A\u0005\u0019\"A\u0001\u000b\u0001BC\u0002\u0013\u00051\n\u0003\u0005R\u0001\t\u0005\t\u0015!\u0003M\u0011!\u0011\u0006AaA!\u0002\u0017\u0019\u0006\u0002C-\u0001\u0005\u0003\u0005\u000b1\u0002.\t\u000b9\u0004A\u0011A8\t\u000b]\u0004A\u0011\u0002=\t\u000bm\u0004A\u0011\t?\t\u000f\u0005\u001d\u0001\u0001\"\u0011\u0002\n!9\u00111\u0004\u0001\u0005B\u0005u\u0001bBA\u0018\u0001\u0011\u0005\u0013\u0011\u0007\u0005\b\u0003{\u0001A\u0011IA \u0011\u001d\t\u0019\u0005\u0001C!\u0003\u000bBq!a\u0012\u0001\t\u0003\nI\u0005C\u0005\u0002N\u0001\u0001\r\u0011\"\u0003\u0002P!I\u0011\u0011\u000e\u0001A\u0002\u0013%\u00111\u000e\u0005\t\u0003c\u0002\u0001\u0015)\u0003\u0002R\u001d9\u0011qQ\r\t\u0002\u0005%eA\u0002\r\u001a\u0011\u0003\tY\t\u0003\u0004o)\u0011\u0005\u0011\u0011\u0014\u0005\b\u00037#B\u0011AAO\u0011%\t9\u000eFA\u0001\n\u0013\tINA\u0005K_&tG+\u00192mK*\u0011!dG\u0001\u0003]:T!\u0001H\u000f\u0002\u000b\tLw\r\u001a7\u000b\u0005yy\u0012!C1oC2LH/[2t\u0015\t\u0001\u0013%A\u0003j]R,GNC\u0001#\u0003\r\u0019w.\\\u0002\u0001+\t)\u0003j\u0005\u0002\u0001MA)qE\u000b\u00173\u000f6\t\u0001F\u0003\u0002*3\u0005Q\u0011MY:ue\u0006\u001cGO\u001c8\n\u0005-B#AD!cgR\u0014\u0018m\u0019;N_\u0012,H.\u001a\t\u0003[Aj\u0011A\f\u0006\u0003_m\tQ!\u001e;jYNL!!\r\u0018\u0003\u000bQ\u000b'\r\\31\u0005MZ\u0004c\u0001\u001b8s5\tQG\u0003\u000277\u00051A/\u001a8t_JL!\u0001O\u001b\u0003\rQ+gn]8s!\tQ4\b\u0004\u0001\u0005\u0013q\u0002\u0011\u0011!A\u0001\u0006\u0003i$aA0%cE\u0011a\b\u0012\t\u0003\u007f\tk\u0011\u0001\u0011\u0006\u0002\u0003\u0006)1oY1mC&\u00111\t\u0011\u0002\b\u001d>$\b.\u001b8h!\tyT)\u0003\u0002G\u0001\n\u0019\u0011I\\=\u0011\u0005iBE!B%\u0001\u0005\u0004i$!\u0001+\u0002\u0013\u0011LW.\u001a8tS>tW#\u0001'\u0011\u0005}j\u0015B\u0001(A\u0005\rIe\u000e^\u0001\u000bI&lWM\\:j_:\u0004\u0013A\u00038J]B,H\u000fR5ng\u0006Ya.\u00138qkR$\u0015.
\\:!\u0003))g/\u001b3f]\u000e,G%\r\t\u0004)^;U\"A+\u000b\u0005Y\u0003\u0015a\u0002:fM2,7\r^\u0005\u00031V\u0013\u0001b\u00117bgN$\u0016mZ\u0001\u0003KZ\u00042aW6H\u001d\ta\u0016N\u0004\u0002^Q:\u0011al\u001a\b\u0003?\u001at!\u0001Y3\u000f\u0005\u0005$W\"\u00012\u000b\u0005\r\u001c\u0013A\u0002\u001fs_>$h(C\u0001#\u0013\t\u0001\u0013%\u0003\u0002\u001f?%\u0011A$H\u0005\u0003mmI!A[\u001b\u0002#Q+gn]8s\u001dVlWM]5d\u001b\u0006$\b.\u0003\u0002m[\niA+\u001a8t_JtU/\\3sS\u000eT!A[\u001b\u0002\rqJg.\u001b;?)\r\u0001XO\u001e\u000b\u0004cN$\bc\u0001:\u0001\u000f6\t\u0011\u0004C\u0003S\u000f\u0001\u000f1\u000bC\u0003Z\u000f\u0001\u000f!\fC\u0003K\u000f\u0001\u0007A\nC\u0003Q\u000f\u0001\u0007A*\u0001\u000bhKR\u0004vn]5uSZ,G)[7f]NLwN\u001c\u000b\u0003\u0019fDQA\u001f\u0005A\u00021\nQ!\u001b8qkR\fA\"\u001e9eCR,w*\u001e;qkR$2!`A\u0003a\rq\u0018\u0011\u0001\t\u0004i]z\bc\u0001\u001e\u0002\u0002\u0011Q\u00111A\u0005\u0002\u0002\u0003\u0005)\u0011A\u001f\u0003\u0007}#3\u0007C\u0003{\u0013\u0001\u0007A&A\bva\u0012\fG/Z$sC\u0012Le\u000e];u)\u0015a\u00131BA\u0007\u0011\u0015Q(\u00021\u0001-\u0011\u001d\tyA\u0003a\u0001\u0003#\t!b\u001a:bI>+H\u000f];ua\u0011\t\u0019\"a\u0006\u0011\tQ:\u0014Q\u0003\t\u0004u\u0005]AaCA\r\u0003\u001b\t\t\u0011!A\u0003\u0002u\u00121a\u0018\u00137\u0003!!xn\u0015;sS:<GCAA\u0010!\u0011\t\t#!\u000b\u000f\t\u0005\r\u0012Q\u0005\t\u0003C\u0002K1!a\nA\u0003\u0019\u0001&/\u001a3fM&!\u00111FA\u0017\u0005\u0019\u0019FO]5oO*\u0019\u0011q\u0005!\u0002\u0011\r\fg.R9vC2$B!a\r\u0002:A\u0019q(!\u000e\n\u0007\u0005]\u0002IA\u0004C_>dW-\u00198\t\r\u0005mB\u00021\u0001E\u0003\u0015yG\u000f[3s\u0003\u0019)\u0017/^1mgR!\u00111GA!\u0011\u0019\tY$\u0004a\u0001\t\u0006A\u0001.Y:i\u0007>$W\rF\u0001M\u0003)\u0019G.Z1s'R\fG/\u001a\u000b\u0003\u0003\u0017j\u0011\u0001A\u0001\be\u0016\u001cX\u000f\u001c;t+\t\t\t\u0006E\u0003@\u0003'\n9&C\u0002\u0002V\u0001\u0013Q!\u0011:sCf\u0004b!!\u0017\u0002`\u0005\rTBAA.\u0015\r\ti\u0006Q\u0001\u000bG>t7-\u001e:sK:$\u0018\u0002BA1\u00037\u0012aAR;
ukJ,\u0007cA \u0002f%\u0019\u0011q\r!\u0003\tUs\u0017\u000e^\u0001\fe\u0016\u001cX\u000f\u001c;t?\u0012*\u0017\u000f\u0006\u0003\u0002d\u00055\u0004\"CA8#\u0005\u0005\t\u0019AA)\u0003\rAH%M\u0001\te\u0016\u001cX\u000f\u001c;tA!\u001a!#!\u001e\u0011\u0007}\n9(C\u0002\u0002z\u0001\u0013\u0011\u0002\u001e:b]NLWM\u001c;)\u000f\u0001\ti(a!\u0002\u0006B\u0019q(a \n\u0007\u0005\u0005\u0005I\u0001\tTKJL\u0017\r\u001c,feNLwN\\+J\t\u0006)a/\u00197vKzA!R<3Z\u0003\u001e\u0005\u0013'A\u0005K_&tG+\u00192mKB\u0011!\u000fF\n\u0006)\u00055\u00151\u0013\t\u0004\u007f\u0005=\u0015bAAI\u0001\n1\u0011I\\=SK\u001a\u00042aPAK\u0013\r\t9\n\u0011\u0002\r'\u0016\u0014\u0018.\u00197ju\u0006\u0014G.\u001a\u000b\u0003\u0003\u0013\u000bQ!\u00199qYf,B!a(\u0002(R1\u0011\u0011UAj\u0003+$b!a)\u0002J\u0006=\u0007\u0003\u0002:\u0001\u0003K\u00032AOAT\t%Ie\u0003)A\u0001\u0002\u000b\u0007Q\b\u000b\u0005\u0002(\u0006-\u0016\u0011WA`!\ry\u0014QV\u0005\u0004\u0003_\u0003%aC:qK\u000eL\u0017\r\\5{K\u0012\f\u0014bIAZ\u0003k\u000bI,a.\u000f\u0007}\n),C\u0002\u00028\u0002\u000bQA\u00127pCR\fd\u0001JA^\u0003{\u000bebA1\u0002>&\t\u0011)M\u0005$\u0003\u0003\f\u0019-a2\u0002F:\u0019q(a1\n\u0007\u0005\u0015\u0007)\u0001\u0004E_V\u0014G.Z\u0019\u0007I\u0005m\u0016QX!\t\u0013\u0005-g#!AA\u0004\u00055\u0017AC3wS\u0012,gnY3%eA!AkVAS\u0011\u0019If\u0003q\u0001\u0002RB!1l[AS\u0011\u0015Qe\u00031\u0001M\u0011\u0015\u0001f\u00031\u0001M\u0003-\u0011X-\u00193SKN|GN^3\u0015\u0005\u0005m\u0007\u0003BAo\u0003Ol!!a8\u000b\t\u0005\u0005\u00181]\u0001\u0005Y\u0006twM\u0003\u0002\u0002f\u0006!!.\u0019<b\u0013\u0011\tI/a8\u0003\r=\u0013'.Z2u\u0001")
/* loaded from: input_file:com/intel/analytics/bigdl/nn/JoinTable.class */
/**
 * Joins (concatenates) the tensors held in an input {@link Table} along one dimension.
 *
 * <p>The join dimension is 1-based. A negative value counts back from the last dimension of the
 * first input ({@code -1} selects the last one). When {@code nInputDims > 0} and the first input
 * has exactly {@code nInputDims + 1} dimensions, the join dimension is shifted by one to skip the
 * extra leading dimension.
 *
 * <p>The per-input copies of both passes are submitted to {@code Engine.model()} and awaited with
 * {@code sync} before the method returns.
 */
public class JoinTable<T> extends AbstractModule<Table, Tensor<?>, T> {
    public static final long serialVersionUID = -8435694717504118735L;

    /** Join dimension as given to the constructor; may be negative. */
    private final int dimension;

    /** Dimension count of one non-batched input; values {@code <= 0} disable the batch shift. */
    private final int nInputDims;

    /** Task handles reused across calls; transient, so it is (re)allocated on demand. */
    private transient Future<BoxedUnit>[] results;

    public int dimension() {
        return this.dimension;
    }

    public int nInputDims() {
        return this.nInputDims;
    }

    private Future<BoxedUnit>[] results() {
        return this.results;
    }

    private void results_$eq(Future<BoxedUnit>[] futureArr) {
        this.results = futureArr;
    }

    /**
     * Returns the task-handle array, allocating a new one when none exists yet (for example after
     * deserialization or {@link #clearState()}) or when the number of inputs changed.
     */
    @SuppressWarnings("unchecked")
    private Future<BoxedUnit>[] ensureResults(int count) {
        Future<BoxedUnit>[] current = results();
        if (current == null || current.length != count) {
            current = (Future<BoxedUnit>[]) new Future[count];
            results_$eq(current);
        }
        return current;
    }

    /**
     * Resolves the configured dimension against the first input into a 1-based dimension index.
     *
     * @param input the module input; entry 1 must be a tensor
     * @return the join dimension, guaranteed to lie in {@code [1, first.dim()]}
     * @throws IllegalArgumentException (via {@code Predef.require}) if the table is empty or the
     *     resolved dimension is out of range
     */
    @SuppressWarnings("rawtypes")
    private int getPositiveDimension(Table input) {
        Predef$.MODULE$.require(input.length() > 0,
            () -> "JoinTable: the input table must contain at least one tensor");
        Tensor first = (Tensor) input.apply(BoxesRunTime.boxToInteger(1));
        int resolved = dimension();
        if (resolved < 0) {
            resolved = first.dim() + resolved + 1;
        } else if (nInputDims() > 0 && first.dim() == nInputDims() + 1) {
            resolved++;
        }
        // Validate the *resolved* index. Checking the raw field would accept negative,
        // zero or batch-shifted values that narrow() later rejects with a less helpful error.
        int joinDim = resolved;
        Predef$.MODULE$.require(joinDim >= 1 && joinDim <= first.dim(),
            () -> "dimension exceeds input dimensions input dimension " + first.dim()
                + ", dimension " + joinDim + " (requested " + dimension() + ")");
        return joinDim;
    }

    /**
     * Concatenates all inputs along the resolved join dimension into {@code output()}.
     *
     * <p>The output takes the shape of the first input, with the join-dimension extent replaced by
     * the sum of all inputs' extents. The current output tensor is resized and reused when its
     * {@code getType()} equals the first input's. Otherwise it is replaced by a resized
     * {@code first.emptyInstance()}.
     */
    @Override // com.intel.analytics.bigdl.nn.abstractnn.AbstractModule
    @SuppressWarnings({"unchecked", "rawtypes"})
    public Tensor<?> updateOutput(Table input) {
        int joinDim = getPositiveDimension(input);
        int count = input.length();
        Tensor first = (Tensor) input.apply(BoxesRunTime.boxToInteger(1));

        // Clone so the accumulation below can never write into the first input's size array.
        int[] outSize = first.size().clone();
        for (int k = 2; k <= count; k++) {
            Tensor next = (Tensor) input.apply(BoxesRunTime.boxToInteger(k));
            outSize[joinDim - 1] += next.size(joinDim);
        }

        TensorDataType outType = output().getType();
        TensorDataType inType = first.getType();
        if (outType == null ? inType == null : outType.equals(inType)) {
            Tensor<?> out = output();
            out.resize(outSize, out.resize$default$2());
        } else {
            Tensor fresh = first.emptyInstance();
            output_$eq(fresh.resize(outSize, fresh.resize$default$2()));
        }

        Future<BoxedUnit>[] tasks = ensureResults(count);
        ThreadPool pool = Engine$.MODULE$.model();
        int offset = 1;
        for (int k = 0; k < count; k++) {
            Tensor part = (Tensor) input.apply(BoxesRunTime.boxToInteger(k + 1));
            int start = offset;
            tasks[k] = pool.invoke(() -> {
                Tensor slice = output().narrow(joinDim, start, part.size(joinDim));
                joinCopy(slice, part, slice, joinDim);
                return BoxedUnit.UNIT;
            });
            offset += part.size(joinDim);
        }
        pool.sync(Predef$.MODULE$.wrapRefArray(tasks), pool.sync$default$2());
        return output();
    }

    /**
     * Splits {@code gradOutput} along the resolved join dimension back into one gradient per input.
     *
     * <p>Entry {@code k} of {@code gradInput()} is created with {@code emptyInstance()} when
     * missing. It is resized like input {@code k} and filled with the matching slice of
     * {@code gradOutput}.
     */
    @Override // com.intel.analytics.bigdl.nn.abstractnn.AbstractModule
    @SuppressWarnings({"unchecked", "rawtypes"})
    public Table updateGradInput(Table input, Tensor<?> gradOutput) {
        int joinDim = getPositiveDimension(input);
        int count = input.length();
        // results is transient and may not have been allocated by a forward pass on this
        // instance (e.g. after deserialization), so allocate it here as well.
        Future<BoxedUnit>[] tasks = ensureResults(count);
        ThreadPool pool = Engine$.MODULE$.model();
        int offset = 1;
        for (int k = 0; k < count; k++) {
            Tensor part = (Tensor) input.apply(BoxesRunTime.boxToInteger(k + 1));
            // Create/resize the gradInput entry on the calling thread. The tasks run
            // concurrently, and nothing in view shows Table tolerates concurrent updates.
            Tensor grad = prepareGradInput(k + 1, part);
            int start = offset;
            tasks[k] = pool.invoke(() -> {
                Tensor slice = gradOutput.narrow(joinDim, start, part.size(joinDim));
                joinCopy(grad, slice, slice, joinDim);
                return BoxedUnit.UNIT;
            });
            offset += part.size(joinDim);
        }
        pool.sync(Predef$.MODULE$.wrapRefArray(tasks), pool.sync$default$2());
        return gradInput();
    }

    /**
     * Makes {@code gradInput()} entry {@code index} a tensor sized like {@code part} and returns it.
     * When the entry is absent, it is created from {@code part.emptyInstance()}.
     */
    @SuppressWarnings("rawtypes")
    private Tensor prepareGradInput(int index, Tensor part) {
        Object key = BoxesRunTime.boxToInteger(index);
        Table grads = gradInput();
        if (grads.contains(key)) {
            ((Tensor) grads.apply(key)).resizeAs(part);
        } else {
            grads.update(key, part.emptyInstance().resizeAs(part));
        }
        return (Tensor) grads.apply(key);
    }

    /**
     * Copies {@code src} into {@code dst}.
     *
     * <p>When {@code view} (the narrowed side of the copy) is contiguous or {@code joinDim > 2}, a
     * single {@code copy} is issued. Otherwise the copy proceeds one index of dimension 1 at a
     * time, and each selected sub-tensor must be contiguous on both sides.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private static void joinCopy(Tensor dst, Tensor src, Tensor view, int joinDim) {
        if (view.isContiguous() || joinDim > 2) {
            dst.copy(src);
            return;
        }
        int rows = view.size(1);
        for (int row = 1; row <= rows; row++) {
            Tensor dstRow = dst.select(1, row);
            Tensor srcRow = src.select(1, row);
            Predef$.MODULE$.require(dstRow.isContiguous());
            Predef$.MODULE$.require(srcRow.isContiguous());
            dstRow.copy(srcRow);
        }
    }

    @Override // com.intel.analytics.bigdl.nn.abstractnn.AbstractModule
    public String toString() {
        return "nn.JoinTable";
    }

    @Override // com.intel.analytics.bigdl.nn.abstractnn.AbstractModule
    public boolean canEqual(Object obj) {
        return obj instanceof JoinTable;
    }

    /** Equal when the superclass state matches and both join settings are identical. */
    @Override // com.intel.analytics.bigdl.nn.abstractnn.AbstractModule
    public boolean equals(Object obj) {
        if (!(obj instanceof JoinTable)) {
            return false;
        }
        JoinTable<?> that = (JoinTable<?>) obj;
        return super.equals(that)
            && that.canEqual(this)
            && dimension() == that.dimension()
            && nInputDims() == that.nInputDims();
    }

    /** Folds {@code super.hashCode()}, {@code dimension()} and {@code nInputDims()} with factor 31. */
    @Override // com.intel.analytics.bigdl.nn.abstractnn.AbstractModule
    public int hashCode() {
        int[] parts = {super.hashCode(), dimension(), nInputDims()};
        int h = 0;
        for (int part : parts) {
            h = 31 * h + getHashCode$1(part);
        }
        return h;
    }

    /** Clears inherited state, all gradient entries and the cached task handles. */
    @Override // com.intel.analytics.bigdl.nn.abstractnn.AbstractModule
    public JoinTable<T> clearState() {
        super.clearState();
        gradInput().clear();
        // Both passes reallocate the handle array on demand, so it is safe to drop here.
        results_$eq(null);
        return this;
    }

    /** Null-safe {@code hashCode}; compiler-generated helper kept for binary compatibility. */
    public static final int getHashCode$1(Object obj) {
        if (obj == null) {
            return 0;
        }
        return obj.hashCode();
    }

    /**
     * Creates the module.
     *
     * @param dimension join dimension (1-based; negative values count back from the last dimension)
     * @param nInputDims dimension count of one non-batched input, or {@code <= 0} to disable the
     *     batch shift
     * @param classTag class tag for {@code T}
     * @param tensorNumeric numeric operations for {@code T}
     */
    public JoinTable(int dimension, int nInputDims, ClassTag<T> classTag, TensorNumericMath.TensorNumeric<T> tensorNumeric) {
        super(ClassTag$.MODULE$.apply(Table.class), ClassTag$.MODULE$.apply(Tensor.class), classTag, tensorNumeric);
        // NOTE(review): the decompiler reported that the super(...) call was moved to the top;
        // in the original bytecode these field assignments may have preceded it.
        this.dimension = dimension;
        this.nInputDims = nInputDims;
        this.results = null;
    }
}
