package org.nd4j.linalg.api.ops.impl.shape;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

/* loaded from: input_file:org/nd4j/linalg/api/ops/impl/shape/ConfusionMatrix.class */
public class ConfusionMatrix extends DynamicCustomOp {
    /** Output data type used when the caller does not specify one. */
    public static final DataType DEFAULT_DTYPE = DataType.INT;

    /** Data type reported for the op's single output by {@link #calculateOutputDataTypes(List)}. */
    private DataType outputType = DEFAULT_DTYPE;

    /** Creates an unconfigured op whose output type is {@link #DEFAULT_DTYPE}. */
    public ConfusionMatrix() {
    }

    /**
     * Creates the op over labels and predictions with an explicit output data type.
     *
     * @param labels    label array (must not be null)
     * @param predicted prediction array (must not be null)
     * @param dataType  output data type; also passed to the native op as its D-argument (must not be null)
     * @throws NullPointerException if any argument is null
     */
    public ConfusionMatrix(@NonNull INDArray labels, @NonNull INDArray predicted, @NonNull DataType dataType) {
        super(new INDArray[]{labels, predicted}, (INDArray[]) null);
        checkNotNull(labels, "labels");
        checkNotNull(predicted, "predicted");
        checkNotNull(dataType, "dataType");
        this.outputType = dataType;
        addDArgument(dataType);
    }

    /**
     * Creates the op with an integer argument and the default output type.
     *
     * @param labels     label array (must not be null)
     * @param predicted  prediction array (must not be null)
     * @param numClasses value passed to the native op as its integer argument
     *                   (presumably the number of classes — TODO confirm against the native op)
     * @throws NullPointerException if {@code labels} or {@code predicted} is null
     */
    public ConfusionMatrix(@NonNull INDArray labels, @NonNull INDArray predicted, int numClasses) {
        this(labels, predicted, Integer.valueOf(numClasses), DEFAULT_DTYPE);
    }

    /**
     * Creates the op with an optional third input and the default output type.
     *
     * @param labels    label array (must not be null)
     * @param predicted prediction array (must not be null)
     * @param weights   optional third input, omitted when null (presumably per-example weights)
     * @throws NullPointerException if {@code labels} or {@code predicted} is null
     */
    public ConfusionMatrix(@NonNull INDArray labels, @NonNull INDArray predicted, INDArray weights) {
        this(labels, predicted, weights, (Integer) null);
    }

    /**
     * Creates the op with an optional third input, an optional integer argument and the default output type.
     *
     * @param labels     label array (must not be null)
     * @param predicted  prediction array (must not be null)
     * @param weights    optional third input, omitted when null
     * @param numClasses optional integer argument, omitted when null
     * @throws NullPointerException if {@code labels} or {@code predicted} is null
     */
    public ConfusionMatrix(@NonNull INDArray labels, @NonNull INDArray predicted, INDArray weights, Integer numClasses) {
        this(labels, predicted, weights, numClasses, DEFAULT_DTYPE);
    }

    /**
     * Creates the op with an optional integer argument and an explicit output type.
     *
     * @param labels     label array (must not be null)
     * @param predicted  prediction array (must not be null)
     * @param numClasses optional integer argument, omitted when null
     * @param dataType   output data type (must not be null)
     * @throws NullPointerException if {@code labels}, {@code predicted} or {@code dataType} is null
     */
    public ConfusionMatrix(@NonNull INDArray labels, @NonNull INDArray predicted, Integer numClasses, @NonNull DataType dataType) {
        this(labels, predicted, (INDArray) null, numClasses, dataType);
    }

    /**
     * Primary array-based constructor: every other {@code INDArray} constructor except the
     * three-argument dtype one delegates here.
     *
     * @param labels     label array (must not be null)
     * @param predicted  prediction array (must not be null)
     * @param weights    optional third input; null inputs are filtered out before being passed to the op
     * @param numClasses optional integer argument, omitted when null
     * @param dataType   output data type; also passed to the native op as its D-argument (must not be null)
     * @throws NullPointerException if {@code labels}, {@code predicted} or {@code dataType} is null
     */
    public ConfusionMatrix(@NonNull INDArray labels, @NonNull INDArray predicted, INDArray weights, Integer numClasses, @NonNull DataType dataType) {
        super((INDArray[]) wrapFilterNull(labels, predicted, weights), (INDArray[]) null);
        checkNotNull(labels, "labels");
        checkNotNull(predicted, "predicted");
        checkNotNull(dataType, "dataType");
        this.outputType = dataType;
        if (numClasses != null) {
            addIArgument(numClasses.intValue());
        }
        addDArgument(dataType);
    }

    /**
     * SameDiff constructor with three inputs and an explicit output type.
     *
     * @param sameDiff owning SameDiff instance
     * @param labels   label variable
     * @param pred     prediction variable
     * @param weights  third input variable
     * @param dataType output data type
     */
    public ConfusionMatrix(SameDiff sameDiff, SDVariable labels, SDVariable pred, SDVariable weights, DataType dataType) {
        this(sameDiff, labels, pred, weights);
        this.outputType = dataType;
        // Pass the dtype to the native op too, as every other dtype-taking constructor does, so the
        // executed output type matches what calculateOutputDataTypes reports.
        addDArgument(dataType);
    }

    /**
     * SameDiff constructor with two inputs and an explicit output type.
     *
     * @param sameDiff owning SameDiff instance
     * @param labels   label variable
     * @param pred     prediction variable
     * @param dataType output data type; also passed to the native op as its D-argument
     */
    public ConfusionMatrix(SameDiff sameDiff, SDVariable labels, SDVariable pred, DataType dataType) {
        super((String) null, sameDiff, new SDVariable[]{labels, pred});
        this.outputType = dataType;
        addDArgument(dataType);
    }

    /**
     * SameDiff constructor with three inputs and the default output type.
     *
     * @param sameDiff owning SameDiff instance
     * @param labels   label variable
     * @param pred     prediction variable
     * @param weights  third input variable
     */
    public ConfusionMatrix(SameDiff sameDiff, SDVariable labels, SDVariable pred, SDVariable weights) {
        super((String) null, sameDiff, new SDVariable[]{labels, pred, weights});
    }

    /**
     * SameDiff constructor with two inputs, an optional integer argument and the default output type.
     *
     * @param sameDiff   owning SameDiff instance
     * @param labels     label variable
     * @param pred       prediction variable
     * @param numClasses optional integer argument, omitted when null
     */
    public ConfusionMatrix(SameDiff sameDiff, SDVariable labels, SDVariable pred, Integer numClasses) {
        super((String) null, sameDiff, new SDVariable[]{labels, pred});
        // Treat null as "no integer argument", matching the other Integer-taking constructors,
        // instead of failing with an unboxing NullPointerException.
        if (numClasses != null) {
            addIArgument(numClasses.intValue());
        }
    }

    /**
     * SameDiff constructor with three inputs, an optional integer argument and the default output type.
     *
     * @param sameDiff   owning SameDiff instance
     * @param labels     label variable
     * @param pred       prediction variable
     * @param weights    third input variable
     * @param numClasses optional integer argument, omitted when null
     */
    public ConfusionMatrix(SameDiff sameDiff, SDVariable labels, SDVariable pred, SDVariable weights, Integer numClasses) {
        super((String) null, sameDiff, new SDVariable[]{labels, pred, weights});
        if (numClasses != null) {
            addIArgument(numClasses.intValue());
        }
    }

    /**
     * SameDiff constructor with three inputs (integer argument before weights) and the default output type.
     *
     * @param sameDiff   owning SameDiff instance
     * @param labels     label variable
     * @param pred       prediction variable
     * @param numClasses optional integer argument, omitted when null
     * @param weights    third input variable
     */
    public ConfusionMatrix(SameDiff sameDiff, SDVariable labels, SDVariable pred, Integer numClasses, SDVariable weights) {
        super((String) null, sameDiff, new SDVariable[]{labels, pred, weights});
        if (numClasses != null) {
            addIArgument(numClasses.intValue());
        }
    }

    /**
     * Throws the same {@link NullPointerException} message Lombok's {@code @NonNull} would produce.
     */
    private static void checkNotNull(Object value, String name) {
        if (value == null) {
            throw new NullPointerException(name + " is marked non-null but is null");
        }
    }

    /** Delegates TensorFlow import entirely to {@link DynamicCustomOp}. */
    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff sameDiff, Map<String, AttrValue> map, GraphDef graphDef) {
        super.initFromTensorFlow(nodeDef, sameDiff, map, graphDef);
    }

    /** @return the native op name, {@code "confusion_matrix"} */
    @Override
    public String opName() {
        return "confusion_matrix";
    }

    /** @return the TensorFlow op name this op maps to, {@code "ConfusionMatrix"} */
    @Override
    public String tensorflowName() {
        return "ConfusionMatrix";
    }

    /**
     * Treats the op as non-differentiable: every input receives a zero gradient.
     *
     * <p>One gradient is produced per actual input, so the two- and three-input forms both get
     * a gradient list whose size matches their input count.
     *
     * @param list incoming output gradients (unused)
     * @return zero gradients, one per input, in input order
     */
    @Override
    public List<SDVariable> doDiff(List<SDVariable> list) {
        SDVariable[] inputs = args();
        SDVariable[] grads = new SDVariable[inputs.length];
        for (int i = 0; i < inputs.length; i++) {
            grads[i] = this.sameDiff.zerosLike(inputs[i]);
        }
        return Arrays.asList(grads);
    }

    /**
     * Reports a single output of the configured {@link #outputType}, regardless of the input types.
     *
     * @param list input data types (unused)
     * @return a singleton list containing the output data type
     */
    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> list) {
        return Collections.singletonList(this.outputType);
    }
}
