Class SaveModelTrainingListener

java.lang.Object
ai.djl.training.listener.TrainingListenerAdapter
ai.djl.training.listener.SaveModelTrainingListener
All Implemented Interfaces:
TrainingListener

public class SaveModelTrainingListener extends TrainingListenerAdapter
A TrainingListener that saves a model and can save checkpoints.
  • Constructor Details

    • SaveModelTrainingListener

      public SaveModelTrainingListener(String outputDir)
      Constructs a SaveModelTrainingListener using the model's name.
      Parameters:
      outputDir - the directory to output the checkpointed models in
    • SaveModelTrainingListener

      public SaveModelTrainingListener(String outputDir, String overrideModelName)
      Constructs a SaveModelTrainingListener with an override model name.
      Parameters:
      outputDir - the directory to output the checkpointed models in
      overrideModelName - an override model name to save checkpoints with
    • SaveModelTrainingListener

      public SaveModelTrainingListener(String outputDir, String overrideModelName, int checkpoint)
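      Constructs a SaveModelTrainingListener with an override model name and a checkpoint frequency.
      Parameters:
      outputDir - the directory to output the checkpointed models in
      overrideModelName - an override model name to save checkpoints with
      checkpoint - the checkpoint frequency; a checkpoint is added every n epochs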
  • Method Details

    • onEpoch

      public void onEpoch(Trainer trainer)
      Listens to the end of an epoch during training.
      Specified by:
      onEpoch in interface TrainingListener
      Overrides:
      onEpoch in class TrainingListenerAdapter
      Parameters:
      trainer - the trainer the listener is attached to
    • onTrainingEnd

      public void onTrainingEnd(Trainer trainer)
      Listens to the end of training.
      Specified by:
      onTrainingEnd in interface TrainingListener
      Overrides:
      onTrainingEnd in class TrainingListenerAdapter
      Parameters:
      trainer - the trainer the listener is attached to
    • getOverrideModelName

      public String getOverrideModelName()
      Returns the override model name to save checkpoints with.
      Returns:
      the override model name to save checkpoints with
    • setOverrideModelName

      public void setOverrideModelName(String overrideModelName)
      Sets the override model name to save checkpoints with.
      Parameters:
      overrideModelName - the override model name to save checkpoints with
    • getCheckpoint

      public int getCheckpoint()
      Returns the checkpoint frequency (or -1 for no checkpointing).
      Returns:
      the checkpoint frequency (or -1 for no checkpointing)
    • setCheckpoint

      public void setCheckpoint(int checkpoint)
      Sets the checkpoint frequency.
      Parameters:
      checkpoint - how many epochs between checkpoints (or -1 for no checkpoints)
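
      The checkpoint frequency can be read as a simple modulo rule. The sketch below is not the listener's source; it is a standalone illustration that assumes epochs are counted from 1, that a checkpoint is written whenever the completed-epoch count is a multiple of the frequency, and that any non-positive value (such as -1) disables periodic checkpoints. The class name CheckpointCadence and the method checkpointEpochs are hypothetical.

```java
import java.util.ArrayList;
import java.util.List;

public class CheckpointCadence {

    /**
     * Returns the 1-based epochs at which a periodic checkpoint would be written.
     * Assumption: a checkpoint fires when the completed-epoch count is a multiple
     * of the frequency; a non-positive frequency (such as -1) disables it.
     */
    static List<Integer> checkpointEpochs(int totalEpochs, int checkpoint) {
        List<Integer> epochs = new ArrayList<>();
        if (checkpoint <= 0) {
            return epochs; // no periodic checkpoints
        }
        for (int epoch = 1; epoch <= totalEpochs; epoch++) {
            if (epoch % checkpoint == 0) {
                epochs.add(epoch);
            }
        }
        return epochs;
    }

    public static void main(String[] args) {
        System.out.println(checkpointEpochs(10, 3)); // prints [3, 6, 9]
        System.out.println(checkpointEpochs(5, -1)); // prints []
    }
}
```

      Under these assumptions, a frequency of 3 over 10 epochs yields checkpoints after epochs 3, 6, and 9; whether a final save also happens at the end of training is outside what this sketch models.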
    • setSaveModelCallback

      public void setSaveModelCallback(Consumer<Trainer> onSaveModel)
      Sets the callback function on model saving.

      This allows the user to set custom properties on the model metadata.

      Parameters:
      onSaveModel - the callback function on model saving
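
      The callback pattern can be sketched without the library on the classpath. Everything below is a hypothetical stand-in: FakeModel, FakeTrainer, and FakeSaveListener are not DJL classes, and the sketch assumes the callback runs just before the model is written, so that properties set in it end up in the saved metadata. Only the shape of the callback (a Consumer that receives the trainer) is taken from the signature above.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.Consumer;

public class SaveCallbackSketch {

    /** Hypothetical stand-in for a model that carries string metadata. */
    static class FakeModel {
        final Map<String, String> properties = new LinkedHashMap<>();

        void setProperty(String key, String value) {
            properties.put(key, value);
        }
    }

    /** Hypothetical stand-in for a trainer that owns a model. */
    static class FakeTrainer {
        private final FakeModel model = new FakeModel();

        FakeModel getModel() {
            return model;
        }
    }

    /** Hypothetical stand-in for the listener: runs the callback, then "saves". */
    static class FakeSaveListener {
        private Consumer<FakeTrainer> onSaveModel;

        void setSaveModelCallback(Consumer<FakeTrainer> onSaveModel) {
            this.onSaveModel = onSaveModel;
        }

        Map<String, String> saveModel(FakeTrainer trainer) {
            if (onSaveModel != null) {
                onSaveModel.accept(trainer); // assumed to run before the write
            }
            // A real listener would write files; this sketch returns the metadata.
            return new LinkedHashMap<>(trainer.getModel().properties);
        }
    }

    /** Registers a callback that records a custom property, then triggers a save. */
    static Map<String, String> demo() {
        FakeSaveListener listener = new FakeSaveListener();
        listener.setSaveModelCallback(
                trainer -> trainer.getModel().setProperty("Accuracy", "0.93"));
        return listener.saveModel(new FakeTrainer());
    }

    public static void main(String[] args) {
        System.out.println(demo()); // prints {Accuracy=0.93}
    }
}
```

      With the real listener, the same lambda shape would be passed to setSaveModelCallback, with the trainer argument being the Trainer the listener is attached to.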
    • saveModel

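      protected void saveModel(Trainer trainer)
      Saves the trainer's model to the output directory.
      Parameters:
      trainer - the trainer the listener is attached to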