Class ParameterStore

java.lang.Object
ai.djl.training.ParameterStore

public class ParameterStore extends Object
The ParameterStore maps each parameter to its mirrors (copies) on other devices.
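The mirror map can be made concrete with a minimal, stdlib-only Java model. This is a hypothetical sketch, not DJL's implementation. Devices are plain strings, a parameter value is a `double[]`, and the names `ToyParameter`, `MirrorMap`, `mirror` and `mirrorCount` are invented for illustration. A mirror is created lazily the first time a device asks for it. When the requested device is the parameter's own device and `copy` is false, the original array is shared instead of copied, following the documented meaning of the constructor's `copy` argument.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

/** Hypothetical toy parameter: a name, a home device and a value. */
final class ToyParameter {
    final String name;
    final String device;
    final double[] value;

    ToyParameter(String name, String device, double[] value) {
        this.name = name;
        this.device = device;
        this.value = value;
    }
}

/** Hypothetical sketch of a parameter -> (device -> mirror) map. Not DJL code. */
final class MirrorMap {
    private final boolean copy;
    // Outer key: parameter (identity); inner key: device name.
    private final Map<ToyParameter, Map<String, double[]>> mirrors = new HashMap<>();

    MirrorMap(boolean copy) {
        this.copy = copy;
    }

    /** Returns the mirror of {@code p} on {@code device}, creating it on first use. */
    double[] mirror(ToyParameter p, String device) {
        Map<String, double[]> perDevice = mirrors.computeIfAbsent(p, k -> new HashMap<>());
        return perDevice.computeIfAbsent(device, d -> {
            if (!copy && d.equals(p.device)) {
                return p.value; // same device and copy == false: share the original array
            }
            return Arrays.copyOf(p.value, p.value.length); // otherwise: independent copy
        });
    }

    /** Number of devices that currently hold a mirror of {@code p}. */
    int mirrorCount(ToyParameter p) {
        Map<String, double[]> perDevice = mirrors.get(p);
        return perDevice == null ? 0 : perDevice.size();
    }
}
```

With `copy` set to false, asking for the parameter's home device returns the very same array, so writes through that mirror are visible on the original; with `copy` set to true, every device, including the home one, gets its own array.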
  • Constructor Details

    • ParameterStore

      public ParameterStore()
      Constructs a new ParameterStore instance.
    • ParameterStore

      public ParameterStore(NDManager manager, boolean copy)
      Constructs an empty ParameterStore.
      Parameters:
      manager - the manager to attach mirrored parameters to
      copy - whether to always copy, even when the target device is the same as the original parameter's device
  • Method Details

    • setParameterServer

      public void setParameterServer(ParameterServer parameterServer, Device[] devices)
      Sets the parameterServer used to apply updates to the parameters.
      Parameters:
      parameterServer - the parameterServer
      devices - the devices to create mirrored parameters on
    • updateAllParameters

      public void updateAllParameters()
      Updates all the mirrored parameters.
    • getValue

      public NDArray getValue(Parameter parameter, Device device, boolean training)
      Returns the value of a mirrored parameter on a device.
      Parameters:
      parameter - the parameter to get the value for
      device - the device to get the mirror from
      training - true for a training forward pass
      Returns:
      the value of the mirrored parameter on the device
    • getManager

      public NDManager getManager()
      Returns the NDManager associated with this ParameterStore.
      Returns:
      the NDManager
    • sync

      public void sync()
      Synchronizes the values on all mirrors with the main parameter.
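The interplay of setParameterServer, updateAllParameters and sync can be illustrated with a single-parameter toy model. This is a hypothetical, stdlib-only sketch, not DJL's implementation, and this page does not specify DJL's internal update logic. The names `SumSgdServer` and `ToyStore`, the string device names, the `double[]` values, and the update rule (sum the per-device gradients, then take one plain SGD step) are all assumptions; a real DJL ParameterServer delegates the update to its optimizer. In this toy, updateAllParameters changes only the main value, and the mirrors stay stale until sync copies the main value back onto them.

```java
import java.util.LinkedHashMap;
import java.util.Map;

/** Hypothetical parameter server: sums per-device gradients, then applies one SGD step. */
final class SumSgdServer {
    private final double learningRate;

    SumSgdServer(double learningRate) {
        this.learningRate = learningRate;
    }

    void update(double[] weight, double[][] grads) {
        for (int i = 0; i < weight.length; i++) {
            double sum = 0.0;
            for (double[] g : grads) {
                sum += g[i]; // aggregate the gradient from every device
            }
            weight[i] -= learningRate * sum;
        }
    }
}

/** Hypothetical single-parameter store mimicking setParameterServer, update and sync. */
final class ToyStore {
    final double[] main; // the "main" parameter value
    final Map<String, double[]> mirrors = new LinkedHashMap<>(); // device -> mirrored value
    final Map<String, double[]> grads = new LinkedHashMap<>(); // device -> gradient of that mirror
    private SumSgdServer server;

    ToyStore(double[] main) {
        this.main = main;
    }

    /** Like setParameterServer: remember the server and create one mirror per device. */
    void setParameterServer(SumSgdServer server, String[] devices) {
        this.server = server;
        for (String d : devices) {
            mirrors.put(d, main.clone());
            grads.put(d, new double[main.length]);
        }
    }

    /** Like updateAllParameters (toy version): update only the main value. */
    void updateAllParameters() {
        server.update(main, grads.values().toArray(new double[0][]));
    }

    /** Like sync: copy the main value onto every mirror. */
    void sync() {
        for (double[] m : mirrors.values()) {
            System.arraycopy(main, 0, m, 0, main.length);
        }
    }
}
```

For example, starting from a main value of `{1.0, 1.0}` on two devices that each report a gradient of `1.0` for the first element, a learning rate of `0.5` gives a new first element of `1.0 - 0.5 * (1.0 + 1.0) = 0.0`; the mirrors still hold `1.0` until sync is called.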