Package ai.djl.nn

Class Blocks

java.lang.Object
ai.djl.nn.Blocks

public final class Blocks extends Object
Utility class that provides commonly used blocks and block-related helper methods.
  • Method Details

    • batchFlatten

      public static NDArray batchFlatten(NDArray array)
      Flattens the NDArray provided as input into a 2-D NDArray of shape (batch, size), where batch is the size of the first axis and size is the product of the remaining axes.
      Parameters:
      array - an array to be flattened
      Returns:
      the flattened NDArray
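      The shape rule can be illustrated without DJL. The following plain-Java sketch assumes the flattened shape is (first axis, product of all remaining axes); flattenedShape is a hypothetical helper, not part of the DJL API, and it only computes the resulting shape rather than reshaping any data.

```java
import java.util.Arrays;

/** Hypothetical helper illustrating the shape produced by batchFlatten(array). */
final class BatchFlattenSketch {

    /** Returns {batch, size}, where size is the product of all axes after the first. */
    static long[] flattenedShape(long[] shape) {
        if (shape.length == 0) {
            throw new IllegalArgumentException("a batch axis is required");
        }
        long size = 1;
        for (int i = 1; i < shape.length; i++) {
            size *= shape[i]; // fold every non-batch axis into the second dimension
        }
        return new long[] {shape[0], size};
    }

    public static void main(String[] args) {
        // e.g. a batch of 32 RGB images of 28x28 pixels
        System.out.println(Arrays.toString(flattenedShape(new long[] {32, 3, 28, 28})));
    }
}
```

      Only the first axis survives as its own dimension; a (32, 3, 28, 28) input becomes (32, 2352).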
    • batchFlatten

      public static NDArray batchFlatten(NDArray array, long size)
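      Flattens the NDArray provided as input into a 2-D NDArray of shape (batch, size), where batch is inferred by dividing the total number of elements by size.
      Parameters:
      array - an array to be flattened
      size - the size of the second dimension of the result, that is, the number of elements per sample
      Returns:
      the flattened NDArray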
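      When size is given, the batch dimension is whatever remains after dividing the total element count by size. A minimal sketch of that arithmetic in plain Java, assuming an input whose element count is not divisible by size should be rejected; flattenedShape and the IllegalArgumentException are illustrative choices, not DJL's documented behavior.

```java
import java.util.Arrays;

/** Hypothetical helper illustrating the shape produced by batchFlatten(array, size). */
final class BatchFlattenWithSizeSketch {

    /** Returns {totalElements / size, size}; rejects inputs that do not divide evenly. */
    static long[] flattenedShape(long[] shape, long size) {
        if (size <= 0) {
            throw new IllegalArgumentException("size must be positive");
        }
        long total = 1;
        for (long dim : shape) {
            total *= dim; // total number of elements in the input
        }
        if (total % size != 0) {
            throw new IllegalArgumentException(
                    total + " elements cannot be split into rows of " + size);
        }
        return new long[] {total / size, size}; // the batch axis is inferred
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(flattenedShape(new long[] {8, 2, 3}, 6)));
        System.out.println(Arrays.toString(flattenedShape(new long[] {8, 2, 3}, 12)));
    }
}
```

      The same 48-element input yields a different batch dimension depending on size, so size must reflect the true per-sample element count.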
    • batchFlattenBlock

      public static Block batchFlattenBlock()
      Creates a Block whose forward function applies the batchFlatten method.
      Returns:
      a Block whose forward function applies the batchFlatten method
    • batchFlattenBlock

      public static Block batchFlattenBlock(long size)
      Creates a Block whose forward function applies the batchFlatten method. The input to the returned block must contain batch_size * size elements in total.
      Parameters:
      size - the expected size of each input
      Returns:
      a Block whose forward function applies the batchFlatten method
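      The block factories capture the flattening rule once and apply it on every forward call. The sketch below models a block as a plain java.util.function.Function from input shapes to output shapes; batchFlattenBlock here is a hypothetical stand-in, not DJL's Block or LambdaBlock API, and the single-input check is an assumption made for illustration.

```java
import java.util.Arrays;
import java.util.List;
import java.util.function.Function;

/**
 * Hypothetical model of a block factory: a "block" is just a forward function from a
 * list of input shapes to a list of output shapes. This is not the DJL Block API.
 */
final class FlattenBlockSketch {

    /** Mirrors the idea of batchFlattenBlock(size): wraps the flatten rule in a lambda. */
    static Function<List<long[]>, List<long[]>> batchFlattenBlock(long size) {
        return inputs -> {
            if (inputs.size() != 1) {
                throw new IllegalArgumentException("expected exactly one input");
            }
            // total number of elements in the single input
            long total = Arrays.stream(inputs.get(0)).reduce(1L, (a, b) -> a * b);
            if (total % size != 0) {
                throw new IllegalArgumentException("input must hold batch_size * size elements");
            }
            return List.of(new long[] {total / size, size});
        };
    }

    public static void main(String[] args) {
        Function<List<long[]>, List<long[]>> block = batchFlattenBlock(12);
        long[] out = block.apply(List.<long[]>of(new long[] {4, 3, 4})).get(0);
        System.out.println(Arrays.toString(out));
    }
}
```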
    • identityBlock

      public static Block identityBlock()
      Creates a LambdaBlock that performs the identity function.
      Returns:
      an identity Block
    • onesBlock

      public static Block onesBlock(ai.djl.util.PairList<DataType,Shape> shapes, String[] names)
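      Creates a LambdaBlock that returns an all-ones NDList.
      Parameters:
      shapes - the data type and shape of each output NDArray
      names - the names of the output NDArrays
      Returns:
      an all-ones Block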
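      To make the output of onesBlock concrete, here is a plain-Java sketch that builds one all-ones array per requested shape and stores it under the corresponding name. It assumes one name per shape, omits data types (every value is a double), and uses a Map in place of DJL's PairList and NDList; ones is a hypothetical helper, not a DJL method.

```java
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical model of onesBlock: each requested shape yields an array of ones. */
final class OnesBlockSketch {

    static Map<String, double[]> ones(List<long[]> shapes, String[] names) {
        if (shapes.size() != names.length) {
            throw new IllegalArgumentException("need one name per shape");
        }
        Map<String, double[]> outputs = new LinkedHashMap<>(); // keeps the requested order
        for (int i = 0; i < names.length; i++) {
            long count = 1;
            for (long dim : shapes.get(i)) {
                count *= dim; // number of elements for this output
            }
            double[] values = new double[Math.toIntExact(count)];
            Arrays.fill(values, 1.0); // every element is one
            outputs.put(names[i], values);
        }
        return outputs;
    }

    public static void main(String[] args) {
        Map<String, double[]> out =
                ones(List.of(new long[] {2, 3}, new long[] {4}), new String[] {"mask", "bias"});
        out.forEach((name, values) -> System.out.println(name + " -> " + Arrays.toString(values)));
    }
}
```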
    • describe

      public static String describe(Block block, String blockName, int beginAxis)
      Returns a string representation of the passed Block describing the input axes, output axes, and the block's children.
      Parameters:
      block - the block to describe
      blockName - the name to be used for the passed block, or null if its class name is to be used
      beginAxis - skips all axes before this axis; use 0 to print all axes and 1 to skip the batch axis.
      Returns:
      the string representation
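      The beginAxis parameter controls which leading axes are left out of the description. The sketch below shows only that slicing idea on a shape array; formatAxes and its "(a, b, c)" output format are hypothetical and are not claimed to match the text that describe actually produces.

```java
import java.util.Arrays;
import java.util.StringJoiner;

/** Hypothetical illustration of beginAxis: axes before it are not printed. */
final class DescribeAxesSketch {

    static String formatAxes(long[] shape, int beginAxis) {
        if (beginAxis < 0 || beginAxis > shape.length) {
            throw new IllegalArgumentException("beginAxis out of range: " + beginAxis);
        }
        StringJoiner joiner = new StringJoiner(", ", "(", ")");
        // keep only the axes from beginAxis onward
        for (long dim : Arrays.copyOfRange(shape, beginAxis, shape.length)) {
            joiner.add(Long.toString(dim));
        }
        return joiner.toString();
    }

    public static void main(String[] args) {
        long[] input = {32, 3, 224, 224};
        System.out.println(formatAxes(input, 0)); // all axes
        System.out.println(formatAxes(input, 1)); // batch axis skipped
    }
}
```

      Passing 1 is useful when the batch size varies between calls, since the remaining axes then describe a single sample.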