package org.nd4j.linalg.factory;

import java.util.Arrays;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastAMax;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastAMin;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastAddOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastCopyOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastDivOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMax;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMin;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastRDivOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastRSubOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastSubOp;
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastEqualTo;
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastGreaterThan;
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastGreaterThanOrEqual;
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastLessThan;
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastLessThanOrEqual;
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastNotEqual;
import org.nd4j.linalg.api.ops.impl.transforms.custom.EqualTo;
import org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThan;
import org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThanOrEqual;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThan;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThanOrEqual;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Max;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Min;
import org.nd4j.linalg.api.ops.impl.transforms.custom.NotEqualTo;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.CopyOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.DivOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RDivOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RSubOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SubOp;
import org.nd4j.linalg.api.ops.impl.transforms.same.AMax;
import org.nd4j.linalg.api.ops.impl.transforms.same.AMin;

/* loaded from: input_file:org/nd4j/linalg/factory/Broadcast.class */
public class Broadcast {
    private Broadcast() {
    }

    public static INDArray add(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, int... iArr) {
        if (iArr != null && iArr.length != 0 && !iNDArray2.isScalar()) {
            return Nd4j.getExecutioner().exec((BroadcastOp) new BroadcastAddOp(iNDArray, iNDArray2, iNDArray3, iArr));
        }
        validateShapesNoDimCase(iNDArray, iNDArray2, iNDArray3);
        return Nd4j.getExecutioner().exec(new AddOp(iNDArray, iNDArray2, iNDArray3))[0];
    }

    public static INDArray copy(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, int... iArr) {
        if (iArr != null && iArr.length != 0 && !iNDArray2.isScalar()) {
            return Nd4j.getExecutioner().exec((BroadcastOp) new BroadcastCopyOp(iNDArray, iNDArray2, iNDArray3, iArr));
        }
        validateShapesNoDimCase(iNDArray, iNDArray2, iNDArray3);
        return Nd4j.getExecutioner().exec(new CopyOp(iNDArray, iNDArray2, iNDArray3));
    }

    public static INDArray div(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, int... iArr) {
        if (iArr != null && iArr.length != 0 && !iNDArray2.isScalar()) {
            return Nd4j.getExecutioner().exec((BroadcastOp) new BroadcastDivOp(iNDArray, iNDArray2, iNDArray3, iArr));
        }
        validateShapesNoDimCase(iNDArray, iNDArray2, iNDArray3);
        return Nd4j.getExecutioner().exec(new DivOp(iNDArray, iNDArray2, iNDArray3))[0];
    }

    public static INDArray eq(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, int... iArr) {
        if (iArr != null && iArr.length != 0 && !iNDArray2.isScalar()) {
            return Nd4j.getExecutioner().exec((BroadcastOp) new BroadcastEqualTo(iNDArray, iNDArray2, iNDArray3, iArr));
        }
        validateShapesNoDimCase(iNDArray, iNDArray2, iNDArray3);
        return Nd4j.getExecutioner().exec(new EqualTo(iNDArray, iNDArray2, iNDArray3))[0];
    }

    public static INDArray gt(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, int... iArr) {
        if (iArr != null && iArr.length != 0 && !iNDArray2.isScalar()) {
            return Nd4j.getExecutioner().exec((BroadcastOp) new BroadcastGreaterThan(iNDArray, iNDArray2, iNDArray3, iArr));
        }
        validateShapesNoDimCase(iNDArray, iNDArray2, iNDArray3);
        return Nd4j.getExecutioner().exec(new GreaterThan(iNDArray, iNDArray2, iNDArray3))[0];
    }

    public static INDArray gte(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, int... iArr) {
        if (iArr != null && iArr.length != 0 && !iNDArray2.isScalar()) {
            return Nd4j.getExecutioner().exec((BroadcastOp) new BroadcastGreaterThanOrEqual(iNDArray, iNDArray2, iNDArray3, iArr));
        }
        validateShapesNoDimCase(iNDArray, iNDArray2, iNDArray3);
        return Nd4j.getExecutioner().exec(new GreaterThanOrEqual(iNDArray, iNDArray2, iNDArray3))[0];
    }

    public static INDArray lt(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, int... iArr) {
        if (iArr != null && iArr.length != 0 && !iNDArray2.isScalar()) {
            return Nd4j.getExecutioner().exec((BroadcastOp) new BroadcastLessThan(iNDArray, iNDArray2, iNDArray3, iArr));
        }
        validateShapesNoDimCase(iNDArray, iNDArray2, iNDArray3);
        return Nd4j.getExecutioner().exec(new LessThan(iNDArray, iNDArray2, iNDArray3))[0];
    }

    public static INDArray lte(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, int... iArr) {
        if (iArr != null && iArr.length != 0 && !iNDArray2.isScalar()) {
            return Nd4j.getExecutioner().exec((BroadcastOp) new BroadcastLessThanOrEqual(iNDArray, iNDArray2, iNDArray3, iArr));
        }
        validateShapesNoDimCase(iNDArray, iNDArray2, iNDArray3);
        return Nd4j.getExecutioner().exec(new LessThanOrEqual(iNDArray, iNDArray2, iNDArray3))[0];
    }

    public static INDArray mul(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, int... iArr) {
        if (iArr != null && iArr.length != 0 && !iNDArray2.isScalar()) {
            return Nd4j.getExecutioner().exec((BroadcastOp) new BroadcastMulOp(iNDArray, iNDArray2, iNDArray3, iArr));
        }
        validateShapesNoDimCase(iNDArray, iNDArray2, iNDArray3);
        return Nd4j.getExecutioner().exec(new MulOp(iNDArray, iNDArray2, iNDArray3))[0];
    }

    public static INDArray neq(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, int... iArr) {
        if (iArr != null && iArr.length != 0 && !iNDArray2.isScalar()) {
            return Nd4j.getExecutioner().exec((BroadcastOp) new BroadcastNotEqual(iNDArray, iNDArray2, iNDArray3, iArr));
        }
        validateShapesNoDimCase(iNDArray, iNDArray2, iNDArray3);
        return Nd4j.getExecutioner().exec(new NotEqualTo(iNDArray, iNDArray2, iNDArray3))[0];
    }

    public static INDArray rdiv(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, int... iArr) {
        if (iArr != null && iArr.length != 0 && !iNDArray2.isScalar()) {
            return Nd4j.getExecutioner().exec((BroadcastOp) new BroadcastRDivOp(iNDArray, iNDArray2, iNDArray3, iArr));
        }
        validateShapesNoDimCase(iNDArray, iNDArray2, iNDArray3);
        return Nd4j.getExecutioner().exec(new RDivOp(iNDArray, iNDArray2, iNDArray3))[0];
    }

    public static INDArray rsub(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, int... iArr) {
        if (iArr != null && iArr.length != 0 && !iNDArray2.isScalar()) {
            return Nd4j.getExecutioner().exec((BroadcastOp) new BroadcastRSubOp(iNDArray, iNDArray2, iNDArray3, iArr));
        }
        validateShapesNoDimCase(iNDArray, iNDArray2, iNDArray3);
        return Nd4j.getExecutioner().exec(new SubOp(iNDArray, iNDArray2, iNDArray3))[0];
    }

    public static INDArray sub(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, int... iArr) {
        if (iArr != null && iArr.length != 0 && !iNDArray2.isScalar()) {
            return Nd4j.getExecutioner().exec((BroadcastOp) new BroadcastSubOp(iNDArray, iNDArray2, iNDArray3, iArr));
        }
        validateShapesNoDimCase(iNDArray, iNDArray2, iNDArray3);
        return Nd4j.getExecutioner().exec(new SubOp(iNDArray, iNDArray2, iNDArray3))[0];
    }

    public static INDArray max(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, int... iArr) {
        if (iArr != null && iArr.length != 0 && !iNDArray2.isScalar()) {
            return Nd4j.getExecutioner().exec((BroadcastOp) new BroadcastMax(iNDArray, iNDArray2, iNDArray3, iArr));
        }
        validateShapesNoDimCase(iNDArray, iNDArray2, iNDArray3);
        return Nd4j.getExecutioner().exec(new Max(iNDArray, iNDArray2, iNDArray3))[0];
    }

    public static INDArray min(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, int... iArr) {
        if (iArr != null && iArr.length != 0 && !iNDArray2.isScalar()) {
            return Nd4j.getExecutioner().exec((BroadcastOp) new BroadcastMin(iNDArray, iNDArray2, iNDArray3, iArr));
        }
        validateShapesNoDimCase(iNDArray, iNDArray2, iNDArray3);
        return Nd4j.getExecutioner().exec(new Min(iNDArray, iNDArray2, iNDArray3))[0];
    }

    public static INDArray amax(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, int... iArr) {
        if (iArr != null && iArr.length != 0 && !iNDArray2.isScalar()) {
            return Nd4j.getExecutioner().exec((BroadcastOp) new BroadcastAMax(iNDArray, iNDArray2, iNDArray3, iArr));
        }
        validateShapesNoDimCase(iNDArray, iNDArray2, iNDArray3);
        return Nd4j.getExecutioner().exec(new AMax(iNDArray, iNDArray2, iNDArray3));
    }

    public static INDArray amin(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, int... iArr) {
        if (iArr != null && iArr.length != 0 && !iNDArray2.isScalar()) {
            return Nd4j.getExecutioner().exec((BroadcastOp) new BroadcastAMin(iNDArray, iNDArray2, iNDArray3, iArr));
        }
        validateShapesNoDimCase(iNDArray, iNDArray2, iNDArray3);
        return Nd4j.getExecutioner().exec(new AMin(iNDArray, iNDArray2, iNDArray3));
    }

    public static void validateShapesNoDimCase(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3) {
        if (iNDArray.isScalar() || iNDArray2.isScalar()) {
            return;
        }
        Preconditions.checkArgument(iNDArray.equalShapes(iNDArray2), "When no dimensions are provided, X and Y shapes must be equal (x shape: %s, y shape: %s)", iNDArray.shape(), iNDArray2.shape());
        Preconditions.checkArgument(iNDArray.equalShapes(iNDArray3), "When no dimensions are provided, X and Z (result) shapes must be equal (x shape: %s, z shape: %s)", iNDArray.shape(), iNDArray3.shape());
    }

    public static void validateBroadcastDims(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, int... iArr) {
        Preconditions.checkArgument(iNDArray == iNDArray3 || iNDArray.equalShapes(iNDArray3), "X and Z arrays must be equal shape. X shape: %s, Z shape: %s", iNDArray.shape(), iNDArray3.shape());
        long[] shape = iNDArray.shape();
        long[] shape2 = iNDArray2.shape();
        if (iArr.length == 1 && shape2.length == 2 && (shape2[0] == 1 || shape2[1] == 1)) {
            int length = iArr[0] < 0 ? iArr[0] + shape.length : iArr[0];
            if (shape2[0] == 1) {
                Preconditions.checkState(shape[length] == shape2[1], "Shapes do not match: dimensions[0] - x[%s] must match y[%s], x shape %s, y shape %s, dimensions %s", Integer.valueOf(iArr[0]), 1, shape, shape2, iArr);
                return;
            } else {
                Preconditions.checkState(shape[length] == shape2[0], "Shapes do not match: dimensions[0] - x[%s] must match y[%s], x shape %s, y shape %s, dimensions %s", Integer.valueOf(iArr[0]), 0, shape, shape2, iArr);
                return;
            }
        }
        if (shape.length == shape2.length) {
            int length2 = iArr.length;
            for (int i = 0; i < length2; i++) {
                int i2 = iArr[i];
                int length3 = i2 < 0 ? i2 + shape.length : i2;
                Preconditions.checkState(shape[length3] == shape2[length3], "Dimensions mismatch on dimension %s: x shape %s, y shape %s", Integer.valueOf(i2), shape, shape2);
            }
            return;
        }
        if (iArr.length != shape2.length) {
            throw new IllegalStateException("Invalid broadcast dimensions: x shape " + Arrays.toString(shape) + ", y shape " + Arrays.toString(shape2) + ", dimensions " + Arrays.toString(iArr));
        }
        for (int i3 = 0; i3 < iArr.length; i3++) {
            int length4 = iArr[i3] < 0 ? iArr[i3] + shape.length : iArr[i3];
            Preconditions.checkState(shape[length4] == shape2[i3], "Shapes do not match: dimensions[%s] - x[%s] must match y[%s], x shape %s, y shape %s, dimensions %s", Integer.valueOf(i3), Integer.valueOf(length4), Integer.valueOf(i3), shape, shape2, iArr);
        }
    }
}
