Package ai.djl.training.listener
Class SaveModelTrainingListener
java.lang.Object
ai.djl.training.listener.TrainingListenerAdapter
ai.djl.training.listener.SaveModelTrainingListener
- All Implemented Interfaces:
TrainingListener
A TrainingListener that saves a model and can save checkpoints.
Nested Class Summary
Nested classes/interfaces inherited from interface ai.djl.training.listener.TrainingListener
TrainingListener.BatchData, TrainingListener.Defaults
Constructor Summary
Constructors
Constructor | Description
SaveModelTrainingListener(String outputDir) | Constructs a SaveModelTrainingListener using the model's name.
SaveModelTrainingListener(String outputDir, String overrideModelName) | Constructs a SaveModelTrainingListener.
SaveModelTrainingListener(String outputDir, String overrideModelName, int checkpoint) | Constructs a SaveModelTrainingListener.
Method Summary
Modifier and Type | Method | Description
int | getCheckpoint() | Returns the checkpoint frequency (or -1 for no checkpointing) in SaveModelTrainingListener.
String | getOverrideModelName() | Returns the override model name to save checkpoints with.
void | onEpoch(Trainer trainer) | Listens to the end of an epoch during training.
void | onTrainingEnd(Trainer trainer) | Listens to the end of training.
protected void | saveModel |
void | setCheckpoint(int checkpoint) | Sets the checkpoint frequency in SaveModelTrainingListener.
void | setOverrideModelName(String overrideModelName) | Sets the override model name to save checkpoints with.
void | setSaveModelCallback(Consumer<Trainer> onSaveModel) | Sets the callback function on model saving.
Methods inherited from class ai.djl.training.listener.TrainingListenerAdapter
onTrainingBatch, onTrainingBegin, onValidationBatch
-
Constructor Details
-
SaveModelTrainingListener
public SaveModelTrainingListener(String outputDir)
Constructs a SaveModelTrainingListener using the model's name.
- Parameters:
outputDir - the directory to output the checkpointed models in
-
SaveModelTrainingListener
public SaveModelTrainingListener(String outputDir, String overrideModelName)
Constructs a SaveModelTrainingListener.
- Parameters:
overrideModelName - an override model name to save checkpoints with
outputDir - the directory to output the checkpointed models in
-
SaveModelTrainingListener
public SaveModelTrainingListener(String outputDir, String overrideModelName, int checkpoint)
Constructs a SaveModelTrainingListener.
- Parameters:
overrideModelName - an override model name to save checkpoints with
outputDir - the directory to output the checkpointed models in
checkpoint - adds a checkpoint every n epochs
-
Method Details
-
onEpoch
public void onEpoch(Trainer trainer)
Listens to the end of an epoch during training.
- Specified by:
onEpoch in interface TrainingListener
- Overrides:
onEpoch in class TrainingListenerAdapter
- Parameters:
trainer - the trainer the listener is attached to
-
onTrainingEnd
public void onTrainingEnd(Trainer trainer)
Listens to the end of training.
- Specified by:
onTrainingEnd in interface TrainingListener
- Overrides:
onTrainingEnd in class TrainingListenerAdapter
- Parameters:
trainer - the trainer the listener is attached to
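The relationship between the listener interface, the adapter, and a subclass that overrides only the epoch-end and training-end hooks can be illustrated with a small sketch. Every type below (SketchListener, SketchAdapter, RecordingListener) is a hypothetical stand-in and not a DJL class; the real callbacks also receive a Trainer argument, which is omitted here to keep the example self-contained.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical stand-ins, not DJL types: a listener interface, a no-op adapter, one subclass. */
final class ListenerLifecycleSketch {

    interface SketchListener {
        void onTrainingBegin();
        void onEpoch();
        void onTrainingEnd();
    }

    /** Adapter with empty defaults so subclasses override only the hooks they need. */
    static class SketchAdapter implements SketchListener {
        @Override public void onTrainingBegin() {}
        @Override public void onEpoch() {}
        @Override public void onTrainingEnd() {}
    }

    /** Records an event at the end of each epoch and at the end of training. */
    static class RecordingListener extends SketchAdapter {
        final List<String> events = new ArrayList<>();
        private int epoch;

        @Override public void onEpoch() {
            epoch++;
            events.add("epoch " + epoch);
        }

        @Override public void onTrainingEnd() {
            events.add("end after " + epoch + " epochs");
        }
    }

    /** Simulated training loop that fires the hooks in order. */
    static List<String> run(int epochs) {
        RecordingListener listener = new RecordingListener();
        listener.onTrainingBegin(); // inherited no-op from the adapter
        for (int i = 0; i < epochs; i++) {
            listener.onEpoch();
        }
        listener.onTrainingEnd();
        return listener.events;
    }

    public static void main(String[] args) {
        System.out.println(run(2)); // prints [epoch 1, epoch 2, end after 2 epochs]
    }
}
```

The adapter is what lets the subclass leave the batch-level and begin-of-training hooks untouched while still satisfying the interface.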
-
getOverrideModelName
public String getOverrideModelName()
Returns the override model name to save checkpoints with.
- Returns:
the override model name to save checkpoints with
-
setOverrideModelName
public void setOverrideModelName(String overrideModelName)
Sets the override model name to save checkpoints with.
- Parameters:
overrideModelName - the override model name to save checkpoints with
-
getCheckpoint
public int getCheckpoint()
Returns the checkpoint frequency (or -1 for no checkpointing) in SaveModelTrainingListener.
- Returns:
the checkpoint frequency (or -1 for no checkpointing)
-
setCheckpoint
public void setCheckpoint(int checkpoint)
Sets the checkpoint frequency in SaveModelTrainingListener.
- Parameters:
checkpoint - how many epochs between checkpoints (or -1 for no checkpoints)
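To make the "every n epochs" wording concrete, here is a minimal sketch of one plausible reading of the checkpoint frequency. It assumes epochs are counted from 1 and that a checkpoint is written whenever a positive frequency divides the epoch count; the class and method names are hypothetical and this is not the library's implementation.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical stand-in, not part of DJL: models an "every n epochs" checkpoint rule. */
final class CheckpointFrequencySketch {

    /** Assumed rule: save when the frequency is positive and divides the 1-based epoch count. */
    static boolean shouldCheckpoint(int epoch, int checkpoint) {
        return checkpoint > 0 && epoch % checkpoint == 0;
    }

    /** Lists the 1-based epochs at which a checkpoint would be written. */
    static List<Integer> checkpointEpochs(int totalEpochs, int checkpoint) {
        List<Integer> epochs = new ArrayList<>();
        for (int epoch = 1; epoch <= totalEpochs; epoch++) {
            if (shouldCheckpoint(epoch, checkpoint)) {
                epochs.add(epoch);
            }
        }
        return epochs;
    }

    public static void main(String[] args) {
        System.out.println(checkpointEpochs(5, 2));  // prints [2, 4]
        System.out.println(checkpointEpochs(5, -1)); // prints []
    }
}
```

Under this reading, a frequency of -1 never triggers a mid-training checkpoint, which matches the documented "no checkpoints" value.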
-
setSaveModelCallback
public void setSaveModelCallback(Consumer<Trainer> onSaveModel)
Sets the callback function on model saving. This allows the user to set custom properties in the model metadata.
- Parameters:
onSaveModel - the callback function on model saving
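The callback idea can be sketched as follows: a save routine invokes a user-supplied Consumer just before persisting, giving the caller a chance to attach metadata. SketchModel, SketchTrainer, and SketchSaver are hypothetical stand-ins rather than DJL types, and the property key and value are purely illustrative.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.Consumer;

/** Hypothetical stand-ins, not DJL types: a save routine that runs a callback first. */
final class SaveCallbackSketch {

    /** Stand-in for a model that carries string metadata. */
    static class SketchModel {
        final Map<String, String> properties = new LinkedHashMap<>();
    }

    /** Stand-in for a trainer that owns a model. */
    static class SketchTrainer {
        final SketchModel model = new SketchModel();
    }

    static class SketchSaver {
        private Consumer<SketchTrainer> onSaveModel;

        void setSaveModelCallback(Consumer<SketchTrainer> onSaveModel) {
            this.onSaveModel = onSaveModel;
        }

        /** Runs the callback, if one is set, then returns the metadata that would be persisted. */
        Map<String, String> save(SketchTrainer trainer) {
            if (onSaveModel != null) {
                onSaveModel.accept(trainer);
            }
            return trainer.model.properties;
        }
    }

    /** Registers a callback that adds one illustrative property, then triggers a save. */
    static Map<String, String> demo() {
        SketchSaver saver = new SketchSaver();
        saver.setSaveModelCallback(t -> t.model.properties.put("dataset", "example-v1"));
        return saver.save(new SketchTrainer());
    }

    public static void main(String[] args) {
        System.out.println(demo()); // prints {dataset=example-v1}
    }
}
```

With the real class, the callback receives a Trainer, so a lambda along the lines of `trainer -> trainer.getModel().setProperty(key, value)` is the likely shape; check the Trainer and Model Javadoc before relying on it.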
-
saveModel
-