package com.intel.analytics.bigdl.nn;

import com.intel.analytics.bigdl.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.tensor.Tensor;
import com.intel.analytics.bigdl.tensor.Tensor$;
import com.intel.analytics.bigdl.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.utils.Table;
import scala.Predef$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/* compiled from: MaskHead.scala */
@ScalaSignature(bytes = "\u0006\u0001\u001d4QAC\u0006\u0001\u0017UA\u0001B\f\u0001\u0003\u0002\u0003\u0006Y\u0001\r\u0005\u0006\t\u0002!\t!\u0012\u0005\b\u0015\u0002\u0011\r\u0011\"\u0003L\u0011\u0019y\u0005\u0001)A\u0005\u0019\")\u0001\u000b\u0001C!#\")A\u000b\u0001C!+\"9\u0011\f\u0001a\u0001\n\u0003Q\u0006bB.\u0001\u0001\u0004%\t\u0001\u0018\u0005\u0007E\u0002\u0001\u000b\u0015\u0002\u0012\u0003#5\u000b7o\u001b)pgR\u0004&o\\2fgN|'O\u0003\u0002\r\u001b\u0005\u0011aN\u001c\u0006\u0003\u001d=\tQAY5hI2T!\u0001E\t\u0002\u0013\u0005t\u0017\r\\=uS\u000e\u001c(B\u0001\n\u0014\u0003\u0015Ig\u000e^3m\u0015\u0005!\u0012aA2p[N\u0011\u0001A\u0006\t\u0006/ia\"\u0005K\u0007\u00021)\u0011\u0011dC\u0001\u000bC\n\u001cHO]1di:t\u0017BA\u000e\u0019\u00059\t%m\u001d;sC\u000e$Xj\u001c3vY\u0016\u0004\"!\b\u0011\u000e\u0003yQ!aH\u0007\u0002\u000bU$\u0018\u000e\\:\n\u0005\u0005r\"!\u0002+bE2,\u0007cA\u0012'Q5\tAE\u0003\u0002&\u001b\u00051A/\u001a8t_JL!a\n\u0013\u0003\rQ+gn]8s!\tIC&D\u0001+\u0015\u0005Y\u0013!B:dC2\f\u0017BA\u0017+\u0005\u00151En\\1u\u0003\t)go\u0001\u0001\u0011\u0007E\n\u0005F\u0004\u00023\u007f9\u00111G\u0010\b\u0003iur!!\u000e\u001f\u000f\u0005YZdBA\u001c;\u001b\u0005A$BA\u001d0\u0003\u0019a$o\\8u}%\tA#\u0003\u0002\u0013'%\u0011\u0001#E\u0005\u0003\u001d=I!!J\u0007\n\u0005\u0001#\u0013!\u0005+f]N|'OT;nKJL7-T1uQ&\u0011!i\u0011\u0002\u000e)\u0016t7o\u001c:Ok6,'/[2\u000b\u0005\u0001#\u0013A\u0002\u001fj]&$h\bF\u0001G)\t9\u0015\n\u0005\u0002I\u00015\t1\u0002C\u0003/\u0005\u0001\u000f\u0001'A\u0004tS\u001elw.\u001b3\u0016\u00031\u00032\u0001S')\u0013\tq5BA\u0004TS\u001elw.\u001b3\u0002\u0011MLw-\\8jI\u0002\nA\"\u001e9eCR,w*\u001e;qkR$\"A\t*\t\u000bM+\u0001\u0019\u0001\u000f\u0002\u000b%t\u0007/\u001e;\u0002\u001fU\u0004H-\u0019;f\u000fJ\fG-\u00138qkR$2\u0001\b,X\u0011\u0015\u0019f\u00011\u0001\u001d\u0011\u0015Af\u00011\u0001#\u0003)9'/\u00193PkR\u0004X\u000f^\u0001\fe\u0006tw-\u001a\"vM\u001a,'/F\u0001#\u0003=\u0011\u0018M\\4f\u0005V4g-\u001a:`I\u0015\fHCA/a!\tIc,\u0003\u00
02`U\t!QK\\5u\u0011\u001d\t\u0007\"!AA\u0002\t\n1\u0001\u001f\u00132\u00031\u0011\u0018M\\4f\u0005V4g-\u001a:!Q\tIA\r\u0005\u0002*K&\u0011aM\u000b\u0002\niJ\fgn]5f]R\u0004")
/* loaded from: input_file:com/intel/analytics/bigdl/nn/MaskPostProcessor.class */
public class MaskPostProcessor extends AbstractModule<Table, Tensor<Object>, Object> {
    /** Numeric type class supplied at construction; used when allocating {@code rangeBuffer}. */
    private final TensorNumericMath.TensorNumeric<Object> ev;

    /** Sigmoid applied to the mask logits before per-mask channel selection. */
    private final Sigmoid<Object> sigmoid;

    /**
     * Cached tensor filled with {@code 0 .. numMasks - 1}; reallocated whenever the number of
     * masks changes. Declared transient, so it is {@code null} after Java deserialization and is
     * rebuilt lazily by {@link #updateOutput(Table)}.
     */
    private transient Tensor<Object> rangeBuffer;

    /** Returns the cached range tensor (may be {@code null} before the first forward pass). */
    public Tensor<Object> rangeBuffer() {
        return this.rangeBuffer;
    }

    /** Scala-style setter for {@code rangeBuffer}. */
    public void rangeBuffer_$eq(Tensor<Object> tensor) {
        this.rangeBuffer = tensor;
    }

    private Sigmoid<Object> sigmoid() {
        return this.sigmoid;
    }

    /**
     * Applies a sigmoid to the mask logits and, for every mask {@code i}, keeps only the channel
     * selected by {@code labels[i] + 1}.
     *
     * @param table input table: element 1 is the mask-logit tensor (indexed here along dims 1..4;
     *              presumably {@code (numMasks, numClasses, H, W)} — confirm against caller),
     *              element 2 is a one-dimensional label tensor with one entry per mask
     * @return {@code output()}, resized to {@code (numMasks, 1, size(3), size(4))}
     * @throws IllegalArgumentException if the labels are not one-dimensional or their count
     *                                  differs from the number of masks
     */
    @Override // com.intel.analytics.bigdl.nn.abstractnn.AbstractModule
    public Tensor<Object> updateOutput(Table table) {
        @SuppressWarnings("unchecked") // the Table holds untyped values; element 1 is a Tensor
        Tensor<Object> maskLogits = (Tensor<Object>) table.apply(BoxesRunTime.boxToInteger(1));
        @SuppressWarnings("unchecked") // element 2 is the label Tensor
        Tensor<Object> labels = (Tensor<Object>) table.apply(BoxesRunTime.boxToInteger(2));

        int numMasks = maskLogits.size(1);
        // Reallocate the cached 0..numMasks-1 range only when the mask count changes.
        if (rangeBuffer() == null || rangeBuffer().nElement() != numMasks) {
            rangeBuffer_$eq(Tensor$.MODULE$.apply(numMasks, ClassTag$.MODULE$.Float(), this.ev));
            rangeBuffer().range(0.0d, numMasks - 1, 1);
        }

        Tensor<Object> maskProb = sigmoid().forward(maskLogits);

        Predef$.MODULE$.require(labels.nDimension() == 1, () ->
            "Labels should be tensor with one dimension, but get " + labels.nDimension());
        Predef$.MODULE$.require(rangeBuffer().nElement() == labels.nElement(), () ->
            "number of masks should be same with labels, but get "
                + this.rangeBuffer().nElement() + " " + labels.nElement());

        int count = rangeBuffer().nElement();
        output().resize(count, 1, maskProb.size(3), maskProb.size(4));
        // Indices are 1-based (narrow(1, i, 1) starting at i = 1). The +1 maps each stored label
        // to a 1-based channel index — presumably labels are 0-based; confirm against producer.
        for (int i = 1; i <= count; i++) {
            int channel = ((int) BoxesRunTime.unboxToFloat(labels.mo1135valueAt(i))) + 1;
            output().narrow(1, i, 1).copy(maskProb.narrow(1, i, 1).narrow(2, channel, 1));
        }
        return output();
    }

    /**
     * Backward pass is not supported; this module is inference-only.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // com.intel.analytics.bigdl.nn.abstractnn.AbstractModule
    /* renamed from: updateGradInput, reason: avoid collision after fix types in other method */
    public Table updateGradInput2(Table table, Tensor<Object> tensor) {
        throw new UnsupportedOperationException("MaskPostProcessor only support inference");
    }

    /**
     * Creates the post-processor.
     *
     * @param tensorNumeric numeric type class for the module's element type ({@code Float} class tag)
     */
    /* JADX WARN: 'super' call moved to the top of the method (can break code semantics) */
    public MaskPostProcessor(TensorNumericMath.TensorNumeric<Object> tensorNumeric) {
        super(ClassTag$.MODULE$.apply(Table.class), ClassTag$.MODULE$.apply(Tensor.class), ClassTag$.MODULE$.Float(), tensorNumeric);
        this.ev = tensorNumeric;
        this.rangeBuffer = null;
        this.sigmoid = Sigmoid$.MODULE$.apply(ClassTag$.MODULE$.Float(), tensorNumeric);
    }
}
