Package ai.djl.training
Class DefaultTrainingConfig
java.lang.Object
ai.djl.training.DefaultTrainingConfig
- All Implemented Interfaces:
TrainingConfig
DefaultTrainingConfig is an implementation of the TrainingConfig interface.
Constructor Summary
DefaultTrainingConfig(Loss loss)
    Creates an instance of DefaultTrainingConfig with the given Loss.
Method Summary
DefaultTrainingConfig addEvaluator(Evaluator evaluator)
    Adds an Evaluator that needs to be computed during training.
<T extends Evaluator> DefaultTrainingConfig addEvaluators(Collection<T> evaluators)
    Adds multiple Evaluators that need to be computed during training.
DefaultTrainingConfig addTrainingListeners(TrainingListener... listeners)
    Adds TrainingListeners for training.
Device[] getDevices()
    Gets the Devices that are available for computation.
List<Evaluator> getEvaluators()
    Returns the list of Evaluators that should be computed during training.
ExecutorService getExecutorService()
    Gets the ExecutorService for parallelization.
ai.djl.util.PairList<Initializer,Predicate<Parameter>> getInitializers()
    Gets a list of Initializer and Predicate pairs used to initialize the parameters of the model.
Loss getLossFunction()
    Gets the Loss function used to compute the loss.
Optimizer getOptimizer()
    Gets the Optimizer to use during training.
List<TrainingListener> getTrainingListeners()
    Returns the list of TrainingListeners that should be used during training.
DefaultTrainingConfig optDevices(Device[] devices)
    Sets the array of Devices available for training.
DefaultTrainingConfig optExecutorService()
    Sets the ExecutorService to the global ForkJoinPool.commonPool().
DefaultTrainingConfig optExecutorService(ExecutorService executorService)
    Sets the ExecutorService to train with multiple threads.
DefaultTrainingConfig optInitializer(Initializer initializer, Parameter.Type type)
    Sets the Initializer to use for the parameters of the given type (default from paper).
DefaultTrainingConfig optInitializer(Initializer initializer, String name)
    Sets the Initializer to use for the parameter with the given name (default from paper).
DefaultTrainingConfig optInitializer(Initializer initializer, Predicate<Parameter> predicate)
    Sets the Initializer to use for the parameters that match the given predicate (default from paper).
DefaultTrainingConfig optOptimizer(Optimizer optimizer)
    Sets the Optimizer to use during training.
-
Constructor Details
-
DefaultTrainingConfig
public DefaultTrainingConfig(Loss loss)
Creates an instance of DefaultTrainingConfig with the given Loss. DefaultTrainingConfig creates a default TrainingConfig with Adam as the optimizer and the given Loss. The evaluators and listeners are left to the user's discretion.
- Parameters:
loss - the loss to use for training
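The constructor description amounts to a simple contract: store the given loss, start with a default optimizer, and leave evaluators and listeners empty until the user adds them, with opt*/add* methods that return the same instance so calls can be chained. A minimal sketch of that contract follows; SketchConfig and its use of plain strings for the loss, optimizer, evaluators, and listeners are hypothetical stand-ins, not DJL's real types.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

// Hypothetical stand-in for DefaultTrainingConfig; strings replace DJL's Loss, Optimizer, etc.
final class SketchConfig {
    private final String loss;
    private String optimizer = "Adam"; // default optimizer, as in the constructor description
    private final List<String> evaluators = new ArrayList<>(); // empty until the user adds some
    private final List<String> listeners = new ArrayList<>();  // empty until the user adds some

    SketchConfig(String loss) {
        this.loss = loss;
    }

    // Each opt*/add* method mutates this config and returns it, enabling chained calls.
    SketchConfig optOptimizer(String newOptimizer) {
        this.optimizer = newOptimizer;
        return this;
    }

    SketchConfig addEvaluator(String evaluator) {
        evaluators.add(evaluator);
        return this;
    }

    SketchConfig addTrainingListeners(String... newListeners) {
        Collections.addAll(listeners, newListeners);
        return this;
    }

    String getLossFunction() { return loss; }
    String getOptimizer() { return optimizer; }

    // Read-only views: a design choice of this sketch, not necessarily of DJL.
    List<String> getEvaluators() { return Collections.unmodifiableList(evaluators); }
    List<String> getTrainingListeners() { return Collections.unmodifiableList(listeners); }
}
```

Chaining keeps configuration in one expression, for example new SketchConfig("l2").addEvaluator("accuracy").addTrainingListeners("logging", "timing"), which leaves the optimizer at its "Adam" default.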
-
-
Method Details
-
optInitializer
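public DefaultTrainingConfig optInitializer(Initializer initializer, Parameter.Type type)
Sets the Initializer to use for the parameters of the given type (default from paper).
- Parameters:
initializer - the initializer to use for the parameters
type - the Parameter.Type of the parameters
- Returns:
- this DefaultTrainingConfig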
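A type-based overload can be understood as a shorthand for a predicate over parameters: "every parameter whose type equals type". The sketch below shows that reduction and which parameters such a predicate selects. ParamType, Param, and TypeSelection are hypothetical stand-ins for DJL's Parameter.Type and Parameter; whether DJL implements the overload exactly this way is an assumption.

```java
import java.util.List;
import java.util.function.Predicate;
import java.util.stream.Collectors;

// Hypothetical stand-ins for DJL's Parameter.Type and Parameter.
enum ParamType { WEIGHT, BIAS, GAMMA, BETA }

record Param(String name, ParamType type) {}

final class TypeSelection {
    // Reduces "all parameters of this type" to a predicate.
    static Predicate<Param> ofType(ParamType type) {
        return p -> p.type() == type;
    }

    // Returns the names of the parameters the predicate selects, in their original order.
    static List<String> select(List<Param> params, Predicate<Param> predicate) {
        return params.stream().filter(predicate).map(Param::name).collect(Collectors.toList());
    }
}
```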
-
optInitializer
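public DefaultTrainingConfig optInitializer(Initializer initializer, String name)
Sets the Initializer to use for the parameter with the given name (default from paper).
- Parameters:
initializer - the initializer to use for the parameter
name - the name of the parameter
- Returns:
- this DefaultTrainingConfig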
-
optInitializer
public DefaultTrainingConfig optInitializer(Initializer initializer, Predicate<Parameter> predicate)
Sets the Initializer to use for the parameters that match the given predicate (default from paper).
- Parameters:
initializer - the initializer to use for the parameters
predicate - the predicate that identifies the parameters to initialize
- Returns:
- this DefaultTrainingConfig
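Because this overload accepts any java.util.function.Predicate, it can express selections that a single name or a single type cannot, such as "every weight inside one block", by composing predicates with and, or, and negate. A minimal sketch follows; ParamInfo and the naming scheme (block prefix plus a kind such as "weight" or "bias") are hypothetical and only illustrate the composition.

```java
import java.util.function.Predicate;

// Hypothetical stand-in for a model parameter: a dotted name and a kind label.
record ParamInfo(String name, String kind) {}

final class ComposedSelection {
    // Matches parameters whose name starts with the given block prefix.
    static Predicate<ParamInfo> inBlock(String prefix) {
        return p -> p.name().startsWith(prefix);
    }

    // Matches parameters of the given kind, e.g. "weight" or "bias".
    static Predicate<ParamInfo> ofKind(String kind) {
        return p -> p.kind().equals(kind);
    }

    // Every weight inside one block: the conjunction of the two predicates above.
    static Predicate<ParamInfo> weightsIn(String prefix) {
        return inBlock(prefix).and(ofKind("weight"));
    }
}
```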
-
optDevices
public DefaultTrainingConfig optDevices(Device[] devices)
Sets the array of Devices available for training.
- Parameters:
devices - an array of devices to be set
- Returns:
- this DefaultTrainingConfig
-
optOptimizer
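public DefaultTrainingConfig optOptimizer(Optimizer optimizer)
Sets the Optimizer to use during training.
- Parameters:
optimizer - the optimizer to be set
- Returns:
- this DefaultTrainingConfig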
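The constructor names Adam as the default optimizer, which optOptimizer replaces. For reference, here is a minimal sketch of the Adam update rule (Kingma and Ba) for a single scalar parameter, with bias-corrected first and second moment estimates. ScalarAdam is hypothetical and is not DJL's Adam class; the values used in the usage note (learning rate 0.001, beta1 0.9, beta2 0.999, epsilon 1e-8) are the ones commonly quoted from the Adam paper, and whether DJL uses the same defaults is an assumption.

```java
// Hypothetical scalar Adam optimizer illustrating the update rule; not DJL's implementation.
final class ScalarAdam {
    private final double learningRate;
    private final double beta1;
    private final double beta2;
    private final double epsilon;
    private double m; // running estimate of the gradient mean
    private double v; // running estimate of the uncentered gradient variance
    private int t;    // number of update steps taken so far

    ScalarAdam(double learningRate, double beta1, double beta2, double epsilon) {
        this.learningRate = learningRate;
        this.beta1 = beta1;
        this.beta2 = beta2;
        this.epsilon = epsilon;
    }

    // Returns the updated weight after one step with the given gradient.
    double step(double weight, double gradient) {
        t++;
        m = beta1 * m + (1 - beta1) * gradient;
        v = beta2 * v + (1 - beta2) * gradient * gradient;
        double mHat = m / (1 - Math.pow(beta1, t)); // bias correction for the zero-initialized m
        double vHat = v / (1 - Math.pow(beta2, t)); // bias correction for the zero-initialized v
        return weight - learningRate * mHat / (Math.sqrt(vHat) + epsilon);
    }
}
```

On the first step the bias corrections make mHat equal the gradient and vHat its square, so the update is roughly learningRate in the direction opposite the gradient's sign, regardless of the gradient's magnitude.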
-
optExecutorService
public DefaultTrainingConfig optExecutorService()
Sets the ExecutorService to the global ForkJoinPool.commonPool().
- Returns:
- this DefaultTrainingConfig
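ForkJoinPool.commonPool() returns a ForkJoinPool, which implements ExecutorService, so it can be used anywhere an ExecutorService is expected. The standalone sketch below submits independent tasks to the common pool and collects the results; CommonPoolDemo is hypothetical and uses only the JDK. Because the common pool is shared across the whole JVM, and its state is unaffected by shutdown() or shutdownNow(), the sketch never shuts it down.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.Future;

final class CommonPoolDemo {
    // Sums the squares of 1..n, one task per value, on the JVM-wide common pool.
    static long sumOfSquares(int n) {
        ExecutorService pool = ForkJoinPool.commonPool(); // ForkJoinPool is an ExecutorService
        List<Callable<Long>> tasks = new ArrayList<>();
        for (int i = 1; i <= n; i++) {
            final long value = i;
            tasks.add(() -> value * value);
        }
        long total = 0;
        try {
            for (Future<Long> future : pool.invokeAll(tasks)) {
                total += future.get();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        } catch (ExecutionException e) {
            throw new IllegalStateException(e.getCause());
        }
        // No shutdown(): the common pool is shared and unaffected by shutdown attempts.
        return total;
    }
}
```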
-
optExecutorService
public DefaultTrainingConfig optExecutorService(ExecutorService executorService)
Sets the ExecutorService to train with multiple threads.
- Parameters:
executorService - the executor service
- Returns:
- this DefaultTrainingConfig
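Supplying your own ExecutorService, unlike the common pool, means someone has to shut it down when training is finished. This reference does not say whether DJL takes ownership of the pool passed to optExecutorService; the sketch assumes the caller keeps it. DedicatedPoolDemo is hypothetical: it runs one task per simulated data shard on a fixed-size pool and shuts the pool down in a finally block.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

final class DedicatedPoolDemo {
    // Sums each shard on a dedicated pool owned by this method; results keep shard order.
    static List<Integer> processShards(List<List<Integer>> shards, int threads) {
        ExecutorService pool = Executors.newFixedThreadPool(threads);
        try {
            List<Callable<Integer>> tasks = new ArrayList<>();
            for (List<Integer> shard : shards) {
                tasks.add(() -> shard.stream().mapToInt(Integer::intValue).sum());
            }
            List<Integer> results = new ArrayList<>();
            // invokeAll returns futures in the same order as the task list.
            for (Future<Integer> future : pool.invokeAll(tasks)) {
                results.add(future.get());
            }
            return results;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        } catch (ExecutionException e) {
            throw new IllegalStateException(e.getCause());
        } finally {
            // A pool you create yourself must be shut down explicitly, or its threads linger.
            pool.shutdown();
        }
    }
}
```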
-
addEvaluators
public <T extends Evaluator> DefaultTrainingConfig addEvaluators(Collection<T> evaluators)
Adds multiple Evaluators that need to be computed during training.
- Type Parameters:
T - the type of evaluator to be added
- Parameters:
evaluators - the evaluators to be added
- Returns:
- this DefaultTrainingConfig
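The type parameter T matters because Java generics are invariant: a parameter declared as Collection<Evaluator> would reject a List<Accuracy> at compile time, while Collection<T> with T extends Evaluator accepts a collection of any evaluator subtype. The sketch below demonstrates this with hypothetical stand-in types (Evaluator, AccuracyStub, EvaluatorRegistry), not DJL's own classes.

```java
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

// Hypothetical stand-in for DJL's Evaluator.
interface Evaluator {
    String name();
}

// Hypothetical concrete evaluator.
final class AccuracyStub implements Evaluator {
    public String name() { return "Accuracy"; }
}

final class EvaluatorRegistry {
    private final List<Evaluator> evaluators = new ArrayList<>();

    // Bounded type parameter: accepts List<AccuracyStub>, Set<Evaluator>, and so on.
    // A Collection<Evaluator> parameter would not compile with a List<AccuracyStub> argument.
    <T extends Evaluator> EvaluatorRegistry addEvaluators(Collection<T> toAdd) {
        evaluators.addAll(toAdd); // Collection<T> is a Collection<? extends Evaluator>
        return this;
    }

    int size() { return evaluators.size(); }
}
```

The same effect could be had with a wildcard parameter, Collection<? extends Evaluator>; the named type parameter is simply the form this method signature uses.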
-
addEvaluator
public DefaultTrainingConfig addEvaluator(Evaluator evaluator)
Adds an Evaluator that needs to be computed during training.
- Parameters:
evaluator - the evaluator to be added
- Returns:
- this DefaultTrainingConfig
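An evaluator reports a metric alongside the loss while training runs. As an illustration of the kind of computation a classification-accuracy evaluator performs, here is a hypothetical plain-Java version: it takes the highest-scoring class in each row and compares it with the label. AccuracySketch does not use DJL's Evaluator API.

```java
// Hypothetical accuracy computation over per-class scores; not DJL's Accuracy evaluator.
final class AccuracySketch {
    // Index of the highest score in one row, i.e. the predicted class.
    static int argmax(double[] row) {
        int best = 0;
        for (int i = 1; i < row.length; i++) {
            if (row[i] > row[best]) {
                best = i;
            }
        }
        return best;
    }

    // Fraction of rows whose predicted class equals the label.
    static double accuracy(double[][] scores, int[] labels) {
        if (scores.length != labels.length || labels.length == 0) {
            throw new IllegalArgumentException("scores and labels must be non-empty and equally long");
        }
        int correct = 0;
        for (int i = 0; i < labels.length; i++) {
            if (argmax(scores[i]) == labels[i]) {
                correct++;
            }
        }
        return (double) correct / labels.length;
    }
}
```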
-
addTrainingListeners
public DefaultTrainingConfig addTrainingListeners(TrainingListener... listeners)
Adds TrainingListeners for training.
- Parameters:
listeners - the TrainingListeners to add
- Returns:
- this DefaultTrainingConfig
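Training listeners follow the observer pattern: the training loop calls back into each registered listener at fixed points, and the varargs signature lets one call register several listeners. The sketch below is hypothetical: EpochListener has a single end-of-epoch callback, unlike DJL's TrainingListener, and the notification order (registration order) is a property of this sketch only.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

// Hypothetical listener with one end-of-epoch callback; DJL's TrainingListener differs.
interface EpochListener {
    void onEpoch(int epoch, List<String> log);
}

final class ListenerHub {
    private final List<EpochListener> listeners = new ArrayList<>();

    // Varargs lets callers register one or many listeners in a single chained call.
    ListenerHub addTrainingListeners(EpochListener... newListeners) {
        Collections.addAll(listeners, newListeners);
        return this;
    }

    // Simulated training loop: notifies every listener, in registration order, after each epoch.
    List<String> runEpochs(int epochs) {
        List<String> log = new ArrayList<>();
        for (int epoch = 0; epoch < epochs; epoch++) {
            for (EpochListener listener : listeners) {
                listener.onEpoch(epoch, log);
            }
        }
        return log;
    }
}
```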
-
getDevices
public Device[] getDevices()
Gets the Devices that are available for computation.
This is necessary for a Trainer, as it needs to know what kind of device it is running on and how many devices it is running on.
- Specified by:
getDevices in interface TrainingConfig
- Returns:
- an array of Device
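One reason a trainer needs the number of devices is data-parallel training, where each batch is divided among them. A hypothetical sketch of an even split follows: every device gets batchSize / deviceCount examples, and the first batchSize % deviceCount devices get one extra. BatchSplit is illustrative only and is not DJL's own splitting logic.

```java
// Hypothetical even split of a batch across devices; not DJL's implementation.
final class BatchSplit {
    // Returns how many examples each device receives; sizes differ by at most one.
    static int[] shardSizes(int batchSize, int deviceCount) {
        if (deviceCount <= 0) {
            throw new IllegalArgumentException("need at least one device");
        }
        int[] sizes = new int[deviceCount];
        int base = batchSize / deviceCount;
        int remainder = batchSize % deviceCount;
        for (int i = 0; i < deviceCount; i++) {
            // The first `remainder` devices absorb the leftover examples.
            sizes[i] = base + (i < remainder ? 1 : 0);
        }
        return sizes;
    }
}
```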
-
getInitializers
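public ai.djl.util.PairList<Initializer,Predicate<Parameter>> getInitializers()
Gets a list of Initializer and Predicate pairs used to initialize the parameters of the model.
- Specified by:
getInitializers in interface TrainingConfig
- Returns:
- a list of Initializer and Predicate pairs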
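A consumer of these (initializer, predicate) pairs has to decide what happens when more than one pair matches the same parameter. The sketch below applies the pairs in order so that a later match overrides an earlier one, and falls back to a default when nothing matches; both rules are assumptions of this sketch, not documented DJL behavior. Initializers and parameters are represented by plain names, and InitializerResolution is hypothetical.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;

final class InitializerResolution {
    // Maps each parameter name to the initializer chosen for it.
    // Assumptions of this sketch: pairs apply in order, and a later match overrides an earlier one.
    static Map<String, String> resolve(List<String> paramNames,
                                       List<Map.Entry<String, Predicate<String>>> pairs,
                                       String fallback) {
        Map<String, String> chosen = new LinkedHashMap<>();
        for (String name : paramNames) {
            chosen.put(name, fallback); // used when no pair matches
        }
        for (Map.Entry<String, Predicate<String>> pair : pairs) {
            for (String name : paramNames) {
                if (pair.getValue().test(name)) {
                    chosen.put(name, pair.getKey());
                }
            }
        }
        return chosen;
    }
}
```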
-
getOptimizer
public Optimizer getOptimizer()
Gets the Optimizer to use during training.
- Specified by:
getOptimizer in interface TrainingConfig
- Returns:
- an Optimizer
-
getLossFunction
public Loss getLossFunction()
Gets the Loss function used to compute the loss.
- Specified by:
getLossFunction in interface TrainingConfig
- Returns:
- a Loss function
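The loss function reduces a batch of predictions and labels to a single number that the optimizer then minimizes. As a concrete example of such a function, here is mean squared error in plain Java. MeanSquaredError is hypothetical; DJL's own Loss implementations, and any constant factors or weighting they apply, may differ.

```java
// Hypothetical mean squared error loss; DJL's Loss classes may scale or weight differently.
final class MeanSquaredError {
    // Average of the squared differences between predictions and labels.
    static double evaluate(double[] predictions, double[] labels) {
        if (predictions.length != labels.length || labels.length == 0) {
            throw new IllegalArgumentException("predictions and labels must be non-empty and equally long");
        }
        double sum = 0;
        for (int i = 0; i < labels.length; i++) {
            double diff = predictions[i] - labels[i];
            sum += diff * diff;
        }
        return sum / labels.length;
    }
}
```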
-
getExecutorService
public ExecutorService getExecutorService()
Gets the ExecutorService for parallelization.
- Specified by:
getExecutorService in interface TrainingConfig
- Returns:
- an ExecutorService
-
getEvaluators
public List<Evaluator> getEvaluators()
Returns the list of Evaluators that should be computed during training.
- Specified by:
getEvaluators in interface TrainingConfig
- Returns:
- a list of Evaluators
-
getTrainingListeners
public List<TrainingListener> getTrainingListeners()
Returns the list of TrainingListeners that should be used during training.
- Specified by:
getTrainingListeners in interface TrainingConfig
- Returns:
- a list of TrainingListeners
-