Package ai.djl.training
Class EasyTrain
java.lang.Object
ai.djl.training.EasyTrain
Helper for easy training of a whole model, a training batch, or a validation batch.
-
Method Summary
Modifier and Type | Method | Description
static void | evaluateDataset(Trainer trainer, Dataset testDataset) | Evaluates the test dataset.
static void | fit(Trainer trainer, int numEpoch, Dataset trainingDataset, Dataset validateDataset) | Runs a basic epoch-based training loop with the given trainer.
static void | trainBatch(Trainer trainer, Batch batch) | Trains the model with one iteration of the given Batch of data.
static void | validateBatch(Trainer trainer, Batch batch) | Validates the given batch of data.
-
Method Details
-
fit
public static void fit(Trainer trainer, int numEpoch, Dataset trainingDataset, Dataset validateDataset) throws IOException, TranslateException
Runs a basic epoch-based training loop with the given trainer.
Parameters:
  trainer - the trainer to train with
  numEpoch - the number of epochs to train
  trainingDataset - the dataset to train on
  validateDataset - the dataset to validate against; can be null for no validation
Throws:
  IOException - for various exceptions depending on the dataset
  TranslateException - if there is an error while processing input
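The loop that fit describes can be modeled without DJL. The following is a minimal, stdlib-only sketch, not DJL's implementation: LoopTrainer, CountingTrainer, and FitSketch are hypothetical names, batches are plain int arrays, and both the separate step() call after every training batch and the end-of-epoch hook are assumptions of the sketch. It shows the documented contract: numEpoch passes over trainingDataset, and no validation work at all when validateDataset is null.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical, minimal stand-in for a trainer. This is NOT the DJL Trainer API;
// it only exposes the operations the epoch loop needs.
interface LoopTrainer {
    void trainBatch(int[] batch);    // one training iteration (forward + backward)
    void step();                     // apply the parameter update (assumed separate)
    void validateBatch(int[] batch); // forward pass only, no update
    void endEpoch(int epoch);        // end-of-epoch hook (e.g., reset metrics)
}

// Counts calls so the structure of the loop can be checked.
final class CountingTrainer implements LoopTrainer {
    int trained, steps, validated;
    final List<Integer> finishedEpochs = new ArrayList<>();

    public void trainBatch(int[] batch) { trained++; }
    public void step() { steps++; }
    public void validateBatch(int[] batch) { validated++; }
    public void endEpoch(int epoch) { finishedEpochs.add(epoch); }
}

final class FitSketch {
    // numEpoch full passes over the training data; validation runs once per epoch
    // and is skipped entirely when validateDataset is null.
    static void fit(LoopTrainer trainer, int numEpoch,
                    List<int[]> trainingDataset, List<int[]> validateDataset) {
        for (int epoch = 0; epoch < numEpoch; epoch++) {
            for (int[] batch : trainingDataset) {
                trainer.trainBatch(batch);
                trainer.step();
            }
            if (validateDataset != null) {
                for (int[] batch : validateDataset) {
                    trainer.validateBatch(batch);
                }
            }
            trainer.endEpoch(epoch);
        }
    }

    public static void main(String[] args) {
        List<int[]> train = List.of(new int[] {1}, new int[] {2}, new int[] {3}, new int[] {4});
        List<int[]> validate = List.of(new int[] {5}, new int[] {6});
        CountingTrainer trainer = new CountingTrainer();
        fit(trainer, 3, train, validate);
        System.out.println(trainer.trained + " train, " + trainer.steps + " steps, "
                + trainer.validated + " validate, epochs " + trainer.finishedEpochs);
    }
}
```

With four training batches, two validation batches, and three epochs, the counting trainer records 12 training iterations, 12 updates, and 6 validation calls; passing null as the validation dataset drops the validation calls to zero.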
-
trainBatch
public static void trainBatch(Trainer trainer, Batch batch)
Trains the model with one iteration of the given Batch of data.
Parameters:
  trainer - the trainer to train the batch with
  batch - a Batch that contains data and its respective labels
Throws:
  IllegalArgumentException - if the batch engine does not match the trainer engine
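To make "one iteration" concrete, here is a hypothetical one-parameter model (prediction w * x, mean squared error loss) whose trainBatch checks the batch engine, runs the forward pass, and computes the gradient. LinearBatch, LinearTrainer, and the engine labels are illustrative assumptions, not DJL types, and keeping the parameter update in a separate step() method is a design assumption of this sketch rather than something the trainBatch description states.

```java
import java.util.Objects;

// Hypothetical batch: features, labels, and the name of the engine that created it.
final class LinearBatch {
    final String engine;
    final double[] x;
    final double[] y;

    LinearBatch(String engine, double[] x, double[] y) {
        this.engine = engine;
        this.x = x;
        this.y = y;
    }
}

// Hypothetical trainer for the one-parameter model y = w * x with mean squared error.
// Not the DJL Trainer API.
final class LinearTrainer {
    final String engine;
    final double learningRate;
    double w;
    double grad;     // gradient from the most recent trainBatch call
    double lastLoss; // loss from the most recent trainBatch call

    LinearTrainer(String engine, double w, double learningRate) {
        this.engine = engine;
        this.w = w;
        this.learningRate = learningRate;
    }

    // One training iteration: forward pass, loss, and gradient of the loss w.r.t. w.
    // In this sketch the parameter itself is only changed by step().
    void trainBatch(LinearBatch batch) {
        if (!Objects.equals(batch.engine, engine)) {
            throw new IllegalArgumentException(
                    "Batch engine " + batch.engine + " does not match trainer engine " + engine);
        }
        int n = batch.x.length;
        double loss = 0;
        double g = 0;
        for (int i = 0; i < n; i++) {
            double err = w * batch.x[i] - batch.y[i];
            loss += err * err;         // squared error
            g += 2 * err * batch.x[i]; // d(err^2)/dw
        }
        lastLoss = loss / n;
        grad = g / n;
    }

    // Plain SGD update using the stored gradient.
    void step() {
        w -= learningRate * grad;
        grad = 0;
    }

    public static void main(String[] args) {
        LinearTrainer trainer = new LinearTrainer("PyTorch", 0.0, 0.1);
        LinearBatch batch = new LinearBatch("PyTorch", new double[] {1, 2}, new double[] {2, 4});
        trainer.trainBatch(batch);
        System.out.println("loss=" + trainer.lastLoss + " grad=" + trainer.grad);
        trainer.step();
        System.out.println("w=" + trainer.w);
    }
}
```

For the batch x = {1, 2}, y = {2, 4} and w = 0, the sketch computes a loss of 10 and a gradient of -10; a step() with learning rate 0.1 then moves w to 1. A batch tagged with a different engine label is rejected with IllegalArgumentException before any computation happens.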
-
validateBatch
public static void validateBatch(Trainer trainer, Batch batch)
Validates the given batch of data.
During validation, the evaluators and losses are computed, but gradients are not computed and parameters are not updated.
Parameters:
  trainer - the trainer to validate the batch with
  batch - a Batch of data
Throws:
  IllegalArgumentException - if the batch engine does not match the trainer engine
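The validation contract (losses and evaluators computed, no gradients, no parameter updates) can be sketched with the same kind of one-parameter model. In this hypothetical LinearValidator, w is declared final so the compiler enforces that validation never changes it, and a running mean squared error plays the role of an evaluator; ValBatch, LinearValidator, and the engine labels are illustrative assumptions, not DJL types.

```java
import java.util.Objects;

// Hypothetical validation batch: features, labels, and the engine that created it.
final class ValBatch {
    final String engine;
    final double[] x;
    final double[] y;

    ValBatch(String engine, double[] x, double[] y) {
        this.engine = engine;
        this.x = x;
        this.y = y;
    }
}

// Hypothetical validator for the model y = w * x. Not a DJL type.
final class LinearValidator {
    final String engine;
    final double w; // final: validation cannot change the parameter
    private double squaredErrorSum;
    private int sampleCount;

    LinearValidator(String engine, double w) {
        this.engine = engine;
        this.w = w;
    }

    // Forward pass and loss only: no gradient is computed and w is never written.
    void validateBatch(ValBatch batch) {
        if (!Objects.equals(batch.engine, engine)) {
            throw new IllegalArgumentException(
                    "Batch engine " + batch.engine + " does not match trainer engine " + engine);
        }
        for (int i = 0; i < batch.x.length; i++) {
            double err = w * batch.x[i] - batch.y[i];
            squaredErrorSum += err * err;
            sampleCount++;
        }
    }

    // Running mean squared error over every sample validated so far.
    double meanLoss() {
        return sampleCount == 0 ? Double.NaN : squaredErrorSum / sampleCount;
    }

    public static void main(String[] args) {
        LinearValidator v = new LinearValidator("PyTorch", 2.0);
        v.validateBatch(new ValBatch("PyTorch", new double[] {1, 2}, new double[] {2, 5}));
        v.validateBatch(new ValBatch("PyTorch", new double[] {3}, new double[] {3}));
        System.out.println("mean loss=" + v.meanLoss() + ", w=" + v.w);
    }
}
```

With w = 2, validating x = {1, 2}, y = {2, 5} and then x = {3}, y = {3} accumulates squared errors 0, 1, and 9, so meanLoss() returns 10 / 3 while w stays at 2.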
-
evaluateDataset
public static void evaluateDataset(Trainer trainer, Dataset testDataset) throws IOException, TranslateException
Evaluates the test dataset.
Parameters:
  trainer - the trainer to evaluate on
  testDataset - the test dataset to evaluate
Throws:
  IOException - for various exceptions depending on the dataset
  TranslateException - if there is an error while processing input
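Evaluating a whole test dataset amounts to validating it batch by batch while an evaluator accumulates results across batches. This sketch assumes a hypothetical fixed threshold classifier and a hypothetical AccuracyEvaluator that pools correct and total counts; EvalBatch, ThresholdModel, AccuracyEvaluator, and EvaluateDatasetSketch are illustrative names, not DJL classes.

```java
import java.util.List;

// Hypothetical test batch: one feature per sample plus a 0/1 label.
final class EvalBatch {
    final double[] x;
    final int[] labels;

    EvalBatch(double[] x, int[] labels) {
        this.x = x;
        this.labels = labels;
    }
}

// Hypothetical fixed classifier: predicts 1 when x >= threshold, else 0.
final class ThresholdModel {
    final double threshold;

    ThresholdModel(double threshold) {
        this.threshold = threshold;
    }

    int predict(double x) {
        return x >= threshold ? 1 : 0;
    }
}

// Pools correct/total counts across every batch it is updated with.
final class AccuracyEvaluator {
    private long correct;
    private long total;

    void update(int[] predictions, int[] labels) {
        for (int i = 0; i < labels.length; i++) {
            if (predictions[i] == labels[i]) {
                correct++;
            }
            total++;
        }
    }

    double accuracy() {
        return total == 0 ? Double.NaN : (double) correct / total;
    }
}

final class EvaluateDatasetSketch {
    // Validate one batch: predict, then update the evaluator. Nothing is trained.
    static void validateBatch(ThresholdModel model, AccuracyEvaluator evaluator, EvalBatch batch) {
        int[] predictions = new int[batch.x.length];
        for (int i = 0; i < batch.x.length; i++) {
            predictions[i] = model.predict(batch.x[i]);
        }
        evaluator.update(predictions, batch.labels);
    }

    // Evaluate a whole test dataset by validating it batch by batch.
    static void evaluateDataset(ThresholdModel model, AccuracyEvaluator evaluator,
                                Iterable<EvalBatch> testDataset) {
        for (EvalBatch batch : testDataset) {
            validateBatch(model, evaluator, batch);
        }
    }

    public static void main(String[] args) {
        List<EvalBatch> testDataset = List.of(
                new EvalBatch(new double[] {0.2, 0.7, 0.9}, new int[] {0, 1, 0}),
                new EvalBatch(new double[] {0.4}, new int[] {1}));
        AccuracyEvaluator evaluator = new AccuracyEvaluator();
        evaluateDataset(new ThresholdModel(0.5), evaluator, testDataset);
        System.out.println("accuracy=" + evaluator.accuracy());
    }
}
```

Here the first batch scores 2 of 3 and the second 0 of 1, so the pooled accuracy is 2 / 4 = 0.5. Averaging the per-batch accuracies instead would give 1/3; pooling counts weights every sample equally regardless of batch size, which is a design choice of this sketch and not a claim about how DJL's evaluators aggregate.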
-