Package ai.djl.training.hyperparameter
Class EasyHpo
java.lang.Object
ai.djl.training.hyperparameter.EasyHpo
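public abstract class EasyHpo
extends Object
Helper that simplifies training with hyperparameter optimization. A subclass implements the abstract methods to supply the initial hyperparameters, the dataset, the model and its input shape, the training configuration, the number of epochs per hyperparameter set, and the number of hyperparameter sets to try. fit() then trains with those hyperparameter sets and returns the best model together with its training result.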
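The abstract methods act as hooks in a template-method design: the subclass supplies the pieces and fit() drives the search. The following is a minimal, dependency-free sketch of that pattern, not DJL's implementation. MiniHpo, Trial, sampleHyperParams, and trainAndValidate are hypothetical stand-ins for HpSet, Model, TrainingResult, and DJL's training loop. The search strategy (independent random samples, keeping the trial with the lowest validation loss) is an assumption; this page does not say how EasyHpo chooses or ranks hyperparameter sets.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Random;

/** Hypothetical stand-in for one trained model plus its training result. */
record Trial(Map<String, Double> hyperParams, double validationLoss) {}

/** Minimal template-method base class; a sketch, not ai.djl's EasyHpo. */
abstract class MiniHpo {
    private final List<Trial> history = new ArrayList<>();

    /** Hook: draws one hyperparameter set (assumed random sampling). */
    protected abstract Map<String, Double> sampleHyperParams(Random rng);

    /** Hook: how many hyperparameter sets to try (mirrors numHyperParameterTests). */
    protected abstract int numHyperParameterTests();

    /** Hook: builds, trains, and validates one model; returns its validation loss. */
    protected abstract double trainAndValidate(Map<String, Double> hpVals);

    /** Template method: runs every trial and returns the one with the lowest loss. */
    public final Trial fit(long seed) {
        Random rng = new Random(seed);
        history.clear();
        for (int i = 0; i < numHyperParameterTests(); i++) {
            Map<String, Double> hpVals = sampleHyperParams(rng);
            history.add(new Trial(hpVals, trainAndValidate(hpVals)));
        }
        // Throws NoSuchElementException if numHyperParameterTests() returned 0.
        return history.stream()
                .min(Comparator.comparingDouble(Trial::validationLoss))
                .orElseThrow();
    }

    /** Read-only view of every trial from the last fit() call. */
    public List<Trial> history() {
        return List.copyOf(history);
    }
}

/** Toy subclass: "training" is a quadratic whose minimum sits at lr = 0.1. */
class ToyHpo extends MiniHpo {
    @Override
    protected Map<String, Double> sampleHyperParams(Random rng) {
        return Map.of("lr", rng.nextDouble()); // uniform in [0, 1)
    }

    @Override
    protected int numHyperParameterTests() {
        return 20;
    }

    @Override
    protected double trainAndValidate(Map<String, Double> hpVals) {
        double lr = hpVals.get("lr");
        return (lr - 0.1) * (lr - 0.1);
    }
}

class MiniHpoDemo {
    public static void main(String[] args) {
        Trial best = new ToyHpo().fit(42L);
        System.out.println("best lr = " + best.hyperParams().get("lr")
                + ", loss = " + best.validationLoss());
    }
}
```

Calling new ToyHpo().fit(42L) runs 20 trials and returns the one whose sampled lr lands closest to 0.1. Declaring the sketch's fit() final keeps subclasses from altering the search loop, so they customize only the hooks.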
Constructor Summary
Constructors
EasyHpo()
Method Summary
Modifier and Type | Method | Description
protected abstract Model | buildModel(HpSet hpVals) | Builds the model to train with the given hyperparameter values.
ai.djl.util.Pair<Model,TrainingResult> | fit() | Fits the model given the implemented abstract methods.
protected abstract RandomAccessDataset | getDataset(Dataset.Usage usage) | Returns the dataset for the given usage.
protected abstract Shape | inputShape(HpSet hpVals) | Returns the input shape for the model.
protected abstract int | numEpochs(HpSet hpVals) | Returns the number of epochs to train for the current hyperparameter set.
protected abstract int | numHyperParameterTests() | Returns the number of hyperparameter sets to train with.
protected void | saveModel(Model model, TrainingResult result) | Saves the model trained with the best hyperparameter set.
protected abstract HpSet | setupHyperParams() | Returns the initial hyperparameters.
protected abstract TrainingConfig | setupTrainingConfig(HpSet hpVals) | Returns the TrainingConfig used to train each hyperparameter set.
Constructor Details
EasyHpo
public EasyHpo()
Method Details
fit
public ai.djl.util.Pair<Model,TrainingResult> fit() throws IOException, TranslateException
Fits the model given the implemented abstract methods.
Returns:
the best model and its training results
Throws:
IOException - if an I/O error occurs; the specific causes depend on the dataset
TranslateException - if there is an error while processing input
setupHyperParams
protected abstract HpSet setupHyperParams()
Returns the initial hyperparameters.
Returns:
the initial hyperparameters
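setupHyperParams() returns the initial hyperparameters as an HpSet. To illustrate what such a set might describe, the sketch below models a search space in plain Java, assuming each hyperparameter is either an integer range or a fixed list of categorical choices, and that one configuration is drawn by sampling each hyperparameter uniformly at random. Param, IntRange, Choice, and SearchSpace are hypothetical names, not DJL classes.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

/** Hypothetical hyperparameter description that can draw one value. */
interface Param {
    Object sample(Random rng);
}

/** Integer hyperparameter drawn uniformly from [lower, upper], inclusive. */
record IntRange(int lower, int upper) implements Param {
    IntRange {
        if (lower > upper) {
            throw new IllegalArgumentException("lower must not exceed upper");
        }
    }

    @Override
    public Object sample(Random rng) {
        return lower + rng.nextInt(upper - lower + 1); // boxed to Integer
    }
}

/** Categorical hyperparameter drawn uniformly from a fixed list of options. */
record Choice(List<String> options) implements Param {
    Choice {
        options = List.copyOf(options);
        if (options.isEmpty()) {
            throw new IllegalArgumentException("need at least one option");
        }
    }

    @Override
    public Object sample(Random rng) {
        return options.get(rng.nextInt(options.size()));
    }
}

/** Hypothetical stand-in for an HpSet: a collection of named hyperparameters. */
record SearchSpace(Map<String, Param> params) {
    /** Draws one concrete configuration by sampling every hyperparameter once. */
    Map<String, Object> sample(Random rng) {
        Map<String, Object> values = new LinkedHashMap<>();
        params.forEach((name, p) -> values.put(name, p.sample(rng)));
        return values;
    }
}
```

Each call to sample(rng) yields one concrete configuration, roughly the role played by the hpVals argument that the other hooks receive.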
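getDataset
protected abstract RandomAccessDataset getDataset(Dataset.Usage usage) throws IOException
Returns the dataset for the given usage.
Parameters:
usage - the usage of the dataset, such as training or testing
Returns:
the dataset to use
Throws:
IOException - if the dataset could not be loaded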
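The usage argument lets one subclass serve several splits of the same data. Below is a minimal sketch, assuming the data fits in a list and that an 80/20 split between training and testing is acceptable. Usage is a hypothetical two-value enum standing in for Dataset.Usage, and SplitDataset is a hypothetical helper rather than a RandomAccessDataset. This page does not say which usages fit() requests.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Hypothetical stand-in for Dataset.Usage. */
enum Usage { TRAIN, TEST }

/** Splits one sample list into reproducible, disjoint train/test subsets. */
final class SplitDataset<T> {
    private final List<T> train;
    private final List<T> test;

    SplitDataset(List<T> samples, double trainFraction, long seed) {
        if (trainFraction <= 0.0 || trainFraction >= 1.0) {
            throw new IllegalArgumentException("trainFraction must be in (0, 1)");
        }
        List<T> shuffled = new ArrayList<>(samples);
        // A fixed seed makes the split identical from run to run.
        Collections.shuffle(shuffled, new Random(seed));
        int cut = (int) Math.round(shuffled.size() * trainFraction);
        this.train = List.copyOf(shuffled.subList(0, cut));
        this.test = List.copyOf(shuffled.subList(cut, shuffled.size()));
    }

    /** Same shape as getDataset(usage): one call per split. */
    List<T> get(Usage usage) {
        return switch (usage) {
            case TRAIN -> train;
            case TEST -> test;
        };
    }
}
```

Shuffling with a fixed seed makes the split reproducible, so rerunning the search evaluates every hyperparameter set against the same held-out samples.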
setupTrainingConfig
protected abstract TrainingConfig setupTrainingConfig(HpSet hpVals)
Returns the TrainingConfig used to train each hyperparameter set.
Parameters:
hpVals - the hyperparameters to train with
Returns:
the TrainingConfig used to train each hyperparameter set
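buildModel
protected abstract Model buildModel(HpSet hpVals)
Builds the model to train with the given hyperparameter values.
Parameters:
hpVals - the hyperparameter values to use for the model
Returns:
the model to train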
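buildModel receives one hyperparameter set and returns the model to train, so architectural choices such as depth and width can come straight from hpVals. The sketch below stays framework-free: MlpSpec, ModelFactory, and the keys hiddenLayersSize and hiddenLayersCount are hypothetical, and a real implementation would read its values from the HpSet and return a DJL Model.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/** Hypothetical, framework-free description of a multilayer perceptron. */
record MlpSpec(int inputSize, List<Integer> hiddenSizes, int outputSize) {}

final class ModelFactory {
    private ModelFactory() {}

    /** Builds a fresh architecture description from one hyperparameter set. */
    static MlpSpec buildModel(Map<String, Integer> hpVals, int inputSize, int outputSize) {
        int width = hpVals.get("hiddenLayersSize");  // hypothetical key
        int depth = hpVals.get("hiddenLayersCount"); // hypothetical key
        if (width <= 0 || depth < 0) {
            throw new IllegalArgumentException("invalid layer hyperparameters: " + hpVals);
        }
        List<Integer> hidden = new ArrayList<>();
        for (int i = 0; i < depth; i++) {
            hidden.add(width);
        }
        return new MlpSpec(inputSize, List.copyOf(hidden), outputSize);
    }

    /** Counts the weights and biases of the fully connected stack. */
    static long parameterCount(MlpSpec spec) {
        long total = 0;
        int previous = spec.inputSize();
        for (int width : spec.hiddenSizes()) {
            total += (long) previous * width + width;
            previous = width;
        }
        total += (long) previous * spec.outputSize() + spec.outputSize();
        return total;
    }
}
```

For instance, assuming flattened 28×28 inputs and 10 classes, buildModel(Map.of("hiddenLayersSize", 64, "hiddenLayersCount", 2), 784, 10) describes two hidden layers of 64 units, and parameterCount reports 55,050 weights and biases (784*64 + 64 + 64*64 + 64 + 64*10 + 10), which shows how much a single hyperparameter set can change model size.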
inputShape
protected abstract Shape inputShape(HpSet hpVals)
Returns the input shape for the model.
Parameters:
hpVals - the hyperparameter values for the model
Returns:
the model input shape
numEpochs
protected abstract int numEpochs(HpSet hpVals)
Returns the number of epochs to train for the current hyperparameter set.
Parameters:
hpVals - the current hyperparameter set
Returns:
the number of epochs
numHyperParameterTests
protected abstract int numHyperParameterTests()
Returns the number of hyperparameter sets to train with.
Returns:
the number of hyperparameter sets to train with
saveModel
protected void saveModel(Model model, TrainingResult result) throws IOException
Saves the model trained with the best hyperparameter set.
Parameters:
model - the model to save
result - the training result for training with this model's hyperparameters
Throws:
IOException - if the model could not be saved
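Unlike the other hooks, saveModel is not abstract, so a subclass may override it to control how the winning result is persisted; what the default implementation writes is not described on this page. The sketch below records winning hyperparameters and their validation loss as java.util.Properties text. BestHpStore is a hypothetical helper, and a plain Map<String, String> stands in for the HpSet and TrainingResult. A real override would also persist the Model itself, for example with DJL's Model.save(Path, String).

```java
import java.io.IOException;
import java.io.StringReader;
import java.io.StringWriter;
import java.io.UncheckedIOException;
import java.util.Map;
import java.util.Properties;
import java.util.TreeMap;

/** Hypothetical helper that records the winning hyperparameters as key=value text. */
final class BestHpStore {
    private BestHpStore() {}

    /** Serializes the hyperparameters plus validation loss in Properties format. */
    static String toText(Map<String, String> hpVals, double validationLoss) {
        Properties props = new Properties();
        hpVals.forEach(props::setProperty);
        props.setProperty("validationLoss", Double.toString(validationLoss));
        StringWriter out = new StringWriter();
        try {
            props.store(out, "best hyperparameter set");
        } catch (IOException e) {
            // StringWriter does not fail, but Properties.store declares IOException.
            throw new UncheckedIOException(e);
        }
        return out.toString();
    }

    /** Parses text written by toText back into a sorted map. */
    static Map<String, String> fromText(String text) {
        Properties props = new Properties();
        try {
            props.load(new StringReader(text));
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        Map<String, String> values = new TreeMap<>();
        for (String name : props.stringPropertyNames()) {
            values.put(name, props.getProperty(name));
        }
        return values;
    }
}
```

Working on strings keeps the sketch free of file-system access; a file-based override would write the same text through a Writer and let the IOException propagate, matching the throws clause of saveModel.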