Package ai.djl.training.loss
Class TabNetRegressionLoss
java.lang.Object
ai.djl.training.evaluator.Evaluator
ai.djl.training.loss.Loss
ai.djl.training.loss.TabNetRegressionLoss
Calculates the loss of TabNet for regression tasks.
TabNet is used not only for supervised learning but is also widely used in unsupervised learning. For unsupervised learning, the loss should come from the decoder (aka the attentionTransformer of TabNet).
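This page does not state the formula this loss computes. The sketch below is a minimal, dependency-free illustration built on two assumptions. First, the regression term is the mean squared error between the labels and the model's first output. Second, the TabNet forward pass emits a precomputed sparsity penalty as a second output, and its mean is added to the total. This reflects our reading of DJL's TabNet and is not confirmed by this page. The class `TabNetRegressionLossSketch` and its methods are hypothetical names that operate on plain `double[]` arrays rather than DJL's `NDList`/`NDArray`.

```java
// Hypothetical, dependency-free sketch; NOT DJL's actual implementation.
// Assumption: total loss = MSE(labels, predictions) + mean(sparseLoss),
// where sparseLoss is a penalty already computed by TabNet's forward pass.
final class TabNetRegressionLossSketch {

    private TabNetRegressionLossSketch() {}

    /** Mean squared error between labels and point predictions. */
    static double meanSquaredError(double[] labels, double[] predictions) {
        if (labels.length == 0 || labels.length != predictions.length) {
            throw new IllegalArgumentException(
                    "labels and predictions must be non-empty and of equal length");
        }
        double sum = 0.0;
        for (int i = 0; i < labels.length; i++) {
            double diff = labels[i] - predictions[i];
            sum += diff * diff;
        }
        return sum / labels.length;
    }

    /** Arithmetic mean of the (assumed) sparsity-penalty values. */
    static double mean(double[] values) {
        if (values.length == 0) {
            throw new IllegalArgumentException("values must be non-empty");
        }
        double sum = 0.0;
        for (double v : values) {
            sum += v;
        }
        return sum / values.length;
    }

    /** Assumed total loss: regression error plus the mean sparsity penalty. */
    static double evaluate(double[] labels, double[] predictions, double[] sparseLoss) {
        return meanSquaredError(labels, predictions) + mean(sparseLoss);
    }

    public static void main(String[] args) {
        double[] labels = {1.0, 2.0, 3.0};
        double[] predictions = {1.5, 2.0, 2.0};
        double[] sparseLoss = {0.1, 0.3}; // hypothetical penalty values
        System.out.println(evaluate(labels, predictions, sparseLoss));
    }
}
```

Adding the sparsity term directly to the regression error means a single scalar drives backpropagation. Under this assumption, the loss class only has to read a penalty the model already produced rather than recompute it.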
Field Summary
Fields inherited from class ai.djl.training.evaluator.Evaluator
totalInstances
Constructor Summary
Constructors
TabNetRegressionLoss()
Calculates the loss of a TabNet instance for regression tasks.
TabNetRegressionLoss(String name)
Calculates the loss of a TabNet instance for regression tasks.
Method Summary
Methods inherited from class ai.djl.training.loss.Loss
addAccumulator, elasticNetWeightedDecay, elasticNetWeightedDecay, elasticNetWeightedDecay, elasticNetWeightedDecay, getAccumulator, hingeLoss, hingeLoss, hingeLoss, l1Loss, l1Loss, l1Loss, l1WeightedDecay, l1WeightedDecay, l1WeightedDecay, l2Loss, l2Loss, l2Loss, l2WeightedDecay, l2WeightedDecay, l2WeightedDecay, maskedSoftmaxCrossEntropyLoss, maskedSoftmaxCrossEntropyLoss, maskedSoftmaxCrossEntropyLoss, quantileL1Loss, quantileL1Loss, resetAccumulator, sigmoidBinaryCrossEntropyLoss, sigmoidBinaryCrossEntropyLoss, sigmoidBinaryCrossEntropyLoss, softmaxCrossEntropyLoss, softmaxCrossEntropyLoss, softmaxCrossEntropyLoss, updateAccumulator, updateAccumulators
Methods inherited from class ai.djl.training.evaluator.Evaluator
checkLabelShapes, checkLabelShapes, getName
Constructor Details
TabNetRegressionLoss
public TabNetRegressionLoss()
Calculates the loss of a TabNet instance for regression tasks.
TabNetRegressionLoss
public TabNetRegressionLoss(String name)
Calculates the loss of a TabNet instance for regression tasks.
Parameters:
name - the name of the loss function
Method Details