Package ai.djl.training
Class Trainer
java.lang.Object
ai.djl.training.Trainer
- All Implemented Interfaces:
AutoCloseable
The Trainer class provides a session for model training.
Trainer provides an easy and manageable interface for training. Trainer is not thread-safe, so an instance should not be shared across threads without external synchronization.
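Because Trainer implements AutoCloseable and is not thread-safe, a common pattern is to confine each instance to one thread and open it in a try-with-resources block so that close() always runs. In DJL code a trainer is commonly obtained from Model.newTrainer(TrainingConfig) rather than from the constructor. The sketch below shows only the lifecycle, using a hypothetical stand-in class, SessionSketch. It is not the DJL Trainer, and its owner-thread check is an illustrative assumption rather than something Trainer is documented to do.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical stand-in for a trainer-like session; NOT the DJL Trainer class. */
class SessionSketch implements AutoCloseable {
    // Records lifecycle events so their ordering can be inspected.
    static final List<String> EVENTS = new ArrayList<>();

    // Illustrative assumption: confine the session to the thread that created it.
    private final Thread owner = Thread.currentThread();
    private boolean closed;

    void trainStep() {
        if (Thread.currentThread() != owner) {
            throw new IllegalStateException("session used from another thread");
        }
        if (closed) {
            throw new IllegalStateException("session already closed");
        }
        EVENTS.add("step");
    }

    @Override
    public void close() {
        // Resources (native memory, executors, ...) would be released here;
        // try-with-resources runs this even if a training step throws.
        if (!closed) {
            closed = true;
            EVENTS.add("close");
        }
    }

    /** Opens a session, runs two steps, and returns the recorded events. */
    static List<String> runOnce() {
        EVENTS.clear();
        try (SessionSketch session = new SessionSketch()) {
            session.trainStep();
            session.trainStep();
        }
        return new ArrayList<>(EVENTS);
    }
}
```

SessionSketch.runOnce() records two steps followed by a single close, which is the ordering try-with-resources guarantees.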
See the tutorials on:
Constructor Summary
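Constructors
- Trainer(Model model, TrainingConfig trainingConfig): Creates a trainer for the given model and training configuration.
-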
Method Summary
Modifier and Type, Method, Description
- void addMetric(String metricName, long begin): Helper to add a metric for a time difference.
- void close()
- NDList evaluate(NDList input): Evaluates the function of the model once on the given input NDList.
- protected void finalize()
- NDList forward(NDList input): Applies the forward function of the model once on the given input NDList.
- NDList forward(NDList data, NDList labels): Applies the forward function of the model once with both data and labels.
- Device[] getDevices(): Returns the devices used for training.
- List<Evaluator> getEvaluators(): Gets all Evaluators.
- getExecutorService(): Returns the ExecutorService.
- Loss getLoss(): Gets the training Loss function of the trainer.
- NDManager getManager(): Gets the NDManager from the model.
- Metrics getMetrics(): Returns the Metrics param used for benchmarking.
- Model getModel(): Returns the model used to create this trainer.
- TrainingResult getTrainingResult(): Returns the TrainingResult.
- void initialize(Shape... shapes): Initializes the Model that the Trainer is going to train.
- Iterable<Batch> iterateDataset(Dataset dataset): Fetches an iterator that can iterate through the given Dataset.
- GradientCollector newGradientCollector(): Returns a new instance of GradientCollector.
- final void notifyListeners(Consumer<TrainingListener> listenerConsumer): Executes a method on each of the TrainingListeners.
- void setMetrics(Metrics metrics): Attaches a Metrics param to use for benchmarking.
- void step(): Updates all of the parameters of the model once.
-
Constructor Details
-
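Trainer
public Trainer(Model model, TrainingConfig trainingConfig)
Creates a trainer for the given model and training configuration.
- Parameters:
model - the model the trainer will train on
trainingConfig - the configuration used by the trainer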
-
-
Method Details
-
initialize
Initializes the Model that the Trainer is going to train.
- Parameters:
shapes - an array of Shape of the inputs
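As a rough illustration of what initializing parameters from input shapes involves, the hypothetical ShapeInitSketch below sizes the weights of a single-output linear layer from the last dimension of an input shape (assumed here to be the feature dimension) and fills them with Xavier/Glorot uniform values. The initializer DJL actually applies depends on the model's blocks and configuration; this is not its implementation.

```java
import java.util.Random;

/** Hypothetical sketch: derive parameter sizes from an input shape, then fill them. */
class ShapeInitSketch {
    /**
     * Allocates weights for a single-output linear layer whose input has the given
     * shape, e.g. (batchSize, features). Uses Xavier/Glorot uniform with fanOut = 1.
     */
    static double[] initialize(long seed, int... inputShape) {
        if (inputShape.length == 0) {
            throw new IllegalArgumentException("input shape must have at least one dimension");
        }
        int fanIn = inputShape[inputShape.length - 1]; // assumption: last dimension = features
        int fanOut = 1;
        double limit = Math.sqrt(6.0 / (fanIn + fanOut));
        Random rng = new Random(seed); // seeded so the result is reproducible
        double[] weights = new double[fanIn];
        for (int i = 0; i < fanIn; i++) {
            weights[i] = (rng.nextDouble() * 2 - 1) * limit; // uniform in [-limit, limit)
        }
        return weights;
    }
}
```

For an input shape of (32, 4), the sketch returns four weights bounded by sqrt(6 / 5); the batch dimension does not affect the parameter count.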
-
iterateDataset
Fetches an iterator that can iterate through the given Dataset.
- Parameters:
dataset - the dataset to iterate through
- Returns:
an Iterable of Batch that contains batches of data from the dataset
- Throws:
IOException - for various exceptions depending on the dataset
TranslateException - if there is an error while processing input
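The sketch below shows the general idea behind returning an Iterable of batches: a hypothetical BatchIterableSketch slices an in-memory list into fixed-size batches, with a possibly smaller final batch. It does not model DJL's Dataset or Batch types, shuffling, or the IOException and TranslateException cases.

```java
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;

/** Hypothetical sketch of an Iterable that yields fixed-size batches from a list. */
class BatchIterableSketch<T> implements Iterable<List<T>> {
    private final List<T> data;
    private final int batchSize;

    BatchIterableSketch(List<T> data, int batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive");
        }
        this.data = data;
        this.batchSize = batchSize;
    }

    @Override
    public Iterator<List<T>> iterator() {
        // Each call returns a fresh iterator, so the Iterable can be traversed once per epoch.
        return new Iterator<List<T>>() {
            private int cursor = 0; // index of the first element of the next batch

            @Override
            public boolean hasNext() {
                return cursor < data.size();
            }

            @Override
            public List<T> next() {
                if (!hasNext()) {
                    throw new NoSuchElementException();
                }
                int end = Math.min(cursor + batchSize, data.size());
                List<T> batch = new ArrayList<>(data.subList(cursor, end)); // copy, not a view
                cursor = end;
                return batch; // the last batch may be smaller than batchSize
            }
        };
    }
}
```

Iterating over new BatchIterableSketch<>(List.of(1, 2, 3, 4, 5), 2) yields [1, 2], [3, 4], and [5].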
-
newGradientCollector
Returns a new instance of GradientCollector.
- Returns:
a new instance of GradientCollector
-
forward
Applies the forward function of the model once on the given input NDList.
- Parameters:
input - the input NDList
- Returns:
the output of the forward function
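To make the roles of the forward pass, the gradient computation, and a step that updates every parameter once concrete, here is a plain-Java analogue of one training iteration, fitted to the hypothetical target y = 2x + 1. It is not DJL code: the gradients of the mean squared error are derived by hand instead of being recorded by a GradientCollector, and plain SGD stands in for whatever optimizer the training configuration specifies.

```java
/** Hypothetical plain-Java analogue of forward -> loss -> backward -> step; not DJL code. */
class TrainingLoopSketch {
    double w = 0.0; // slope parameter
    double b = 0.0; // bias parameter
    private double gradW;
    private double gradB;

    /** Forward pass: yHat = w * x + b for each input. */
    double[] forward(double[] x) {
        double[] yHat = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            yHat[i] = w * x[i] + b;
        }
        return yHat;
    }

    /** Mean squared error; also stores its gradients with respect to w and b (the "backward" part). */
    double lossAndBackward(double[] x, double[] y) {
        double[] yHat = forward(x);
        int n = x.length;
        double loss = 0.0;
        gradW = 0.0;
        gradB = 0.0;
        for (int i = 0; i < n; i++) {
            double err = yHat[i] - y[i];
            loss += err * err / n;
            gradW += 2 * err * x[i] / n; // d(loss)/dw
            gradB += 2 * err / n;        // d(loss)/db
        }
        return loss;
    }

    /** Step: update every parameter once with plain SGD. */
    void step(double learningRate) {
        w -= learningRate * gradW;
        b -= learningRate * gradB;
    }

    /** Runs the given number of full-batch iterations on y = 2x + 1 and returns the trained model. */
    static TrainingLoopSketch fit(int epochs, double learningRate) {
        double[] x = {0, 1, 2, 3};
        double[] y = {1, 3, 5, 7};
        TrainingLoopSketch model = new TrainingLoopSketch();
        for (int epoch = 0; epoch < epochs; epoch++) {
            model.lossAndBackward(x, y);
            model.step(learningRate);
        }
        return model;
    }
}
```

With TrainingLoopSketch.fit(2000, 0.05), w approaches 2 and b approaches 1; the learning rate and iteration count are illustrative choices, not recommendations.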
-
forward
Applies the forward function of the model once with both data and labels.
-
evaluate
Evaluates the function of the model once on the given input NDList.
- Parameters:
input - the input NDList
- Returns:
the output of the predict function
-
step
public void step()
Updates all of the parameters of the model once.
-
getMetrics
Returns the Metrics param used for benchmarking.
- Returns:
the Metrics param used for benchmarking
-
setMetrics
Attaches a Metrics param to use for benchmarking.
- Parameters:
metrics - the Metrics class
-
getDevices
Returns the devices used for training.
- Returns:
the devices used for training
-
getLoss
Gets the training Loss function of the trainer.
- Returns:
the Loss function
-
getModel
Returns the model used to create this trainer.
- Returns:
the model associated with this trainer
-
getExecutorService
Returns the ExecutorService.
- Returns:
the ExecutorService
-
getEvaluators
Gets all Evaluators.
- Returns:
the evaluators used during training
-
notifyListeners
Executes a method on each of the TrainingListeners.
- Parameters:
listenerConsumer - a consumer that executes the method
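The dispatch pattern behind notifyListeners can be sketched as follows: the caller passes a Consumer that invokes one callback, and the method applies it to every registered listener in turn. EpochListenerSketch and ListenerHubSketch are hypothetical; DJL's TrainingListener defines its own callbacks, and the registration-order guarantee here is an assumption of the sketch.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;

/** Hypothetical listener interface with a single callback. */
interface EpochListenerSketch {
    void onEpoch(int epoch);
}

/** Hypothetical sketch of dispatching one callback to every registered listener. */
class ListenerHubSketch {
    private final List<EpochListenerSketch> listeners = new ArrayList<>();

    void addListener(EpochListenerSketch listener) {
        listeners.add(listener);
    }

    /** Runs listenerConsumer once per listener, in registration order. */
    final void notifyListeners(Consumer<EpochListenerSketch> listenerConsumer) {
        listeners.forEach(listenerConsumer);
    }
}
```

A call such as hub.notifyListeners(l -> l.onEpoch(3)) lets one method fire any callback without the hub needing a dedicated method per event.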
-
getTrainingResult
Returns the TrainingResult.
- Returns:
the TrainingResult
-
getManager
Gets the NDManager from the model.
- Returns:
the NDManager
-
finalize
protected void finalize()
-
close
public void close()
- Specified by:
close in interface AutoCloseable
-
addMetric
Helper to add a metric for a time difference.
- Parameters:
metricName - the metric name
begin - the time difference start (this method is called at the time difference end)
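A minimal sketch of the time-difference pattern described for addMetric, assuming begin is a timestamp taken from System.nanoTime() when the timed span starts. TimingMetricsSketch is hypothetical and stores durations in a plain Map rather than DJL's Metrics class. Skipping the call when no store is attached is an assumption that mirrors metrics being optional via setMetrics.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/** Hypothetical sketch of a "time difference" metric helper; not the DJL Metrics class. */
class TimingMetricsSketch {
    private Map<String, List<Long>> metrics; // null until a store is attached

    void setMetrics(Map<String, List<Long>> metrics) {
        this.metrics = metrics;
    }

    /** Records (now - begin) in nanoseconds under metricName; call at the end of the timed span. */
    void addMetric(String metricName, long begin) {
        if (metrics == null) {
            return; // assumption: benchmarking is optional, so nothing is recorded
        }
        long elapsed = System.nanoTime() - begin;
        metrics.computeIfAbsent(metricName, k -> new ArrayList<>()).add(elapsed);
    }
}
```

A caller would take long begin = System.nanoTime() before the work and call addMetric("forward", begin) immediately after it, so each call appends one duration under that name.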
-