Class AdamW

java.lang.Object
ai.djl.training.optimizer.Optimizer
ai.djl.training.optimizer.AdamW

public class AdamW extends Optimizer
AdamW is the Adam Optimizer with decoupled weight decay: the weight decay is applied directly to the weights instead of being folded into the gradient.

AdamW updates the weights using:

\( w *= (1 - learning_rate * weight_decay) \)
\( m = beta1 * m + (1 - beta1) * grad \)
\( v = beta2 * v + (1 - beta2) * grad^2 \)
\( learning_rate_bias_correction = learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t) \)
\( w -= learning_rate_bias_correction * m / (sqrt(v) + epsilon) \)

where grad represents the gradient, m and v are the 1st and 2nd order moment estimates (mean and uncentered variance), and t is the update step count.
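
As a rough illustration of the formulas above (a minimal sketch, not DJL's implementation; the plain float arrays, the class name AdamWStep, and the loop structure are assumptions made only for this example), one decoupled-weight-decay step could look like:

    // One AdamW step on plain float arrays, mirroring the formulas above.
    // Illustration only; DJL performs the equivalent work on NDArrays.
    public final class AdamWStep {

        public static void step(
                float[] w, float[] grad, float[] m, float[] v,
                float learningRate, float weightDecay,
                float beta1, float beta2, float epsilon, int t) {
            // Bias-corrected learning rate for step t (t starts at 1).
            float lr = (float) (learningRate * Math.sqrt(1 - Math.pow(beta2, t))
                    / (1 - Math.pow(beta1, t)));
            for (int i = 0; i < w.length; i++) {
                // Decoupled weight decay: shrink the weight directly.
                w[i] *= 1 - learningRate * weightDecay;
                // Update the first and second order moment estimates.
                m[i] = beta1 * m[i] + (1 - beta1) * grad[i];
                v[i] = beta2 * v[i] + (1 - beta2) * grad[i] * grad[i];
                // Apply the bias-corrected Adam update.
                w[i] -= lr * m[i] / (Math.sqrt(v[i]) + epsilon);
            }
        }
    }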

  • Constructor Details

    • AdamW

      protected AdamW(AdamW.Builder builder)
      Creates a new instance of the AdamW optimizer.
      Parameters:
      builder - the builder to create a new instance of the AdamW optimizer
  • Method Details

    • update

      public void update(String parameterId, NDArray weight, NDArray grad)
      Updates the parameters according to the gradients.
      Specified by:
      update in class Optimizer
      Parameters:
      parameterId - the parameter to be updated
      weight - the weights of the parameter
      grad - the gradients
    • builder

      public static AdamW.Builder builder()
      Creates a builder to build an AdamW.
      Returns:
      a new builder
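
A typical way to obtain and exercise the optimizer (a minimal sketch: it assumes the builder's defaults are acceptable and uses arbitrary shapes and values; the class name AdamWExample and the parameter id "w" are made up for the example) is:

    import ai.djl.ndarray.NDArray;
    import ai.djl.ndarray.NDManager;
    import ai.djl.ndarray.types.Shape;
    import ai.djl.training.optimizer.AdamW;

    public final class AdamWExample {

        public static void main(String[] args) {
            // Build an AdamW optimizer with its default hyperparameters.
            AdamW optimizer = AdamW.builder().build();

            try (NDManager manager = NDManager.newBaseManager()) {
                NDArray weight = manager.ones(new Shape(2, 2));
                NDArray grad = manager.ones(new Shape(2, 2)).mul(0.1f);

                // Apply one AdamW update to the parameter identified by "w".
                optimizer.update("w", weight, grad);
                System.out.println(weight);
            }
        }
    }

In a real training loop the optimizer is normally not invoked directly; it is usually handed to the training configuration (for example via DefaultTrainingConfig's optOptimizer) and the Trainer calls update internally.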