package org.nd4j.linalg.checkutil;

import java.io.PrintStream;
import java.util.Arrays;
import org.apache.commons.math3.linear.BlockRealMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.exception.ND4JArraySizeException;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/nd4j/linalg/checkutil/CheckUtil.class */
/**
 * Test helpers that verify ND4J matrix operations against reference results, either Apache Commons Math
 * matrices or values computed element by element. When a check fails, diagnostic details are printed to
 * {@code System.out}.
 *
 * <p>All checks treat their inputs as rank-2 matrices: entries are read with {@code getDouble(row, col)}.
 */
public class CheckUtil {

    /**
     * Checks {@code first.mmul(second)} against the Apache Commons Math product of the same matrices.
     *
     * @param first                 left-hand matrix
     * @param second                right-hand matrix
     * @param maxRelativeDifference relative-error tolerance per entry
     * @param minAbsDifference      absolute-error tolerance per entry; an entry fails only when both
     *                              tolerances are exceeded
     * @return {@code true} if the shape and all entries match within tolerance
     * @throws IllegalArgumentException if {@code first.size(1) != second.size(0)}
     */
    public static boolean checkMmul(INDArray first, INDArray second, double maxRelativeDifference,
                                    double minAbsDifference) {
        if (first.size(1) != second.size(0)) {
            throw new IllegalArgumentException("first.columns != second.rows");
        }
        RealMatrix rmFirst = convertToApacheMatrix(first);
        RealMatrix rmSecond = convertToApacheMatrix(second);

        INDArray result = first.mmul(second);
        RealMatrix rmResult = rmFirst.multiply(rmSecond);

        if (!checkShape(rmResult, result)) {
            return false;
        }
        boolean ok = checkEntries(rmResult, result, maxRelativeDifference, minAbsDifference);
        if (!ok) {
            INDArray onCopies = Shape.toOffsetZeroCopy(first).mmul(Shape.toOffsetZeroCopy(second));
            printFailureDetails(first, second, rmResult, result, onCopies, "mmul");
        }
        return ok;
    }

    /**
     * Checks {@code Nd4j.gemm(a, b, c, transposeA, transposeB, alpha, beta)} against the reference result
     * {@code alpha * op(a) * op(b) + beta * c} computed with Apache Commons Math. Here {@code op(x)} is
     * {@code x} or its transpose.
     *
     * <p>{@code c} is passed to gemm as the output array, and the check fails if gemm returns a different
     * array. On return, {@code c} holds whatever gemm wrote into it.
     *
     * @param a                     left-hand matrix
     * @param b                     right-hand matrix
     * @param c                     accumulator/output matrix; must have offset 0 and 'f' ordering
     * @param transposeA            whether {@code a} is transposed
     * @param transposeB            whether {@code b} is transposed
     * @param alpha                 scale applied to {@code op(a) * op(b)}
     * @param beta                  scale applied to the original {@code c}
     * @param maxRelativeDifference relative-error tolerance per entry
     * @param minAbsDifference      absolute-error tolerance per entry
     * @return {@code true} if gemm returned {@code c} and its shape and entries match within tolerance
     * @throws IllegalArgumentException if the inner dimensions disagree, {@code c} has the wrong shape, or
     *                                  {@code c} is not a zero-offset 'f'-ordered array
     */
    public static boolean checkGemm(INDArray a, INDArray b, INDArray c, boolean transposeA, boolean transposeB,
                                    double alpha, double beta, double maxRelativeDifference,
                                    double minAbsDifference) {
        long commonDimA = transposeA ? a.rows() : a.columns();
        long commonDimB = transposeB ? b.columns() : b.rows();
        if (commonDimA != commonDimB) {
            throw new IllegalArgumentException("Common dimensions don't match: a.shape=" + Arrays.toString(a.shape())
                    + ", b.shape=" + Arrays.toString(b.shape()) + ", tA=" + transposeA + ", tb=" + transposeB);
        }
        long outRows = transposeA ? a.columns() : a.rows();
        long outCols = transposeB ? b.rows() : b.columns();
        if (c.rows() != outRows || c.columns() != outCols) {
            throw new IllegalArgumentException("C does not match outRows or outCols");
        }
        if (c.offset() != 0 || c.ordering() != 'f') {
            throw new IllegalArgumentException("Invalid c");
        }

        // The reference result must be computed before gemm runs, because gemm writes into c.
        RealMatrix rmA = convertToApacheMatrix(transposeA ? a.transpose() : a);
        RealMatrix rmB = convertToApacheMatrix(transposeB ? b.transpose() : b);
        RealMatrix rmC = convertToApacheMatrix(c);
        RealMatrix rmExpected = rmA.scalarMultiply(alpha).multiply(rmB).add(rmC.scalarMultiply(beta));

        // Two snapshots of the original c. The first is the output buffer for the diagnostic re-run on
        // zero-offset copies. The second is printed unchanged in the failure report.
        INDArray cCopyForRerun = Nd4j.create(c.shape(), 'f');
        cCopyForRerun.assign(c);
        INDArray cCopyForReport = Nd4j.create(c.shape(), 'f');
        cCopyForReport.assign(c);

        INDArray out = Nd4j.gemm(a, b, c, transposeA, transposeB, alpha, beta);
        if (out != c) {
            System.out.println("Returned different array than c");
            return false;
        }
        if (!checkShape(rmExpected, out)) {
            return false;
        }
        boolean ok = checkEntries(rmExpected, out, maxRelativeDifference, minAbsDifference);
        if (!ok) {
            INDArray onCopies = Nd4j.gemm(Shape.toOffsetZeroCopy(a), Shape.toOffsetZeroCopy(b), cCopyForRerun,
                    transposeA, transposeB, alpha, beta);
            printGemmFailureDetails(a, b, cCopyForReport, transposeA, transposeB, alpha, beta, rmExpected, out,
                    onCopies);
        }
        return ok;
    }

    /**
     * Checks {@code first.add(second)} against the Apache Commons Math sum of the same matrices.
     *
     * @param first                 left-hand matrix
     * @param second                right-hand matrix
     * @param maxRelativeDifference relative-error tolerance per entry
     * @param minAbsDifference      absolute-error tolerance per entry
     * @return {@code true} if the shape and all entries match within tolerance
     */
    public static boolean checkAdd(INDArray first, INDArray second, double maxRelativeDifference,
                                   double minAbsDifference) {
        RealMatrix rmFirst = convertToApacheMatrix(first);
        RealMatrix rmSecond = convertToApacheMatrix(second);

        INDArray result = first.add(second);
        RealMatrix rmResult = rmFirst.add(rmSecond);

        if (!checkShape(rmResult, result)) {
            return false;
        }
        boolean ok = checkEntries(rmResult, result, maxRelativeDifference, minAbsDifference);
        if (!ok) {
            INDArray onCopies = Shape.toOffsetZeroCopy(first).add(Shape.toOffsetZeroCopy(second));
            printFailureDetails(first, second, rmResult, result, onCopies, AddOp.OP_NAME);
        }
        return ok;
    }

    /**
     * Checks {@code first.sub(second)} against the Apache Commons Math difference of the same matrices.
     *
     * @param first                 left-hand matrix
     * @param second                right-hand matrix
     * @param maxRelativeDifference relative-error tolerance per entry
     * @param minAbsDifference      absolute-error tolerance per entry
     * @return {@code true} if the shape and all entries match within tolerance
     */
    public static boolean checkSubtract(INDArray first, INDArray second, double maxRelativeDifference,
                                        double minAbsDifference) {
        RealMatrix rmFirst = convertToApacheMatrix(first);
        RealMatrix rmSecond = convertToApacheMatrix(second);

        INDArray result = first.sub(second);
        RealMatrix rmResult = rmFirst.subtract(rmSecond);

        if (!checkShape(rmResult, result)) {
            return false;
        }
        boolean ok = checkEntries(rmResult, result, maxRelativeDifference, minAbsDifference);
        if (!ok) {
            INDArray onCopies = Shape.toOffsetZeroCopy(first).sub(Shape.toOffsetZeroCopy(second));
            printFailureDetails(first, second, rmResult, result, onCopies, "sub");
        }
        return ok;
    }

    /**
     * Checks element-wise {@code first.mul(second)} against products computed entry by entry. Apache Commons
     * Math has no element-wise matrix product, so the reference is built manually.
     *
     * @param first                 left-hand matrix; its shape drives the comparison
     * @param second                right-hand matrix; assumed to have the same shape as {@code first}
     * @param maxRelativeDifference relative-error tolerance per entry
     * @param minAbsDifference      absolute-error tolerance per entry
     * @return {@code true} if the shape and all entries match within tolerance
     */
    public static boolean checkMulManually(INDArray first, INDArray second, double maxRelativeDifference,
                                           double minAbsDifference) {
        INDArray result = first.mul(second);
        long[] shape = first.shape();

        INDArray expected = Nd4j.zeros(first.shape());
        for (int i = 0; i < shape[0]; i++) {
            for (int j = 0; j < shape[1]; j++) {
                expected.putScalar(new int[] {i, j}, first.getDouble(i, j) * second.getDouble(i, j));
            }
        }
        if (!checkShape(expected, result)) {
            return false;
        }
        boolean ok = checkEntries(expected, result, maxRelativeDifference, minAbsDifference);
        if (!ok) {
            INDArray onCopies = Shape.toOffsetZeroCopy(first).mul(Shape.toOffsetZeroCopy(second));
            printFailureDetails(first, second, expected, result, onCopies, "mul");
        }
        return ok;
    }

    /**
     * Checks element-wise {@code first.div(second)} against quotients computed entry by entry. Apache Commons
     * Math has no element-wise matrix division, so the reference is built manually.
     *
     * @param first                 dividend matrix; its shape drives the comparison
     * @param second                divisor matrix; assumed to have the same shape as {@code first}
     * @param maxRelativeDifference relative-error tolerance per entry
     * @param minAbsDifference      absolute-error tolerance per entry
     * @return {@code true} if the shape and all entries match within tolerance
     */
    public static boolean checkDivManually(INDArray first, INDArray second, double maxRelativeDifference,
                                           double minAbsDifference) {
        INDArray result = first.div(second);
        long[] shape = first.shape();

        INDArray expected = Nd4j.zeros(first.shape());
        for (int i = 0; i < shape[0]; i++) {
            for (int j = 0; j < shape[1]; j++) {
                expected.putScalar(new int[] {i, j}, first.getDouble(i, j) / second.getDouble(i, j));
            }
        }
        if (!checkShape(expected, result)) {
            return false;
        }
        boolean ok = checkEntries(expected, result, maxRelativeDifference, minAbsDifference);
        if (!ok) {
            // Re-run the same op (div, not mul) on zero-offset copies so the diagnostic matches its label.
            INDArray onCopies = Shape.toOffsetZeroCopy(first).div(Shape.toOffsetZeroCopy(second));
            printFailureDetails(first, second, expected, result, onCopies, "div");
        }
        return ok;
    }

    /** Returns whether {@code result} has the same [rows, cols] shape as {@code expected}; prints on mismatch. */
    private static boolean checkShape(RealMatrix expected, INDArray result) {
        long[] expectedShape = {expected.getRowDimension(), expected.getColumnDimension()};
        if (Arrays.equals(expectedShape, result.shape())) {
            return true;
        }
        System.out.println("Failure on shape: " + Arrays.toString(result.shape()) + ", expected "
                + Arrays.toString(expectedShape));
        return false;
    }

    /** Returns whether {@code result} has exactly the shape of {@code expected}; prints on mismatch. */
    private static boolean checkShape(INDArray expected, INDArray result) {
        if (Arrays.equals(expected.shape(), result.shape())) {
            return true;
        }
        System.out.println("Failure on shape: " + Arrays.toString(result.shape()) + ", expected "
                + Arrays.toString(expected.shape()));
        return false;
    }

    /**
     * Compares every entry of {@code result} against the Apache Commons Math reference {@code rmResult}.
     * Each entry is judged as described in {@link #checkEntries(INDArray, INDArray, double, double)}'s
     * per-entry rule: equal values and both-NaN pass, a NaN on one side or an infinity against a different
     * value fails, and otherwise an entry fails only when its relative error exceeds
     * {@code maxRelativeDifference} and its absolute error exceeds {@code minAbsDifference}.
     *
     * @param rmResult              expected values
     * @param result                actual values, read as {@code result.getDouble(row, col)}
     * @param maxRelativeDifference relative-error tolerance per entry
     * @param minAbsDifference      absolute-error tolerance per entry
     * @return {@code true} if all entries match; on the first mismatch a message is printed and
     *         {@code false} is returned
     */
    public static boolean checkEntries(RealMatrix rmResult, INDArray result, double maxRelativeDifference,
                                       double minAbsDifference) {
        int rows = rmResult.getRowDimension();
        int cols = rmResult.getColumnDimension();
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                if (!checkEntry(i, j, rmResult.getEntry(i, j), result.getDouble(i, j), maxRelativeDifference,
                        minAbsDifference)) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Compares every entry of {@code result} against {@code expected}, iterating over {@code expected}'s
     * [rows, cols] shape.
     *
     * <p>Per-entry rule:
     * <ul>
     *   <li>equal values pass, including 0 vs -0 and equal infinities;</li>
     *   <li>both-NaN passes;</li>
     *   <li>a NaN on only one side fails;</li>
     *   <li>an infinity compared with a different value fails;</li>
     *   <li>otherwise the entry fails only when its relative error exceeds {@code maxRelativeDifference}
     *       and its absolute error exceeds {@code minAbsDifference}.</li>
     * </ul>
     *
     * @param expected              expected values
     * @param result                actual values
     * @param maxRelativeDifference relative-error tolerance per entry
     * @param minAbsDifference      absolute-error tolerance per entry
     * @return {@code true} if all entries match; on the first mismatch a message is printed and
     *         {@code false} is returned
     */
    public static boolean checkEntries(INDArray expected, INDArray result, double maxRelativeDifference,
                                       double minAbsDifference) {
        long[] shape = expected.shape();
        for (int i = 0; i < shape[0]; i++) {
            for (int j = 0; j < shape[1]; j++) {
                if (!checkEntry(i, j, expected.getDouble(i, j), result.getDouble(i, j), maxRelativeDifference,
                        minAbsDifference)) {
                    return false;
                }
            }
        }
        return true;
    }

    /** Applies the per-entry comparison rule to one (expected, actual) pair, printing a message on failure. */
    private static boolean checkEntry(int row, int col, double expected, double actual,
                                      double maxRelativeDifference, double minAbsDifference) {
        boolean expectedNaN = Double.isNaN(expected);
        boolean actualNaN = Double.isNaN(actual);
        if (expectedNaN || actualNaN) {
            if (expectedNaN && actualNaN) {
                return true;
            }
            System.out.println("NaN failure on value: (" + row + "," + col + ") exp=" + expected + ", act="
                    + actual);
            return false;
        }
        // Exact equality covers 0 vs 0 (where the relative error would be 0/0) and equal infinities.
        if (expected == actual) {
            return true;
        }
        double absError = Math.abs(expected - actual);
        double relError = absError / (Math.abs(expected) + Math.abs(actual));
        // With an infinity involved, relError is NaN and would never compare greater than the tolerance.
        // The values are known to differ here, so treat the entry as a failure explicitly.
        boolean infiniteMismatch = Double.isInfinite(expected) || Double.isInfinite(actual);
        if (infiniteMismatch || (relError > maxRelativeDifference && absError > minAbsDifference)) {
            System.out.println("Failure on value: (" + row + "," + col + ") exp=" + expected + ", act=" + actual
                    + ", absError=" + absError + ", relError=" + relError);
            return false;
        }
        return true;
    }

    /**
     * Copies a rank-2 INDArray into a new Apache Commons Math {@link BlockRealMatrix}.
     *
     * @param matrix rank-2 array to copy
     * @return a new matrix with the same [rows, cols] and entry values (as doubles)
     * @throws IllegalArgumentException if {@code matrix} is not rank 2
     * @throws ND4JArraySizeException   if either dimension exceeds {@link Integer#MAX_VALUE}
     */
    public static RealMatrix convertToApacheMatrix(INDArray matrix) {
        if (matrix.rank() != 2) {
            throw new IllegalArgumentException("Input rank is not 2 (not matrix)");
        }
        long[] shape = matrix.shape();
        if (matrix.columns() > Integer.MAX_VALUE || matrix.rows() > Integer.MAX_VALUE) {
            throw new ND4JArraySizeException();
        }
        BlockRealMatrix out = new BlockRealMatrix((int) shape[0], (int) shape[1]);
        for (int i = 0; i < shape[0]; i++) {
            for (int j = 0; j < shape[1]; j++) {
                out.setEntry(i, j, matrix.getDouble(i, j));
            }
        }
        return out;
    }

    /**
     * Copies an Apache Commons Math matrix into a new INDArray.
     *
     * @param matrix   source matrix
     * @param dataType data type of the created array
     * @return a new [rows, cols] array holding the matrix entries
     */
    public static INDArray convertFromApacheMatrix(RealMatrix matrix, DataType dataType) {
        long[] shape = {matrix.getRowDimension(), matrix.getColumnDimension()};
        INDArray out = Nd4j.create(dataType, shape);
        for (int i = 0; i < shape[0]; i++) {
            for (int j = 0; j < shape[1]; j++) {
                out.putScalar(new int[] {i, j}, matrix.getEntry(i, j));
            }
        }
        return out;
    }

    /**
     * Prints the inputs, the Apache Commons Math expected result, the result of the same ND4J op on
     * zero-offset copies, and the actual result.
     *
     * @param first      left-hand input
     * @param second     right-hand input
     * @param expected   reference result
     * @param actual     result under test
     * @param onCopies   result of the same op applied to {@code Shape.toOffsetZeroCopy} copies of the inputs
     * @param opName     op name used in the printed description
     */
    public static void printFailureDetails(INDArray first, INDArray second, RealMatrix expected, INDArray actual,
                                           INDArray onCopies, String opName) {
        System.out.println("\nFactory: " + Nd4j.factory().getClass() + "\n");
        System.out.println("First:");
        printMatrixFullPrecision(first);
        System.out.println("\nSecond:");
        printMatrixFullPrecision(second);
        System.out.println("\nExpected (Apache Commons)");
        printApacheMatrix(expected);
        System.out.println("\nSame Nd4j op on copies: (Shape.toOffsetZeroCopy(first)." + opName
                + "(Shape.toOffsetZeroCopy(second)))");
        printMatrixFullPrecision(onCopies);
        System.out.println("\nActual:");
        printMatrixFullPrecision(actual);
    }

    /**
     * Prints the gemm call parameters and operands, the Apache Commons Math expected result, the result of
     * gemm on zero-offset copies, and the actual result.
     *
     * @param a          left-hand operand
     * @param b          right-hand operand
     * @param c          copy of the original accumulator, before gemm ran
     * @param transposeA whether {@code a} was transposed
     * @param transposeB whether {@code b} was transposed
     * @param alpha      product scale
     * @param beta       accumulator scale
     * @param expected   reference result
     * @param actual     result under test
     * @param onCopies   result of gemm on zero-offset copies
     */
    public static void printGemmFailureDetails(INDArray a, INDArray b, INDArray c, boolean transposeA,
                                               boolean transposeB, double alpha, double beta, RealMatrix expected,
                                               INDArray actual, INDArray onCopies) {
        System.out.println("\nFactory: " + Nd4j.factory().getClass() + "\n");
        System.out.println("Op: gemm(a,b,c,transposeA=" + transposeA + ",transposeB=" + transposeB + ",alpha="
                + alpha + ",beta=" + beta + ")");
        System.out.println("a:");
        printMatrixFullPrecision(a);
        System.out.println("\nb:");
        printMatrixFullPrecision(b);
        System.out.println("\nc:");
        printMatrixFullPrecision(c);
        System.out.println("\nExpected (Apache Commons)");
        printApacheMatrix(expected);
        System.out.println("\nSame Nd4j op on zero offset copies: gemm(aCopy,bCopy,cCopy," + transposeA + ","
                + transposeB + "," + alpha + "," + beta + ")");
        printMatrixFullPrecision(onCopies);
        System.out.println("\nActual:");
        printMatrixFullPrecision(actual);
    }

    /**
     * Prints a header line, then every entry of a rank-2 array row by row, comma-separated. FLOAT arrays are
     * printed via {@code getFloat}; all other types via {@code getDouble}.
     *
     * @param matrix rank-2 array to print
     */
    public static void printMatrixFullPrecision(INDArray matrix) {
        boolean isFloat = matrix.data().dataType() == DataType.FLOAT;
        printNDArrayHeader(matrix);
        long[] shape = matrix.shape();
        for (int i = 0; i < shape[0]; i++) {
            for (int j = 0; j < shape[1]; j++) {
                if (isFloat) {
                    System.out.print(matrix.getFloat(i, j));
                } else {
                    System.out.print(matrix.getDouble(i, j));
                }
                if (j != shape[1] - 1) {
                    System.out.print(", ");
                } else {
                    System.out.println();
                }
            }
        }
    }

    /**
     * Prints one line describing an array's data type, ordering, offset, shape, stride, length and underlying
     * buffer length.
     *
     * @param array array to describe
     */
    public static void printNDArrayHeader(INDArray array) {
        System.out.println(array.data().dataType() + " - order=" + array.ordering() + ", offset=" + array.offset()
                + ", shape=" + Arrays.toString(array.shape()) + ", stride=" + Arrays.toString(array.stride())
                + ", length=" + array.length() + ", data().length()=" + array.data().length());
    }

    /**
     * Prints the inputs, the manually computed expected result, the result of the same ND4J op on
     * zero-offset copies, and the actual result.
     *
     * @param first    left-hand input
     * @param second   right-hand input
     * @param expected reference result
     * @param actual   result under test
     * @param onCopies result of the same op applied to {@code Shape.toOffsetZeroCopy} copies of the inputs
     * @param opName   op name used in the printed description
     */
    public static void printFailureDetails(INDArray first, INDArray second, INDArray expected, INDArray actual,
                                           INDArray onCopies, String opName) {
        System.out.println("\nFactory: " + Nd4j.factory().getClass() + "\n");
        System.out.println("First:");
        printMatrixFullPrecision(first);
        System.out.println("\nSecond:");
        printMatrixFullPrecision(second);
        System.out.println("\nExpected");
        printMatrixFullPrecision(expected);
        System.out.println("\nSame Nd4j op on copies: (Shape.toOffsetZeroCopy(first)." + opName
                + "(Shape.toOffsetZeroCopy(second)))");
        printMatrixFullPrecision(onCopies);
        System.out.println("\nActual:");
        printMatrixFullPrecision(actual);
    }

    /**
     * Prints an Apache Commons Math matrix: a shape header line, then every entry row by row,
     * comma-separated.
     *
     * @param matrix matrix to print
     */
    public static void printApacheMatrix(RealMatrix matrix) {
        int rows = matrix.getRowDimension();
        int cols = matrix.getColumnDimension();
        System.out.println("Apache Commons RealMatrix: Shape: [" + rows + "," + cols + "]");
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                System.out.print(matrix.getEntry(i, j));
                if (j != cols - 1) {
                    System.out.print(", ");
                } else {
                    System.out.println();
                }
            }
        }
    }
}
