Class BertPretrainingLoss


public class BertPretrainingLoss extends AbstractCompositeLoss
Loss that combines the next sentence prediction loss and the masked language model loss used in BERT pretraining.
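A minimal, dependency-free sketch of how such a combination can work, written in plain Java rather than against DJL's NDList API. It assumes the composite value is the unweighted sum of its component losses; the class CompositeLossSketch and its use of ToDoubleBiFunction and double[][] are hypothetical stand-ins, not part of DJL.

```java
import java.util.List;
import java.util.function.ToDoubleBiFunction;

/**
 * Hypothetical sketch of a composite loss (not the DJL API).
 * Assumption: the composite value is the unweighted sum of the component losses.
 * double[][] stands in for DJL's NDList (a list of arrays).
 */
class CompositeLossSketch {
    private final List<ToDoubleBiFunction<double[][], double[][]>> components;

    CompositeLossSketch(List<ToDoubleBiFunction<double[][], double[][]>> components) {
        this.components = List.copyOf(components);
    }

    /** Evaluates every component on the same inputs and adds the results. */
    double evaluate(double[][] labels, double[][] predictions) {
        double total = 0.0;
        for (ToDoubleBiFunction<double[][], double[][]> component : components) {
            total += component.applyAsDouble(labels, predictions);
        }
        return total;
    }
}
```

With a next sentence component and a masked language model component plugged in, evaluate would return the sum of the two losses under the stated assumption.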
  • Constructor Details

    • BertPretrainingLoss

      public BertPretrainingLoss()
      Creates a loss that combines the next sentence prediction and masked language model losses for BERT pretraining.
  • Method Details

    • inputForComponent

      protected ai.djl.util.Pair<NDList,NDList> inputForComponent(int componentIndex, NDList labels, NDList predictions)
      Description copied from class: AbstractCompositeLoss
      Returns the inputs used to compute the loss of a single component loss.
      Specified by:
      inputForComponent in class AbstractCompositeLoss
      Parameters:
      componentIndex - the index of the component loss
      labels - the label input to the composite loss
      predictions - the predictions input to the composite loss
      Returns:
      a pair of the (labels, predictions) inputs to the component loss
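This page does not specify how the labels and predictions are divided between the two components. The sketch below assumes one possible layout, in which each component index selects the entries it needs from shared lists; an implementation could just as well pass the full lists to every component and let each one pick its own entries. The class ComponentRoutingSketch, its Inputs holder (a stand-in for ai.djl.util.Pair), and the index layout are all hypothetical.

```java
/**
 * Hypothetical sketch of per-component input routing (not the DJL implementation).
 * Assumed layout, for illustration only:
 *   labels      = [nextSentenceLabels, maskedTokenIds, maskWeights]
 *   predictions = [nextSentenceLogits, maskedTokenLogProbs]
 */
class ComponentRoutingSketch {
    /** Stand-in for ai.djl.util.Pair<NDList, NDList>. */
    static final class Inputs {
        final double[][] labels;
        final double[][] predictions;

        Inputs(double[][] labels, double[][] predictions) {
            this.labels = labels;
            this.predictions = predictions;
        }
    }

    static Inputs inputForComponent(int componentIndex, double[][] labels, double[][] predictions) {
        switch (componentIndex) {
            case 0: // next sentence loss: its label entry and its logits
                return new Inputs(new double[][] {labels[0]}, new double[][] {predictions[0]});
            case 1: // masked LM loss: token ids plus mask weights, and the token log-probabilities
                return new Inputs(
                        new double[][] {labels[1], labels[2]}, new double[][] {predictions[1]});
            default:
                throw new IllegalArgumentException("Unknown component index: " + componentIndex);
        }
    }
}
```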
    • getBertNextSentenceLoss

      public BertNextSentenceLoss getBertNextSentenceLoss()
      Gets the BertNextSentenceLoss component of this composite loss.
      Returns:
      the BertNextSentenceLoss component
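To illustrate what a next sentence loss typically computes, here is a sketch of a common formulation: softmax cross-entropy over two classes, averaged over the batch. It is not DJL's BertNextSentenceLoss implementation; the class NextSentenceLossSketch, the int label encoding, and the logits layout (one row of two scores per sentence pair) are assumptions.

```java
/**
 * Hypothetical sketch of a next sentence prediction loss (not DJL's code):
 * two-class softmax cross-entropy on raw logits, averaged over the batch.
 * Assumed layout: logits[i] = {score for class 0, score for class 1}, labels[i] in {0, 1}.
 */
class NextSentenceLossSketch {
    static double loss(int[] labels, double[][] logits) {
        double sum = 0.0;
        for (int i = 0; i < labels.length; i++) {
            double a = logits[i][0];
            double b = logits[i][1];
            double max = Math.max(a, b);
            // log-sum-exp with the max subtracted for numerical stability
            double logSumExp = max + Math.log(Math.exp(a - max) + Math.exp(b - max));
            // negative log-softmax probability of the true class
            sum += logSumExp - logits[i][labels[i]];
        }
        return sum / labels.length;
    }
}
```

Equal logits give a loss of ln 2 for either label, which is the cross-entropy of an uninformed binary guess.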
    • getBertMaskedLanguageModelLoss

      public BertMaskedLanguageModelLoss getBertMaskedLanguageModelLoss()
      Gets the BertMaskedLanguageModelLoss component of this composite loss.
      Returns:
      the BertMaskedLanguageModelLoss component
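For comparison, a sketch of a common masked language model loss: the negative log-likelihood of the true token at each masked slot, weighted by a 0/1 mask so that padding slots are ignored, and divided by the number of real masked tokens. This is not DJL's BertMaskedLanguageModelLoss; the class MaskedLanguageModelLossSketch and its argument layout are assumptions.

```java
/**
 * Hypothetical sketch of a masked language model loss (not DJL's code).
 * Assumed layout: tokenIds[i] is the true vocabulary id at masked slot i,
 * maskWeights[i] is 1.0 for a real masked token and 0.0 for padding,
 * logProbs[i][v] is the predicted log-probability of vocabulary id v at slot i.
 */
class MaskedLanguageModelLossSketch {
    static double loss(int[] tokenIds, double[] maskWeights, double[][] logProbs) {
        double weightedNll = 0.0;
        double totalWeight = 0.0;
        for (int i = 0; i < tokenIds.length; i++) {
            // padding slots contribute nothing because their weight is 0
            weightedNll += -logProbs[i][tokenIds[i]] * maskWeights[i];
            totalWeight += maskWeights[i];
        }
        // avoid dividing by zero when no slot is masked
        return totalWeight == 0.0 ? 0.0 : weightedNll / totalWeight;
    }
}
```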