Package ai.djl.nn.transformer
Class BertPretrainingLoss
java.lang.Object
ai.djl.training.evaluator.Evaluator
ai.djl.training.loss.Loss
ai.djl.training.loss.AbstractCompositeLoss
ai.djl.nn.transformer.BertPretrainingLoss
Loss that combines the next-sentence and masked-language-model losses of BERT pretraining.
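As a rough numerical illustration of what combining the two losses can mean, the sketch below adds a next-sentence term and a masked-language term. It assumes both terms are negative log-likelihoods and that the combination is an unweighted sum; this page states neither detail. The class BertPretrainingLossSketch and all of its methods are hypothetical and operate on plain arrays, not on DJL types.

```java
/** Hypothetical numeric sketch; not part of DJL and not the library's implementation. */
final class BertPretrainingLossSketch {

    private BertPretrainingLossSketch() {}

    /** Negative log-likelihood of the true class under a probability vector. */
    static double negLogLikelihood(double[] probs, int trueClass) {
        return -Math.log(probs[trueClass]);
    }

    /** Assumed masked-language term: mean negative log-likelihood over the masked positions. */
    static double maskedLanguageLoss(double[][] tokenProbs, int[] trueTokens) {
        double sum = 0.0;
        for (int i = 0; i < trueTokens.length; i++) {
            sum += negLogLikelihood(tokenProbs[i], trueTokens[i]);
        }
        return sum / trueTokens.length;
    }

    /** Assumed combination: next-sentence term plus masked-language term, unweighted. */
    static double total(
            double[] nextSentenceProbs, int isNext, double[][] tokenProbs, int[] trueTokens) {
        return negLogLikelihood(nextSentenceProbs, isNext)
                + maskedLanguageLoss(tokenProbs, trueTokens);
    }
}
```

With next-sentence probabilities {0.5, 0.5} and two masked positions whose true tokens have probabilities 0.25 and 0.5, the sketch gives ln 2 + (ln 4 + ln 2) / 2 = 2.5 ln 2.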
Field Summary
Fields inherited from class ai.djl.training.loss.AbstractCompositeLoss
components
Fields inherited from class ai.djl.training.evaluator.Evaluator
totalInstances
Constructor Summary
Constructors
Constructor: BertPretrainingLoss()
Description: Creates a loss combining the next-sentence and masked-language-model losses for BERT pretraining.
Method Summary
Methods
BertMaskedLanguageModelLoss getBertMaskedLanguageModelLoss()
    Gets the BertMaskedLanguageModelLoss.
BertNextSentenceLoss getBertNextSentenceLoss()
    Gets the BertNextSentenceLoss.
protected ai.djl.util.Pair<NDList,NDList> inputForComponent(int componentIndex, NDList labels, NDList predictions)
    Returns the inputs used to compute a component loss.
Methods inherited from class ai.djl.training.loss.AbstractCompositeLoss
addAccumulator, evaluate, getAccumulator, getComponents, resetAccumulator, updateAccumulators
Methods inherited from class ai.djl.training.loss.Loss
elasticNetWeightedDecay, elasticNetWeightedDecay, elasticNetWeightedDecay, elasticNetWeightedDecay, hingeLoss, hingeLoss, hingeLoss, l1Loss, l1Loss, l1Loss, l1WeightedDecay, l1WeightedDecay, l1WeightedDecay, l2Loss, l2Loss, l2Loss, l2WeightedDecay, l2WeightedDecay, l2WeightedDecay, maskedSoftmaxCrossEntropyLoss, maskedSoftmaxCrossEntropyLoss, maskedSoftmaxCrossEntropyLoss, quantileL1Loss, quantileL1Loss, sigmoidBinaryCrossEntropyLoss, sigmoidBinaryCrossEntropyLoss, sigmoidBinaryCrossEntropyLoss, softmaxCrossEntropyLoss, softmaxCrossEntropyLoss, softmaxCrossEntropyLoss, updateAccumulator
Methods inherited from class ai.djl.training.evaluator.Evaluator
checkLabelShapes, checkLabelShapes, getName
Constructor Details
BertPretrainingLoss
public BertPretrainingLoss()
Creates a loss combining the next-sentence and masked-language-model losses for BERT pretraining.
Method Details
inputForComponent
protected ai.djl.util.Pair<NDList,NDList> inputForComponent(int componentIndex, NDList labels, NDList predictions)
Description copied from class: AbstractCompositeLoss
Returns the inputs used to compute a component loss.
Specified by:
inputForComponent in class AbstractCompositeLoss
Parameters:
componentIndex - the index of the component loss
labels - the label input to the composite loss
predictions - the predictions input to the composite loss
Returns:
a pair of the (labels, predictions) inputs to the component loss
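The role of inputForComponent can be sketched with a small stand-alone composite. Everything below is hypothetical: CompositeLossSketch, Inputs, and TwoPartLossSketch are illustrative stand-ins built on plain arrays, not DJL classes. The sketch also assumes that the composite value is the sum of its component losses, which this page does not state.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.ToDoubleBiFunction;

/** Hypothetical stand-in for a composite loss; none of these types are DJL classes. */
abstract class CompositeLossSketch {

    /** Plain (labels, predictions) holder, playing the role of a pair of NDLists. */
    static final class Inputs {
        final double[] labels;
        final double[] predictions;

        Inputs(double[] labels, double[] predictions) {
            this.labels = labels;
            this.predictions = predictions;
        }
    }

    private final List<ToDoubleBiFunction<double[], double[]>> components;

    CompositeLossSketch(List<ToDoubleBiFunction<double[], double[]>> components) {
        this.components = components;
    }

    /** Hook: selects the inputs that the component at componentIndex should see. */
    protected abstract Inputs inputForComponent(
            int componentIndex, double[] labels, double[] predictions);

    /** Assumption: the composite value is the plain sum of the component losses. */
    double evaluate(double[] labels, double[] predictions) {
        double total = 0.0;
        for (int i = 0; i < components.size(); i++) {
            Inputs in = inputForComponent(i, labels, predictions);
            total += components.get(i).applyAsDouble(in.labels, in.predictions);
        }
        return total;
    }
}

/** Hypothetical subclass: component i only sees element i of each array. */
final class TwoPartLossSketch extends CompositeLossSketch {

    TwoPartLossSketch() {
        super(parts());
    }

    private static List<ToDoubleBiFunction<double[], double[]>> parts() {
        List<ToDoubleBiFunction<double[], double[]>> parts = new ArrayList<>();
        parts.add((l, p) -> Math.abs(l[0] - p[0]));           // component 0: absolute error
        parts.add((l, p) -> (l[0] - p[0]) * (l[0] - p[0]));   // component 1: squared error
        return parts;
    }

    @Override
    protected Inputs inputForComponent(
            int componentIndex, double[] labels, double[] predictions) {
        // Route a single element to each component instead of the full arrays.
        return new Inputs(
                new double[] {labels[componentIndex]},
                new double[] {predictions[componentIndex]});
    }
}
```

Calling new TwoPartLossSketch().evaluate(new double[] {1.0, 2.0}, new double[] {0.5, 4.0}) gives 0.5 from the absolute-error component plus 4.0 from the squared-error component, so 4.5 in total. Keeping the routing in one overridable hook lets the summing loop stay identical across composites.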
getBertNextSentenceLoss
Gets the BertNextSentenceLoss.
Returns:
BertNextSentenceLoss
getBertMaskedLanguageModelLoss
Gets the BertMaskedLanguageModelLoss.
Returns:
BertMaskedLanguageModelLoss