package org.nd4j.linalg.indexing.masking;

import java.util.ArrayList;
import org.nd4j.autodiff.samediff.SDIndex;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.controlflow.Where;
import org.nd4j.linalg.api.ops.impl.shape.Gather;
import org.nd4j.linalg.api.ops.impl.shape.Squeeze;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.guava.primitives.Longs;

/* loaded from: input_file:org/nd4j/linalg/indexing/masking/Masking.class */
public class Masking {

    /**
     * Applies a boolean mask to {@code input} (SameDiff graph version).
     *
     * <p>The {@code mask.rank()} dimensions of {@code input} starting at {@code axis} are collapsed
     * into a single dimension of size {@code prod(inputShape[axis, axis + maskRank))}. The entries
     * along that collapsed dimension whose position in the flattened mask is selected by
     * {@code where} are then gathered. The result has rank {@code inputRank - maskRank + 1}.
     *
     * <p>Both the mask rank and the input rank are evaluated eagerly via {@code rank().eval()}, so
     * they must be computable at graph-construction time.
     *
     * @param sameDiff the graph to add ops to
     * @param input    the tensor to mask
     * @param mask     the mask; its dims are expected to match {@code input}'s dims at
     *                 {@code [axis, axis + maskRank)}
     * @param axis     first input dimension covered by the mask
     * @return the gathered (masked) variable
     * @throws IllegalStateException if the mask is scalar or does not fit inside {@code input} at
     *                               {@code axis}
     */
    public static SDVariable applyMask(SameDiff sameDiff, SDVariable input, SDVariable mask, int axis) {
        int maskRank = mask.rank().eval().getInt(0);
        int inputRank = input.rank().eval().getInt(0);
        Preconditions.checkState(maskRank > 0, "Mask shape must not be scalar");
        checkAxisRange(inputRank, maskRank, axis);

        int maskEnd = axis + maskRank;
        // The reshape target must be derived from the INPUT's shape: leading dims, the product of
        // the masked dims, then the trailing dims that the mask does not cover.
        SDVariable inputShape = input.shape();
        SDVariable leadingDims = inputShape.get(SDIndex.interval(0, axis));
        SDVariable maskedSize = sameDiff.prod(inputShape.get(SDIndex.interval(axis, maskEnd)), 0);
        SDVariable trailingDims = inputShape.get(SDIndex.interval(maskEnd, inputRank));
        SDVariable reshaped = input.reshape(sameDiff.concat(0, leadingDims, maskedSize, trailingDims));

        // where() on the flattened (1-D) mask yields coordinates shaped [numSelected, 1]; squeeze
        // axis 1 (matching the INDArray overload) to obtain a 1-D index vector.
        SDVariable indices = sameDiff.squeeze(sameDiff.where(mask.reshape(-1)), 1);
        return sameDiff.gather(reshaped, indices, axis);
    }

    /**
     * Applies a boolean mask to {@code input} (eager INDArray version).
     *
     * <p>The {@code mask.rank()} dimensions of {@code input} starting at {@code axis} are collapsed
     * into one dimension of size {@code prod(inputShape[axis, axis + maskRank))}. The entries along
     * that dimension selected by {@code Where} on the flattened mask are then gathered. The result
     * has rank {@code inputRank - maskRank + 1}.
     *
     * @param input the array to mask
     * @param mask  the mask; its shape must equal {@code input}'s dims at
     *              {@code [axis, axis + maskRank)}
     * @param axis  first input dimension covered by the mask
     * @return the gathered (masked) array
     * @throws IllegalStateException if the mask is scalar, does not fit inside {@code input} at
     *                               {@code axis}, or its dims differ from the covered input dims
     */
    public static INDArray applyMask(INDArray input, INDArray mask, int axis) {
        long[] maskShape = mask.shape();
        long[] inputShape = input.shape();
        Preconditions.checkState(maskShape.length > 0, "Mask shape must not be scalar");
        checkAxisRange(inputShape.length, maskShape.length, axis);
        for (int k = 0; k < maskShape.length; k++) {
            // Flattened-mask positions index the collapsed dimension, so sizes must line up exactly.
            Preconditions.checkState(inputShape[axis + k] == maskShape[k],
                    "Mask dimension %s (size %s) does not match input dimension %s (size %s)",
                    k, maskShape[k], axis + k, inputShape[axis + k]);
        }

        INDArray reshaped = input.reshape(collapseMaskedDims(inputShape, maskShape.length, axis));
        INDArray coordinates = Nd4j.getExecutioner().exec(new Where(mask.reshape(-1)))[0];
        // Coordinates of a 1-D mask are [numSelected, 1]; drop the trailing unit dimension.
        INDArray indices = Nd4j.getExecutioner().exec(new Squeeze(coordinates, 1))[0];
        return Nd4j.getExecutioner().exec(new Gather(reshaped, indices, axis))[0];
    }

    /** Ensures the masked range {@code [axis, axis + maskRank)} lies inside an input of the given rank. */
    private static void checkAxisRange(int inputRank, int maskRank, int axis) {
        Preconditions.checkState(axis >= 0 && axis + maskRank <= inputRank,
                "Mask of rank %s applied at axis %s does not fit input of rank %s",
                maskRank, axis, inputRank);
    }

    /**
     * Builds the reshape target: input dims before {@code axis}, then the product of the
     * {@code maskRank} masked dims, then the input dims after the masked range.
     */
    private static long[] collapseMaskedDims(long[] inputShape, int maskRank, int axis) {
        int maskEnd = axis + maskRank;
        long[] collapsed = new long[inputShape.length - maskRank + 1];
        System.arraycopy(inputShape, 0, collapsed, 0, axis);
        long maskedSize = 1;
        for (int d = axis; d < maskEnd; d++) {
            maskedSize *= inputShape[d];
        }
        collapsed[axis] = maskedSize;
        System.arraycopy(inputShape, maskEnd, collapsed, axis + 1, inputShape.length - maskEnd);
        return collapsed;
    }
}
