package org.nd4j.evaluation.classification;

import java.io.Serializable;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import lombok.NonNull;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.primitives.Counter;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.primitives.Triple;
import org.nd4j.evaluation.BaseEvaluation;
import org.nd4j.evaluation.EvaluationAveraging;
import org.nd4j.evaluation.EvaluationUtils;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.IMetric;
import org.nd4j.evaluation.meta.Prediction;
import org.nd4j.evaluation.serde.ConfusionMatrixDeserializer;
import org.nd4j.evaluation.serde.ConfusionMatrixSerializer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.ReduceOp;
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.serde.jackson.shaded.NDArrayTextDeSerializer;
import org.nd4j.serde.jackson.shaded.NDArrayTextSerializer;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@JsonIgnoreProperties({"confusionMatrixMetaData"})
/* loaded from: input_file:org/nd4j/evaluation/classification/Evaluation.class */
public class Evaluation extends BaseEvaluation<Evaluation> {
    private static final Logger log = LoggerFactory.getLogger(Evaluation.class);
    /**
     * Edge value for metrics whose ratio is undefined (presumably 0/0 cases — TODO confirm against
     * EvaluationUtils). The single-argument precision/recall/falsePositiveRate overloads pass this same value (0.0).
     */
    protected static final double DEFAULT_EDGE_VALUE = 0.0d;
    /** Class-count limit for printing the confusion matrix; stats(boolean) uses this same threshold (20). */
    protected static final int CONFUSION_PRINT_MAX_CLASSES = 20;
    /** Axis passed to BaseEvaluation.reshapeAndExtractNotMasked when evaluating INDArray batches (default 1). */
    protected int axis;
    /**
     * For 2-class problems, the class (0 or 1) whose precision/recall/F1 are reported.
     * When null, the metrics are macro-averaged instead.
     */
    protected Integer binaryPositiveClass;
    /** N for top-N accuracy; the top-N counters are only updated when this is greater than 1. */
    protected final int topN;
    /** Number of rows whose true class scored within the top N predictions. */
    protected int topNCorrectCount;
    /** Number of rows considered for top-N accuracy. */
    protected int topNTotalCount;
    // Per-class outcome counts, keyed by class index.
    protected Counter<Integer> truePositives;
    protected Counter<Integer> falsePositives;
    protected Counter<Integer> trueNegatives;
    protected Counter<Integer> falseNegatives;

    /**
     * Confusion matrix keyed by class index (actual, predicted). It is created lazily by eval(...) on the
     * first batch when no constructor created it, and reset() clears it to null.
     */
    @JsonDeserialize(using = ConfusionMatrixDeserializer.class)
    @JsonSerialize(using = ConfusionMatrixSerializer.class)
    protected ConfusionMatrix<Integer> confusion;
    /** Total number of example rows evaluated so far; 0 means "no evaluation performed yet". */
    protected int numRowCounter;
    /** Human-readable class labels, index-aligned with class indices. */
    protected List<String> labelsList;
    /**
     * Optional decision threshold.
     * For single-column predictions, the column is compared against it (0.5 when null).
     * For two-column predictions, column 1 is compared against it.
     */
    protected Double binaryDecisionThreshold;

    /**
     * Optional non-negative row vector. When set, it is multiplied into the predictions before argMax.
     * It is stored as FLOAT.
     */
    @JsonDeserialize(using = NDArrayTextDeSerializer.class)
    @JsonSerialize(using = NDArrayTextSerializer.class)
    protected INDArray costArray;
    /**
     * Record metadata grouped by (actual, predicted) class pair. It is excluded from JSON via the
     * class-level @JsonIgnoreProperties. It is presumably populated by addToMetaConfusionMatrix (not visible here).
     */
    protected Map<Pair<Integer, Integer>, List<Object>> confusionMatrixMetaData;
    /** Maximum number of excluded classes listed individually in stats() warnings (default 16). */
    protected int maxWarningClassesToPrint;

    /**
     * Scalar metrics computable by {@link Evaluation}. Every metric here is "higher is better".
     */
    public enum Metric implements IMetric {
        ACCURACY, F1, PRECISION, RECALL, GMEASURE, MCC;

        /** @return the evaluation class able to compute these metrics */
        @Override
        public Class<? extends IEvaluation> getEvaluationClass() {
            return Evaluation.class;
        }

        /** @return always {@code false}: these metrics are maximized */
        @Override
        public boolean minimize() {
            return false;
        }
    }

    /**
     * Sets every configurable field directly. No confusion matrix is created up front; eval(...) creates it
     * lazily.
     *
     * @param i        axis used when reshaping input arrays
     * @param num      binary positive class (may be null)
     * @param i2       N for top-N accuracy
     * @param list     class labels
     * @param d        binary decision threshold (may be null)
     * @param iNDArray cost array (may be null)
     * @param i3       maximum number of excluded classes to list in warnings
     */
    protected Evaluation(int i, Integer num, int i2, List<String> list, Double d, INDArray iNDArray, int i3) {
        this.topN = i2;
        this.axis = i;
        this.binaryPositiveClass = num;
        this.labelsList = list;
        this.binaryDecisionThreshold = d;
        this.costArray = iNDArray;
        this.maxWarningClassesToPrint = i3;
        this.numRowCounter = 0;
        this.topNCorrectCount = 0;
        this.topNTotalCount = 0;
        this.truePositives = new Counter<>();
        this.falsePositives = new Counter<>();
        this.trueNegatives = new Counter<>();
        this.falseNegatives = new Counter<>();
    }

    /**
     * Creates an evaluation with default settings: top-1, axis 1, positive class 1, and an empty label list.
     * Classes are inferred on the first eval(...).
     */
    public Evaluation() {
        this.topN = 1;
        this.axis = 1;
        this.binaryPositiveClass = 1;
        this.maxWarningClassesToPrint = 16;
        this.labelsList = new ArrayList<>();
        this.numRowCounter = 0;
        this.topNCorrectCount = 0;
        this.topNTotalCount = 0;
        this.truePositives = new Counter<>();
        this.falsePositives = new Counter<>();
        this.trueNegatives = new Counter<>();
        this.falseNegatives = new Counter<>();
    }

    /**
     * Creates an evaluation for {@code i} classes.
     * For exactly 2 classes, class 1 is the positive class; otherwise none is set.
     */
    public Evaluation(int i) {
        this(i, (i == 2) ? Integer.valueOf(1) : null);
    }

    /**
     * Creates an evaluation for {@code i} classes with labels "0".."i-1".
     *
     * @param i   number of classes (1 is treated as 2)
     * @param num positive class for binary metrics. It must be 0 or 1 and is only allowed when
     *            {@code i == 2}. It may be null.
     */
    public Evaluation(int i, Integer num) {
        this(createLabels(i), 1);
        if (num != null) {
            int positive = num.intValue();
            Preconditions.checkArgument(positive == 0 || positive == 1, "Only 0 and 1 are valid inputs for binaryPositiveClass; got " + num);
            Preconditions.checkArgument(i == 2, "Cannot set binaryPositiveClass argument when number of classes is not equal to 2 (got: numClasses=" + i + ")");
        }
        this.binaryPositiveClass = num;
    }

    /** Creates a top-1 evaluation using the given class labels. */
    public Evaluation(List<String> list) {
        this(list, /* topN */ 1);
    }

    /** Creates a top-1 evaluation from a class-index to label map whose keys must be 0..size-1. */
    public Evaluation(Map<Integer, String> map) {
        this(createLabelsFromMap(map), /* topN */ 1);
    }

    /**
     * Creates an evaluation with the given labels and top-N setting.
     *
     * @param list class labels. When non-null, a confusion matrix of matching size is created immediately.
     * @param i    N for top-N accuracy
     */
    public Evaluation(List<String> list, int i) {
        this.topN = i;
        this.axis = 1;
        // Class 1 is the positive class by default (in particular for 2-label lists)
        this.binaryPositiveClass = 1;
        this.maxWarningClassesToPrint = 16;
        this.numRowCounter = 0;
        this.topNCorrectCount = 0;
        this.topNTotalCount = 0;
        this.truePositives = new Counter<>();
        this.falsePositives = new Counter<>();
        this.trueNegatives = new Counter<>();
        this.falseNegatives = new Counter<>();
        this.labelsList = list;
        if (list != null) {
            createConfusion(list.size());
        }
    }

    /** Creates a binary evaluation with decision threshold {@code d} and positive class 1. */
    public Evaluation(double d) {
        this(d, Integer.valueOf(1));
    }

    /**
     * Creates a binary evaluation with an explicit decision threshold.
     *
     * @param d   threshold applied to the positive-class probability
     * @param num positive class; must be non-null and 0 or 1
     * @throws NullPointerException     if {@code num} is null
     * @throws IllegalArgumentException if {@code num} is not 0 or 1
     */
    public Evaluation(double d, @NonNull Integer num) {
        if (num == null) {
            throw new NullPointerException("binaryPositiveClass is marked non-null but is null");
        }
        Preconditions.checkArgument(num.intValue() == 0 || num.intValue() == 1, "Only 0 and 1 are valid inputs for binaryPositiveClass; got " + num);
        this.topN = 1;
        this.axis = 1;
        this.binaryPositiveClass = num;
        this.binaryDecisionThreshold = Double.valueOf(d);
        this.maxWarningClassesToPrint = 16;
        this.labelsList = new ArrayList<>();
        this.numRowCounter = 0;
        this.topNCorrectCount = 0;
        this.topNTotalCount = 0;
        this.truePositives = new Counter<>();
        this.falsePositives = new Counter<>();
        this.trueNegatives = new Counter<>();
        this.falseNegatives = new Counter<>();
    }

    /** Creates a cost-sensitive evaluation with no predefined labels. */
    public Evaluation(INDArray iNDArray) {
        List<String> noLabels = null;
        this(noLabels, iNDArray);
    }

    public Evaluation(List<String> list, INDArray iNDArray) {
        this.axis = 1;
        this.binaryPositiveClass = 1;
        this.topNCorrectCount = 0;
        this.topNTotalCount = 0;
        this.truePositives = new Counter<>();
        this.falsePositives = new Counter<>();
        this.trueNegatives = new Counter<>();
        this.falseNegatives = new Counter<>();
        this.numRowCounter = 0;
        this.labelsList = new ArrayList();
        this.maxWarningClassesToPrint = 16;
        if (iNDArray != null && !iNDArray.isRowVectorOrScalar()) {
            throw new IllegalArgumentException("Invalid cost array: must be a row vector (got shape: " + Arrays.toString(iNDArray.shape()) + ")");
        }
        if (iNDArray != null && iNDArray.minNumber().doubleValue() < 0.0d) {
            throw new IllegalArgumentException("Invalid cost array: Cost array values must be positive");
        }
        this.labelsList = list;
        this.costArray = iNDArray == null ? null : iNDArray.castTo(DataType.FLOAT);
        this.topN = 1;
    }

    /**
     * Returns the number of classes.
     *
     * <p>This is the label-list size when labels are known. Otherwise it is the number of classes in the
     * confusion matrix, or 0 when neither exists yet (e.g. a cost-array-only evaluation before any eval
     * call). Returning 0 there avoids a NullPointerException in callers such as stats(boolean).
     *
     * @return the number of classes, never negative
     */
    protected int numClasses() {
        if (this.labelsList != null) {
            return this.labelsList.size();
        }
        return this.confusion == null ? 0 : this.confusion.getClasses().size();
    }

    /**
     * Clears all accumulated statistics: counters, top-N counts, row count and the confusion matrix.
     * Configuration (labels, axis, topN, threshold, cost array) is kept.
     */
    @Override // org.nd4j.evaluation.IEvaluation
    public void reset() {
        this.numRowCounter = 0;
        this.topNCorrectCount = 0;
        this.topNTotalCount = 0;
        this.truePositives = new Counter<>();
        this.falsePositives = new Counter<>();
        this.trueNegatives = new Counter<>();
        this.falseNegatives = new Counter<>();
        // eval(INDArray...) recreates the confusion matrix from the next batch
        this.confusion = null;
    }

    /** @return the current confusion matrix, or null if none has been created yet */
    private ConfusionMatrix<Integer> confusion() {
        ConfusionMatrix<Integer> current = this.confusion;
        return current;
    }

    /**
     * Builds the default labels "0", "1", ..., "n-1".
     *
     * <p>A request for a single class yields two labels, mirroring eval(...), which treats a single output
     * column as a two-class problem.
     *
     * @param i requested number of classes
     * @return a new mutable list of labels
     */
    private static List<String> createLabels(int i) {
        int numLabels = (i == 1) ? 2 : i;
        List<String> labels = new ArrayList<>(numLabels);
        for (int c = 0; c < numLabels; c++) {
            labels.add(String.valueOf(c));
        }
        return labels;
    }

    /**
     * Converts a class-index to label map into an index-aligned list.
     *
     * @param map labels keyed by class index; the keys must be exactly 0..map.size()-1
     * @return a new mutable list where element i is the label for class i
     * @throws IllegalArgumentException if some index in 0..size-1 has no (non-null) label
     */
    private static List<String> createLabelsFromMap(Map<Integer, String> map) {
        int size = map.size();
        List<String> labels = new ArrayList<>(size);
        for (int i = 0; i < size; i++) {
            String label = map.get(Integer.valueOf(i));
            if (label == null) {
                throw new IllegalArgumentException("Invalid labels map: missing key for class " + i + " (expect integers 0 to " + (size - 1) + ")");
            }
            labels.add(label);
        }
        return labels;
    }

    /**
     * Replaces the confusion matrix with an empty one over classes 0..i-1.
     *
     * @param i number of classes
     */
    private void createConfusion(int i) {
        List<Integer> classes = new ArrayList<>(Math.max(i, 0));
        for (int c = 0; c < i; c++) {
            classes.add(Integer.valueOf(c));
        }
        this.confusion = new ConfusionMatrix<>(classes);
    }

    /**
     * Sets the axis passed to reshapeAndExtractNotMasked in eval(...).
     *
     * @param i the axis to use for subsequent evaluations
     */
    public void setAxis(int i) {
        axis = i;
    }

    /** @return the axis passed to reshapeAndExtractNotMasked in eval(...) (default 1) */
    public int getAxis() {
        return axis;
    }

    /**
     * Evaluates a batch of predictions against labels without record metadata.
     *
     * @param iNDArray  labels
     * @param iNDArray2 predictions
     */
    @Override // org.nd4j.evaluation.BaseEvaluation, org.nd4j.evaluation.IEvaluation
    public void eval(INDArray iNDArray, INDArray iNDArray2) {
        List<? extends Serializable> noMetaData = null;
        eval(iNDArray, iNDArray2, noMetaData);
    }

    /**
     * Evaluates one batch of predictions against the corresponding labels.
     *
     * <p>The batch updates the confusion matrix, the per-class TP/FP/TN/FN counters and, when
     * {@code topN > 1}, the top-N accuracy counts. Inputs are first reshaped/masked via
     * BaseEvaluation.reshapeAndExtractNotMasked using {@code axis}; per-output masks are rejected.
     *
     * <ul>
     *   <li>Single output column: binary; predicted class 1 when the value exceeds
     *       {@code binaryDecisionThreshold} (0.5 when unset).</li>
     *   <li>Multiple columns: predicted class is the argMax of the (optionally cost-weighted) predictions.
     *       When {@code binaryDecisionThreshold} is set (two columns only), it is instead whether column 1
     *       exceeds the threshold.</li>
     * </ul>
     *
     * <p>All validation happens before any state changes, so a rejected batch leaves this instance untouched.
     *
     * @param iNDArray  labels
     * @param iNDArray2 predictions
     * @param iNDArray3 mask array, forwarded to reshapeAndExtractNotMasked
     * @param list      optional per-example metadata (may be null); entry r is recorded against row r's
     *                  (actual, predicted) pair
     * @throws IllegalArgumentException if labels and predictions have different shapes
     * @throws IllegalStateException    if predictions contain NaNs, a per-output mask is present, or a binary
     *                                  decision threshold is set for predictions that are not 1 or 2 columns
     */
    @Override // org.nd4j.evaluation.IEvaluation
    public void eval(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, List<? extends Serializable> list) {
        Triple<INDArray, INDArray, INDArray> extracted = BaseEvaluation.reshapeAndExtractNotMasked(iNDArray, iNDArray2, iNDArray3, this.axis);
        if (extracted == null) {
            // Nothing left to evaluate (presumably every example was masked out)
            return;
        }
        INDArray labels2d = (INDArray) extracted.getFirst();
        INDArray predictions2d = (INDArray) extracted.getSecond();
        Preconditions.checkState(((INDArray) extracted.getThird()) == null, "Per-output masking for Evaluation is not supported");

        long nanCount = Nd4j.getExecutioner().execAndReturn((ReduceOp) new MatchCondition(predictions2d, Conditions.isNan(), new int[0])).getFinalResult().longValue();
        Preconditions.checkState(nanCount == 0, "Cannot perform evaluation with NaNs present in predictions: %s NaNs present in predictions INDArray", nanCount);

        if (!Arrays.equals(labels2d.shape(), predictions2d.shape())) {
            throw new IllegalArgumentException("Unable to evaluate. Predictions and labels arrays are not same shape. Predictions shape: " + Arrays.toString(predictions2d.shape()) + ", Labels shape: " + Arrays.toString(labels2d.shape()));
        }
        int nCols = labels2d.columns();
        if (nCols != 1 && this.binaryDecisionThreshold != null && nCols != 2) {
            throw new IllegalStateException("Binary decision threshold is set, but number of columns for predictions is " + nCols + ". Binary decision threshold can only be used for binary prediction cases");
        }

        // --- validation done; from here on the instance state is updated ---
        this.numRowCounter = (int) (this.numRowCounter + labels2d.size(0));
        if (labels2d.dataType() != predictions2d.dataType()) {
            labels2d = labels2d.castTo(predictions2d.dataType());
        }
        if (this.confusion == null) {
            // A single output column is a binary problem with classes 0 and 1
            int nClasses = (nCols == 1) ? 2 : nCols;
            if (this.labelsList == null || this.labelsList.isEmpty()) {
                this.labelsList = new ArrayList<>(nClasses);
                for (int c = 0; c < nClasses; c++) {
                    this.labelsList.add(String.valueOf(c));
                }
            }
            createConfusion(nClasses);
        }

        if (nCols == 1) {
            evalSingleColumnBinary(labels2d, predictions2d, iNDArray2.dataType(), list);
        } else {
            INDArray actualClasses = Nd4j.argMax(labels2d, 1);
            evalMultiColumn(predictions2d, actualClasses, nCols, list);
            if (nCols > 1 && this.topN > 1) {
                updateTopNCounts(predictions2d, actualClasses);
            }
        }
    }

    /** Updates all counters for single-column (binary) predictions thresholded into class 0 / class 1. */
    private void evalSingleColumnBinary(INDArray labels2d, INDArray predictions2d, DataType predictedType, List<? extends Serializable> list) {
        double threshold = this.binaryDecisionThreshold == null ? 0.5d : this.binaryDecisionThreshold.doubleValue();
        INDArray predictedClass = predictions2d.gt(Double.valueOf(threshold)).castTo(predictedType);
        INDArray notActual = labels2d.rsub(Double.valueOf(1.0d));
        INDArray notPredicted = predictedClass.rsub(Double.valueOf(1.0d));
        int rows = labels2d.rows();

        // Counts from the point of view of class 1 (the positive class)
        int tp = labels2d.mul(predictedClass).castTo(DataType.INT).sumNumber().intValue();
        int fp = notActual.mul(predictedClass).castTo(DataType.INT).sumNumber().intValue();
        int fn = notPredicted.mul(labels2d).castTo(DataType.INT).sumNumber().intValue();
        int tn = rows - tp - fp - fn;

        confusion().add(1, 1, tp);
        confusion().add(1, 0, fn);
        confusion().add(0, 1, fp);
        confusion().add(0, 0, tn);

        this.truePositives.incrementCount(1, tp);
        this.falsePositives.incrementCount(1, fp);
        this.falseNegatives.incrementCount(1, fn);
        this.trueNegatives.incrementCount(1, tn);
        // For class 0 the roles mirror those of class 1
        this.truePositives.incrementCount(0, tn);
        this.falsePositives.incrementCount(0, fn);
        this.falseNegatives.incrementCount(0, fp);
        this.trueNegatives.incrementCount(0, tp);

        if (list != null) {
            for (int r = 0; r < predictedClass.size(0) && r < list.size(); r++) {
                // Use row r (previously row 0 was used for every record)
                int actual = labels2d.getDouble(r) == 0.0d ? 0 : 1;
                int predicted = predictedClass.getDouble(r) == 0.0d ? 0 : 1;
                addToMetaConfusionMatrix(actual, predicted, list.get(r));
            }
        }
    }

    /** Updates all counters for multi-column predictions, one row at a time. */
    private void evalMultiColumn(INDArray predictions2d, INDArray actualClasses, int nCols, List<? extends Serializable> list) {
        INDArray predictedClasses;
        if (this.binaryDecisionThreshold != null) {
            // Only reached with exactly two columns (checked by the caller)
            predictedClasses = predictions2d.getColumn(1L).gt(this.binaryDecisionThreshold);
        } else if (this.costArray != null) {
            predictedClasses = Nd4j.argMax(predictions2d.mulRowVector(this.costArray.castTo(predictions2d.dataType())), 1);
        } else {
            predictedClasses = Nd4j.argMax(predictions2d, 1);
        }

        long n = predictedClasses.length();
        for (int r = 0; r < n; r++) {
            int actual = (int) actualClasses.getDouble(r);
            int predicted = (int) predictedClasses.getDouble(r);
            confusion().add(Integer.valueOf(actual), Integer.valueOf(predicted));
            if (list != null && list.size() > r) {
                addToMetaConfusionMatrix(actual, predicted, list.get(r));
            }
            if (actual == predicted) {
                this.truePositives.incrementCount(Integer.valueOf(actual), 1.0d);
            } else {
                this.falsePositives.incrementCount(Integer.valueOf(predicted), 1.0d);
                this.falseNegatives.incrementCount(Integer.valueOf(actual), 1.0d);
            }
            // Every class that is neither the actual nor the predicted one is a true negative for this row
            for (int c = 0; c < nCols; c++) {
                if (c != actual && c != predicted) {
                    this.trueNegatives.incrementCount(Integer.valueOf(c), 1.0d);
                }
            }
        }
    }

    /** Counts a row as top-N correct when fewer than topN classes score strictly higher than the true class. */
    private void updateTopNCounts(INDArray predictions2d, INDArray actualClasses) {
        long n = actualClasses.length();
        for (int r = 0; r < n; r++) {
            int labelIdx = (int) actualClasses.getDouble(r);
            double trueClassScore = predictions2d.getDouble(r, labelIdx);
            int countGreater = (int) Nd4j.getExecutioner().exec((ReduceOp) new MatchCondition(predictions2d.getRow(r), Conditions.greaterThan(Double.valueOf(trueClassScore)), new int[0])).getDouble(0L);
            if (countGreater < this.topN) {
                this.topNCorrectCount++;
            }
            this.topNTotalCount++;
        }
    }

    /**
     * Records a single example from its predicted and actual class indices.
     *
     * <p>The confusion matrix must already exist (e.g. created by a labels-based constructor). The row
     * counter is only incremented once that is confirmed, so a rejected call leaves the counters untouched.
     *
     * @param i  predicted class index
     * @param i2 actual class index
     * @throws UnsupportedOperationException if no confusion matrix has been initialized
     */
    public void eval(int i, int i2) {
        if (this.confusion == null) {
            throw new UnsupportedOperationException("Cannot evaluate single example without initializing confusion matrix first");
        }
        this.numRowCounter++;
        addToConfusion(Integer.valueOf(i2), Integer.valueOf(i));
        if (i == i2) {
            incrementTruePositives(Integer.valueOf(i));
        } else {
            incrementFalseNegatives(Integer.valueOf(i2));
            incrementFalsePositives(Integer.valueOf(i));
        }
        // All classes other than the predicted and actual ones are true negatives for this example
        for (Integer cls : confusion().getClasses()) {
            int c = cls.intValue();
            if (c != i && c != i2) {
                this.trueNegatives.incrementCount(cls, 1.0d);
            }
        }
    }

    /**
     * Returns a human-readable summary of the evaluation, including precision/recall warnings.
     *
     * @return the formatted statistics
     */
    @Override // org.nd4j.evaluation.IEvaluation
    public String stats() {
        final boolean suppressWarnings = false;
        return stats(suppressWarnings);
    }

    /**
     * Returns a human-readable summary of the evaluation.
     *
     * <p>The confusion matrix is included only when there are at most CONFUSION_PRINT_MAX_CLASSES classes.
     * Otherwise a note explains how to print it anyway.
     *
     * @param z if true, suppress the "class excluded from average" warnings
     * @return the formatted statistics
     */
    public String stats(boolean z) {
        boolean printConfusion = numClasses() <= CONFUSION_PRINT_MAX_CLASSES;
        return stats(z, printConfusion, !printConfusion);
    }

    /**
     * Returns a human-readable summary of the evaluation.
     *
     * @param z  if true, suppress the "class excluded from average" warnings
     * @param z2 if true, append the confusion matrix
     * @return the formatted statistics
     */
    public String stats(boolean z, boolean z2) {
        final boolean noteOmittedConfusion = false;
        return stats(z, z2, noteOmittedConfusion);
    }

    /**
     * Builds the statistics report.
     *
     * <p>The report contains accuracy, optional top-N accuracy, and precision/recall/F1 (positive-class or
     * macro-averaged). It also contains the threshold/cost settings, optional warnings and an optional
     * confusion matrix.
     *
     * @param z  if true, suppress warnings about classes excluded from average precision/recall
     * @param z2 if true, append the confusion matrix
     * @param z3 if true and {@code z2} is false, append a note that the confusion matrix was omitted
     * @return the formatted report, or a short message when no evaluation has been performed
     */
    private String stats(boolean z, boolean z2, boolean z3) {
        if (this.numRowCounter == 0) {
            return "Evaluation: No data available (no evaluation has been performed)";
        }
        ConfusionMatrix<Integer> cm = confusion();
        if (cm == null) {
            cm = new ConfusionMatrix<>();
        }
        List<Integer> classes = cm.getClasses();

        // A class with no TP and no FP has undefined precision; no TP and no FN -> undefined recall
        List<Integer> precisionExcluded = new ArrayList<>();
        List<Integer> recallExcluded = new ArrayList<>();
        if (!z) {
            for (Integer c : classes) {
                if (this.truePositives.getCount(c) == 0.0d) {
                    if (this.falsePositives.getCount(c) == 0.0d) {
                        precisionExcluded.add(c);
                    }
                    if (this.falseNegatives.getCount(c) == 0.0d) {
                        recallExcluded.add(c);
                    }
                }
            }
        }
        StringBuilder warnings = new StringBuilder();
        if (!precisionExcluded.isEmpty()) {
            warningHelper(warnings, precisionExcluded, "precision");
        }
        if (!recallExcluded.isEmpty()) {
            warningHelper(warnings, recallExcluded, "recall");
        }

        int size = classes.size();
        DecimalFormat df = new DecimalFormat("0.0000");
        double accuracy = accuracy();
        double precision = precision();
        double recall = recall();
        double f1 = f1();

        StringBuilder out = new StringBuilder("\n");
        out.append("\n========================Evaluation Metrics========================");
        out.append("\n # of classes:    ").append(size);
        out.append("\n Accuracy:        ").append(format(df, accuracy));
        if (this.topN > 1) {
            out.append("\n Top ").append(this.topN).append(" Accuracy:  ").append(format(df, topNAccuracy()));
        }
        out.append("\n Precision:       ").append(format(df, precision));
        if (size > 2) {
            appendExcludedNote(out, averagePrecisionNumClassesExcluded());
        }
        out.append("\n Recall:          ").append(format(df, recall));
        if (size > 2) {
            appendExcludedNote(out, averageRecallNumClassesExcluded());
        }
        out.append("\n F1 Score:        ").append(format(df, f1));
        if (size > 2) {
            appendExcludedNote(out, averageF1NumClassesExcluded());
        }

        if (size > 2 || this.binaryPositiveClass == null) {
            out.append("\nPrecision, recall & F1: macro-averaged (equally weighted avg. of ").append(size).append(" classes)");
        }
        if (size == 2 && this.binaryPositiveClass != null) {
            int positive = this.binaryPositiveClass.intValue();
            out.append("\nPrecision, recall & F1: reported for positive class (class ").append(this.binaryPositiveClass);
            // Guard the lookup: a label list shorter than the class count must not break the report
            if (this.labelsList != null && positive >= 0 && positive < this.labelsList.size()) {
                out.append(" - \"").append(this.labelsList.get(positive)).append("\"");
            }
            out.append(") only");
        }
        if (this.binaryDecisionThreshold != null) {
            out.append("\nBinary decision threshold: ").append(this.binaryDecisionThreshold);
        }
        if (this.costArray != null) {
            out.append("\nCost array: ").append(Arrays.toString(this.costArray.dup().data().asFloat()));
        }
        out.append("\n\n");
        out.append(warnings);
        if (z2) {
            out.append("\n=========================Confusion Matrix=========================\n");
            out.append(confusionMatrix());
        } else if (z3) {
            out.append("\n\nNote: Confusion matrix not generated due to space requirements for ").append(size).append(" classes.\n").append("Use stats(false,true) to generate anyway");
        }
        out.append("\n==================================================================");
        return out.toString();
    }

    /** Appends "\t(N class[es] excluded from average)" when N is positive; appends nothing otherwise. */
    private static void appendExcludedNote(StringBuilder sb, int numExcluded) {
        if (numExcluded <= 0) {
            return;
        }
        sb.append("\t(").append(numExcluded).append(" class");
        if (numExcluded > 1) {
            sb.append("es");
        }
        sb.append(" excluded from average)");
    }

    /**
     * Renders the confusion matrix as a fixed-width text table.
     *
     * <p>Rows are actual classes and columns are predicted classes. Each row ends with " | index = label".
     *
     * @return the formatted table, or "Confusion matrix: <no data>" if nothing has been evaluated
     */
    public String confusionMatrix() {
        // Check before numClasses(): with no labels and no matrix, numClasses() has nothing to count
        if (this.confusion == null) {
            return "Confusion matrix: <no data>";
        }
        int numClasses = numClasses();
        List<Integer> classes = this.confusion.getClasses();
        int maxValue = 1;
        for (Integer actual : classes) {
            for (Integer predicted : classes) {
                maxValue = Math.max(maxValue, this.confusion.getCount(actual, predicted));
            }
        }
        // The header prints class indices, so those must fit in a cell as well
        maxValue = Math.max(maxValue, numClasses);
        // Exact digit count; ceil(log10(x)) under-counts exact powers of ten (10 -> 1), merging columns
        int numDigits = String.valueOf(maxValue).length();
        String cellFormat = "%" + (numDigits + 1) + "d";

        StringBuilder sb = new StringBuilder();
        for (int c = 0; c < numClasses; c++) {
            sb.append(String.format(cellFormat, Integer.valueOf(c)));
        }
        sb.append("\n");
        int ruleWidth = ((numDigits + 1) * numClasses) + 1;
        for (int k = 0; k < ruleWidth; k++) {
            sb.append("-");
        }
        sb.append("\n");
        for (int actual = 0; actual < numClasses; actual++) {
            String label = resolveLabelForClass(Integer.valueOf(actual));
            for (int predicted = 0; predicted < numClasses; predicted++) {
                sb.append(String.format(cellFormat, Integer.valueOf(this.confusion.getCount(Integer.valueOf(actual), Integer.valueOf(predicted)))));
            }
            sb.append(" | ").append(actual).append(" = ").append(label).append("\n");
        }
        sb.append("\nConfusion matrix format: Actual (rowClass) predicted as (columnClass) N times");
        return sb.toString();
    }

    /** Formats finite values with {@code decimalFormat}; NaN and infinities are printed via String.valueOf. */
    private static String format(DecimalFormat decimalFormat, double d) {
        boolean finite = !Double.isNaN(d) && !Double.isInfinite(d);
        if (finite) {
            return decimalFormat.format(d);
        }
        return String.valueOf(d);
    }

    /** Returns the label for a class index, falling back to the index itself when no label is available. */
    private String resolveLabelForClass(Integer num) {
        int idx = num.intValue();
        if (this.labelsList != null && idx < this.labelsList.size()) {
            return this.labelsList.get(idx);
        }
        return num.toString();
    }

    /**
     * Appends a warning about classes excluded from an averaged metric.
     *
     * <p>Precision is excluded for classes that were never predicted (TP + FP == 0). Recall is excluded for
     * classes that never occur as an actual label (TP + FN == 0). The wording reflects the correct reason for
     * each metric; previously both said "never predicted by the model". The individual class indices are
     * listed when there are at most maxWarningClassesToPrint of them.
     *
     * @param sb   destination buffer
     * @param list excluded class indices (non-empty)
     * @param str  metric name, e.g. "precision" or "recall"
     */
    private void warningHelper(StringBuilder sb, List<Integer> list, String str) {
        boolean single = list.size() == 1;
        String verb = single ? "was" : "were";
        String reason = "recall".equals(str) ? " never present in the labels and " : " never predicted by the model and ";
        sb.append("Warning: ").append(list.size()).append(" class");
        if (!single) {
            sb.append("es");
        }
        sb.append(" ").append(verb);
        sb.append(reason).append(verb).append(" excluded from average ").append(str);
        if (list.size() <= this.maxWarningClassesToPrint) {
            sb.append("\nClasses excluded from average ").append(str).append(": ").append(list).append("\n");
        }
    }

    /**
     * Returns the precision for one class, using DEFAULT_EDGE_VALUE (0.0) as the edge value.
     *
     * @param num class index
     * @return precision for the class
     */
    public double precision(Integer num) {
        return precision(num, DEFAULT_EDGE_VALUE);
    }

    /**
     * Returns the precision for one class.
     *
     * @param num class index
     * @param d   edge value forwarded to EvaluationUtils.precision
     * @return precision computed from this class's TP and FP counts
     * @throws IllegalStateException if no evaluation has been performed
     */
    public double precision(Integer num, double d) {
        Preconditions.checkState(this.numRowCounter > 0, "Cannot get precision: no evaluation has been performed");
        long tp = (long) this.truePositives.getCount(num);
        long fp = (long) this.falsePositives.getCount(num);
        return EvaluationUtils.precision(tp, fp, d);
    }

    /**
     * Returns the headline precision.
     *
     * <p>This is the positive class's precision for 2-class problems with a positive class set. Otherwise
     * it is the macro average.
     */
    public double precision() {
        if (this.binaryPositiveClass != null && numClasses() == 2) {
            return precision(this.binaryPositiveClass);
        }
        return precision(EvaluationAveraging.Macro);
    }

    /**
     * Returns the averaged precision.
     *
     * <p>Macro: the mean of the per-class precisions, skipping classes whose precision is undefined (NaN if
     * all are undefined). Micro: precision of the summed TP and FP counts.
     *
     * @param evaluationAveraging Macro or Micro
     * @throws IllegalStateException         if no evaluation has been performed
     * @throws UnsupportedOperationException for any other averaging value
     */
    public double precision(EvaluationAveraging evaluationAveraging) {
        Preconditions.checkState(this.numRowCounter > 0, "Cannot get precision: no evaluation has been performed");
        int nClasses = confusion().getClasses().size();
        if (evaluationAveraging == EvaluationAveraging.Macro) {
            double sum = 0.0d;
            int counted = 0;
            for (int c = 0; c < nClasses; c++) {
                double p = precision(Integer.valueOf(c), -1.0d);
                if (p == -1.0d) {
                    continue; // -1 marks "undefined" for this class
                }
                sum += p;
                counted++;
            }
            return sum / counted;
        } else if (evaluationAveraging == EvaluationAveraging.Micro) {
            long tpSum = 0;
            long fpSum = 0;
            for (int c = 0; c < nClasses; c++) {
                tpSum += this.truePositives.getCount(Integer.valueOf(c));
                fpSum += this.falsePositives.getCount(Integer.valueOf(c));
            }
            return EvaluationUtils.precision(tpSum, fpSum, 0.0d);
        }
        throw new UnsupportedOperationException("Unknown averaging approach: " + evaluationAveraging);
    }

    /** @return number of classes excluded from macro-averaged precision (precision undefined) */
    public int averagePrecisionNumClassesExcluded() {
        final String metric = "precision";
        return numClassesExcluded(metric);
    }

    /** @return number of classes excluded from macro-averaged recall (recall undefined) */
    public int averageRecallNumClassesExcluded() {
        final String metric = "recall";
        return numClassesExcluded(metric);
    }

    /** @return number of classes excluded from macro-averaged F1 (F1 undefined) */
    public int averageF1NumClassesExcluded() {
        final String metric = "f1";
        return numClassesExcluded(metric);
    }

    /** @return number of classes excluded from macro-averaged F-beta (counted like F1, with beta = 1) */
    public int averageFBetaNumClassesExcluded() {
        final String metric = "fbeta";
        return numClassesExcluded(metric);
    }

    /**
     * Counts classes whose per-class value of the given metric is undefined.
     *
     * <p>A class counts as excluded when its metric, computed with edge value -1, returns exactly -1.
     * "f1" and "fbeta" are both evaluated as fBeta(1.0, class, -1).
     *
     * <p>Reconstructed from the bytecode dump: the previous decompiled body always threw
     * UnsupportedOperationException.
     *
     * @param str metric name: "precision", "recall", "f1" or "fbeta" (case-insensitive)
     * @return number of excluded classes
     * @throws RuntimeException if the metric name is unknown (and there is at least one class)
     */
    private int numClassesExcluded(String str) {
        String metric = str.toLowerCase();
        int nClasses = confusion().getClasses().size();
        int excluded = 0;
        for (int i = 0; i < nClasses; i++) {
            double value;
            switch (metric) {
                case "precision":
                    value = precision(Integer.valueOf(i), -1.0d);
                    break;
                case "recall":
                    value = recall(i, -1.0d);
                    break;
                case "f1":
                case "fbeta":
                    value = fBeta(1.0d, i, -1.0d);
                    break;
                default:
                    throw new RuntimeException("Unknown metric: " + str);
            }
            if (value == -1.0d) {
                excluded++;
            }
        }
        return excluded;
    }

    /**
     * Recall of a single class, using {@link #DEFAULT_EDGE_VALUE} as the edge-case value.
     *
     * @param i class index
     * @return recall of class {@code i}
     */
    public double recall(int i) {
        return recall(i, DEFAULT_EDGE_VALUE);
    }

    /**
     * Recall of a single class, computed by {@link EvaluationUtils#recall} from the class's
     * true-positive and false-negative counts.
     *
     * @param i class index
     * @param d edge-case value forwarded to {@link EvaluationUtils#recall}; the averaging methods
     *          in this class pass -1.0 as a sentinel to detect classes whose recall is undefined
     * @return recall of class {@code i}, or {@code d} in the edge case
     * @throws IllegalStateException (via Preconditions.checkState) if nothing has been evaluated
     */
    public double recall(int i, double d) {
        Preconditions.checkState(this.numRowCounter > 0, "Cannot get recall: no evaluation has been performed");
        Integer classIdx = Integer.valueOf(i);
        long tp = (long) truePositives.getCount(classIdx);
        long fn = (long) falseNegatives.getCount(classIdx);
        return EvaluationUtils.recall(tp, fn, d);
    }

    /**
     * Default recall. In binary mode (a positive class is configured and there are exactly two
     * classes) this is the positive class's recall; otherwise it is the macro-averaged recall.
     */
    public double recall() {
        if (binaryPositiveClass != null && numClasses() == 2) {
            return recall(binaryPositiveClass.intValue());
        }
        return recall(EvaluationAveraging.Macro);
    }

    /**
     * Recall aggregated over all classes.
     *
     * <ul>
     *   <li>Micro: recall computed from TP and FN counts summed across classes.</li>
     *   <li>Macro: mean of the per-class recalls, skipping classes whose recall is undefined
     *       (detected via the -1.0 sentinel). If every class is skipped the result is 0.0/0, i.e. NaN.</li>
     * </ul>
     *
     * @param evaluationAveraging averaging mode
     * @return aggregated recall
     * @throws UnsupportedOperationException for any averaging mode other than Macro/Micro
     */
    public double recall(EvaluationAveraging evaluationAveraging) {
        Preconditions.checkState(this.numRowCounter > 0, "Cannot get recall: no evaluation has been performed");
        int nClasses = confusion().getClasses().size();
        if (evaluationAveraging == EvaluationAveraging.Micro) {
            long tpSum = 0;
            long fnSum = 0;
            for (int c = 0; c < nClasses; c++) {
                Integer cls = Integer.valueOf(c);
                // compound assignment narrows (long)(sum + count), exactly as an explicit cast would
                tpSum += truePositives.getCount(cls);
                fnSum += falseNegatives.getCount(cls);
            }
            return EvaluationUtils.recall(tpSum, fnSum, DEFAULT_EDGE_VALUE);
        }
        if (evaluationAveraging != EvaluationAveraging.Macro) {
            throw new UnsupportedOperationException("Unknown averaging approach: " + evaluationAveraging);
        }
        double total = 0.0d;
        int counted = 0;
        for (int c = 0; c < nClasses; c++) {
            double classRecall = recall(c, -1.0d);
            if (classRecall == -1.0d) {
                continue; // recall undefined for this class: exclude it from the average
            }
            total += classRecall;
            counted++;
        }
        return total / counted;
    }

    /**
     * False positive rate of a single class, using {@link #DEFAULT_EDGE_VALUE} as the edge-case value.
     *
     * @param i class index
     * @return false positive rate of class {@code i}
     */
    public double falsePositiveRate(int i) {
        return falsePositiveRate(i, DEFAULT_EDGE_VALUE);
    }

    /**
     * False positive rate of a single class, computed by {@link EvaluationUtils#falsePositiveRate}
     * from the class's false-positive and true-negative counts.
     *
     * @param i class index
     * @param d edge-case value forwarded to {@link EvaluationUtils#falsePositiveRate}
     * @return false positive rate of class {@code i}, or {@code d} in the edge case
     * @throws IllegalStateException (via Preconditions.checkState) if nothing has been evaluated
     */
    public double falsePositiveRate(int i, double d) {
        Preconditions.checkState(this.numRowCounter > 0, "Cannot get false positive rate: no evaluation has been performed");
        Integer classIdx = Integer.valueOf(i);
        long fp = (long) falsePositives.getCount(classIdx);
        long tn = (long) trueNegatives.getCount(classIdx);
        return EvaluationUtils.falsePositiveRate(fp, tn, d);
    }

    /**
     * Default false positive rate. In binary mode (a positive class is configured and there are
     * exactly two classes) this is the positive class's rate; otherwise the macro average.
     */
    public double falsePositiveRate() {
        if (binaryPositiveClass != null && numClasses() == 2) {
            return falsePositiveRate(binaryPositiveClass.intValue());
        }
        return falsePositiveRate(EvaluationAveraging.Macro);
    }

    /**
     * False positive rate aggregated over all classes.
     *
     * <ul>
     *   <li>Micro: rate computed from FP and TN counts summed across classes.</li>
     *   <li>Macro: plain mean of every class's rate. Undefined classes contribute the default edge
     *       value (0.0) and are not excluded.</li>
     * </ul>
     *
     * @param evaluationAveraging averaging mode
     * @return aggregated false positive rate
     * @throws UnsupportedOperationException for any averaging mode other than Macro/Micro
     */
    public double falsePositiveRate(EvaluationAveraging evaluationAveraging) {
        Preconditions.checkState(this.numRowCounter > 0, "Cannot get false positive rate: no evaluation has been performed");
        int nClasses = confusion().getClasses().size();
        if (evaluationAveraging == EvaluationAveraging.Micro) {
            long fpSum = 0;
            long tnSum = 0;
            for (int c = 0; c < nClasses; c++) {
                Integer cls = Integer.valueOf(c);
                fpSum += falsePositives.getCount(cls);
                tnSum += trueNegatives.getCount(cls);
            }
            return EvaluationUtils.falsePositiveRate(fpSum, tnSum, DEFAULT_EDGE_VALUE);
        }
        if (evaluationAveraging != EvaluationAveraging.Macro) {
            throw new UnsupportedOperationException("Unknown averaging approach: " + evaluationAveraging);
        }
        double total = 0.0d;
        for (int c = 0; c < nClasses; c++) {
            total += falsePositiveRate(c);
        }
        return total / nClasses;
    }

    /**
     * False negative rate of a single class, using {@link #DEFAULT_EDGE_VALUE} as the edge-case value.
     *
     * @param num class index
     * @return false negative rate of class {@code num}
     */
    public double falseNegativeRate(Integer num) {
        return falseNegativeRate(num, DEFAULT_EDGE_VALUE);
    }

    /**
     * False negative rate of a single class, computed by {@link EvaluationUtils#falseNegativeRate}
     * from the class's false-negative and true-positive counts.
     *
     * @param num class index
     * @param d   edge-case value forwarded to {@link EvaluationUtils#falseNegativeRate}
     * @return false negative rate of class {@code num}, or {@code d} in the edge case
     * @throws IllegalStateException (via Preconditions.checkState) if nothing has been evaluated
     */
    public double falseNegativeRate(Integer num, double d) {
        Preconditions.checkState(this.numRowCounter > 0, "Cannot get false negative rate: no evaluation has been performed");
        long fn = (long) falseNegatives.getCount(num);
        long tp = (long) truePositives.getCount(num);
        return EvaluationUtils.falseNegativeRate(fn, tp, d);
    }

    /**
     * Default false negative rate. In binary mode (a positive class is configured and there are
     * exactly two classes) this is the positive class's rate; otherwise the macro average.
     */
    public double falseNegativeRate() {
        if (binaryPositiveClass != null && numClasses() == 2) {
            return falseNegativeRate(binaryPositiveClass);
        }
        return falseNegativeRate(EvaluationAveraging.Macro);
    }

    /**
     * False negative rate aggregated over all classes.
     *
     * <ul>
     *   <li>Macro: plain mean of every class's rate. Undefined classes contribute the default edge
     *       value (0.0) and are not excluded.</li>
     *   <li>Micro: rate computed from FN and TP counts summed across classes. This is the same
     *       pairing used by the per-class {@code falseNegativeRate(Integer, double)}, i.e.
     *       FNR = FN / (FN + TP).</li>
     * </ul>
     *
     * @param evaluationAveraging averaging mode
     * @return aggregated false negative rate
     * @throws UnsupportedOperationException for any averaging mode other than Macro/Micro
     */
    public double falseNegativeRate(EvaluationAveraging evaluationAveraging) {
        Preconditions.checkState(this.numRowCounter > 0, "Cannot get false negative rate: no evaluation has been performed");
        int size = confusion().getClasses().size();
        if (evaluationAveraging == EvaluationAveraging.Macro) {
            double d = 0.0d;
            for (int i = 0; i < size; i++) {
                d += falseNegativeRate(Integer.valueOf(i));
            }
            return d / size;
        }
        if (evaluationAveraging != EvaluationAveraging.Micro) {
            throw new UnsupportedOperationException("Unknown averaging approach: " + evaluationAveraging);
        }
        long fnCount = 0;
        long tpCount = 0;
        for (int i = 0; i < size; i++) {
            fnCount = (long) (fnCount + this.falseNegatives.getCount(Integer.valueOf(i)));
            // Denominator of FNR is FN + TP. Summing true negatives here (as before) produced
            // a ratio unrelated to the false negative rate.
            tpCount = (long) (tpCount + this.truePositives.getCount(Integer.valueOf(i)));
        }
        return EvaluationUtils.falseNegativeRate(fnCount, tpCount, DEFAULT_EDGE_VALUE);
    }

    /**
     * False alarm rate: the mean of the false positive rate and the false negative rate.
     *
     * <p>In binary mode (a positive class is configured and there are exactly two classes) both rates
     * are taken for the positive class. Otherwise the default (macro-averaged) rates are used.
     *
     * @return (FPR + FNR) / 2
     */
    public double falseAlarmRate() {
        double fpr;
        double fnr;
        if (binaryPositiveClass != null && numClasses() == 2) {
            fpr = falsePositiveRate(binaryPositiveClass.intValue());
            fnr = falseNegativeRate(binaryPositiveClass);
        } else {
            fpr = falsePositiveRate();
            fnr = falseNegativeRate();
        }
        return (fpr + fnr) / 2.0d;
    }

    /**
     * F1 score of a single class (F-beta with beta = 1).
     *
     * @param i class index
     * @return F1 of class {@code i}
     */
    public double f1(int i) {
        return fBeta(1.0d, i);
    }

    /**
     * F-beta score of a single class, using {@link #DEFAULT_EDGE_VALUE} when it is undefined.
     *
     * @param d beta
     * @param i class index
     * @return F-beta of class {@code i}
     */
    public double fBeta(double d, int i) {
        return fBeta(d, i, DEFAULT_EDGE_VALUE);
    }

    /**
     * F-beta score of a single class, built from its precision and recall.
     *
     * @param d  beta
     * @param i  class index
     * @param d2 value returned when either precision or recall of the class is undefined
     *           (reported as the -1.0 sentinel)
     * @return F-beta of class {@code i}, or {@code d2}
     * @throws IllegalStateException (via Preconditions.checkState) if nothing has been evaluated
     */
    public double fBeta(double d, int i, double d2) {
        Preconditions.checkState(this.numRowCounter > 0, "Cannot get fBeta score: no evaluation has been performed");
        double classPrecision = precision(Integer.valueOf(i), -1.0d);
        double classRecall = recall(i, -1.0d);
        boolean undefined = classPrecision == -1.0d || classRecall == -1.0d;
        if (undefined) {
            return d2;
        }
        return EvaluationUtils.fBeta(d, classPrecision, classRecall);
    }

    /**
     * Default F1 score. In binary mode (a positive class is configured and there are exactly two
     * classes) this is the positive class's F1; otherwise {@code f1(EvaluationAveraging.Macro)}.
     */
    public double f1() {
        if (binaryPositiveClass != null && numClasses() == 2) {
            return f1(binaryPositiveClass.intValue());
        }
        return f1(EvaluationAveraging.Macro);
    }

    /**
     * F1 score aggregated over classes; shorthand for {@code fBeta(1.0, evaluationAveraging)}.
     *
     * @param evaluationAveraging averaging mode
     * @return aggregated F1
     */
    public double f1(EvaluationAveraging evaluationAveraging) {
        final double beta = 1.0d;
        return fBeta(beta, evaluationAveraging);
    }

    /**
     * F-beta score aggregated over all classes.
     *
     * <ul>
     *   <li>Exactly two classes: F-beta of class 1 from its TP/FP/FN counts (see note).</li>
     *   <li>Micro: F-beta from TP/FP/FN counts summed across classes.</li>
     *   <li>Macro: mean of per-class F-beta over the classes where it is defined. If none is
     *       defined the result is 0.0/0, i.e. NaN.</li>
     * </ul>
     *
     * <p>NOTE(review): in the two-class case the {@code evaluationAveraging} argument is neither used
     * nor validated, and class 1 is always scored, whatever {@code binaryPositiveClass} is.
     * {@code recall(EvaluationAveraging)} has no such special case. Confirm this asymmetry is intended.
     *
     * @param d                   beta
     * @param evaluationAveraging averaging mode
     * @return aggregated F-beta
     * @throws UnsupportedOperationException for an averaging mode other than Macro/Micro (non-binary case)
     */
    public double fBeta(double d, EvaluationAveraging evaluationAveraging) {
        Preconditions.checkState(this.numRowCounter > 0, "Cannot get fBeta score: no evaluation has been performed");
        int nClasses = confusion().getClasses().size();
        if (nClasses == 2) {
            long tp = (long) truePositives.getCount(1);
            long fp = (long) falsePositives.getCount(1);
            long fn = (long) falseNegatives.getCount(1);
            return EvaluationUtils.fBeta(d, tp, fp, fn);
        }
        if (evaluationAveraging == EvaluationAveraging.Micro) {
            long tpSum = 0;
            long fpSum = 0;
            long fnSum = 0;
            for (int c = 0; c < nClasses; c++) {
                Integer cls = Integer.valueOf(c);
                tpSum += truePositives.getCount(cls);
                fpSum += falsePositives.getCount(cls);
                fnSum += falseNegatives.getCount(cls);
            }
            return EvaluationUtils.fBeta(d, tpSum, fpSum, fnSum);
        }
        if (evaluationAveraging != EvaluationAveraging.Macro) {
            throw new UnsupportedOperationException("Unknown averaging approach: " + evaluationAveraging);
        }
        double total = 0.0d;
        int defined = 0;
        for (int c = 0; c < nClasses; c++) {
            double classScore = fBeta(d, c, -1.0d);
            if (classScore == -1.0d) {
                continue; // F-beta undefined for this class
            }
            total += classScore;
            defined++;
        }
        return total / defined;
    }

    /**
     * G-measure of a single class, computed by {@link EvaluationUtils#gMeasure} from the class's
     * precision and recall (each with the default edge value).
     *
     * @param i class index
     * @return G-measure of class {@code i}
     * @throws IllegalStateException (via Preconditions.checkState) if nothing has been evaluated
     */
    public double gMeasure(int i) {
        Preconditions.checkState(this.numRowCounter > 0, "Cannot get gMeasure: no evaluation has been performed");
        double classPrecision = precision(Integer.valueOf(i));
        double classRecall = recall(i);
        return EvaluationUtils.gMeasure(classPrecision, classRecall);
    }

    /**
     * G-measure aggregated over all classes.
     *
     * <ul>
     *   <li>Micro: G-measure of the precision and recall obtained from TP/FP/FN summed across classes.</li>
     *   <li>Macro: plain mean of every class's G-measure.</li>
     * </ul>
     *
     * @param evaluationAveraging averaging mode
     * @return aggregated G-measure
     * @throws UnsupportedOperationException for any averaging mode other than Macro/Micro
     */
    public double gMeasure(EvaluationAveraging evaluationAveraging) {
        Preconditions.checkState(this.numRowCounter > 0, "Cannot get gMeasure: no evaluation has been performed");
        int nClasses = confusion().getClasses().size();
        if (evaluationAveraging == EvaluationAveraging.Micro) {
            long tpSum = 0;
            long fpSum = 0;
            long fnSum = 0;
            for (int c = 0; c < nClasses; c++) {
                Integer cls = Integer.valueOf(c);
                tpSum += truePositives.getCount(cls);
                fpSum += falsePositives.getCount(cls);
                fnSum += falseNegatives.getCount(cls);
            }
            double microPrecision = EvaluationUtils.precision(tpSum, fpSum, DEFAULT_EDGE_VALUE);
            double microRecall = EvaluationUtils.recall(tpSum, fnSum, DEFAULT_EDGE_VALUE);
            return EvaluationUtils.gMeasure(microPrecision, microRecall);
        }
        if (evaluationAveraging != EvaluationAveraging.Macro) {
            throw new UnsupportedOperationException("Unknown averaging approach: " + evaluationAveraging);
        }
        double total = 0.0d;
        for (int c = 0; c < nClasses; c++) {
            total += gMeasure(c);
        }
        return total / nClasses;
    }

    /**
     * Classification accuracy: the sum of the confusion-matrix diagonal (correct predictions)
     * divided by {@link #getNumRowCounter()}.
     *
     * <p>The division is done in floating point. The previous {@code int / int} truncated the
     * result to 0 (or 1 when every prediction was correct).
     *
     * @return accuracy in [0, 1]
     * @throws IllegalStateException (via Preconditions.checkState) if nothing has been evaluated
     */
    public double accuracy() {
        Preconditions.checkState(this.numRowCounter > 0, "Cannot get accuracy: no evaluation has been performed");
        int size = confusion().getClasses().size();
        long correct = 0;
        for (int i = 0; i < size; i++) {
            correct += confusion().getCount(Integer.valueOf(i), Integer.valueOf(i));
        }
        return correct / (double) getNumRowCounter();
    }

    /**
     * Top-N accuracy: the fraction of evaluated examples whose true class was among the top N predictions.
     *
     * <p>When {@code topN <= 1} this is plain {@link #accuracy()}. When no top-N examples have been
     * counted it returns 0.0. The ratio is computed in floating point; the previous
     * {@code int / int} division truncated it to 0 or 1.
     *
     * @return top-N accuracy in [0, 1]
     */
    public double topNAccuracy() {
        if (this.topN <= 1) {
            return accuracy();
        }
        if (this.topNTotalCount == 0) {
            return 0.0d;
        }
        return this.topNCorrectCount / (double) this.topNTotalCount;
    }

    /**
     * Matthews correlation coefficient of a single class (one-vs-rest), computed by
     * {@link EvaluationUtils#matthewsCorrelation} from the class's TP/FP/FN/TN counts.
     *
     * @param i class index
     * @return MCC of class {@code i}
     * @throws IllegalStateException (via Preconditions.checkState) if nothing has been evaluated
     */
    public double matthewsCorrelation(int i) {
        Preconditions.checkState(this.numRowCounter > 0, "Cannot get Matthews correlation: no evaluation has been performed");
        Integer cls = Integer.valueOf(i);
        long tp = (long) truePositives.getCount(cls);
        long fp = (long) falsePositives.getCount(cls);
        long fn = (long) falseNegatives.getCount(cls);
        long tn = (long) trueNegatives.getCount(cls);
        return EvaluationUtils.matthewsCorrelation(tp, fp, fn, tn);
    }

    /**
     * Matthews correlation coefficient aggregated over all classes.
     *
     * <ul>
     *   <li>Micro: MCC from TP/FP/FN/TN counts summed across classes.</li>
     *   <li>Macro: plain mean of every class's MCC.</li>
     * </ul>
     *
     * @param evaluationAveraging averaging mode
     * @return aggregated MCC
     * @throws UnsupportedOperationException for any averaging mode other than Macro/Micro
     */
    public double matthewsCorrelation(EvaluationAveraging evaluationAveraging) {
        Preconditions.checkState(this.numRowCounter > 0, "Cannot get Matthews correlation: no evaluation has been performed");
        int nClasses = confusion().getClasses().size();
        if (evaluationAveraging == EvaluationAveraging.Micro) {
            long tpSum = 0;
            long fpSum = 0;
            long fnSum = 0;
            long tnSum = 0;
            for (int c = 0; c < nClasses; c++) {
                Integer cls = Integer.valueOf(c);
                tpSum += truePositives.getCount(cls);
                fpSum += falsePositives.getCount(cls);
                fnSum += falseNegatives.getCount(cls);
                tnSum += trueNegatives.getCount(cls);
            }
            return EvaluationUtils.matthewsCorrelation(tpSum, fpSum, fnSum, tnSum);
        }
        if (evaluationAveraging != EvaluationAveraging.Macro) {
            throw new UnsupportedOperationException("Unknown averaging approach: " + evaluationAveraging);
        }
        double total = 0.0d;
        for (int c = 0; c < nClasses; c++) {
            total += matthewsCorrelation(c);
        }
        return total / nClasses;
    }

    /** Per-class true-positive counts, keyed by class index 0..numClasses-1. */
    public Map<Integer, Integer> truePositives() {
        Counter<Integer> counts = this.truePositives;
        return convertToMap(counts, confusion().getClasses().size());
    }

    /** Per-class true-negative counts, keyed by class index 0..numClasses-1. */
    public Map<Integer, Integer> trueNegatives() {
        Counter<Integer> counts = this.trueNegatives;
        return convertToMap(counts, confusion().getClasses().size());
    }

    /** Per-class false-positive counts, keyed by class index 0..numClasses-1. */
    public Map<Integer, Integer> falsePositives() {
        Counter<Integer> counts = this.falsePositives;
        return convertToMap(counts, confusion().getClasses().size());
    }

    /** Per-class false-negative counts, keyed by class index 0..numClasses-1. */
    public Map<Integer, Integer> falseNegatives() {
        Counter<Integer> counts = this.falseNegatives;
        return convertToMap(counts, confusion().getClasses().size());
    }

    /** Per-class count of negatives: true negatives + false positives. */
    public Map<Integer, Integer> negative() {
        Map<Integer, Integer> tn = trueNegatives();
        Map<Integer, Integer> fp = falsePositives();
        return addMapsByKey(tn, fp);
    }

    /** Per-class count of positives: true positives + false negatives. */
    public Map<Integer, Integer> positive() {
        Map<Integer, Integer> tp = truePositives();
        Map<Integer, Integer> fn = falseNegatives();
        return addMapsByKey(tp, fn);
    }

    /**
     * Snapshots a counter into a map for class indices 0..numClasses-1. Each count is truncated to int.
     */
    private Map<Integer, Integer> convertToMap(Counter<Integer> counter, int numClasses) {
        Map<Integer, Integer> snapshot = new HashMap<>();
        for (int cls = 0; cls < numClasses; cls++) {
            double count = counter.getCount(Integer.valueOf(cls));
            snapshot.put(Integer.valueOf(cls), Integer.valueOf((int) count));
        }
        return snapshot;
    }

    /**
     * Key-wise sum of two maps over the union of their keys. A missing key or null value counts as 0.
     */
    private Map<Integer, Integer> addMapsByKey(Map<Integer, Integer> first, Map<Integer, Integer> second) {
        HashSet<Integer> allKeys = new HashSet<>(first.keySet());
        allKeys.addAll(second.keySet());
        Map<Integer, Integer> summed = new HashMap<>();
        for (Integer key : allKeys) {
            Integer left = first.get(key);
            Integer right = second.get(key);
            int leftValue = left == null ? 0 : left.intValue();
            int rightValue = right == null ? 0 : right.intValue();
            summed.put(key, Integer.valueOf(leftValue + rightValue));
        }
        return summed;
    }

    /** Adds one to the true-positive counter of class {@code num}. */
    public void incrementTruePositives(Integer num) {
        truePositives.incrementCount(num, 1.0d);
    }

    /** Adds one to the true-negative counter of class {@code num}. */
    public void incrementTrueNegatives(Integer num) {
        trueNegatives.incrementCount(num, 1.0d);
    }

    /** Adds one to the false-negative counter of class {@code num}. */
    public void incrementFalseNegatives(Integer num) {
        falseNegatives.incrementCount(num, 1.0d);
    }

    /** Adds one to the false-positive counter of class {@code num}. */
    public void incrementFalsePositives(Integer num) {
        falsePositives.incrementCount(num, 1.0d);
    }

    /**
     * Records one observation in the confusion matrix via {@code ConfusionMatrix.add(num, num2)}.
     * Presumably {@code num} is the actual class and {@code num2} the predicted one; verify against ConfusionMatrix.
     */
    public void addToConfusion(Integer num, Integer num2) {
        confusion().add(num, num2);
    }

    /** Total for class {@code num} as reported by {@code ConfusionMatrix.getActualTotal}. */
    public int classCount(Integer num) {
        int total = confusion().getActualTotal(num);
        return total;
    }

    /** Number of rows (examples) counted so far. */
    public int getNumRowCounter() {
        return numRowCounter;
    }

    /**
     * Number of correct top-N predictions.
     *
     * <p>Returns 0 if no confusion matrix exists yet. For {@code topN > 1} it returns the tracked
     * top-N counter; otherwise it returns the sum of the confusion-matrix diagonal.
     */
    public int getTopNCorrectCount() {
        if (confusion == null) {
            return 0;
        }
        if (topN > 1) {
            return topNCorrectCount;
        }
        int nClasses = confusion().getClasses().size();
        int diagonal = 0;
        for (int c = 0; c < nClasses; c++) {
            Integer cls = Integer.valueOf(c);
            diagonal += confusion().getCount(cls, cls);
        }
        return diagonal;
    }

    /** Number of examples considered for top-N accuracy; the row counter when {@code topN <= 1}. */
    public int getTopNTotalCount() {
        if (topN > 1) {
            return topNTotalCount;
        }
        return getNumRowCounter();
    }

    /** Label for class index {@code num}, as resolved by {@code resolveLabelForClass}. */
    public String getClassLabel(Integer num) {
        String label = resolveLabelForClass(num);
        return label;
    }

    /** The confusion-matrix field itself (not a copy); may be {@code null} before any evaluation. */
    public ConfusionMatrix<Integer> getConfusionMatrix() {
        return confusion;
    }

    /**
     * Merges another evaluation's statistics into this one.
     *
     * <p>Adds the other instance's TP/FP/TN/FN counters, confusion matrix (copied when this instance
     * has none), row counter and top-N counters. Labels are adopted only when this instance has none.
     * Recorded prediction metadata ({@code confusionMatrixMetaData}) is merged as well. Without that,
     * {@code getPredictionErrors()} and related queries on a merged evaluation only saw
     * predictions recorded by this instance.
     *
     * @param evaluation the evaluation to merge in; {@code null} is ignored
     */
    @Override // org.nd4j.evaluation.IEvaluation
    public void merge(Evaluation evaluation) {
        if (evaluation == null) {
            return;
        }
        this.truePositives.incrementAll(evaluation.truePositives);
        this.falsePositives.incrementAll(evaluation.falsePositives);
        this.trueNegatives.incrementAll(evaluation.trueNegatives);
        this.falseNegatives.incrementAll(evaluation.falseNegatives);
        if (this.confusion == null) {
            if (evaluation.confusion != null) {
                this.confusion = new ConfusionMatrix<>(evaluation.confusion);
            }
        } else if (evaluation.confusion != null) {
            confusion().add(evaluation.confusion);
        }
        this.numRowCounter += evaluation.numRowCounter;
        // Adopt the other instance's labels only if we have none; tolerate null lists (settable via setLabelsList).
        if (evaluation.labelsList != null && !evaluation.labelsList.isEmpty()) {
            if (this.labelsList == null) {
                this.labelsList = new ArrayList<>(evaluation.labelsList);
            } else if (this.labelsList.isEmpty()) {
                this.labelsList.addAll(evaluation.labelsList);
            }
        }
        if (this.topN != evaluation.topN) {
            log.warn("Different topN values ({} vs {}) detected during Evaluation merging. Top N accuracy may not be accurate.", Integer.valueOf(this.topN), Integer.valueOf(evaluation.topN));
        }
        this.topNCorrectCount += evaluation.topNCorrectCount;
        this.topNTotalCount += evaluation.topNTotalCount;
        mergeMetaData(evaluation.confusionMatrixMetaData);
    }

    /** Appends every metadata object recorded in {@code source} to this instance's metadata, cell by cell. */
    private void mergeMetaData(Map<Pair<Integer, Integer>, List<Object>> source) {
        if (source == null || source.isEmpty()) {
            return;
        }
        // Snapshot entries and value lists: on a self-merge, source is our own map, which
        // addToMetaConfusionMatrix appends to while we iterate.
        List<Map.Entry<Pair<Integer, Integer>, List<Object>>> cells = new ArrayList<>(source.entrySet());
        for (Map.Entry<Pair<Integer, Integer>, List<Object>> cell : cells) {
            List<Object> values = cell.getValue();
            if (values == null) {
                continue;
            }
            int actual = ((Integer) cell.getKey().getFirst()).intValue();
            int predicted = ((Integer) cell.getKey().getSecond()).intValue();
            for (Object meta : new ArrayList<>(values)) {
                addToMetaConfusionMatrix(actual, predicted, meta);
            }
        }
    }

    /**
     * Renders the confusion matrix as a fixed-width text table.
     *
     * <p>The table has a "Predicted:" header row of class indices, an "Actual:" divider, and then one
     * row per actual class: its index, its label, and the count for each predicted class. Every line
     * ends with {@code "\n"}.
     *
     * <p>Labels come from {@code labelsList}. If that list is {@code null}, shorter than the number of
     * classes, or holds a {@code null} entry, the class index is printed instead. Previously these
     * cases threw NullPointerException or IndexOutOfBoundsException.
     *
     * @return the formatted confusion matrix
     */
    public String confusionToString() {
        int size = confusion().getClasses().size();

        // Label column width: longest known label + 5, but never narrower than 10.
        int maxLabelLength = 0;
        if (this.labelsList != null) {
            for (String label : this.labelsList) {
                if (label != null) {
                    maxLabelLength = Math.max(maxLabelLength, label.length());
                }
            }
        }
        int labelWidth = Math.max(maxLabelLength + 5, 10);

        StringBuilder rowFormat = new StringBuilder();
        rowFormat.append("%-3d").append("%-").append(labelWidth).append("s | ");
        StringBuilder headerFormat = new StringBuilder();
        headerFormat.append("   %-").append(labelWidth).append("s   ");
        for (int c = 0; c < size; c++) {
            rowFormat.append("%7d");
            headerFormat.append("%7d");
        }
        String rowFormatString = rowFormat.toString();

        StringBuilder out = new StringBuilder();
        Object[] headerArgs = new Object[size + 1];
        headerArgs[0] = "Predicted:";
        for (int c = 0; c < size; c++) {
            headerArgs[c + 1] = Integer.valueOf(c);
        }
        out.append(String.format(headerFormat.toString(), headerArgs)).append("\n");
        out.append("   Actual:\n");

        for (int actual = 0; actual < size; actual++) {
            Object[] rowArgs = new Object[size + 2];
            rowArgs[0] = Integer.valueOf(actual);
            rowArgs[1] = confusionRowLabel(actual);
            for (int predicted = 0; predicted < size; predicted++) {
                rowArgs[predicted + 2] = Integer.valueOf(confusion().getCount(Integer.valueOf(actual), Integer.valueOf(predicted)));
            }
            out.append(String.format(rowFormatString, rowArgs));
            out.append("\n");
        }
        return out.toString();
    }

    /** Label printed for an actual-class row: the configured label, or the class index as a fallback. */
    private String confusionRowLabel(int classIndex) {
        if (this.labelsList != null && classIndex < this.labelsList.size()) {
            String label = this.labelsList.get(classIndex);
            if (label != null) {
                return label;
            }
        }
        return String.valueOf(classIndex);
    }

    /**
     * Appends a metadata object to the list stored for confusion-matrix cell ({@code i}, {@code i2}).
     * The metadata map and the per-cell list are created lazily.
     */
    private void addToMetaConfusionMatrix(int i, int i2, Object obj) {
        if (confusionMatrixMetaData == null) {
            confusionMatrixMetaData = new HashMap<>();
        }
        Pair<Integer, Integer> cell = new Pair<>(Integer.valueOf(i), Integer.valueOf(i2));
        List<Object> bucket = confusionMatrixMetaData.get(cell);
        if (bucket != null) {
            bucket.add(obj);
            return;
        }
        List<Object> fresh = new ArrayList<>();
        confusionMatrixMetaData.put(cell, fresh);
        fresh.add(obj);
    }

    /**
     * All recorded misclassifications, i.e. metadata in off-diagonal cells where actual != predicted.
     * Results are ordered by actual class, then by predicted class.
     *
     * @return the prediction errors, or {@code null} if no metadata has been recorded
     */
    public List<Prediction> getPredictionErrors() {
        if (confusionMatrixMetaData == null) {
            return null;
        }
        List<Map.Entry<Pair<Integer, Integer>, List<Object>>> cells = new ArrayList<>(confusionMatrixMetaData.entrySet());
        Collections.sort(cells, new Comparator<Map.Entry<Pair<Integer, Integer>, List<Object>>>() {
            @Override
            public int compare(Map.Entry<Pair<Integer, Integer>, List<Object>> a, Map.Entry<Pair<Integer, Integer>, List<Object>> b) {
                int byActual = Integer.compare(((Integer) a.getKey().getFirst()).intValue(), ((Integer) b.getKey().getFirst()).intValue());
                if (byActual != 0) {
                    return byActual;
                }
                return Integer.compare(((Integer) a.getKey().getSecond()).intValue(), ((Integer) b.getKey().getSecond()).intValue());
            }
        });
        List<Prediction> errors = new ArrayList<>();
        for (Map.Entry<Pair<Integer, Integer>, List<Object>> cell : cells) {
            Integer actual = (Integer) cell.getKey().getFirst();
            if (actual.equals(cell.getKey().getSecond())) {
                continue; // diagonal cell: correct predictions are not errors
            }
            int predicted = ((Integer) cell.getKey().getSecond()).intValue();
            for (Object meta : cell.getValue()) {
                errors.add(new Prediction(actual.intValue(), predicted, meta));
            }
        }
        return errors;
    }

    /**
     * All recorded predictions whose actual class is {@code i}, across every predicted class.
     *
     * @param i actual class index
     * @return matching predictions (in map iteration order), or {@code null} if no metadata has been recorded
     */
    public List<Prediction> getPredictionsByActualClass(int i) {
        if (confusionMatrixMetaData == null) {
            return null;
        }
        List<Prediction> matches = new ArrayList<>();
        for (Map.Entry<Pair<Integer, Integer>, List<Object>> cell : confusionMatrixMetaData.entrySet()) {
            int actual = ((Integer) cell.getKey().getFirst()).intValue();
            if (actual != i) {
                continue;
            }
            int predicted = ((Integer) cell.getKey().getSecond()).intValue();
            for (Object meta : cell.getValue()) {
                matches.add(new Prediction(actual, predicted, meta));
            }
        }
        return matches;
    }

    /**
     * All recorded predictions whose predicted class is {@code i}, across every actual class.
     *
     * @param i predicted class index
     * @return matching predictions (in map iteration order), or {@code null} if no metadata has been recorded
     */
    public List<Prediction> getPredictionByPredictedClass(int i) {
        if (confusionMatrixMetaData == null) {
            return null;
        }
        List<Prediction> matches = new ArrayList<>();
        for (Map.Entry<Pair<Integer, Integer>, List<Object>> cell : confusionMatrixMetaData.entrySet()) {
            if (((Integer) cell.getKey().getSecond()).intValue() != i) {
                continue;
            }
            int actual = ((Integer) cell.getKey().getFirst()).intValue();
            int predicted = ((Integer) cell.getKey().getSecond()).intValue();
            for (Object meta : cell.getValue()) {
                matches.add(new Prediction(actual, predicted, meta));
            }
        }
        return matches;
    }

    /**
     * All recorded predictions in confusion-matrix cell (actual {@code i}, predicted {@code i2}).
     *
     * @return the predictions (empty if the cell has none), or {@code null} if no metadata has been recorded
     */
    public List<Prediction> getPredictions(int i, int i2) {
        if (confusionMatrixMetaData == null) {
            return null;
        }
        List<Prediction> result = new ArrayList<>();
        Pair<Integer, Integer> cell = new Pair<>(Integer.valueOf(i), Integer.valueOf(i2));
        List<Object> recorded = confusionMatrixMetaData.get(cell);
        if (recorded != null) {
            for (Object meta : recorded) {
                result.add(new Prediction(i, i2, meta));
            }
        }
        return result;
    }

    /**
     * Value of the given metric. GMEASURE and MCC use macro averaging; the others use their
     * default (binary-or-macro) variants.
     *
     * @throws IllegalStateException for a metric this method does not handle
     */
    public double scoreForMetric(Metric metric) {
        double score;
        switch (metric) {
            case ACCURACY:
                score = accuracy();
                break;
            case F1:
                score = f1();
                break;
            case PRECISION:
                score = precision();
                break;
            case RECALL:
                score = recall();
                break;
            case GMEASURE:
                score = gMeasure(EvaluationAveraging.Macro);
                break;
            case MCC:
                score = matthewsCorrelation(EvaluationAveraging.Macro);
                break;
            default:
                throw new IllegalStateException("Unknown metric: " + metric);
        }
        return score;
    }

    /** Deserializes an {@code Evaluation} from its JSON representation. */
    public static Evaluation fromJson(String str) {
        Evaluation parsed = (Evaluation) fromJson(str, Evaluation.class);
        return parsed;
    }

    /** Deserializes an {@code Evaluation} from its YAML representation. */
    public static Evaluation fromYaml(String str) {
        Evaluation parsed = (Evaluation) fromYaml(str, Evaluation.class);
        return parsed;
    }

    /**
     * Generic metric accessor. It only accepts this class's {@link Metric} values and delegates to
     * {@link #scoreForMetric(Metric)}.
     *
     * @throws IllegalStateException if {@code iMetric} is not a {@link Metric}
     */
    @Override // org.nd4j.evaluation.IEvaluation
    public double getValue(IMetric iMetric) {
        if (!(iMetric instanceof Metric)) {
            throw new IllegalStateException("Can't get value for non-evaluation Metric " + iMetric);
        }
        Metric metric = (Metric) iMetric;
        return scoreForMetric(metric);
    }

    /**
     * Creates an empty evaluation with this instance's configuration: axis, binary positive class,
     * topN, labels, decision threshold, cost array and warning limit. The labels list and cost array
     * are passed by reference.
     */
    @Override // org.nd4j.evaluation.IEvaluation
    public Evaluation newInstance() {
        Evaluation fresh = new Evaluation(axis, binaryPositiveClass, topN, labelsList, binaryDecisionThreshold, costArray, maxWarningClassesToPrint);
        return fresh;
    }

    /**
     * Value equality. It compares superclass state, the scalar counters and settings, and then
     * (null-safely) the binary settings, TP/FP/TN/FN counters, confusion matrix, labels, cost array
     * and recorded prediction metadata, in that order.
     */
    @Override // org.nd4j.evaluation.BaseEvaluation
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof Evaluation)) {
            return false;
        }
        Evaluation other = (Evaluation) obj;
        boolean scalarsMatch = other.canEqual(this)
                && super.equals(obj)
                && getTopN() == other.getTopN()
                && getTopNCorrectCount() == other.getTopNCorrectCount()
                && getTopNTotalCount() == other.getTopNTotalCount()
                && getNumRowCounter() == other.getNumRowCounter()
                && getMaxWarningClassesToPrint() == other.getMaxWarningClassesToPrint();
        if (!scalarsMatch) {
            return false;
        }
        return equalsOrBothNull(getBinaryPositiveClass(), other.getBinaryPositiveClass())
                && equalsOrBothNull(getBinaryDecisionThreshold(), other.getBinaryDecisionThreshold())
                && equalsOrBothNull(getTruePositives(), other.getTruePositives())
                && equalsOrBothNull(getFalsePositives(), other.getFalsePositives())
                && equalsOrBothNull(getTrueNegatives(), other.getTrueNegatives())
                && equalsOrBothNull(getFalseNegatives(), other.getFalseNegatives())
                && equalsOrBothNull(getConfusion(), other.getConfusion())
                && equalsOrBothNull(getLabelsList(), other.getLabelsList())
                && equalsOrBothNull(getCostArray(), other.getCostArray())
                && equalsOrBothNull(getConfusionMatrixMetaData(), other.getConfusionMatrixMetaData());
    }

    /** Null-safe equality: true when both are null, otherwise {@code mine.equals(theirs)}. */
    private static boolean equalsOrBothNull(Object mine, Object theirs) {
        return mine == null ? theirs == null : mine.equals(theirs);
    }

    /** Symmetry hook for {@link #equals(Object)}: only {@code Evaluation} instances are comparable. */
    @Override // org.nd4j.evaluation.BaseEvaluation
    protected boolean canEqual(Object obj) {
        boolean comparable = obj instanceof Evaluation;
        return comparable;
    }

    /**
     * Hash code consistent with {@link #equals(Object)}. It folds the same fields in the same order,
     * using multiplier 59 and 43 as the stand-in for null fields.
     */
    @Override // org.nd4j.evaluation.BaseEvaluation
    public int hashCode() {
        final int prime = 59;
        int result = super.hashCode();
        result = result * prime + getTopN();
        result = result * prime + getTopNCorrectCount();
        result = result * prime + getTopNTotalCount();
        result = result * prime + getNumRowCounter();
        result = result * prime + getMaxWarningClassesToPrint();
        result = result * prime + hashOrNullSentinel(getBinaryPositiveClass());
        result = result * prime + hashOrNullSentinel(getBinaryDecisionThreshold());
        result = result * prime + hashOrNullSentinel(getTruePositives());
        result = result * prime + hashOrNullSentinel(getFalsePositives());
        result = result * prime + hashOrNullSentinel(getTrueNegatives());
        result = result * prime + hashOrNullSentinel(getFalseNegatives());
        result = result * prime + hashOrNullSentinel(getConfusion());
        result = result * prime + hashOrNullSentinel(getLabelsList());
        result = result * prime + hashOrNullSentinel(getCostArray());
        result = result * prime + hashOrNullSentinel(getConfusionMatrixMetaData());
        return result;
    }

    /** Hash of {@code value}, or the constant 43 when it is null. */
    private static int hashOrNullSentinel(Object value) {
        return value == null ? 43 : value.hashCode();
    }

    /** Positive class used for binary metrics, or {@code null} if none is configured. */
    public Integer getBinaryPositiveClass() {
        return binaryPositiveClass;
    }

    /** Top-N setting for top-N accuracy. */
    public int getTopN() {
        return topN;
    }

    /** Live true-positive counter (not a copy). */
    public Counter<Integer> getTruePositives() {
        return truePositives;
    }

    /** Live false-positive counter (not a copy). */
    public Counter<Integer> getFalsePositives() {
        return falsePositives;
    }

    /** Live true-negative counter (not a copy). */
    public Counter<Integer> getTrueNegatives() {
        return trueNegatives;
    }

    /** Live false-negative counter (not a copy). */
    public Counter<Integer> getFalseNegatives() {
        return falseNegatives;
    }

    /** Confusion-matrix field (not a copy); may be {@code null}. */
    public ConfusionMatrix<Integer> getConfusion() {
        return confusion;
    }

    /** Binary decision threshold, or {@code null} if not set. */
    public Double getBinaryDecisionThreshold() {
        return binaryDecisionThreshold;
    }

    /** Cost array, or {@code null} if not set. */
    public INDArray getCostArray() {
        return costArray;
    }

    /** Recorded prediction metadata keyed by (actual, predicted) cell; {@code null} if none recorded. */
    public Map<Pair<Integer, Integer>, List<Object>> getConfusionMatrixMetaData() {
        return confusionMatrixMetaData;
    }

    /** Sets the positive class used for binary metrics. */
    public void setBinaryPositiveClass(Integer num) {
        binaryPositiveClass = num;
    }

    /** Overwrites the top-N correct counter. */
    public void setTopNCorrectCount(int i) {
        topNCorrectCount = i;
    }

    /** Overwrites the top-N total counter. */
    public void setTopNTotalCount(int i) {
        topNTotalCount = i;
    }

    /** Replaces the true-positive counter. */
    public void setTruePositives(Counter<Integer> counter) {
        truePositives = counter;
    }

    /** Replaces the false-positive counter. */
    public void setFalsePositives(Counter<Integer> counter) {
        falsePositives = counter;
    }

    /** Replaces the true-negative counter. */
    public void setTrueNegatives(Counter<Integer> counter) {
        trueNegatives = counter;
    }

    /** Replaces the false-negative counter. */
    public void setFalseNegatives(Counter<Integer> counter) {
        falseNegatives = counter;
    }

    /** Replaces the confusion matrix. */
    public void setConfusion(ConfusionMatrix<Integer> confusionMatrix) {
        confusion = confusionMatrix;
    }

    /** Overwrites the row counter. */
    public void setNumRowCounter(int i) {
        numRowCounter = i;
    }

    /** Sets the binary decision threshold. */
    public void setBinaryDecisionThreshold(Double d) {
        binaryDecisionThreshold = d;
    }

    /** Sets the cost array. */
    public void setCostArray(INDArray iNDArray) {
        costArray = iNDArray;
    }

    /** Replaces the recorded prediction metadata. */
    public void setConfusionMatrixMetaData(Map<Pair<Integer, Integer>, List<Object>> map) {
        confusionMatrixMetaData = map;
    }

    /** Class labels list (live reference). */
    public List<String> getLabelsList() {
        return labelsList;
    }

    /** Replaces the class labels list (stored by reference). */
    public void setLabelsList(List<String> list) {
        labelsList = list;
    }

    /** Maximum number of classes listed in warnings. */
    public int getMaxWarningClassesToPrint() {
        return maxWarningClassesToPrint;
    }

    /** Sets the maximum number of classes listed in warnings. */
    public void setMaxWarningClassesToPrint(int i) {
        maxWarningClassesToPrint = i;
    }
}
