package org.nd4j.evaluation.classification;

import java.io.Serializable;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.lang3.ArrayUtils;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.primitives.Triple;
import org.nd4j.evaluation.BaseEvaluation;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.IMetric;
import org.nd4j.evaluation.curves.PrecisionRecallCurve;
import org.nd4j.evaluation.curves.RocCurve;
import org.nd4j.evaluation.serde.ROCSerializer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.ReduceOp;
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;

@JsonSerialize(using = ROCSerializer.class)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY)
@JsonIgnoreProperties({"probAndLabel", "exactAllocBlockSize"})
/* loaded from: input_file:org/nd4j/evaluation/classification/ROC.class */
public class ROC extends BaseEvaluation<ROC> {
    private static final int DEFAULT_EXACT_ALLOC_BLOCK_SIZE = 2048;
    private final Map<Double, CountsForThreshold> counts;
    private int thresholdSteps;
    private long countActualPositive;
    private long countActualNegative;
    private Double auc;
    private Double auprc;
    private RocCurve rocCurve;
    private PrecisionRecallCurve prCurve;
    private boolean isExact;
    private INDArray probAndLabel;
    private int exampleCount;
    private boolean rocRemoveRedundantPts;
    private int exactAllocBlockSize;
    protected int axis;

    /* loaded from: input_file:org/nd4j/evaluation/classification/ROC$CountsForThreshold.class */
    /**
     * Mutable tally of true-positive and false-positive counts recorded for a single
     * decision threshold. Used by the thresholded (non-exact) ROC mode, which keeps one
     * instance per threshold step.
     */
    public static class CountsForThreshold implements Serializable, Cloneable {
        private double threshold;
        private long countTruePositive;
        private long countFalsePositive;

        /** Creates an instance with all fields at their zero defaults. */
        public CountsForThreshold() {
        }

        /**
         * Creates a tally for the given threshold with both counts at zero.
         *
         * @param d the decision threshold
         */
        public CountsForThreshold(double d) {
            this(d, 0L, 0L);
        }

        /**
         * Creates a tally with explicit starting counts.
         *
         * @param d  the decision threshold
         * @param j  initial true-positive count
         * @param j2 initial false-positive count
         */
        public CountsForThreshold(double d, long j, long j2) {
            this.threshold = d;
            this.countTruePositive = j;
            this.countFalsePositive = j2;
        }

        /**
         * Returns an independent copy of this tally.
         */
        /* renamed from: clone, reason: merged with bridge method [inline-methods] */
        public CountsForThreshold m2237clone() {
            return new CountsForThreshold(this.threshold, this.countTruePositive, this.countFalsePositive);
        }

        /** Adds {@code j} to the false-positive count. */
        public void incrementFalsePositive(long j) {
            this.countFalsePositive += j;
        }

        /** Adds {@code j} to the true-positive count. */
        public void incrementTruePositive(long j) {
            this.countTruePositive += j;
        }

        public double getThreshold() {
            return this.threshold;
        }

        public long getCountTruePositive() {
            return this.countTruePositive;
        }

        public long getCountFalsePositive() {
            return this.countFalsePositive;
        }

        public void setThreshold(double d) {
            this.threshold = d;
        }

        public void setCountTruePositive(long j) {
            this.countTruePositive = j;
        }

        public void setCountFalsePositive(long j) {
            this.countFalsePositive = j;
        }

        /**
         * Two tallies are equal when threshold (compared via {@link Double#compare}) and
         * both counts match.
         */
        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof CountsForThreshold)) {
                return false;
            }
            CountsForThreshold other = (CountsForThreshold) obj;
            return other.canEqual(this)
                    && Double.compare(getThreshold(), other.getThreshold()) == 0
                    && getCountTruePositive() == other.getCountTruePositive()
                    && getCountFalsePositive() == other.getCountFalsePositive();
        }

        /** Hook that lets subclasses veto equality with instances of other types. */
        protected boolean canEqual(Object obj) {
            return obj instanceof CountsForThreshold;
        }

        @Override
        public int hashCode() {
            long thresholdBits = Double.doubleToLongBits(getThreshold());
            int result = (1 * 59) + ((int) ((thresholdBits >>> 32) ^ thresholdBits));
            long tp = getCountTruePositive();
            result = (result * 59) + ((int) ((tp >>> 32) ^ tp));
            long fp = getCountFalsePositive();
            return (result * 59) + ((int) ((fp >>> 32) ^ fp));
        }

        /**
         * Renders all three fields, each under its own label.
         */
        @Override
        public String toString() {
            // Each label must be paired with its own value; the previous version shifted the
            // values by one position and never printed the false-positive count.
            return "ROC.CountsForThreshold(threshold=" + getThreshold()
                    + ", countTruePositive=" + getCountTruePositive()
                    + ", countFalsePositive=" + getCountFalsePositive() + ")";
        }
    }

    /* loaded from: input_file:org/nd4j/evaluation/classification/ROC$Metric.class */
    /**
     * Metrics that this evaluation can report through {@code scoreForMetric} / {@code getValue}:
     * {@code AUROC} maps to {@code calculateAUC()} and {@code AUPRC} to {@code calculateAUCPR()}.
     */
    public enum Metric implements IMetric {
        AUROC,
        AUPRC;

        /** Always returns {@code ROC.class}, the evaluation type that produces these metrics. */
        @Override // org.nd4j.evaluation.IMetric
        public Class<? extends IEvaluation> getEvaluationClass() {
            return ROC.class;
        }

        /** Always returns {@code false} for both metrics. */
        @Override // org.nd4j.evaluation.IMetric
        public boolean minimize() {
            return false;
        }
    }

    /**
     * Full constructor: behaves like {@code ROC(i, z, i2)} and then overrides the axis.
     *
     * @param i  number of threshold steps; values {@code <= 0} select exact mode
     * @param z  whether redundant curve points are removed in exact mode
     * @param i2 allocation block size for the exact-mode buffer
     * @param i3 axis passed to {@code BaseEvaluation.reshapeAndExtractNotMasked} during evaluation
     */
    public ROC(int i, boolean z, int i2, int i3) {
        this(i, z, i2);
        axis = i3;
    }

    /**
     * Default constructor: delegates with 0 threshold steps, which the three-argument
     * constructor treats as exact mode.
     */
    public ROC() {
        this(0);
    }

    /**
     * Creates an instance with the given number of threshold steps and redundant-point
     * removal enabled.
     *
     * @param i number of threshold steps; values {@code <= 0} select exact mode
     */
    public ROC(int i) {
        this(i, true);
    }

    /**
     * Creates an instance using the default exact-mode allocation block size
     * ({@code DEFAULT_EXACT_ALLOC_BLOCK_SIZE}).
     *
     * @param i number of threshold steps; values {@code <= 0} select exact mode
     * @param z whether redundant curve points are removed in exact mode
     */
    public ROC(int i, boolean z) {
        this(i, z, DEFAULT_EXACT_ALLOC_BLOCK_SIZE);
    }

    /**
     * Primary constructor. A positive step count selects thresholded mode and pre-populates
     * one {@link CountsForThreshold} per threshold {@code k / i} for {@code k = 0..i};
     * any other value selects exact mode, where raw probabilities and labels are stored.
     * The axis is initialised to 1.
     *
     * @param i  number of threshold steps; values {@code <= 0} select exact mode
     * @param z  whether redundant curve points are removed in exact mode
     * @param i2 allocation block size for the exact-mode buffer
     */
    public ROC(int i, boolean z, int i2) {
        this.counts = new LinkedHashMap<>();
        this.exampleCount = 0;
        this.axis = 1;
        this.rocRemoveRedundantPts = z;
        this.exactAllocBlockSize = i2;
        this.isExact = i <= 0;
        if (!this.isExact) {
            this.thresholdSteps = i;
            final double stepSize = 1.0d / i;
            for (int k = 0; k <= i; k++) {
                final double threshold = k * stepSize;
                this.counts.put(Double.valueOf(threshold), new CountsForThreshold(threshold));
            }
        }
    }

    /**
     * Deserialises a {@code ROC} instance from JSON by delegating to the inherited
     * two-argument {@code fromJson}.
     *
     * @param str JSON text
     * @return the deserialised instance
     */
    public static ROC fromJson(String str) {
        final ROC parsed = (ROC) fromJson(str, ROC.class);
        return parsed;
    }

    /**
     * Sets the axis forwarded to {@code BaseEvaluation.reshapeAndExtractNotMasked} by {@code eval}.
     *
     * @param i the new axis
     */
    public void setAxis(int i) {
        axis = i;
    }

    /** Returns the axis currently used when reshaping inputs in {@code eval}. */
    public int getAxis() {
        return axis;
    }

    /**
     * Returns the cached AUROC, computing (and caching) it through {@link #calculateAUC()}
     * when no cached value exists.
     */
    private double getAuc() {
        if (auc == null) {
            auc = Double.valueOf(calculateAUC());
        }
        return auc.doubleValue();
    }

    /**
     * Computes the area under the ROC curve, caching the result until the next
     * {@code eval}/{@code merge}/{@code reset}.
     *
     * @return the AUROC of {@link #getRocCurve()}
     * @throws IllegalStateException (via {@code Preconditions.checkState}) when no examples
     *         have been evaluated and no cached value exists
     */
    public double calculateAUC() {
        if (auc == null) {
            Preconditions.checkState(exampleCount > 0, "Unable to calculate AUC: no evaluation has been performed (no examples)");
            auc = Double.valueOf(getRocCurve().calculateAUC());
        }
        return auc.doubleValue();
    }

    /**
     * Returns the ROC curve (threshold, false-positive rate, true-positive rate) for the data
     * evaluated so far.
     *
     * <p>In thresholded mode one point is produced per entry of {@code counts}; the curve is
     * rebuilt on every call. In exact mode the stored probability/label rows are sorted by
     * probability, cumulative label sums give the rates, two end points are added, and the
     * result is cached in {@code rocCurve}.
     *
     * @return the ROC curve
     * @throws IllegalStateException (via {@code Preconditions.checkState}) when no curve is
     *         cached and no examples have been evaluated
     */
    public RocCurve getRocCurve() {
        if (this.rocCurve != null) {
            return this.rocCurve;
        }
        Preconditions.checkState(this.exampleCount > 0, "Unable to get ROC curve: no evaluation has been performed (no examples)");
        if (!this.isExact) {
            // Rows: 0 = threshold, 1 = false-positive rate, 2 = true-positive rate.
            double[][] out = new double[3][this.thresholdSteps + 1];
            int idx = 0;
            for (CountsForThreshold c : this.counts.values()) {
                // The divisor must be a double: long / long would truncate every rate to 0 or 1.
                double tpr = c.getCountTruePositive() / (double) this.countActualPositive;
                double fpr = c.getCountFalsePositive() / (double) this.countActualNegative;
                out[0][idx] = c.getThreshold();
                out[1][idx] = fpr;
                out[2][idx] = tpr;
                idx++;
            }
            return new RocCurve(out[0], out[1], out[2]);
        }
        // Exact mode. Column 0 holds probabilities and column 1 labels (see eval()).
        // NOTE(review): sortRows(..., 0, false) is presumably a descending sort on column 0 - confirm.
        INDArray sortRows = Nd4j.sortRows(getProbAndLabelUsed(), 0, false);
        INDArray column = sortRows.getColumn(1L, true);
        // 1 - label: marks the negative examples.
        INDArray rsub = sortRows.getColumn(1L, true).rsub(Double.valueOf(1.0d));
        INDArray cumsum = column.cumsum(-1);
        INDArray cumsum2 = rsub.cumsum(-1);
        long size = sortRows.size(0);
        // Each output gets two extra rows: index 0 and index size + 1 hold the curve end points.
        INDArray create = Nd4j.create(DataType.DOUBLE, size + 2, 1);
        create.put(new INDArrayIndex[]{NDArrayIndex.interval(1L, size + 1), NDArrayIndex.all()}, sortRows.getColumn(0L, true));
        INDArray create2 = Nd4j.create(DataType.DOUBLE, size + 2, 1);
        create2.put(new INDArrayIndex[]{NDArrayIndex.interval(1L, size + 1), NDArrayIndex.all()}, cumsum2.div(Long.valueOf(this.countActualNegative)));
        INDArray create3 = Nd4j.create(DataType.DOUBLE, size + 2, 1);
        create3.put(new INDArrayIndex[]{NDArrayIndex.interval(1L, size + 1), NDArrayIndex.all()}, cumsum.div(Long.valueOf(this.countActualPositive)));
        // First point: threshold 1.0 with FPR = TPR = 0; last point: FPR = TPR = 1.
        create.putScalar(0L, 0L, 1.0d);
        create2.putScalar(0L, 0L, 0.0d);
        create3.putScalar(0L, 0L, 0.0d);
        create2.putScalar(size + 1, 0L, 1.0d);
        create3.putScalar(size + 1, 0L, 1.0d);
        double[] asDouble = create2.data().asDouble();
        double[] asDouble2 = create3.data().asDouble();
        double[] asDouble3 = create.data().asDouble();
        if (this.rocRemoveRedundantPts) {
            double[][] dArr2 = (double[][]) removeRedundant(asDouble3, asDouble, asDouble2, null, null, null).getFirst();
            asDouble3 = dArr2[0];
            asDouble = dArr2[1];
            asDouble2 = dArr2[2];
        }
        this.rocCurve = new RocCurve(asDouble3, asDouble, asDouble2);
        return this.rocCurve;
    }

    /**
     * Returns a view of the first {@code exampleCount} rows of the exact-mode buffer,
     * i.e. the part actually filled, or {@code null} when the buffer is unallocated or empty.
     */
    protected INDArray getProbAndLabelUsed() {
        final boolean hasData = probAndLabel != null && exampleCount != 0;
        if (!hasData) {
            return null;
        }
        return probAndLabel.get(NDArrayIndex.interval(0, exampleCount), NDArrayIndex.all());
    }

    /**
     * Drops interior curve points that lie on a horizontal or vertical run. A point is
     * removed when its x value equals both neighbours' x values, or its y value equals both
     * neighbours' y values. The first and last points are always kept.
     *
     * @param dArr  thresholds
     * @param dArr2 x values (FPR for ROC curves, precision for PR curves)
     * @param dArr3 y values (TPR for ROC curves, recall for PR curves)
     * @param iArr  optional per-point counts, or {@code null}; when {@code null} the other two
     *              count arrays are ignored and the second element of the result is {@code null}
     * @param iArr2 optional per-point counts (read only when {@code iArr} is non-null)
     * @param iArr3 optional per-point counts (read only when {@code iArr} is non-null)
     * @return pair of {@code {thresholds, x, y}} and, when counts were supplied, the three
     *         filtered count arrays in argument order
     */
    private static Pair<double[][], int[][]> removeRedundant(double[] dArr, double[] dArr2, double[] dArr3, int[] iArr, int[] iArr2, int[] iArr3) {
        final int n = dArr.length;
        final boolean hasCounts = iArr != null;

        double[] keptThreshold = new double[n];
        double[] keptX = new double[dArr2.length];
        double[] keptY = new double[dArr3.length];
        int[] keptCounts1 = hasCounts ? new int[iArr.length] : null;
        int[] keptCounts2 = hasCounts ? new int[iArr2.length] : null;
        int[] keptCounts3 = hasCounts ? new int[iArr3.length] : null;

        int kept = 0;
        for (int i = 0; i < n; i++) {
            final boolean isEndPoint = i == 0 || i == n - 1;
            if (!isEndPoint) {
                final boolean flatInX = dArr2[i - 1] == dArr2[i] && dArr2[i] == dArr2[i + 1];
                final boolean flatInY = dArr3[i - 1] == dArr3[i] && dArr3[i] == dArr3[i + 1];
                if (flatInX || flatInY) {
                    // Middle of a straight run: contributes nothing to the curve's shape.
                    continue;
                }
            }
            keptThreshold[kept] = dArr[i];
            keptY[kept] = dArr3[i];
            keptX[kept] = dArr2[i];
            if (hasCounts) {
                keptCounts1[kept] = iArr[i];
                keptCounts2[kept] = iArr2[i];
                keptCounts3[kept] = iArr3[i];
            }
            kept++;
        }

        // Trim the scratch arrays when at least one point was dropped.
        if (kept < dArr2.length) {
            keptThreshold = Arrays.copyOfRange(keptThreshold, 0, kept);
            keptX = Arrays.copyOfRange(keptX, 0, kept);
            keptY = Arrays.copyOfRange(keptY, 0, kept);
            if (hasCounts) {
                keptCounts1 = Arrays.copyOfRange(keptCounts1, 0, kept);
                keptCounts2 = Arrays.copyOfRange(keptCounts2, 0, kept);
                keptCounts3 = Arrays.copyOfRange(keptCounts3, 0, kept);
            }
        }

        // The initialisers must be 2-D: each element is itself an array.
        double[][] curve = new double[][]{keptThreshold, keptX, keptY};
        int[][] countsOut = hasCounts ? new int[][]{keptCounts1, keptCounts2, keptCounts3} : null;
        return new Pair<>(curve, countsOut);
    }

    /**
     * Returns the cached AUPRC, computing (and caching) it through {@link #calculateAUCPR()}
     * when no cached value exists.
     */
    private double getAuprc() {
        if (auprc == null) {
            auprc = Double.valueOf(calculateAUCPR());
        }
        return auprc.doubleValue();
    }

    /**
     * Computes the area under the precision/recall curve, caching the result until the next
     * {@code eval}/{@code merge}/{@code reset}.
     *
     * @return the AUPRC of {@link #getPrecisionRecallCurve()}
     * @throws IllegalStateException (via {@code Preconditions.checkState}) when no examples
     *         have been evaluated and no cached value exists
     */
    public double calculateAUCPR() {
        if (auprc == null) {
            Preconditions.checkState(exampleCount > 0, "Unable to calculate AUPRC: no evaluation has been performed (no examples)");
            auprc = Double.valueOf(getPrecisionRecallCurve().calculateAUPRC());
        }
        return auprc.doubleValue();
    }

    /**
     * Returns the precision/recall curve for the data evaluated so far, caching it in
     * {@code prCurve}.
     *
     * <p>In exact mode the stored probability/label rows are sorted by probability and
     * cumulative label sums yield precision, recall and the TP/FP/FN counts per point; the
     * arrays are then reversed and optionally stripped of redundant points. In thresholded
     * mode one point is produced per entry of {@code counts}.
     *
     * @return the precision/recall curve
     * @throws IllegalStateException (via {@code Preconditions.checkState}) when no curve is
     *         cached and no examples have been evaluated
     */
    public PrecisionRecallCurve getPrecisionRecallCurve() {
        double[] dArr;
        double[] dArr2;
        double[] dArr3;
        int[] iArr;
        int[] iArr2;
        int[] iArr3;
        if (this.prCurve != null) {
            return this.prCurve;
        }
        Preconditions.checkState(this.exampleCount > 0, "Unable to get PR curve: no evaluation has been performed (no examples)");
        if (this.isExact) {
            // Column 0 holds probabilities and column 1 labels (see eval()).
            // NOTE(review): sortRows(..., 0, false) is presumably a descending sort on column 0 - confirm.
            INDArray sortRows = Nd4j.sortRows(getProbAndLabelUsed(), 0, false);
            // Running count of positives among the first k sorted rows.
            INDArray cumsum = sortRows.getColumn(1L, true).cumsum(-1);
            long size = sortRows.size(0);
            // Each output gets two extra rows: index 0 and index size + 1 hold the end points.
            INDArray create = Nd4j.create(DataType.DOUBLE, size + 2, 1);
            create.put(new INDArrayIndex[]{NDArrayIndex.interval(1L, size + 1), NDArrayIndex.all()}, sortRows.getColumn(0L, true));
            // Precision at row k = positives so far / k.
            INDArray div = cumsum.castTo(DataType.DOUBLE).div(Nd4j.linspace(1L, size, size, DataType.DOUBLE).reshape(cumsum.shape()));
            INDArray create2 = Nd4j.create(DataType.DOUBLE, size + 2, 1);
            create2.put(new INDArrayIndex[]{NDArrayIndex.interval(1L, size + 1), NDArrayIndex.all()}, div);
            // Recall at row k = positives so far / all positives.
            INDArray create3 = Nd4j.create(DataType.DOUBLE, size + 2, 1);
            create3.put(new INDArrayIndex[]{NDArrayIndex.interval(1L, size + 1), NDArrayIndex.all()}, cumsum.div(Long.valueOf(this.countActualPositive)));
            create.putScalar(0L, 0L, 1.0d);
            create2.putScalar(0L, 0L, 1.0d);
            create3.putScalar(0L, 0L, 0.0d);
            create2.putScalar(size + 1, 0L, cumsum.getDouble(cumsum.length() - 1) / size);
            create3.putScalar(size + 1, 0L, 1.0d);
            dArr = create.data().asDouble();
            dArr2 = create2.data().asDouble();
            dArr3 = create3.data().asDouble();
            // iArr = true positives, iArr2 = false positives, iArr3 = false negatives per point.
            iArr = new int[dArr.length];
            iArr2 = new int[dArr.length];
            iArr3 = new int[dArr.length];
            for (int i = 1; i < iArr.length - 1; i++) {
                iArr[i] = cumsum.getInt(i - 1);
                iArr2[i] = i - iArr[i];
                iArr3[i] = ((int) this.countActualPositive) - iArr[i];
            }
            iArr[iArr.length - 1] = (int) this.countActualPositive;
            iArr2[iArr.length - 1] = (int) (this.exampleCount - this.countActualPositive);
            iArr3[iArr.length - 1] = 0;
            iArr[0] = 0;
            iArr2[0] = 0;
            iArr3[0] = (int) this.countActualPositive;
            // Flip every array so the points run in the opposite order.
            ArrayUtils.reverse(dArr);
            ArrayUtils.reverse(dArr2);
            ArrayUtils.reverse(dArr3);
            ArrayUtils.reverse(iArr);
            ArrayUtils.reverse(iArr2);
            ArrayUtils.reverse(iArr3);
            if (this.rocRemoveRedundantPts) {
                Pair<double[][], int[][]> removeRedundant = removeRedundant(dArr, dArr2, dArr3, iArr, iArr2, iArr3);
                double[][] dArr4 = (double[][]) removeRedundant.getFirst();
                int[][] iArr4 = (int[][]) removeRedundant.getSecond();
                dArr = dArr4[0];
                dArr2 = dArr4[1];
                dArr3 = dArr4[2];
                iArr = iArr4[0];
                iArr2 = iArr4[1];
                iArr3 = iArr4[2];
            }
        } else {
            dArr = new double[this.counts.size()];
            dArr2 = new double[this.counts.size()];
            dArr3 = new double[this.counts.size()];
            iArr = new int[this.counts.size()];
            iArr2 = new int[this.counts.size()];
            iArr3 = new int[this.counts.size()];
            int i2 = 0;
            for (Map.Entry<Double, CountsForThreshold> entry : this.counts.entrySet()) {
                CountsForThreshold value = entry.getValue();
                long countTruePositive = value.getCountTruePositive();
                long countFalsePositive = value.getCountFalsePositive();
                // Divisors are cast to double: long / long would truncate the ratio to 0 or 1.
                // With no positive predictions at this threshold, precision is defined as 1.
                double precision = (countTruePositive == 0 && countFalsePositive == 0)
                        ? 1.0d
                        : countTruePositive / (double) (countTruePositive + countFalsePositive);
                // With no actual positives at all, recall is defined as 1.
                double recall = this.countActualPositive == 0
                        ? 1.0d
                        : countTruePositive / (double) this.countActualPositive;
                dArr[i2] = value.getThreshold();
                dArr2[i2] = precision;
                dArr3[i2] = recall;
                iArr[i2] = (int) countTruePositive;
                iArr2[i2] = (int) countFalsePositive;
                iArr3[i2] = (int) (this.countActualPositive - countTruePositive);
                i2++;
            }
        }
        this.prCurve = new PrecisionRecallCurve(dArr, dArr2, dArr3, iArr, iArr2, iArr3, this.exampleCount);
        return this.prCurve;
    }

    /**
     * Accumulates one batch into this evaluation and invalidates all cached results.
     *
     * <p>The inputs are first passed through {@code BaseEvaluation.reshapeAndExtractNotMasked};
     * a {@code null} result makes this call a no-op. The resulting label and prediction arrays
     * must be rank 2 with a matching second dimension of size 1 (single probability column) or
     * 2 (column 1 is the positive class). In exact mode the probabilities and labels are
     * appended to {@code probAndLabel}; in thresholded mode the true/false-positive counts of
     * every threshold step are incremented.
     *
     * @param iNDArray  first argument to {@code reshapeAndExtractNotMasked} (presumably labels)
     * @param iNDArray2 second argument to {@code reshapeAndExtractNotMasked} (presumably predictions)
     * @param iNDArray3 third argument to {@code reshapeAndExtractNotMasked} (presumably a mask)
     * @param list      record metadata; not used by this implementation
     * @throws IllegalArgumentException if the extracted arrays have an unsupported shape
     * @throws IllegalStateException (via {@code Preconditions.checkState}) if the predictions
     *         contain NaN values
     */
    @Override // org.nd4j.evaluation.IEvaluation
    public void eval(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, List<? extends Serializable> list) {
        INDArray column;
        INDArray column2;
        INDArray column3;
        INDArray iNDArray4;
        INDArray iNDArray5;
        INDArray column4;
        INDArray column5;
        Triple<INDArray, INDArray, INDArray> reshapeAndExtractNotMasked = BaseEvaluation.reshapeAndExtractNotMasked(iNDArray, iNDArray2, iNDArray3, this.axis);
        if (reshapeAndExtractNotMasked == null) {
            return;
        }
        INDArray iNDArray6 = (INDArray) reshapeAndExtractNotMasked.getFirst();
        INDArray iNDArray7 = (INDArray) reshapeAndExtractNotMasked.getSecond();
        if (iNDArray6.rank() == 3 && iNDArray7.rank() == 3) {
            evalTimeSeries(iNDArray6, iNDArray7);
            // The time-series path has handled this batch. Without returning here the rank
            // check below would always throw for rank-3 input.
            return;
        }
        if (iNDArray6.rank() > 2 || iNDArray7.rank() > 2 || iNDArray6.size(1) != iNDArray7.size(1) || iNDArray6.size(1) > 2) {
            throw new IllegalArgumentException("Invalid input data shape: labels shape = " + Arrays.toString(iNDArray6.shape()) + ", predictions shape = " + Arrays.toString(iNDArray7.shape()) + "; require rank 2 array with size(1) == 1 or 2");
        }
        if (iNDArray6.dataType() != iNDArray7.dataType()) {
            iNDArray6 = iNDArray6.castTo(iNDArray7.dataType());
        }
        // Count NaNs in the predictions and refuse to evaluate if any are present.
        long longValue = Nd4j.getExecutioner().execAndReturn((ReduceOp) new MatchCondition(iNDArray7, Conditions.isNan(), new int[0])).getFinalResult().longValue();
        Preconditions.checkState(longValue == 0, "Cannot perform evaluation with NaN(s) present: %s NaN(s) present in predictions INDArray", longValue);
        // Threshold spacing; only read in thresholded mode, where thresholdSteps > 0.
        double d = 1.0d / this.thresholdSteps;
        // true when labels/predictions hold a single probability column.
        boolean z = iNDArray6.size(1) == 1;
        if (this.isExact) {
            if (this.probAndLabel == null) {
                this.probAndLabel = Nd4j.create(DataType.DOUBLE, new long[]{Math.max(iNDArray6.size(0), this.exactAllocBlockSize), 2}, 'c');
            }
            // Grow the buffer when this batch would fill or overflow it, keeping the used rows.
            if (this.exampleCount + iNDArray6.size(0) >= this.probAndLabel.size(0)) {
                INDArray create = Nd4j.create(DataType.DOUBLE, new long[]{this.probAndLabel.size(0) + Math.max(this.exactAllocBlockSize, iNDArray6.size(0)), 2}, 'c');
                if (this.exampleCount > 0) {
                    create.get(NDArrayIndex.interval(0, this.exampleCount), NDArrayIndex.all()).assign(this.probAndLabel.get(NDArrayIndex.interval(0, this.exampleCount), NDArrayIndex.all()));
                }
                this.probAndLabel = create;
            }
            if (z) {
                column4 = iNDArray7;
                column5 = iNDArray6;
            } else {
                column4 = iNDArray7.getColumn(1L, true);
                column5 = iNDArray6.getColumn(1L, true);
            }
            long size = iNDArray6.size(0);
            // Buffer layout: column 0 = predicted probability, column 1 = label.
            this.probAndLabel.get(NDArrayIndex.interval(this.exampleCount, this.exampleCount + size), NDArrayIndex.point(0L)).assign(column4);
            this.probAndLabel.get(NDArrayIndex.interval(this.exampleCount, this.exampleCount + size), NDArrayIndex.point(1L)).assign(column5);
            int intValue = column5.sumNumber().intValue();
            this.countActualPositive += intValue;
            this.countActualNegative += iNDArray6.size(0) - intValue;
        } else {
            // column = positive-class labels, column2 = negative-class labels,
            // column3 = predicted probability of the positive class.
            if (z) {
                column = iNDArray6;
                column2 = iNDArray6.rsub(Double.valueOf(1.0d));
                column3 = iNDArray7;
            } else {
                column = iNDArray6.getColumn(1L, true);
                column2 = iNDArray6.getColumn(0L, true);
                column3 = iNDArray7.getColumn(1L, true);
            }
            this.countActualPositive += column.sumNumber().intValue();
            this.countActualNegative += column2.sumNumber().intValue();
            // Scratch arrays reused across threshold steps.
            INDArray iNDArray8 = null;
            INDArray iNDArray9 = null;
            INDArray iNDArray10 = null;
            for (int i = 0; i <= this.thresholdSteps; i++) {
                double d2 = i * d;
                Condition greaterThanOrEqual = Conditions.greaterThanOrEqual(Double.valueOf(d2));
                Condition lessThanOrEqual = Conditions.lessThanOrEqual(Double.valueOf(d2));
                if (iNDArray8 == null) {
                    iNDArray8 = column3.dup(column.ordering());
                } else {
                    iNDArray8.assign(column3);
                }
                // Two CompareAndSet passes over the copied probabilities: first set values
                // matching ">= threshold" to 1.0, then set values matching "<= threshold" to 0.0.
                INDArray exec = Nd4j.getExecutioner().exec(new CompareAndSet(Nd4j.getExecutioner().exec(new CompareAndSet(iNDArray8, 1.0d, greaterThanOrEqual)), 0.0d, lessThanOrEqual));
                if (i == 0) {
                    iNDArray4 = exec.mul(column);
                    iNDArray5 = exec.mul(column2);
                    iNDArray9 = iNDArray4;
                    iNDArray10 = iNDArray5;
                } else {
                    // Later steps write the products into the arrays allocated at step 0.
                    iNDArray4 = Nd4j.getExecutioner().exec(new MulOp(exec, column, iNDArray9))[0];
                    iNDArray5 = Nd4j.getExecutioner().exec(new MulOp(exec, column2, iNDArray10))[0];
                }
                int intValue2 = iNDArray4.sumNumber().intValue();
                int intValue3 = iNDArray5.sumNumber().intValue();
                CountsForThreshold countsForThreshold = this.counts.get(Double.valueOf(d2));
                countsForThreshold.incrementTruePositive(intValue2);
                countsForThreshold.incrementFalsePositive(intValue3);
            }
        }
        this.exampleCount = (int) (this.exampleCount + iNDArray6.size(0));
        // New data invalidates every cached result.
        this.auc = null;
        this.auprc = null;
        this.rocCurve = null;
        this.prCurve = null;
    }

    /**
     * Folds the state of another {@code ROC} into this one and invalidates cached results.
     *
     * <p>In thresholded mode the per-threshold counts are summed. In exact mode the other
     * instance's used rows are appended to this instance's buffer, growing it when needed;
     * if this instance holds no examples yet it adopts the other instance's buffer reference.
     *
     * @param roc the instance to merge in
     * @throws UnsupportedOperationException if the two instances use different numbers of
     *         threshold steps
     */
    @Override // org.nd4j.evaluation.IEvaluation
    public void merge(ROC roc) {
        if (thresholdSteps != roc.thresholdSteps) {
            throw new UnsupportedOperationException("Cannot merge ROC instances with different numbers of threshold steps (" + thresholdSteps + " vs. " + roc.thresholdSteps + ")");
        }
        countActualPositive += roc.countActualPositive;
        countActualNegative += roc.countActualNegative;
        auc = null;
        auprc = null;
        rocCurve = null;
        prCurve = null;

        if (isExact) {
            if (roc.exampleCount == 0) {
                return;
            }
            if (exampleCount == 0) {
                exampleCount = roc.exampleCount;
                probAndLabel = roc.probAndLabel;
                return;
            }
            final int mergedCount = exampleCount + roc.exampleCount;
            if (mergedCount > probAndLabel.size(0)) {
                final long grownRows = probAndLabel.size(0) + Math.max(roc.probAndLabel.size(0), exactAllocBlockSize);
                INDArray grown = Nd4j.create(DataType.DOUBLE, grownRows, 2);
                INDArray usedRows = probAndLabel.get(NDArrayIndex.interval(0, exampleCount), NDArrayIndex.all());
                grown.put(new INDArrayIndex[]{NDArrayIndex.interval(0, exampleCount), NDArrayIndex.all()}, usedRows);
                probAndLabel = grown;
            }
            INDArray incoming = roc.probAndLabel.get(NDArrayIndex.interval(0, roc.exampleCount), NDArrayIndex.all());
            probAndLabel.put(new INDArrayIndex[]{NDArrayIndex.interval(exampleCount, mergedCount), NDArrayIndex.all()}, incoming);
        } else {
            for (Map.Entry<Double, CountsForThreshold> entry : counts.entrySet()) {
                CountsForThreshold mine = entry.getValue();
                CountsForThreshold theirs = roc.counts.get(entry.getKey());
                mine.countTruePositive += theirs.countTruePositive;
                mine.countFalsePositive += theirs.countFalsePositive;
            }
        }
        exampleCount += roc.exampleCount;
    }

    /**
     * Returns this instance to its freshly-constructed state: zeroes the positive/negative
     * totals and the example count, discards the exact-mode buffer or re-creates the
     * per-threshold counters, and drops every cached result (AUC, AUPRC and both curves).
     */
    @Override // org.nd4j.evaluation.IEvaluation
    public void reset() {
        this.countActualPositive = 0L;
        this.countActualNegative = 0L;
        this.counts.clear();
        if (this.isExact) {
            this.probAndLabel = null;
        } else {
            double d = 1.0d / this.thresholdSteps;
            for (int i = 0; i <= this.thresholdSteps; i++) {
                double d2 = i * d;
                this.counts.put(Double.valueOf(d2), new CountsForThreshold(d2));
            }
        }
        this.exampleCount = 0;
        this.auc = null;
        this.auprc = null;
        // The curve getters return their cached value before checking exampleCount, so the
        // cached curves must be dropped here too or they would survive the reset.
        this.rocCurve = null;
        this.prCurve = null;
    }

    /**
     * Produces a short human-readable summary: AUC and AUPRC on separate lines, followed in
     * thresholded mode by a note giving the number of threshold steps. Returns a fixed
     * placeholder message when nothing has been evaluated.
     */
    @Override // org.nd4j.evaluation.IEvaluation
    public String stats() {
        if (exampleCount == 0) {
            return "ROC: No data available (no data has been performed)";
        }
        StringBuilder report = new StringBuilder()
                .append("AUC (Area under ROC Curve):                ").append(calculateAUC()).append("\n")
                .append("AUPRC (Area under Precision/Recall Curve): ").append(calculateAUCPR());
        if (isExact) {
            return report.toString();
        }
        return report.append("\n")
                .append("[Note: Thresholded AUC/AUPRC calculation used with ")
                .append(thresholdSteps)
                .append(" steps); accuracy may reduced compared to exact mode]")
                .toString();
    }

    /** Returns the same text as {@link #stats()}. */
    @Override // org.nd4j.evaluation.BaseEvaluation
    public String toString() {
        return stats();
    }

    /**
     * Returns the value of the requested metric.
     *
     * @param metric {@code AUROC} or {@code AUPRC}
     * @return {@link #calculateAUC()} or {@link #calculateAUCPR()} respectively
     * @throws IllegalStateException for any other metric constant
     */
    public double scoreForMetric(Metric metric) {
        final double score;
        switch (metric) {
            case AUROC:
                score = calculateAUC();
                break;
            case AUPRC:
                score = calculateAUCPR();
                break;
            default:
                throw new IllegalStateException("Unknown metric: " + metric);
        }
        return score;
    }

    /**
     * Generic metric lookup: accepts only {@link Metric} constants and forwards them to
     * {@link #scoreForMetric(Metric)}.
     *
     * @throws IllegalStateException if {@code iMetric} is not a {@link Metric}
     */
    @Override // org.nd4j.evaluation.IEvaluation
    public double getValue(IMetric iMetric) {
        if (!(iMetric instanceof Metric)) {
            throw new IllegalStateException("Can't get value for non-ROC Metric " + iMetric);
        }
        return scoreForMetric((Metric) iMetric);
    }

    /**
     * Creates an empty {@code ROC} configured like this one (threshold steps, redundant-point
     * removal, allocation block size and axis); no accumulated data is copied.
     */
    @Override // org.nd4j.evaluation.IEvaluation
    public ROC newInstance() {
        return new ROC(thresholdSteps, rocRemoveRedundantPts, exactAllocBlockSize, axis);
    }

    /**
     * Compares threshold steps, positive/negative totals, mode, example count, the
     * redundant-point flag and the per-threshold counts, in addition to the superclass
     * comparison. Cached results, the exact-mode buffer, the allocation block size and the
     * axis are not compared.
     */
    @Override // org.nd4j.evaluation.BaseEvaluation
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof ROC)) {
            return false;
        }
        ROC other = (ROC) obj;
        if (!other.canEqual(this) || !super.equals(obj)) {
            return false;
        }
        if (getThresholdSteps() != other.getThresholdSteps()) {
            return false;
        }
        if (getCountActualPositive() != other.getCountActualPositive()) {
            return false;
        }
        if (getCountActualNegative() != other.getCountActualNegative()) {
            return false;
        }
        if (isExact() != other.isExact()) {
            return false;
        }
        if (getExampleCount() != other.getExampleCount()) {
            return false;
        }
        if (isRocRemoveRedundantPts() != other.isRocRemoveRedundantPts()) {
            return false;
        }
        Map<Double, CountsForThreshold> mine = getCounts();
        Map<Double, CountsForThreshold> theirs = other.getCounts();
        if (mine == null) {
            return theirs == null;
        }
        return mine.equals(theirs);
    }

    /** Equality hook used by {@code equals}: only other {@code ROC} instances may be equal. */
    @Override // org.nd4j.evaluation.BaseEvaluation
    protected boolean canEqual(Object obj) {
        return obj instanceof ROC;
    }

    /**
     * Hash over the same fields that {@code equals} compares, combined with multiplier 59
     * on top of the superclass hash.
     */
    @Override // org.nd4j.evaluation.BaseEvaluation
    public int hashCode() {
        final int prime = 59;
        int result = super.hashCode();
        result = (result * prime) + getThresholdSteps();
        final long positives = getCountActualPositive();
        result = (result * prime) + ((int) ((positives >>> 32) ^ positives));
        final long negatives = getCountActualNegative();
        result = (result * prime) + ((int) ((negatives >>> 32) ^ negatives));
        result = (result * prime) + (isExact() ? 79 : 97);
        result = (result * prime) + getExampleCount();
        result = (result * prime) + (isRocRemoveRedundantPts() ? 79 : 97);
        final Map<Double, CountsForThreshold> thresholdCounts = getCounts();
        result = (result * prime) + (thresholdCounts == null ? 43 : thresholdCounts.hashCode());
        return result;
    }

    /** Returns the live (not copied) map of per-threshold counts, keyed by threshold. */
    public Map<Double, CountsForThreshold> getCounts() {
        return counts;
    }

    /** Returns the number of threshold steps; the constructors leave this at 0 in exact mode. */
    public int getThresholdSteps() {
        return thresholdSteps;
    }

    /** Returns the running total of positive labels accumulated so far. */
    public long getCountActualPositive() {
        return countActualPositive;
    }

    /** Returns the running total of negative labels accumulated so far. */
    public long getCountActualNegative() {
        return countActualNegative;
    }

    /**
     * Returns the cached precision/recall curve field as-is, without computing it; this is
     * {@code null} until {@code getPrecisionRecallCurve()} has run or after the cache was cleared.
     */
    public PrecisionRecallCurve getPrCurve() {
        return prCurve;
    }

    /** Returns {@code true} when this instance runs in exact (non-thresholded) mode. */
    public boolean isExact() {
        return isExact;
    }

    /**
     * Returns the whole exact-mode buffer, including any allocated but unused rows; may be
     * {@code null}. See {@code getProbAndLabelUsed()} for the filled part only.
     */
    public INDArray getProbAndLabel() {
        return probAndLabel;
    }

    /** Returns the number of examples accumulated so far. */
    public int getExampleCount() {
        return exampleCount;
    }

    /** Returns whether redundant points are stripped from exact-mode curves. */
    public boolean isRocRemoveRedundantPts() {
        return rocRemoveRedundantPts;
    }

    /** Returns the block size used when allocating or growing the exact-mode buffer. */
    public int getExactAllocBlockSize() {
        return exactAllocBlockSize;
    }

    /**
     * Overwrites the threshold-step count. Only the field is assigned; the per-threshold
     * counters are not rebuilt.
     */
    public void setThresholdSteps(int i) {
        thresholdSteps = i;
    }

    /** Overwrites the running total of positive labels. */
    public void setCountActualPositive(long j) {
        countActualPositive = j;
    }

    /** Overwrites the running total of negative labels. */
    public void setCountActualNegative(long j) {
        countActualNegative = j;
    }

    /** Overwrites the cached AUROC; {@code null} forces recomputation on next access. */
    public void setAuc(Double d) {
        auc = d;
    }

    /** Overwrites the cached AUPRC; {@code null} forces recomputation on next access. */
    public void setAuprc(Double d) {
        auprc = d;
    }

    /** Overwrites the cached ROC curve; {@code null} forces recomputation on next access. */
    public void setRocCurve(RocCurve rocCurve) {
        this.rocCurve = rocCurve;
    }

    /** Overwrites the cached precision/recall curve; {@code null} forces recomputation on next access. */
    public void setPrCurve(PrecisionRecallCurve precisionRecallCurve) {
        prCurve = precisionRecallCurve;
    }

    /** Overwrites the exact-mode flag. Only the field is assigned; no other state is adjusted. */
    public void setExact(boolean z) {
        isExact = z;
    }

    /** Replaces the exact-mode buffer reference; the array is stored as-is, not copied. */
    public void setProbAndLabel(INDArray iNDArray) {
        probAndLabel = iNDArray;
    }

    /** Overwrites the number of accumulated examples. */
    public void setExampleCount(int i) {
        exampleCount = i;
    }

    /** Sets whether redundant points are stripped from exact-mode curves. */
    public void setRocRemoveRedundantPts(boolean z) {
        rocRemoveRedundantPts = z;
    }

    /** Sets the block size used when allocating or growing the exact-mode buffer. */
    public void setExactAllocBlockSize(int i) {
        exactAllocBlockSize = i;
    }
}
