Package ai.djl.training.loss
Class ElasticNetWeightDecay
java.lang.Object
ai.djl.training.evaluator.Evaluator
ai.djl.training.loss.Loss
ai.djl.training.loss.ElasticNetWeightDecay
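ElasticNetWeightDecay calculates the combined L1 + L2 (Elastic Net) penalty of a set of parameters. It is used for regularization.
The loss \(L\) is defined as \(L = \lambda_1 \sum_i \vert W_i\vert + \lambda_2 \sum_i {W_i}^2\), where the sums run over the weights \(W_i\) of the penalized parameters.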
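The definition above can be checked with a small plain-Java sketch that does not depend on DJL. It assumes each parameter array has been flattened into a float[] and that the sums run over every weight in every parameter. The class ElasticNetPenaltySketch and its penalty method are hypothetical illustrations, not part of the DJL API or its actual implementation.

```java
import java.util.List;

/** Hypothetical plain-Java illustration of the Elastic Net penalty formula (not DJL code). */
public class ElasticNetPenaltySketch {

    /**
     * Computes lambda1 * sum(|w|) + lambda2 * sum(w^2) over all weights
     * in all (already flattened) parameter arrays.
     */
    static float penalty(List<float[]> parameters, float lambda1, float lambda2) {
        float l1 = 0f; // running sum of absolute values (L1 term)
        float l2 = 0f; // running sum of squares (L2 term)
        for (float[] w : parameters) {
            for (float wi : w) {
                l1 += Math.abs(wi);
                l2 += wi * wi;
            }
        }
        return lambda1 * l1 + lambda2 * l2;
    }

    public static void main(String[] args) {
        // Two hypothetical parameter arrays: L1 part = 1 + 2 + 3 = 6, L2 part = 1 + 4 + 9 = 14
        List<float[]> params = List.of(new float[] {1f, -2f}, new float[] {3f});
        System.out.println(penalty(params, 0.5f, 0.25f)); // 0.5 * 6 + 0.25 * 14 = 6.5
    }
}
```

With lambda1 = lambda2 = 1 the same weights give 6 + 14 = 20, the plain sum of the L1 and squared L2 terms.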
Field Summary
Fields inherited from class ai.djl.training.evaluator.Evaluator
totalInstances
Constructor Summary
Constructors
ElasticNetWeightDecay(NDList parameters)
    Calculates Elastic Net weight decay for regularization.
ElasticNetWeightDecay(String name, NDList parameters)
    Calculates Elastic Net weight decay for regularization.
ElasticNetWeightDecay(String name, NDList parameters, float lambda)
    Calculates Elastic Net weight decay for regularization.
ElasticNetWeightDecay(String name, NDList parameters, float lambda1, float lambda2)
    Calculates Elastic Net weight decay for regularization.
Method Summary
Methods inherited from class ai.djl.training.loss.Loss
addAccumulator, elasticNetWeightedDecay, elasticNetWeightedDecay, elasticNetWeightedDecay, elasticNetWeightedDecay, getAccumulator, hingeLoss, hingeLoss, hingeLoss, l1Loss, l1Loss, l1Loss, l1WeightedDecay, l1WeightedDecay, l1WeightedDecay, l2Loss, l2Loss, l2Loss, l2WeightedDecay, l2WeightedDecay, l2WeightedDecay, maskedSoftmaxCrossEntropyLoss, maskedSoftmaxCrossEntropyLoss, maskedSoftmaxCrossEntropyLoss, quantileL1Loss, quantileL1Loss, resetAccumulator, sigmoidBinaryCrossEntropyLoss, sigmoidBinaryCrossEntropyLoss, sigmoidBinaryCrossEntropyLoss, softmaxCrossEntropyLoss, softmaxCrossEntropyLoss, softmaxCrossEntropyLoss, updateAccumulator, updateAccumulators
Methods inherited from class ai.djl.training.evaluator.Evaluator
checkLabelShapes, checkLabelShapes, getName
Constructor Details
ElasticNetWeightDecay(NDList parameters)
Calculates Elastic Net weight decay for regularization.
Parameters:
parameters - holds the model weights that will be penalized
ElasticNetWeightDecay(String name, NDList parameters)
Calculates Elastic Net weight decay for regularization.
Parameters:
name - the name of the penalty
parameters - holds the model weights that will be penalized
ElasticNetWeightDecay(String name, NDList parameters, float lambda)
Calculates Elastic Net weight decay for regularization.
Parameters:
name - the name of the penalty
parameters - holds the model weights that will be penalized
lambda - the weight applied to both the L1 and the L2 penalty values, default 1
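When this constructor is used, the single lambda multiplies both terms of the definition \(L = \lambda_1 \sum_i \vert W_i\vert + \lambda_2 \sum_i {W_i}^2\) (that is, \(\lambda_1 = \lambda_2 = \lambda\)), so the penalty factors as follows; with the default of 1 it reduces to the unweighted sum of the two terms.

```latex
L = \lambda \sum_i \vert W_i \vert + \lambda \sum_i {W_i}^2
  = \lambda \left( \sum_i \vert W_i \vert + \sum_i {W_i}^2 \right),
\qquad
\lambda = 1 \;\Rightarrow\; L = \sum_i \vert W_i \vert + \sum_i {W_i}^2
```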
ElasticNetWeightDecay(String name, NDList parameters, float lambda1, float lambda2)
Calculates Elastic Net weight decay for regularization.
Parameters:
name - the name of the penalty
parameters - holds the model weights that will be penalized
lambda1 - the weight to apply to the L1 penalty value, default 1
lambda2 - the weight to apply to the L2 penalty value, default 1
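The defaults described for these constructors can be pictured as a chain of delegating overloads. The following is a hypothetical plain-Java stand-in, not DJL source: it uses List<float[]> in place of NDList, and its default name string is an assumption, since this page does not state which name the single-argument constructor uses.

```java
import java.util.List;

/**
 * Hypothetical stand-in (not DJL source) showing how the four constructor
 * overloads could delegate to one another so that omitted weights default to 1.
 */
public class ElasticNetOverloadsSketch {
    // Assumed default name; the reference page does not state it.
    static final String DEFAULT_NAME = "ElasticNetWeightDecay";

    final String name;
    final List<float[]> parameters; // stands in for NDList
    final float lambda1;
    final float lambda2;

    ElasticNetOverloadsSketch(List<float[]> parameters) {
        this(DEFAULT_NAME, parameters);
    }

    ElasticNetOverloadsSketch(String name, List<float[]> parameters) {
        this(name, parameters, 1f); // lambda defaults to 1
    }

    ElasticNetOverloadsSketch(String name, List<float[]> parameters, float lambda) {
        this(name, parameters, lambda, lambda); // one lambda weights both L1 and L2
    }

    ElasticNetOverloadsSketch(String name, List<float[]> parameters, float lambda1, float lambda2) {
        this.name = name;
        this.parameters = parameters;
        this.lambda1 = lambda1;
        this.lambda2 = lambda2;
    }

    public static void main(String[] args) {
        List<float[]> params = List.of(new float[] {0.5f, -1.5f});
        ElasticNetOverloadsSketch a = new ElasticNetOverloadsSketch(params);
        ElasticNetOverloadsSketch b = new ElasticNetOverloadsSketch("reg", params, 0.1f);
        System.out.println(a.lambda1 + " " + a.lambda2); // 1.0 1.0
        System.out.println(b.lambda1 + " " + b.lambda2); // 0.1 0.1
    }
}
```

Routing every overload through the four-argument constructor keeps the defaulting rules in one place: omitting lambda gives 1, and a single lambda is copied into both lambda1 and lambda2.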
Method Details