package org.nd4j.autodiff.listeners.debugging;

import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.NonNull;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.BaseListener;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Xor;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Debugging listener that saves every op output array produced during execution to a
 * directory, one binary file per array, so that two runs can later be diffed with
 * {@link #compare(File, File, double)}.
 *
 * <p>Files are named {@code <counter>_<outputName>.bin}. The {@code counter} value increases
 * by one for every saved array. Every {@code /} in the output variable name is replaced by
 * {@code __} so that the name is usable as a file name.
 *
 * <p>NOTE(review): {@code count} is a plain, unsynchronized field. This listener does not
 * appear intended for concurrent executions.
 */
public class ArraySavingListener extends BaseListener {
    /** Extension of every file written by this listener. */
    private static final String EXTENSION = ".bin";

    protected final File dir;
    protected int count = 0;

    /**
     * Creates a listener that saves arrays into {@code file}.
     *
     * @param file output directory. If it does not exist, it is created together with any
     *     missing parent directories.
     * @throws NullPointerException if {@code file} is null
     * @throws IllegalStateException if the path cannot be created or is not a directory, or if
     *     the directory already contains files
     */
    public ArraySavingListener(@NonNull File file) {
        if (file == null) {
            throw new NullPointerException("dir is marked non-null but is null");
        }
        if (!file.exists()) {
            // mkdirs (not mkdir) so that a nested, not-yet-existing output path also works;
            // the result is verified by the isDirectory() check below.
            file.mkdirs();
        }
        if (!file.isDirectory()) {
            throw new IllegalStateException("Could not create or access directory: " + file.getAbsolutePath());
        }
        File[] existing = file.listFiles();
        if (existing != null && existing.length > 0) {
            throw new IllegalStateException("Directory is not empty: " + file.getAbsolutePath());
        }
        this.dir = file;
    }

    @Override // org.nd4j.autodiff.listeners.Listener
    public boolean isActive(Operation operation) {
        return true;
    }

    /**
     * Saves each output array of the executed op to its own file in {@link #dir} and prints
     * the file's absolute path.
     *
     * @throws RuntimeException wrapping the {@link IOException} if an array cannot be saved
     */
    @Override // org.nd4j.autodiff.listeners.BaseListener, org.nd4j.autodiff.listeners.Listener
    public void opExecution(SameDiff sameDiff, At at, MultiDataSet multiDataSet, SameDiffOp sameDiffOp, OpContext opContext, INDArray[] iNDArrayArr) {
        List<String> outputsOfOp = sameDiffOp.getOutputsOfOp();
        for (int i = 0; i < iNDArrayArr.length; i++) {
            String outputName = outputsOfOp.get(i);
            File file = new File(this.dir, (this.count++) + "_" + toFileStem(outputName) + EXTENSION);
            try {
                Nd4j.saveBinary(iNDArrayArr[i], file);
                System.out.println(file.getAbsolutePath());
            } catch (IOException e) {
                throw new RuntimeException("Failed to save output \"" + outputName + "\" to " + file.getAbsolutePath(), e);
            }
        }
    }

    /**
     * Compares the arrays saved by two listener runs and prints one result line per array.
     *
     * <p>Files are matched by output name; the {@code <counter>_} prefix is ignored. For each
     * file in {@code file}, the method prints one of the following:
     * <ul>
     *   <li>{@code Equals: <name>} when {@code equalsWithEps} holds.</li>
     *   <li>{@code MISSING: <name>} when {@code file2} has no file for that name.</li>
     *   <li>{@code FAILS: <name>}, followed by details, in every other case.</li>
     * </ul>
     *
     * @param file first directory of saved arrays
     * @param file2 second directory of saved arrays
     * @param d tolerance passed to {@link INDArray#equalsWithEps}
     * @throws Exception if a directory cannot be listed, the file counts differ (reported via
     *     {@link Preconditions}), or an array cannot be read
     */
    public static void compare(File file, File file2, double d) throws Exception {
        File[] listFiles = file.listFiles();
        File[] listFiles2 = file2.listFiles();
        Preconditions.checkNotNull(listFiles, "No files in directory 1: %s", file);
        Preconditions.checkNotNull(listFiles2, "No files in directory 2: %s", file2);
        Preconditions.checkState(listFiles.length == listFiles2.length, "Different number of files: %s vs %s", listFiles.length, listFiles2.length);
        Map<String, File> map = toMap(listFiles2);
        for (File file3 : listFiles) {
            String key = keyOf(file3.getName());
            String displayName = key.replace("__", "/");
            File file4 = map.get(key);
            if (file4 == null) {
                // Previously this produced an opaque NPE from Nd4j.readBinary(null).
                System.out.println("MISSING: " + displayName + " - no matching file in " + file2.getAbsolutePath());
                System.out.println("\t" + file3.getAbsolutePath());
                continue;
            }
            compareFiles(file3, file4, displayName, d);
        }
    }

    /**
     * Loads one pair of saved arrays, prints the comparison result, and always releases both
     * arrays, even when the comparison throws.
     */
    private static void compareFiles(File first, File second, String displayName, double eps) throws IOException {
        INDArray a = Nd4j.readBinary(first);
        INDArray b = null;
        try {
            b = Nd4j.readBinary(second);
            if (a.equalsWithEps(b, eps)) {
                System.out.println("Equals: " + displayName);
                return;
            }
            if (!Arrays.equals(a.shape(), b.shape()) || a.dataType() != b.dataType()) {
                // Element-wise diffing is meaningless (or throws) when shape/type differ.
                System.out.println("FAILS: " + displayName + " - shape/type mismatch: "
                        + Arrays.toString(a.shape()) + " " + a.dataType() + " vs "
                        + Arrays.toString(b.shape()) + " " + b.dataType());
            } else if (a.dataType() == DataType.BOOL) {
                INDArray xor = Nd4j.exec(new Xor(a, b));
                INDArray asInt = xor.castTo(DataType.INT);
                System.out.println("FAILS: " + displayName + " - boolean, # differences = " + asInt.sumNumber().intValue());
                asInt.close();
                xor.close();
            } else {
                INDArray sub = a.sub(b);
                INDArray abs = Nd4j.math.abs(sub);
                System.out.println("FAILS: " + displayName + " - max difference = " + abs.maxNumber().doubleValue());
                sub.close();
                abs.close();
            }
            System.out.println("\t" + first.getAbsolutePath());
            System.out.println("\t" + second.getAbsolutePath());
        } finally {
            a.close();
            if (b != null) {
                b.close();
            }
        }
    }

    /** Indexes saved files by their output-name key (see {@link #keyOf(String)}). */
    private static Map<String, File> toMap(File[] fileArr) {
        Map<String, File> byKey = new HashMap<>();
        for (File file : fileArr) {
            byKey.put(keyOf(file.getName()), file);
        }
        return byKey;
    }

    /**
     * Recovers the output-name key from a saved file name. The key is everything after the
     * first {@code _} (which drops the counter prefix), minus the {@code .bin} extension if
     * that extension is present.
     */
    private static String keyOf(String fileName) {
        int start = fileName.indexOf('_') + 1;
        int end = fileName.endsWith(EXTENSION) ? fileName.length() - EXTENSION.length() : fileName.length();
        return fileName.substring(start, end);
    }

    /** Converts an output variable name to a file-name-safe stem (literal {@code /} → {@code __}). */
    private static String toFileStem(String outputName) {
        return outputName.replace("/", "__");
    }
}
