Class Prelu

All Implemented Interfaces:
Block

public class Prelu extends AbstractBlock
Applies the Leaky Parametric ReLU activation element-wise to the input.

Leaky ReLUs attempt to fix the 'dying ReLU' problem by allowing a small slope when the input is negative and a slope of one when the input is positive. This is defined by \(y = x \gt 0 \;?\; x : slope \cdot x\).

Parametric ReLU is a Leaky ReLU in which the slope is learnt during training.
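
As a minimal sketch of typical usage (the layer sizes are illustrative, and the package locations ai.djl.nn.SequentialBlock, ai.djl.nn.core.Linear, and ai.djl.nn.core.Prelu are assumptions here, not part of this page), a Prelu block can be placed wherever a fixed activation such as ReLU would otherwise go; its slope is then trained along with the rest of the model's parameters.

  import ai.djl.nn.SequentialBlock;
  import ai.djl.nn.core.Linear;
  import ai.djl.nn.core.Prelu;

  // Sketch: a two-layer MLP whose hidden activation is a learnable PReLU.
  // The PReLU slope is updated by the optimizer together with the Linear weights.
  SequentialBlock net = new SequentialBlock()
          .add(Linear.builder().setUnits(64).build())
          .add(new Prelu())
          .add(Linear.builder().setUnits(10).build());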

  • Constructor Details

    • Prelu

      public Prelu()
      Creates a Parametric ReLU Block.
  • Method Details

    • forwardInternal

      protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<String,Object> params)
      Specified by:
      forwardInternal in class AbstractBaseBlock
      Parameters:
      parameterStore - the parameter store
      inputs - the input NDList
      training - true for a training forward pass
      params - optional parameters
      Returns:
      the output of the forward pass
    • getOutputShapes

      public Shape[] getOutputShapes(Shape[] inputs)
      Returns the expected output shapes of the block for the specified input shapes.
      Parameters:
      inputs - the shapes of the inputs
      Returns:
      the expected output shapes of the block
    • loadMetadata

      public void loadMetadata(byte loadVersion, DataInputStream is) throws IOException, MalformedModelException
      Override this to load additional metadata along with the parameter values.

      If you override AbstractBaseBlock.saveMetadata(DataOutputStream) or need to provide backward compatibility with older binary formats, you probably need to override this as well. The default implementation checks whether the version number matches; if it does not, it throws a MalformedModelException. It then restores the input shapes.

      Overrides:
      loadMetadata in class AbstractBaseBlock
      Parameters:
      loadVersion - the version used for loading this metadata.
      is - the input stream we are loading from
      Throws:
      IOException - if loading fails
      MalformedModelException - if the data can be loaded but has the wrong format
    • prelu

      public static NDList prelu(NDArray input, NDArray alpha)
      Applies the Prelu activation to the input NDArray.

      Prelu is defined as \(y = \max(0, x) + \alpha \cdot \min(0, x)\), where \(\alpha\) is a learnable parameter.

      Parameters:
      input - the input NDArray
      alpha - the learnable slope parameter
      Returns:
      the NDArray after applying Prelu activation
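
      As a concrete illustration of the formula above, the following sketch applies the static helper to a small array. It assumes the Prelu class lives in ai.djl.nn.core and that a scalar alpha broadcasts over the input, which may depend on the underlying engine.

        import ai.djl.ndarray.NDArray;
        import ai.djl.ndarray.NDList;
        import ai.djl.ndarray.NDManager;
        import ai.djl.nn.core.Prelu;

        try (NDManager manager = NDManager.newBaseManager()) {
            NDArray input = manager.create(new float[] {-2f, -1f, 0f, 1f, 2f});
            NDArray alpha = manager.create(0.25f); // slope applied to negative values
            NDList out = Prelu.prelu(input, alpha);
            // y = max(0, x) + alpha * min(0, x) -> [-0.5, -0.25, 0, 1, 2]
            System.out.println(out.singletonOrThrow());
        }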