Class TabNetRegressionLoss


public class TabNetRegressionLoss extends Loss
Calculates the loss of TabNet for regression tasks.

TabNet is used not only for supervised learning but also, widely, for unsupervised learning. In the unsupervised case, the loss should come from the decoder (also called the attentionTransformer of TabNet).

  • Constructor Details

    • TabNetRegressionLoss

      public TabNetRegressionLoss()
      Creates a TabNetRegressionLoss with a default name, which calculates the loss of a TabNet instance for regression tasks.
    • TabNetRegressionLoss

      public TabNetRegressionLoss(String name)
      Creates a TabNetRegressionLoss with the given name, which calculates the loss of a TabNet instance for regression tasks.
      Parameters:
      name - the name of the loss function
  • Method Details

    • evaluate

      public NDArray evaluate(NDList labels, NDList predictions)
      Calculates the evaluation between the labels and the predictions.
      Specified by:
      evaluate in class Evaluator
      Parameters:
      labels - the correct values
      predictions - the predicted values
      Returns:
      the evaluation result
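The evaluate contract above can be illustrated with a dependency-free sketch. It does not use DJL's NDArray or NDList. It assumes, without confirmation from this page, that the regression loss is the mean squared error between the labels and the first prediction, plus the mean of a sparsity-regularization term that the model is assumed to emit as a second prediction. The class TabNetRegressionLossSketch and its methods are hypothetical.

```java
import java.util.Arrays;

/**
 * Hypothetical, dependency-free sketch of a TabNet-style regression loss.
 * ASSUMPTION: loss = mean squared error(labels, output) + mean(sparseLoss),
 * where sparseLoss is assumed to be a second output produced by the model.
 */
public class TabNetRegressionLossSketch {

    /** Returns the arithmetic mean of the array (0.0 for an empty array). */
    static double mean(double[] values) {
        return Arrays.stream(values).average().orElse(0.0);
    }

    /**
     * Mirrors evaluate(labels, predictions) using plain arrays.
     *
     * @param labels     the correct values
     * @param output     the regression output (assumed first prediction)
     * @param sparseLoss the sparsity term (assumed second prediction)
     * @return the scalar loss value
     */
    static double evaluate(double[] labels, double[] output, double[] sparseLoss) {
        if (labels.length == 0 || labels.length != output.length) {
            throw new IllegalArgumentException(
                    "labels and output must be non-empty and have the same length");
        }
        // Mean squared error between the labels and the regression output.
        double sumSquares = 0.0;
        for (int i = 0; i < labels.length; i++) {
            double diff = labels[i] - output[i];
            sumSquares += diff * diff;
        }
        double mse = sumSquares / labels.length;
        // Add the averaged sparsity-regularization term.
        return mse + mean(sparseLoss);
    }

    public static void main(String[] args) {
        double[] labels = {1.0, 2.0, 3.0};
        double[] output = {1.5, 2.0, 2.0};
        double[] sparse = {0.1, 0.3};
        System.out.println(evaluate(labels, output, sparse));
    }
}
```

Plain double arrays stand in for NDArray so that the arithmetic is easy to follow; the actual class operates on NDList inputs and returns an NDArray.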