package org.nd4j.linalg.convolution;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Col2Im;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Im2col;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.PaddingMode;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/nd4j/linalg/convolution/Convolution.class */
/**
 * Static helpers around the {@link Col2Im}, {@link Im2col} and {@link Pooling2D} ops, plus the
 * output-size arithmetic they share. Every array operation is executed through
 * {@code Nd4j.getExecutioner()}; {@code conv2d}/{@code convn} delegate to
 * {@code Nd4j.getConvolution()}.
 *
 * <p>This class is not instantiable.
 */
public class Convolution {

    /**
     * Convolution mode selector handed through unchanged to {@code Nd4j.getConvolution()} by
     * {@link #conv2d(INDArray, INDArray, Type)} and the {@code convn} overloads.
     */
    public enum Type {
        FULL,
        VALID,
        SAME
    }

    private Convolution() {
    }

    /**
     * Array-argument convenience form of
     * {@link #col2im(INDArray, int, int, int, int, int, int)}.
     *
     * @param col     rank-6 column array
     * @param stride  {@code [strideH, strideW]}; only the first two elements are read
     * @param padding {@code [padH, padW]}; only the first two elements are read
     * @param height  forwarded as the 6th argument of the delegate
     * @param width   forwarded as the 7th argument of the delegate
     * @return the array produced by the delegate
     */
    public static INDArray col2im(INDArray col, int[] stride, int[] padding, int height, int width) {
        return col2im(col, stride[0], stride[1], padding[0], padding[1], height, width);
    }

    /**
     * Runs {@link Col2Im} on {@code col}, allocating the output array.
     *
     * <p>The output is created with the data type of {@code col} and shape
     * {@code [col.size(0), col.size(1), height, width]}. Dilation is fixed at 1.
     *
     * @param col    rank-6 column array
     * @param sH     stride (config field {@code sH})
     * @param sW     stride (config field {@code sW})
     * @param pH     padding (config field {@code pH})
     * @param pW     padding (config field {@code pW})
     * @param height third dimension of the allocated output; also passed as config field
     *               {@code kH}
     * @param width  fourth dimension of the allocated output; also passed as config field
     *               {@code kW}
     * @return the first output argument of the executed op
     * @throws IllegalArgumentException if {@code col} is not rank 6
     */
    public static INDArray col2im(INDArray col, int sH, int sW, int pH, int pW, int height, int width) {
        if (col.rank() != 6) {
            throw new IllegalArgumentException("col2im input array must be rank 6");
        }
        INDArray output = Nd4j.create(col.dataType(), col.size(0), col.size(1), height, width);
        Conv2DConfig config = Conv2DConfig.builder()
                .sH(sH).sW(sW)
                .dH(1L).dW(1L)
                .kH(height).kW(width)
                .pH(pH).pW(pW)
                .build();
        Col2Im op = Col2Im.builder()
                .inputArrays(new INDArray[]{col})
                .outputs(new INDArray[]{output})
                .conv2DConfig(config)
                .build();
        Nd4j.getExecutioner().execAndReturn(op);
        return op.outputArguments().get(0);
    }

    /**
     * Runs {@link Col2Im} on {@code col}, writing into a caller-supplied output array.
     *
     * @param col rank-6 column array
     * @param z   rank-4 output array; used as the op's output and returned
     * @param sH  stride (config field {@code sH})
     * @param sW  stride (config field {@code sW})
     * @param pH  padding (config field {@code pH})
     * @param pW  padding (config field {@code pW})
     * @param kH  value for config field {@code kH}
     * @param kW  value for config field {@code kW}
     * @param dH  dilation (config field {@code dH})
     * @param dW  dilation (config field {@code dW})
     * @return {@code z}
     * @throws IllegalArgumentException if {@code col} is not rank 6 or {@code z} is not rank 4
     */
    public static INDArray col2im(INDArray col, INDArray z, int sH, int sW, int pH, int pW, int kH, int kW, int dH, int dW) {
        if (col.rank() != 6) {
            throw new IllegalArgumentException("col2im input array must be rank 6");
        }
        if (z.rank() != 4) {
            throw new IllegalArgumentException("col2im output array must be rank 4");
        }
        Conv2DConfig config = Conv2DConfig.builder()
                .sH(sH).sW(sW)
                .dH(dH).dW(dW)
                .kH(kH).kW(kW)
                .pH(pH).pW(pW)
                .build();
        Col2Im op = Col2Im.builder()
                .inputArrays(new INDArray[]{col})
                .outputs(new INDArray[]{z})
                .conv2DConfig(config)
                .build();
        Nd4j.getExecutioner().execAndReturn(op);
        return z;
    }

    /**
     * Array-argument convenience form of im2col. Decompresses {@code img} if needed, then runs
     * {@link #im2col(INDArray, int, int, int, int, int, int, int, boolean)} with a pad value of
     * 0 and {@link PaddingMode#VALID}.
     *
     * @param img     input array
     * @param kernel  {@code [kH, kW]}; only the first two elements are read
     * @param stride  {@code [sH, sW]}; only the first two elements are read
     * @param padding {@code [pH, pW]}; only the first two elements are read
     * @return the column array produced by the delegate
     */
    public static INDArray im2col(INDArray img, int[] kernel, int[] stride, int[] padding) {
        Nd4j.getCompressor().autoDecompress(img);
        return im2col(img, kernel[0], kernel[1], stride[0], stride[1], padding[0], padding[1], 0, false);
    }

    /**
     * im2col with dilation fixed at 1, allocating the output. Delegates to
     * {@link #im2col(INDArray, int, int, int, int, int, int, int, int, boolean)}.
     *
     * @param img        input array
     * @param kH         kernel height
     * @param kW         kernel width
     * @param sH         stride along the height axis
     * @param sW         stride along the width axis
     * @param pH         padding along the height axis
     * @param pW         padding along the width axis
     * @param isSameMode {@code true} selects {@link PaddingMode#SAME}, {@code false}
     *                   {@link PaddingMode#VALID}
     * @return the column array produced by the delegate
     */
    public static INDArray im2col(INDArray img, int kH, int kW, int sH, int sW, int pH, int pW, boolean isSameMode) {
        return im2col(img, kH, kW, sH, sW, pH, pW, 1, 1, isSameMode);
    }

    /**
     * im2col with explicit dilation, allocating the output.
     *
     * <p>Decompresses {@code img} if needed, sizes the two trailing output dimensions with
     * {@link #outputSize(long, long, long, long, int, boolean)} applied to {@code img.size(2)}
     * and {@code img.size(3)}, and allocates a 'c'-ordered array of shape
     * {@code [img.size(0), img.size(1), kH, kW, outH, outW]}. The output uses the data type of
     * {@code img}, matching the other allocating helpers in this class, rather than the global
     * default data type.
     *
     * @param img        input array; dimensions 2 and 3 are treated as the spatial extents
     *                   (presumably NCHW layout - confirm against callers)
     * @param kH         kernel height
     * @param kW         kernel width
     * @param sH         stride along the height axis
     * @param sW         stride along the width axis
     * @param pH         padding along the height axis
     * @param pW         padding along the width axis
     * @param dH         dilation along the height axis
     * @param dW         dilation along the width axis
     * @param isSameMode {@code true} selects {@link PaddingMode#SAME}, {@code false}
     *                   {@link PaddingMode#VALID}
     * @return the first output argument of the executed op
     */
    public static INDArray im2col(INDArray img, int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, boolean isSameMode) {
        Nd4j.getCompressor().autoDecompress(img);
        long outH = outputSize(img.size(2), kH, sH, pH, dH, isSameMode);
        long outW = outputSize(img.size(3), kW, sW, pW, dW, isSameMode);
        long[] outShape = {img.size(0), img.size(1), kH, kW, outH, outW};
        INDArray out = Nd4j.create(img.dataType(), outShape, 'c');
        return im2col(img, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, out);
    }

    /**
     * im2col with dilation fixed at 1, writing into a caller-supplied output array. Delegates
     * to {@link #im2col(INDArray, int, int, int, int, int, int, int, int, boolean, INDArray)}.
     *
     * @param img        input array
     * @param kH         kernel height
     * @param kW         kernel width
     * @param sH         stride along the height axis
     * @param sW         stride along the width axis
     * @param pH         padding along the height axis
     * @param pW         padding along the width axis
     * @param isSameMode {@code true} selects {@link PaddingMode#SAME}, {@code false}
     *                   {@link PaddingMode#VALID}
     * @param out        array used as the op's output
     * @return the first output argument of the executed op
     */
    public static INDArray im2col(INDArray img, int kH, int kW, int sH, int sW, int pH, int pW, boolean isSameMode, INDArray out) {
        return im2col(img, kH, kW, sH, sW, pH, pW, 1, 1, isSameMode, out);
    }

    /**
     * im2col with explicit dilation, writing into a caller-supplied output array. This is the
     * overload every other im2col variant in this class ends up calling.
     *
     * @param img        input array
     * @param kH         kernel height
     * @param kW         kernel width
     * @param sH         stride along the height axis
     * @param sW         stride along the width axis
     * @param pH         padding along the height axis
     * @param pW         padding along the width axis
     * @param dH         dilation along the height axis
     * @param dW         dilation along the width axis
     * @param isSameMode {@code true} selects {@link PaddingMode#SAME}, {@code false}
     *                   {@link PaddingMode#VALID}
     * @param out        array used as the op's output
     * @return the first output argument of the executed op
     */
    public static INDArray im2col(INDArray img, int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, boolean isSameMode, INDArray out) {
        Conv2DConfig config = Conv2DConfig.builder()
                .kH(kH).kW(kW)
                .sH(sH).sW(sW)
                .pH(pH).pW(pW)
                .dH(dH).dW(dW)
                .paddingMode(isSameMode ? PaddingMode.SAME : PaddingMode.VALID)
                .build();
        Im2col op = Im2col.builder()
                .inputArrays(new INDArray[]{img})
                .outputs(new INDArray[]{out})
                .conv2DConfig(config)
                .build();
        Nd4j.getExecutioner().execAndReturn(op);
        return op.outputArguments().get(0);
    }

    /**
     * Runs a {@link Pooling2D} op on {@code img}, writing into {@code out}.
     *
     * @param img        input array
     * @param kH         kernel height
     * @param kW         kernel width
     * @param sH         stride along the height axis
     * @param sW         stride along the width axis
     * @param pH         padding along the height axis
     * @param pW         padding along the width axis
     * @param dH         dilation along the height axis
     * @param dW         dilation along the width axis
     * @param isSameMode {@code true} selects {@link PaddingMode#SAME}, {@code false}
     *                   {@link PaddingMode#VALID}
     * @param type       pooling type forwarded to the config
     * @param divisor    divisor setting forwarded to the config
     * @param extra      value forwarded as the config's {@code extra} field
     * @param unused0    not read by this method; kept for signature compatibility
     * @param unused1    not read by this method; kept for signature compatibility
     * @param out        array used as the op's output
     * @return {@code out}
     */
    public static INDArray pooling2D(INDArray img, int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, boolean isSameMode, Pooling2D.Pooling2DType type, Pooling2D.Divisor divisor, double extra, int unused0, int unused1, INDArray out) {
        Pooling2DConfig config = Pooling2DConfig.builder()
                .kH(kH).kW(kW)
                .sH(sH).sW(sW)
                .pH(pH).pW(pW)
                .dH(dH).dW(dW)
                .paddingMode(isSameMode ? PaddingMode.SAME : PaddingMode.VALID)
                .type(type)
                .divisor(divisor)
                .extra(extra)
                .build();
        Nd4j.getExecutioner().execAndReturn(new Pooling2D(img, out, config));
        return out;
    }

    /**
     * im2col with dilation fixed at 1, allocating an uninitialized output.
     *
     * <p>The output has the data type of {@code img}, 'c' ordering and shape
     * {@code [img.size(0), img.size(1), kH, kW, outH, outW]}, where the last two extents come
     * from {@link #outputSize(long, long, long, long, int, boolean)} with a dilation of 1.
     *
     * @param img        input array; dimensions 2 and 3 are treated as the spatial extents
     * @param kH         kernel height
     * @param kW         kernel width
     * @param sH         stride along the height axis
     * @param sW         stride along the width axis
     * @param pH         padding along the height axis
     * @param pW         padding along the width axis
     * @param pval       not read by this method; kept for signature compatibility
     * @param isSameMode {@code true} selects {@link PaddingMode#SAME}, {@code false}
     *                   {@link PaddingMode#VALID}
     * @return the first output argument of the executed op
     */
    public static INDArray im2col(INDArray img, int kH, int kW, int sH, int sW, int pH, int pW, int pval, boolean isSameMode) {
        long outH = outputSize(img.size(2), kH, sH, pH, 1, isSameMode);
        long outW = outputSize(img.size(3), kW, sW, pW, 1, isSameMode);
        long[] outShape = {img.size(0), img.size(1), kH, kW, outH, outW};
        INDArray out = Nd4j.createUninitialized(img.dataType(), outShape, 'c');
        return im2col(img, kH, kW, sH, sW, pH, pW, 1, 1, isSameMode, out);
    }

    /**
     * Computes an output extent from input size, kernel, stride, padding and dilation.
     *
     * <p>With {@code coverAll} the result is
     * {@code (size + 2p - effectiveKernel + s - 1) / s + 1}; otherwise
     * {@code (size + 2p - effectiveKernel) / s + 1}. Both use integer division.
     *
     * @param size     input extent
     * @param k        kernel extent
     * @param s        stride
     * @param p        padding
     * @param dilation dilation factor
     * @param coverAll selects the first formula above
     * @return the computed output extent
     * @deprecated marked deprecated in this class;
     *     {@link #outputSize(long, long, long, long, int, boolean)} is the non-deprecated
     *     sibling (note that its {@code true} branch uses a different formula)
     */
    @Deprecated
    public static long outSize(long size, long k, long s, long p, int dilation, boolean coverAll) {
        long padded = size + (p * 2) - effectiveKernelSize(k, dilation);
        if (coverAll) {
            return ((padded + s - 1) / s) + 1;
        }
        return (padded / s) + 1;
    }

    /**
     * Computes an output extent for the given padding mode.
     *
     * <p>In same mode the result is {@code ceil(size / s)}, evaluated with exact integer
     * arithmetic so that sizes beyond float precision (above 2^24) and results beyond
     * {@code Integer.MAX_VALUE} are not rounded or clamped. Otherwise the result is
     * {@code (size - effectiveKernel + 2p) / s + 1} using integer division.
     *
     * @param size       input extent
     * @param k          kernel extent
     * @param s          stride
     * @param p          padding; only used when {@code isSameMode} is {@code false}
     * @param dilation   dilation factor; only used when {@code isSameMode} is {@code false}
     * @param isSameMode selects the ceiling-division formula
     * @return the computed output extent
     * @throws ArithmeticException if {@code s} is zero
     */
    public static long outputSize(long size, long k, long s, long p, int dilation, boolean isSameMode) {
        if (isSameMode) {
            // ceil(a / b) == -floor(-a / b); floorDiv keeps this exact for any sign.
            return -Math.floorDiv(-size, s);
        }
        return ((size - effectiveKernelSize(k, dilation)) + (2 * p)) / s + 1;
    }

    /**
     * Returns the extent a kernel spans once dilation is applied:
     * {@code kernel + (kernel - 1) * (dilation - 1)}.
     *
     * @param kernel   undilated kernel extent
     * @param dilation dilation factor; 1 leaves the kernel extent unchanged
     * @return the dilated kernel extent
     */
    public static long effectiveKernelSize(long kernel, int dilation) {
        return kernel + ((kernel - 1) * (dilation - 1));
    }

    /**
     * Delegates to {@code Nd4j.getConvolution().conv2d(input, kernel, type)}.
     *
     * @param input  first array argument
     * @param kernel second array argument
     * @param type   convolution mode
     * @return whatever the underlying convolution instance returns
     */
    public static INDArray conv2d(INDArray input, INDArray kernel, Type type) {
        return Nd4j.getConvolution().conv2d(input, kernel, type);
    }

    /**
     * Delegates to {@code Nd4j.getConvolution().convn(input, kernel, type, axes)}.
     *
     * @param input  first array argument
     * @param kernel second array argument
     * @param type   convolution mode
     * @param axes   integer array forwarded unchanged (presumably the axes to convolve over -
     *               confirm against the convolution instance)
     * @return whatever the underlying convolution instance returns
     */
    public static INDArray convn(INDArray input, INDArray kernel, Type type, int[] axes) {
        return Nd4j.getConvolution().convn(input, kernel, type, axes);
    }

    /**
     * Delegates to {@code Nd4j.getConvolution().convn(input, kernel, type)}.
     *
     * @param input  first array argument
     * @param kernel second array argument
     * @param type   convolution mode
     * @return whatever the underlying convolution instance returns
     */
    public static INDArray convn(INDArray input, INDArray kernel, Type type) {
        return Nd4j.getConvolution().convn(input, kernel, type);
    }
}
