Class RecurrentBlock

All Implemented Interfaces:
Block
Direct Known Subclasses:
GRU, LSTM, RNN

public abstract class RecurrentBlock extends AbstractBlock
RecurrentBlock is an abstract implementation of recurrent neural networks.

Recurrent neural networks are neural networks that carry a hidden state from one step of a sequence to the next. They are widely used for natural language processing and other tasks that involve sequential data.

This [article](http://karpathy.github.io/2015/05/21/rnn-effectiveness/) written by Andrej Karpathy provides a detailed explanation of recurrent neural networks.

Currently, vanilla RNN, LSTM and GRU are implemented, with both multi-layer and bidirectional support.

  • Field Details

    • stateSize

      protected long stateSize
    • dropRate

      protected float dropRate
    • numLayers

      protected int numLayers
    • gates

      protected int gates
    • batchFirst

      protected boolean batchFirst
    • hasBiases

      protected boolean hasBiases
    • bidirectional

      protected boolean bidirectional
    • returnState

      protected boolean returnState
  • Constructor Details

    • RecurrentBlock

      public RecurrentBlock(RecurrentBlock.BaseBuilder<?> builder)
      Creates a RecurrentBlock object.
      Parameters:
      builder - the Builder that has the necessary configurations
  • Method Details

    • getOutputShapes

      public Shape[] getOutputShapes(Shape[] inputs)
      Returns the expected output shapes of the block for the specified input shapes.
      Parameters:
      inputs - the shapes of the inputs
      Returns:
      the expected output shapes of the block
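To illustrate the kind of shape arithmetic a recurrent layer usually performs, here is a minimal sketch in plain Java. It is not this class's implementation. The class `RecurrentShapeSketch`, the `long[]` shape representation, and the layout rules are assumptions based on the common RNN convention: the first two input dimensions are preserved, the last dimension becomes `stateSize` times the number of directions, and the optional state shape is `numLayers * numDirections` by batch size by `stateSize`.

```java
import java.util.Arrays;

/** Hypothetical helper, not part of the library: follows the common RNN output-shape convention. */
final class RecurrentShapeSketch {

    private RecurrentShapeSketch() {}

    /** Assumption: a bidirectional layer has two directions, otherwise one. */
    static int numDirections(boolean bidirectional) {
        return bidirectional ? 2 : 1;
    }

    /**
     * Computes the assumed output shapes for a 3-D input shape.
     * The input is (batch, time, features) when batchFirst is true,
     * otherwise (time, batch, features).
     */
    static long[][] outputShapes(
            long[] input,
            long stateSize,
            int numLayers,
            boolean batchFirst,
            boolean bidirectional,
            boolean returnState) {
        if (input.length != 3) {
            throw new IllegalArgumentException("Expected a 3-D input shape");
        }
        int directions = numDirections(bidirectional);
        // The feature dimension is replaced by the (possibly doubled) state size.
        long[] output = {input[0], input[1], stateSize * directions};
        if (!returnState) {
            return new long[][] {output};
        }
        long batch = batchFirst ? input[0] : input[1];
        // One state slice per layer and direction.
        long[] state = {(long) numLayers * directions, batch, stateSize};
        return new long[][] {output, state};
    }

    public static void main(String[] args) {
        long[][] shapes =
                outputShapes(new long[] {32, 10, 100}, 64, 2, true, true, true);
        System.out.println(Arrays.deepToString(shapes));
    }
}
```

Under these assumptions, setting `returnState` adds a second shape for the hidden state, and `batchFirst` only changes which input dimension is read as the batch size.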
    • beforeInitialize

      protected void beforeInitialize(Shape... inputShapes)
      Performs any action necessary before initialization, for example keeping the input information or verifying the layout.
      Overrides:
      beforeInitialize in class AbstractBaseBlock
      Parameters:
      inputShapes - the expected shapes of the input
    • prepare

      public void prepare(Shape[] inputs)
      Sets the shape of Parameters.
      Overrides:
      prepare in class AbstractBaseBlock
      Parameters:
      inputs - the shapes of inputs
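As a rough illustration of what setting parameter shapes can involve for a recurrent layer, the sketch below derives weight and bias shapes from `stateSize`, `gates`, and the input size. It is an assumption-laden example, not this method's implementation: the class `RecurrentParameterShapeSketch` and its method names are hypothetical, and the shapes follow the widely used convention in which each weight has `gates * stateSize` rows (with `gates` commonly 1 for a vanilla RNN, 3 for GRU, and 4 for LSTM).

```java
/** Hypothetical helper, not part of the library: conventional RNN parameter shapes. */
final class RecurrentParameterShapeSketch {

    private RecurrentParameterShapeSketch() {}

    /**
     * Shape of the input-to-hidden weight of a layer. Layer 0 consumes the raw
     * features; deeper layers consume the previous layer's output, which is
     * assumed to be stateSize wide per direction.
     */
    static long[] inputToHiddenWeight(
            int layer, long inputSize, long stateSize, int gates, boolean bidirectional) {
        long directions = bidirectional ? 2 : 1;
        long columns = layer == 0 ? inputSize : stateSize * directions;
        return new long[] {gates * stateSize, columns};
    }

    /** Shape of the hidden-to-hidden weight: one row block per gate. */
    static long[] hiddenToHiddenWeight(long stateSize, int gates) {
        return new long[] {gates * stateSize, stateSize};
    }

    /** Shape of a bias vector: one entry per gate and state unit. */
    static long[] bias(long stateSize, int gates) {
        return new long[] {gates * stateSize};
    }
}
```

For example, under these assumptions a bidirectional LSTM with `stateSize` 64 and 100 input features would have a first-layer input-to-hidden weight of shape (256, 100) and a second-layer one of shape (256, 128).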
    • loadMetadata

      public void loadMetadata(byte loadVersion, DataInputStream is) throws IOException, MalformedModelException
      Override this to load additional metadata along with the parameter values.

      If you override AbstractBaseBlock.saveMetadata(DataOutputStream) or need to provide backward compatibility with older binary formats, you probably need to override this method as well. The default implementation checks whether the version number matches and throws a MalformedModelException if it does not. It then restores the input shapes.

      Overrides:
      loadMetadata in class AbstractBaseBlock
      Parameters:
      loadVersion - the version used for loading this metadata
      is - the input stream to load from
      Throws:
      IOException - if loading fails
      MalformedModelException - if the data can be loaded but has the wrong format
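The check-the-version-then-restore-shapes pattern described for this method can be sketched with the standard `java.io` data streams. Everything below is illustrative: the class `MetadataLoadSketch`, the stand-in exception `MalformedMetadataException`, the version constant, and the on-disk layout (a shape count, then each shape's rank and dimensions) are assumptions, not the library's actual binary format.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;

/** Hypothetical sketch of a versioned metadata round trip; not the library's format. */
final class MetadataLoadSketch {

    /** Assumed current version of the hypothetical format. */
    static final byte VERSION = 2;

    /** Stand-in for a "malformed model" exception. */
    static final class MalformedMetadataException extends Exception {
        MalformedMetadataException(String message) {
            super(message);
        }
    }

    private MetadataLoadSketch() {}

    /** Writes the shape count, then each shape's rank and dimensions. */
    static void saveMetadata(DataOutputStream os, long[][] inputShapes) throws IOException {
        os.writeInt(inputShapes.length);
        for (long[] shape : inputShapes) {
            os.writeInt(shape.length);
            for (long dim : shape) {
                os.writeLong(dim);
            }
        }
    }

    /** Rejects unknown versions first, then restores the input shapes. */
    static long[][] loadMetadata(byte loadVersion, DataInputStream is)
            throws IOException, MalformedMetadataException {
        if (loadVersion != VERSION) {
            throw new MalformedMetadataException("Unsupported version: " + loadVersion);
        }
        long[][] shapes = new long[is.readInt()][];
        for (int i = 0; i < shapes.length; i++) {
            long[] shape = new long[is.readInt()];
            for (int j = 0; j < shape.length; j++) {
                shape[j] = is.readLong();
            }
            shapes[i] = shape;
        }
        return shapes;
    }

    /** Saves the shapes to memory and loads them back with the given version. */
    static long[][] roundTrip(byte loadVersion, long[][] shapes)
            throws IOException, MalformedMetadataException {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        try (DataOutputStream os = new DataOutputStream(buffer)) {
            saveMetadata(os, shapes);
        }
        try (DataInputStream is =
                new DataInputStream(new ByteArrayInputStream(buffer.toByteArray()))) {
            return loadMetadata(loadVersion, is);
        }
    }
}
```

A subclass that writes extra fields when saving would read them back in the same order after the version check, which is why the save and load overrides usually change together.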
    • getNumDirections

      protected int getNumDirections()