package org.nd4j.linalg.api.ops.random;

import java.util.Collections;
import java.util.List;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.RandomOp;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/nd4j/linalg/api/ops/random/BaseRandomOp.class */
/**
 * Base class for random-number-generating ops.
 *
 * <p>Reports {@link Op.Type#RANDOM} as its op type. Holds an optional {@link #shape} and a
 * {@link #dataType}. Subclasses supply the concrete random distribution/op.
 */
public abstract class BaseRandomOp extends BaseOp implements RandomOp {

    /**
     * Shape recorded for this op. Set from the constructor arguments when available; may remain
     * {@code null} (e.g. when neither the input variable nor its array exposes a shape).
     */
    protected long[] shape;

    /**
     * Data type associated with this op, initialised to
     * {@link Nd4j#defaultFloatingPointType()}.
     *
     * <p>NOTE(review): this field is not consulted by {@link #calculateOutputDataTypes(List)} in
     * this class; subclasses may read it.
     */
    protected DataType dataType = Nd4j.defaultFloatingPointType();

    /** No-argument constructor (presumably for reflective instantiation — TODO confirm). */
    public BaseRandomOp() {
    }

    /**
     * Creates a random op inside {@code sameDiff} that takes {@code input} as its single argument.
     *
     * <p>The op's shape is taken from {@code input.getShape()} when present. Otherwise it comes
     * from the shape of the variable's array, if that array exists. The op is then registered with
     * {@code sameDiff} as consuming {@code input}.
     *
     * @param sameDiff graph this op belongs to
     * @param input    input variable; must not be {@code null} (checked via
     *                 {@code Preconditions.checkNotNull})
     */
    public BaseRandomOp(SameDiff sameDiff, SDVariable input) {
        Preconditions.checkNotNull(input, "Input variable can't be null with this constructor");
        this.sameDiff = sameDiff;
        this.xVertexId = input.name();

        long[] declaredShape = input.getShape();
        if (declaredShape != null) {
            this.shape = declaredShape;
        } else {
            // Fall back to the shape of the variable's current array, if it has one.
            INDArray arr = input.getArr();
            if (arr != null && arr.shape() != null) {
                this.shape = arr.shape();
            }
        }

        sameDiff.addArgsFor(new String[] {this.xVertexId}, this);
    }

    /**
     * Creates an argument-less random op inside {@code sameDiff} with an explicit shape.
     *
     * @param sameDiff graph this op belongs to
     * @param shape    shape for this op; must be non-null and non-empty (checked via
     *                 {@code Preconditions.checkArgument})
     */
    public BaseRandomOp(SameDiff sameDiff, long[] shape) {
        super(sameDiff, (Object[]) null);
        Preconditions.checkArgument(shape != null && shape.length > 0, "Shape must be non-null, length > 0. Got: %s", shape);
        this.sameDiff = sameDiff;
        this.shape = shape;
        setInstanceId();
        // Register the op with no input arguments.
        this.sameDiff.addArgsFor(new String[0], this);
    }

    /**
     * Creates a random op operating directly on arrays.
     *
     * @param x first input array (passed through to {@code BaseOp})
     * @param y second input array (passed through to {@code BaseOp})
     * @param z output array (passed through to {@code BaseOp})
     */
    public BaseRandomOp(INDArray x, INDArray y, INDArray z) {
        super(x, y, z);
    }

    /**
     * Returns the op type.
     *
     * @return always {@link Op.Type#RANDOM}
     */
    @Override
    public Op.Type opType() {
        return Op.Type.RANDOM;
    }

    /**
     * Computes output shapes by delegating to {@code calculateOutputShape(null)}.
     *
     * @return the output shape descriptors computed by the delegated overload
     */
    @Override
    public List<LongShapeDescriptor> calculateOutputShape() {
        return calculateOutputShape(null);
    }

    /**
     * Returns the output data types of this op.
     *
     * <p>NOTE(review): always reports {@link DataType#FLOAT}, ignoring both the argument and the
     * {@link #dataType} field. Consider returning {@code dataType} once callers are verified.
     *
     * @param dataTypes input data types (unused)
     * @return a singleton list containing {@link DataType#FLOAT}
     */
    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
        return Collections.singletonList(DataType.FLOAT);
    }

    /**
     * Reports whether this op writes its result into its input buffer.
     *
     * <p>Returns {@code true} in these cases:
     * <ul>
     *   <li>there is no input array;</li>
     *   <li>the input and output are the same object;</li>
     *   <li>their data buffers share the same native address.</li>
     * </ul>
     * Returns {@code false} when an input exists but no output array has been assigned yet.
     * Previously that case threw a {@code NullPointerException}.
     *
     * @return whether the op is in-place
     */
    @Override
    public boolean isInPlace() {
        if (this.x == null || this.x == this.z) {
            return true;
        }
        if (this.z == null) {
            // No output array yet, so it cannot alias the input.
            return false;
        }
        return this.x.data().pointer().address() == this.z.data().pointer().address();
    }

    /**
     * Indicates whether this is a three-argument RNG op.
     *
     * @return {@code false} by default; subclasses may override
     */
    public boolean isTripleArgRngOp() {
        return false;
    }
}
