package org.nd4j.linalg.api.ops;

import java.util.Collections;
import java.util.List;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.util.SameDiffUtils;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.api.shape.Shape;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/linalg/api/ops/BaseIndexAccumulation.class */
/**
 * Base class for index-reduction ops: {@link #opType()} reports {@link Op.Type#INDEXREDUCE}, and
 * this base class reports a single {@link DataType#LONG} output whose shape is the input shape
 * reduced over {@code dimensions} (see {@link #calculateOutputShape(OpContext)}).
 */
public abstract class BaseIndexAccumulation extends BaseOp implements IndexAccumulation {
    private static final Logger log = LoggerFactory.getLogger(BaseIndexAccumulation.class);

    /**
     * Keep-dims flag forwarded to {@code Shape.getReducedShape} when computing the output shape.
     * NOTE(review): presumably retains reduced axes as size-1 dimensions when true — confirm.
     * The explicit initializer runs right after {@code super(...)} in every constructor.
     */
    protected boolean keepDims = false;

    /**
     * Creates a SameDiff-backed op over a single input variable.
     *
     * @param sameDiff   graph the op is registered with
     * @param input      input variable; must not be {@code null}
     * @param keepDims   keep-dims flag, see {@link #isKeepDims()}
     * @param dimensions dimensions to reduce over
     * @throws IllegalArgumentException if {@code input} is {@code null}
     */
    public BaseIndexAccumulation(SameDiff sameDiff, SDVariable input, boolean keepDims, int[] dimensions) {
        // Cast selects the (SameDiff, Object[] extraArgs) super constructor with no extra args.
        super(sameDiff, (Object[]) null);
        if (input == null) {
            throw new IllegalArgumentException("Input not null variable.");
        }
        this.dimensions = dimensions;
        sameDiff.addArgsFor(new SDVariable[]{input}, this);
        this.xVertexId = input.name();
        this.keepDims = keepDims;
        defineDimensions(dimensions);
    }

    /**
     * Creates a SameDiff-backed op over two input variables.
     *
     * @param sameDiff   graph the op is registered with
     * @param input      first input variable; must not be {@code null}
     * @param other      second input variable; validated against {@code sameDiff} before use
     * @param keepDims   keep-dims flag, see {@link #isKeepDims()}
     * @param dimensions dimensions to reduce over
     * @throws IllegalArgumentException if {@code input} is {@code null}
     */
    public BaseIndexAccumulation(SameDiff sameDiff, SDVariable input, SDVariable other, boolean keepDims,
                                 int[] dimensions) {
        super(sameDiff, (Object[]) null);
        if (input == null) {
            throw new IllegalArgumentException("Input not null variable.");
        }
        this.dimensions = dimensions;
        SameDiffUtils.validateDifferentialFunctionSameDiff(sameDiff, input, this);
        SameDiffUtils.validateDifferentialFunctionSameDiff(sameDiff, other, this);
        this.xVertexId = input.name();
        this.yVertexId = other.name();
        sameDiff.addArgsFor(new SDVariable[]{input, other}, this);
        this.keepDims = keepDims;
        defineDimensions(dimensions);
    }

    /** No-arg constructor; leaves inputs, dimensions and SameDiff wiring to be set later. */
    public BaseIndexAccumulation() {
    }

    /**
     * Creates an eager op over a single array, with keepDims {@code false}.
     *
     * @param x          input array
     * @param dimensions dimensions to reduce over
     */
    public BaseIndexAccumulation(INDArray x, int[] dimensions) {
        this(x, (INDArray) null, dimensions);
    }

    /**
     * Creates an eager op over a single array with an explicit keep-dims flag.
     *
     * @param x          input array
     * @param keepDims   keep-dims flag, see {@link #isKeepDims()}
     * @param dimensions dimensions to reduce over
     */
    public BaseIndexAccumulation(INDArray x, boolean keepDims, int[] dimensions) {
        this(x, (INDArray) null, dimensions);
        this.keepDims = keepDims;
        // The delegated constructor already called defineDimensions once; it is invoked again here
        // after keepDims is assigned. NOTE(review): confirm whether this second call is required.
        defineDimensions(dimensions);
    }

    /**
     * Creates an eager op over one or two arrays, with keepDims {@code false}.
     *
     * @param x          input array
     * @param y          optional second input array (may be {@code null})
     * @param dimensions dimensions to reduce over
     */
    public BaseIndexAccumulation(INDArray x, INDArray y, int[] dimensions) {
        super(x, y);
        defineDimensions(dimensions);
    }

    /** Computes the output shape from this op's own input array {@code x()}. */
    @Override
    public List<LongShapeDescriptor> calculateOutputShape() {
        return calculateOutputShape(null);
    }

    /**
     * Computes the output shape: the input shape reduced over {@code dimensions} with the current
     * keep-dims flag, typed {@link DataType#LONG}.
     *
     * @param opContext context whose first input array is used; if {@code null}, {@code x()} is used
     * @return a single-element list with the output descriptor, or an empty list if no input array
     *         is available yet
     */
    @Override
    public List<LongShapeDescriptor> calculateOutputShape(OpContext opContext) {
        INDArray inputArray = opContext != null ? opContext.getInputArray(0) : x();
        if (inputArray == null) {
            return Collections.emptyList();
        }
        return Collections.singletonList(LongShapeDescriptor.fromShape(
                Shape.getReducedShape(inputArray.shape(), this.dimensions, this.keepDims),
                DataType.LONG));
    }

    /** @return {@link Op.Type#INDEXREDUCE} */
    @Override
    public Op.Type opType() {
        return Op.Type.INDEXREDUCE;
    }

    /**
     * Checks that the output array, if already present, is of type {@link DataType#LONG}.
     *
     * @return {@code true} when validation passes (including when there is no output array yet)
     * @throws IllegalArgumentException if the output array is present but not LONG
     */
    @Override
    public boolean validateDataTypes() {
        // Read the accessor once and use it for both the check and the message: the message
        // arguments are evaluated eagerly, so mixing z() with the raw field could NPE or report a
        // different array when a subclass overrides z().
        INDArray result = z();
        if (result == null) {
            return true;
        }
        Preconditions.checkArgument(result.dataType() == DataType.LONG,
                "IndexReduce operations require LONG output: got result array of type %s for op %s",
                result.dataType(), getClass());
        return true;
    }

    /**
     * Output datatype inference: always a single LONG output.
     *
     * @param list input datatypes; must contain exactly one element
     * @return singleton list containing {@link DataType#LONG}
     * @throws IllegalStateException if {@code list} is {@code null} or not of size 1
     */
    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> list) {
        Preconditions.checkState(list != null && list.size() == 1,
                "Expected exactly 1 input datatype for %s, got input %s", getClass(), list);
        return Collections.singletonList(DataType.LONG);
    }

    /** @return the keep-dims flag used when computing the reduced output shape */
    @Override
    public boolean isKeepDims() {
        return this.keepDims;
    }

    /** @param keepDims new keep-dims flag used when computing the reduced output shape */
    public void setKeepDims(boolean keepDims) {
        this.keepDims = keepDims;
    }

    /**
     * Equality over the keep-dims flag only; state held by {@link BaseOp} is not compared.
     * Uses {@link #canEqual(Object)} so subclasses can restrict cross-type equality.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof BaseIndexAccumulation)) {
            return false;
        }
        BaseIndexAccumulation other = (BaseIndexAccumulation) obj;
        return other.canEqual(this) && isKeepDims() == other.isKeepDims();
    }

    /** @return whether {@code obj} may be considered equal to an instance of this class */
    protected boolean canEqual(Object obj) {
        return obj instanceof BaseIndexAccumulation;
    }

    /** Hash consistent with {@link #equals(Object)}: depends only on the keep-dims flag. */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        result = result * prime + (isKeepDims() ? 79 : 97);
        return result;
    }

    @Override
    public String toString() {
        return "BaseIndexAccumulation(keepDims=" + isKeepDims() + ")";
    }
}
