package org.nd4j.autodiff.samediff.ops;

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.enums.PadMode;
import org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd;
import org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNorm;
import org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU;
import org.nd4j.linalg.api.ops.impl.scalar.PRelu;
import org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear;
import org.nd4j.linalg.api.ops.impl.scalar.Relu6;
import org.nd4j.linalg.api.ops.impl.transforms.Pad;
import org.nd4j.linalg.api.ops.impl.transforms.ReluLayer;
import org.nd4j.linalg.api.ops.impl.transforms.custom.CReLU;
import org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttention;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttention;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.custom.TopK;
import org.nd4j.linalg.api.ops.impl.transforms.custom.XwPlusB;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp;
import org.nd4j.linalg.api.ops.impl.transforms.strict.ELU;
import org.nd4j.linalg.api.ops.impl.transforms.strict.GELU;
import org.nd4j.linalg.api.ops.impl.transforms.strict.HardSigmoid;
import org.nd4j.linalg.api.ops.impl.transforms.strict.HardTanh;
import org.nd4j.linalg.api.ops.impl.transforms.strict.LogSigmoid;
import org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELU;
import org.nd4j.linalg.api.ops.impl.transforms.strict.SELU;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Sigmoid;
import org.nd4j.linalg.api.ops.impl.transforms.strict.SoftPlus;
import org.nd4j.linalg.api.ops.impl.transforms.strict.SoftSign;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Swish;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh;
import org.nd4j.linalg.api.ops.random.impl.DropOut;
import org.nd4j.linalg.api.ops.random.impl.DropOutInverted;

/* loaded from: input_file:org/nd4j/autodiff/samediff/ops/SDNN.class */
/**
 * Neural-network operation namespace for a {@link SameDiff} graph.
 *
 * <p>Every op method follows the same recipe:
 * <ol>
 *   <li>check each input with {@code SDValidation.validateNumerical};</li>
 *   <li>construct the matching op against the owning graph {@code sd};</li>
 *   <li>return the op's output variable (or variables).</li>
 * </ol>
 * Each overload taking a leading {@code name} does exactly the same work through its unnamed
 * counterpart. It then renames the result with {@code SameDiff.updateVariableNameAndReference}.
 * The {@code names} overload of {@code topK} uses {@code updateVariableNamesAndReferences} instead.
 */
public class SDNN extends SDOps {

    /**
     * Binds this namespace to a graph.
     *
     * @param sameDiff graph handed to the {@code SDOps} superclass; ops built here use its {@code sd}
     */
    public SDNN(SameDiff sameDiff) {
        super(sameDiff);
    }

    /** Builds a CReLU op on {@code x}. */
    public SDVariable cReLU(SDVariable x) {
        SDValidation.validateNumerical("CReLU", "x", x);
        return new CReLU(sd, x).outputVariable();
    }

    /** {@link #cReLU(SDVariable)} with the output renamed to {@code name}. */
    public SDVariable cReLU(String name, SDVariable x) {
        return sd.updateVariableNameAndReference(cReLU(x), name);
    }

    /**
     * Builds a BatchNorm op.
     *
     * @param epsilon scalar forwarded to the BatchNorm op
     * @param axis    axes forwarded to the op; must contain at least one entry
     */
    public SDVariable batchNorm(SDVariable input, SDVariable mean, SDVariable variance, SDVariable gamma,
                                SDVariable beta, double epsilon, int... axis) {
        SDValidation.validateNumerical("batchNorm", "input", input);
        SDValidation.validateNumerical("batchNorm", "mean", mean);
        SDValidation.validateNumerical("batchNorm", "variance", variance);
        SDValidation.validateNumerical("batchNorm", "gamma", gamma);
        SDValidation.validateNumerical("batchNorm", "beta", beta);
        Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length);
        return new BatchNorm(sd, input, mean, variance, gamma, beta, epsilon, axis).outputVariable();
    }

    /** {@link #batchNorm(SDVariable, SDVariable, SDVariable, SDVariable, SDVariable, double, int...)}, renamed to {@code name}. */
    public SDVariable batchNorm(String name, SDVariable input, SDVariable mean, SDVariable variance, SDVariable gamma,
                                SDVariable beta, double epsilon, int... axis) {
        return sd.updateVariableNameAndReference(batchNorm(input, mean, variance, gamma, beta, epsilon, axis), name);
    }

    /**
     * Builds a BiasAdd op.
     *
     * @param nchw layout flag forwarded to BiasAdd (presumably channels-first when true — confirm)
     */
    public SDVariable biasAdd(SDVariable input, SDVariable bias, boolean nchw) {
        SDValidation.validateNumerical("biasAdd", "input", input);
        SDValidation.validateNumerical("biasAdd", "bias", bias);
        return new BiasAdd(sd, input, bias, nchw).outputVariable();
    }

    /** {@link #biasAdd(SDVariable, SDVariable, boolean)} with the output renamed to {@code name}. */
    public SDVariable biasAdd(String name, SDVariable input, SDVariable bias, boolean nchw) {
        return sd.updateVariableNameAndReference(biasAdd(input, bias, nchw), name);
    }

    /**
     * Builds a DotProductAttention op.
     *
     * <p>The op's final constructor flag is always {@code false} here. In upstream ND4J this is
     * believed to be the "with weights" flag — NOTE(review): confirm.
     *
     * @param scaled flag forwarded to the op (presumably whether to scale the scores — confirm)
     */
    public SDVariable dotProductAttention(SDVariable queries, SDVariable keys, SDVariable values, SDVariable mask,
                                          boolean scaled) {
        SDValidation.validateNumerical("dotProductAttention", "queries", queries);
        SDValidation.validateNumerical("dotProductAttention", "keys", keys);
        SDValidation.validateNumerical("dotProductAttention", "values", values);
        SDValidation.validateNumerical("dotProductAttention", "mask", mask);
        return new DotProductAttention(sd, queries, keys, values, mask, scaled, false).outputVariable();
    }

    /** {@link #dotProductAttention(SDVariable, SDVariable, SDVariable, SDVariable, boolean)}, renamed to {@code name}. */
    public SDVariable dotProductAttention(String name, SDVariable queries, SDVariable keys, SDVariable values,
                                          SDVariable mask, boolean scaled) {
        return sd.updateVariableNameAndReference(dotProductAttention(queries, keys, values, mask, scaled), name);
    }

    /**
     * Builds a DropOut op.
     *
     * @param probability scalar forwarded to DropOut (retain-vs-drop meaning is defined by that op — confirm)
     */
    public SDVariable dropout(SDVariable input, double probability) {
        SDValidation.validateNumerical("dropout", "input", input);
        return new DropOut(sd, input, probability).outputVariable();
    }

    /** {@link #dropout(SDVariable, double)} with the output renamed to {@code name}. */
    public SDVariable dropout(String name, SDVariable input, double probability) {
        return sd.updateVariableNameAndReference(dropout(input, probability), name);
    }

    /**
     * Builds a DropOutInverted op.
     *
     * @param probability scalar forwarded to DropOutInverted (meaning defined by that op — confirm)
     */
    public SDVariable dropoutInverted(SDVariable input, double probability) {
        SDValidation.validateNumerical("dropoutInverted", "input", input);
        return new DropOutInverted(sd, input, probability).outputVariable();
    }

    /** {@link #dropoutInverted(SDVariable, double)} with the output renamed to {@code name}. */
    public SDVariable dropoutInverted(String name, SDVariable input, double probability) {
        return sd.updateVariableNameAndReference(dropoutInverted(input, probability), name);
    }

    /** Builds an ELU op on {@code x}. */
    public SDVariable elu(SDVariable x) {
        SDValidation.validateNumerical("elu", "x", x);
        return new ELU(sd, x).outputVariable();
    }

    /** {@link #elu(SDVariable)} with the output renamed to {@code name}. */
    public SDVariable elu(String name, SDVariable x) {
        return sd.updateVariableNameAndReference(elu(x), name);
    }

    /** Builds a GELU op on {@code x}. */
    public SDVariable gelu(SDVariable x) {
        SDValidation.validateNumerical("gelu", "x", x);
        return new GELU(sd, x).outputVariable();
    }

    /** {@link #gelu(SDVariable)} with the output renamed to {@code name}. */
    public SDVariable gelu(String name, SDVariable x) {
        return sd.updateVariableNameAndReference(gelu(x), name);
    }

    /** Builds a HardSigmoid op on {@code x}. */
    public SDVariable hardSigmoid(SDVariable x) {
        SDValidation.validateNumerical("hardSigmoid", "x", x);
        return new HardSigmoid(sd, x).outputVariable();
    }

    /** {@link #hardSigmoid(SDVariable)} with the output renamed to {@code name}. */
    public SDVariable hardSigmoid(String name, SDVariable x) {
        return sd.updateVariableNameAndReference(hardSigmoid(x), name);
    }

    /** Builds a HardTanh op on {@code x}. */
    public SDVariable hardTanh(SDVariable x) {
        SDValidation.validateNumerical("hardTanh", "x", x);
        return new HardTanh(sd, x).outputVariable();
    }

    /** {@link #hardTanh(SDVariable)} with the output renamed to {@code name}. */
    public SDVariable hardTanh(String name, SDVariable x) {
        return sd.updateVariableNameAndReference(hardTanh(x), name);
    }

    /** Builds a HardTanhDerivative op on {@code x}. */
    public SDVariable hardTanhDerivative(SDVariable x) {
        SDValidation.validateNumerical("hardTanhDerivative", "x", x);
        return new HardTanhDerivative(sd, x).outputVariable();
    }

    /** {@link #hardTanhDerivative(SDVariable)} with the output renamed to {@code name}. */
    public SDVariable hardTanhDerivative(String name, SDVariable x) {
        return sd.updateVariableNameAndReference(hardTanhDerivative(x), name);
    }

    /**
     * Builds a LayerNorm op with a bias input.
     *
     * @param channelsFirst flag forwarded to LayerNorm (upstream name; semantics per that op — confirm)
     * @param dimensions    dimensions forwarded to the op; must contain at least one entry
     */
    public SDVariable layerNorm(SDVariable input, SDVariable gain, SDVariable bias, boolean channelsFirst,
                                int... dimensions) {
        SDValidation.validateNumerical("layerNorm", "input", input);
        SDValidation.validateNumerical("layerNorm", "gain", gain);
        SDValidation.validateNumerical("layerNorm", "bias", bias);
        Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
        return new LayerNorm(sd, input, gain, bias, channelsFirst, dimensions).outputVariable();
    }

    /** {@link #layerNorm(SDVariable, SDVariable, SDVariable, boolean, int...)}, renamed to {@code name}. */
    public SDVariable layerNorm(String name, SDVariable input, SDVariable gain, SDVariable bias,
                                boolean channelsFirst, int... dimensions) {
        return sd.updateVariableNameAndReference(layerNorm(input, gain, bias, channelsFirst, dimensions), name);
    }

    /**
     * Builds a LayerNorm op without a bias input; a {@code null} bias is passed to the op.
     *
     * <p>This overload intentionally does not delegate to the bias variant. That variant would also
     * run the "bias" validation on a null value.
     */
    public SDVariable layerNorm(SDVariable input, SDVariable gain, boolean channelsFirst, int... dimensions) {
        SDValidation.validateNumerical("layerNorm", "input", input);
        SDValidation.validateNumerical("layerNorm", "gain", gain);
        Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
        return new LayerNorm(sd, input, gain, (SDVariable) null, channelsFirst, dimensions).outputVariable();
    }

    /** {@link #layerNorm(SDVariable, SDVariable, boolean, int...)} with the output renamed to {@code name}. */
    public SDVariable layerNorm(String name, SDVariable input, SDVariable gain, boolean channelsFirst,
                                int... dimensions) {
        return sd.updateVariableNameAndReference(layerNorm(input, gain, channelsFirst, dimensions), name);
    }

    /** Builds a LeakyReLU op on {@code x} with the given {@code alpha}. */
    public SDVariable leakyRelu(SDVariable x, double alpha) {
        SDValidation.validateNumerical("leakyRelu", "x", x);
        return new LeakyReLU(sd, x, alpha).outputVariable();
    }

    /** {@link #leakyRelu(SDVariable, double)} with the output renamed to {@code name}. */
    public SDVariable leakyRelu(String name, SDVariable x, double alpha) {
        return sd.updateVariableNameAndReference(leakyRelu(x, alpha), name);
    }

    /** Builds a LeakyReLUDerivative op on {@code x} with the given {@code alpha}. */
    public SDVariable leakyReluDerivative(SDVariable x, double alpha) {
        SDValidation.validateNumerical("leakyReluDerivative", "x", x);
        return new LeakyReLUDerivative(sd, x, alpha).outputVariable();
    }

    /** {@link #leakyReluDerivative(SDVariable, double)} with the output renamed to {@code name}. */
    public SDVariable leakyReluDerivative(String name, SDVariable x, double alpha) {
        return sd.updateVariableNameAndReference(leakyReluDerivative(x, alpha), name);
    }

    /** Builds an XwPlusB op from {@code input}, {@code weights} and {@code bias}. */
    public SDVariable linear(SDVariable input, SDVariable weights, SDVariable bias) {
        SDValidation.validateNumerical("linear", "input", input);
        SDValidation.validateNumerical("linear", "weights", weights);
        SDValidation.validateNumerical("linear", "bias", bias);
        return new XwPlusB(sd, input, weights, bias).outputVariable();
    }

    /** {@link #linear(SDVariable, SDVariable, SDVariable)} with the output renamed to {@code name}. */
    public SDVariable linear(String name, SDVariable input, SDVariable weights, SDVariable bias) {
        return sd.updateVariableNameAndReference(linear(input, weights, bias), name);
    }

    /** Builds a LogSigmoid op on {@code x}. */
    public SDVariable logSigmoid(SDVariable x) {
        SDValidation.validateNumerical("logSigmoid", "x", x);
        return new LogSigmoid(sd, x).outputVariable();
    }

    /** {@link #logSigmoid(SDVariable)} with the output renamed to {@code name}. */
    public SDVariable logSigmoid(String name, SDVariable x) {
        return sd.updateVariableNameAndReference(logSigmoid(x), name);
    }

    /**
     * Builds a LogSoftMax op on {@code x} using the op's dimension-less constructor.
     * This is kept separate from the dimension overload because the two use different constructors.
     */
    public SDVariable logSoftmax(SDVariable x) {
        SDValidation.validateNumerical("logSoftmax", "x", x);
        return new LogSoftMax(sd, x).outputVariable();
    }

    /** {@link #logSoftmax(SDVariable)} with the output renamed to {@code name}. */
    public SDVariable logSoftmax(String name, SDVariable x) {
        return sd.updateVariableNameAndReference(logSoftmax(x), name);
    }

    /** Builds a LogSoftMax op on {@code x} along {@code dimension}. */
    public SDVariable logSoftmax(SDVariable x, int dimension) {
        SDValidation.validateNumerical("logSoftmax", "x", x);
        return new LogSoftMax(sd, x, dimension).outputVariable();
    }

    /** {@link #logSoftmax(SDVariable, int)} with the output renamed to {@code name}. */
    public SDVariable logSoftmax(String name, SDVariable x, int dimension) {
        return sd.updateVariableNameAndReference(logSoftmax(x, dimension), name);
    }

    /**
     * Builds a MultiHeadDotProductAttention op.
     *
     * <p>As with {@code dotProductAttention}, the op's final constructor flag is always {@code false}
     * here. NOTE(review): confirm its meaning.
     *
     * @param scaled flag forwarded to the op (presumably whether to scale the scores — confirm)
     */
    public SDVariable multiHeadDotProductAttention(SDVariable queries, SDVariable keys, SDVariable values,
                                                   SDVariable Wq, SDVariable Wk, SDVariable Wv, SDVariable Wo,
                                                   SDVariable mask, boolean scaled) {
        SDValidation.validateNumerical("multiHeadDotProductAttention", "queries", queries);
        SDValidation.validateNumerical("multiHeadDotProductAttention", "keys", keys);
        SDValidation.validateNumerical("multiHeadDotProductAttention", "values", values);
        SDValidation.validateNumerical("multiHeadDotProductAttention", "Wq", Wq);
        SDValidation.validateNumerical("multiHeadDotProductAttention", "Wk", Wk);
        SDValidation.validateNumerical("multiHeadDotProductAttention", "Wv", Wv);
        SDValidation.validateNumerical("multiHeadDotProductAttention", "Wo", Wo);
        SDValidation.validateNumerical("multiHeadDotProductAttention", "mask", mask);
        return new MultiHeadDotProductAttention(sd, queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled, false).outputVariable();
    }

    /** Named variant of the nine-argument {@code multiHeadDotProductAttention}; output renamed to {@code name}. */
    public SDVariable multiHeadDotProductAttention(String name, SDVariable queries, SDVariable keys,
                                                   SDVariable values, SDVariable Wq, SDVariable Wk, SDVariable Wv,
                                                   SDVariable Wo, SDVariable mask, boolean scaled) {
        SDVariable result = multiHeadDotProductAttention(queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled);
        return sd.updateVariableNameAndReference(result, name);
    }

    /**
     * Builds a Pad op.
     *
     * @param padMode  padding mode forwarded to the op
     * @param constant scalar forwarded to the op (presumably the fill value for constant mode — confirm)
     */
    public SDVariable pad(SDVariable input, SDVariable padding, PadMode padMode, double constant) {
        SDValidation.validateNumerical("pad", "input", input);
        SDValidation.validateNumerical("pad", "padding", padding);
        return new Pad(sd, input, padding, padMode, constant).outputVariable();
    }

    /** {@link #pad(SDVariable, SDVariable, PadMode, double)} with the output renamed to {@code name}. */
    public SDVariable pad(String name, SDVariable input, SDVariable padding, PadMode padMode, double constant) {
        return sd.updateVariableNameAndReference(pad(input, padding, padMode, constant), name);
    }

    /** Shorthand for {@link #pad(SDVariable, SDVariable, PadMode, double)} with {@link PadMode#CONSTANT}. */
    public SDVariable pad(SDVariable input, SDVariable padding, double constant) {
        return pad(input, padding, PadMode.CONSTANT, constant);
    }

    /** {@link #pad(SDVariable, SDVariable, double)} with the output renamed to {@code name}. */
    public SDVariable pad(String name, SDVariable input, SDVariable padding, double constant) {
        return sd.updateVariableNameAndReference(pad(input, padding, constant), name);
    }

    /** Builds a PreciseGELU op on {@code x}. */
    public SDVariable preciseGelu(SDVariable x) {
        SDValidation.validateNumerical("preciseGelu", "x", x);
        return new PreciseGELU(sd, x).outputVariable();
    }

    /** {@link #preciseGelu(SDVariable)} with the output renamed to {@code name}. */
    public SDVariable preciseGelu(String name, SDVariable x) {
        return sd.updateVariableNameAndReference(preciseGelu(x), name);
    }

    /**
     * Builds a PRelu op.
     *
     * @param sharedAxes axes forwarded to the op; must contain at least one entry
     */
    public SDVariable prelu(SDVariable input, SDVariable alpha, int... sharedAxes) {
        SDValidation.validateNumerical("prelu", "input", input);
        SDValidation.validateNumerical("prelu", "alpha", alpha);
        Preconditions.checkArgument(sharedAxes.length >= 1, "sharedAxes has incorrect size/length. Expected: sharedAxes.length >= 1, got %s", sharedAxes.length);
        return new PRelu(sd, input, alpha, sharedAxes).outputVariable();
    }

    /** {@link #prelu(SDVariable, SDVariable, int...)} with the output renamed to {@code name}. */
    public SDVariable prelu(String name, SDVariable input, SDVariable alpha, int... sharedAxes) {
        return sd.updateVariableNameAndReference(prelu(input, alpha, sharedAxes), name);
    }

    /** Builds a RectifiedLinear op on {@code x}; {@code cutoff} is forwarded to the op. */
    public SDVariable relu(SDVariable x, double cutoff) {
        SDValidation.validateNumerical("relu", "x", x);
        return new RectifiedLinear(sd, x, cutoff).outputVariable();
    }

    /** {@link #relu(SDVariable, double)} with the output renamed to {@code name}. */
    public SDVariable relu(String name, SDVariable x, double cutoff) {
        return sd.updateVariableNameAndReference(relu(x, cutoff), name);
    }

    /** Builds a Relu6 op on {@code x}; {@code cutoff} is forwarded to the op. */
    public SDVariable relu6(SDVariable x, double cutoff) {
        SDValidation.validateNumerical("relu6", "x", x);
        return new Relu6(sd, x, cutoff).outputVariable();
    }

    /** {@link #relu6(SDVariable, double)} with the output renamed to {@code name}. */
    public SDVariable relu6(String name, SDVariable x, double cutoff) {
        return sd.updateVariableNameAndReference(relu6(x, cutoff), name);
    }

    /** Builds a ReluLayer op from {@code input}, {@code weights} and {@code bias}. */
    public SDVariable reluLayer(SDVariable input, SDVariable weights, SDVariable bias) {
        SDValidation.validateNumerical("reluLayer", "input", input);
        SDValidation.validateNumerical("reluLayer", "weights", weights);
        SDValidation.validateNumerical("reluLayer", "bias", bias);
        return new ReluLayer(sd, input, weights, bias).outputVariable();
    }

    /** {@link #reluLayer(SDVariable, SDVariable, SDVariable)} with the output renamed to {@code name}. */
    public SDVariable reluLayer(String name, SDVariable input, SDVariable weights, SDVariable bias) {
        return sd.updateVariableNameAndReference(reluLayer(input, weights, bias), name);
    }

    /** Builds a SELU op on {@code x}. */
    public SDVariable selu(SDVariable x) {
        SDValidation.validateNumerical("selu", "x", x);
        return new SELU(sd, x).outputVariable();
    }

    /** {@link #selu(SDVariable)} with the output renamed to {@code name}. */
    public SDVariable selu(String name, SDVariable x) {
        return sd.updateVariableNameAndReference(selu(x), name);
    }

    /** Builds a Sigmoid op on {@code x}. */
    public SDVariable sigmoid(SDVariable x) {
        SDValidation.validateNumerical("sigmoid", "x", x);
        return new Sigmoid(sd, x).outputVariable();
    }

    /** {@link #sigmoid(SDVariable)} with the output renamed to {@code name}. */
    public SDVariable sigmoid(String name, SDVariable x) {
        return sd.updateVariableNameAndReference(sigmoid(x), name);
    }

    /** Builds a SigmoidDerivative op from {@code x} and {@code wrt}. */
    public SDVariable sigmoidDerivative(SDVariable x, SDVariable wrt) {
        SDValidation.validateNumerical("sigmoidDerivative", "x", x);
        SDValidation.validateNumerical("sigmoidDerivative", "wrt", wrt);
        return new SigmoidDerivative(sd, x, wrt).outputVariable();
    }

    /** {@link #sigmoidDerivative(SDVariable, SDVariable)} with the output renamed to {@code name}. */
    public SDVariable sigmoidDerivative(String name, SDVariable x, SDVariable wrt) {
        return sd.updateVariableNameAndReference(sigmoidDerivative(x, wrt), name);
    }

    /** Builds a SoftMax op on {@code x} along {@code dimension}. */
    public SDVariable softmax(SDVariable x, int dimension) {
        SDValidation.validateNumerical("softmax", "x", x);
        return new SoftMax(sd, x, dimension).outputVariable();
    }

    /** {@link #softmax(SDVariable, int)} with the output renamed to {@code name}. */
    public SDVariable softmax(String name, SDVariable x, int dimension) {
        return sd.updateVariableNameAndReference(softmax(x, dimension), name);
    }

    /** Builds a SoftMax op on {@code x} with dimension {@code -1}. */
    public SDVariable softmax(SDVariable x) {
        return softmax(x, -1);
    }

    /** {@link #softmax(SDVariable)} with the output renamed to {@code name}. */
    public SDVariable softmax(String name, SDVariable x) {
        return sd.updateVariableNameAndReference(softmax(x), name);
    }

    /** Builds a SoftmaxBp op from {@code x} and {@code wrt}; {@code dimension} is passed boxed. */
    public SDVariable softmaxDerivative(SDVariable x, SDVariable wrt, int dimension) {
        SDValidation.validateNumerical("softmaxDerivative", "x", x);
        SDValidation.validateNumerical("softmaxDerivative", "wrt", wrt);
        // Explicit boxing keeps the same SoftmaxBp constructor (the Integer-typed one) selected.
        Integer boxedDimension = Integer.valueOf(dimension);
        return new SoftmaxBp(sd, x, wrt, boxedDimension).outputVariable();
    }

    /** {@link #softmaxDerivative(SDVariable, SDVariable, int)} with the output renamed to {@code name}. */
    public SDVariable softmaxDerivative(String name, SDVariable x, SDVariable wrt, int dimension) {
        return sd.updateVariableNameAndReference(softmaxDerivative(x, wrt, dimension), name);
    }

    /** Builds a SoftPlus op on {@code x}. */
    public SDVariable softplus(SDVariable x) {
        SDValidation.validateNumerical("softplus", "x", x);
        return new SoftPlus(sd, x).outputVariable();
    }

    /** {@link #softplus(SDVariable)} with the output renamed to {@code name}. */
    public SDVariable softplus(String name, SDVariable x) {
        return sd.updateVariableNameAndReference(softplus(x), name);
    }

    /** Builds a SoftSign op on {@code x}. */
    public SDVariable softsign(SDVariable x) {
        SDValidation.validateNumerical("softsign", "x", x);
        return new SoftSign(sd, x).outputVariable();
    }

    /** {@link #softsign(SDVariable)} with the output renamed to {@code name}. */
    public SDVariable softsign(String name, SDVariable x) {
        return sd.updateVariableNameAndReference(softsign(x), name);
    }

    /** Builds a SoftSignDerivative op on {@code x}. */
    public SDVariable softsignDerivative(SDVariable x) {
        SDValidation.validateNumerical("softsignDerivative", "x", x);
        return new SoftSignDerivative(sd, x).outputVariable();
    }

    /** {@link #softsignDerivative(SDVariable)} with the output renamed to {@code name}. */
    public SDVariable softsignDerivative(String name, SDVariable x) {
        return sd.updateVariableNameAndReference(softsignDerivative(x), name);
    }

    /** Builds a Swish op on {@code x}. */
    public SDVariable swish(SDVariable x) {
        SDValidation.validateNumerical("swish", "x", x);
        return new Swish(sd, x).outputVariable();
    }

    /** {@link #swish(SDVariable)} with the output renamed to {@code name}. */
    public SDVariable swish(String name, SDVariable x) {
        return sd.updateVariableNameAndReference(swish(x), name);
    }

    /** Builds a Tanh op on {@code x}. */
    public SDVariable tanh(SDVariable x) {
        SDValidation.validateNumerical("tanh", "x", x);
        return new Tanh(sd, x).outputVariable();
    }

    /** {@link #tanh(SDVariable)} with the output renamed to {@code name}. */
    public SDVariable tanh(String name, SDVariable x) {
        return sd.updateVariableNameAndReference(tanh(x), name);
    }

    /**
     * Builds a TopK op and returns all of its output variables.
     *
     * @param k      value forwarded to the op (declared {@code double} by this API)
     * @param sorted flag forwarded to the op
     */
    public SDVariable[] topK(SDVariable input, double k, boolean sorted) {
        SDValidation.validateNumerical("topK", "input", input);
        return new TopK(sd, input, k, sorted).outputVariables();
    }

    /** {@link #topK(SDVariable, double, boolean)} with the outputs renamed to {@code names}. */
    public SDVariable[] topK(String[] names, SDVariable input, double k, boolean sorted) {
        return sd.updateVariableNamesAndReferences(topK(input, k, sorted), names);
    }
}
