Class RecurrentBlock
All Implemented Interfaces:
Block
RecurrentBlock is an abstract implementation of recurrent neural networks.
Recurrent neural networks are neural networks with hidden states. They are widely used for natural language processing and other tasks that involve sequential data.
This [article](http://karpathy.github.io/2015/05/21/rnn-effectiveness/) written by Andrej Karpathy provides a detailed explanation of recurrent neural networks.
Currently, vanilla RNN, LSTM and GRU are implemented, with both multi-layer and bidirectional support.
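To make "hidden state" concrete, here is a minimal sketch of a vanilla (Elman) RNN cell in plain Java, independent of DJL. It assumes the common formulation h_t = tanh(W_x x_t + W_h h_{t-1} + b) with a zero initial state; the class and method names are hypothetical, and this is not how RecurrentBlock is implemented internally.

```java
/** Hypothetical vanilla RNN cell for illustration only; not DJL code. */
final class VanillaRnnSketch {
    private final double[][] wx; // input-to-hidden weights, [stateSize][inputSize]
    private final double[][] wh; // hidden-to-hidden weights, [stateSize][stateSize]
    private final double[] b;    // bias, [stateSize]

    VanillaRnnSketch(double[][] wx, double[][] wh, double[] b) {
        this.wx = wx;
        this.wh = wh;
        this.b = b;
    }

    /** One time step: h_t = tanh(Wx * x_t + Wh * h_{t-1} + b). */
    double[] step(double[] x, double[] hPrev) {
        double[] h = new double[b.length];
        for (int i = 0; i < h.length; i++) {
            double sum = b[i];
            for (int j = 0; j < x.length; j++) {
                sum += wx[i][j] * x[j];
            }
            for (int j = 0; j < hPrev.length; j++) {
                sum += wh[i][j] * hPrev[j];
            }
            h[i] = Math.tanh(sum);
        }
        return h;
    }

    /** Runs the cell over a sequence, carrying the hidden state forward; returns every h_t. */
    double[][] run(double[][] sequence) {
        double[][] outputs = new double[sequence.length][];
        double[] h = new double[b.length]; // initial hidden state h_0 = 0
        for (int t = 0; t < sequence.length; t++) {
            h = step(sequence[t], h);
            outputs[t] = h;
        }
        return outputs;
    }
}
```

A multi-layer RNN feeds each layer's sequence of hidden states into the next layer as its input; a bidirectional RNN runs a second cell over the reversed sequence and concatenates the two hidden states at each time step.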
Nested Class Summary
| Modifier and Type | Class | Description |
|---|---|---|
| static class | `RecurrentBlock.BaseBuilder` | The Builder to construct a RecurrentBlock type of Block. |
Field Summary
| Modifier and Type | Field |
|---|---|
| protected boolean | `batchFirst` |
| protected boolean | `bidirectional` |
| protected float | `dropRate` |
| protected int | `gates` |
| protected boolean | `hasBiases` |
| protected int | `numLayers` |
| protected boolean | `returnState` |
| protected long | `stateSize` |

Fields inherited from class ai.djl.nn.AbstractBlock: children, parameters
Fields inherited from class ai.djl.nn.AbstractBaseBlock: inputNames, inputShapes, outputDataTypes, version
Constructor Summary
| Constructor | Description |
|---|---|
| `RecurrentBlock(RecurrentBlock.BaseBuilder<?> builder)` | Creates a RecurrentBlock object. |
Method Summary
| Modifier and Type | Method | Description |
|---|---|---|
| protected void | `beforeInitialize(Shape... inputShapes)` | Performs any action necessary before initialization. |
| protected int | `getNumDirections()` | |
| Shape[] | `getOutputShapes(Shape[] inputs)` | Returns the expected output shapes of the block for the specified input shapes. |
| void | `loadMetadata(byte loadVersion, DataInputStream is)` | Override this to load additional metadata with the parameter values. |
| void | `prepare(Shape[] inputs)` | Sets the shape of Parameters. |

Methods inherited from class ai.djl.nn.AbstractBlock: addChildBlock, addChildBlock, addChildBlockSingleton, addParameter, getChildren, getDirectParameters
Methods inherited from class ai.djl.nn.AbstractBaseBlock: cast, clear, describeInput, forward, forward, forwardInternal, forwardInternal, getInputShapes, getOutputDataTypes, getParameters, initialize, initializeChildBlocks, isInitialized, loadParameters, readInputShapes, saveInputShapes, saveMetadata, saveParameters, setInitializer, setInitializer, setInitializer, toString
Methods inherited from class java.lang.Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface ai.djl.nn.Block: forward, freezeParameters, freezeParameters, getOutputShapes
Field Details
stateSize
protected long stateSize

dropRate
protected float dropRate

numLayers
protected int numLayers

gates
protected int gates

batchFirst
protected boolean batchFirst

hasBiases
protected boolean hasBiases

bidirectional
protected boolean bidirectional

returnState
protected boolean returnState
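The fields above carry no descriptions on this page. Going by the standard cell definitions (an assumption, not stated here), `gates` would count the gate transforms each cell computes: one for a vanilla RNN, four for an LSTM (input, forget, candidate, output) and three for a GRU (reset, update, candidate). The hypothetical enum below records those counts and the number of rows obtained when the per-gate weight matrices are stacked.

```java
/** Hypothetical enum for illustration; not part of DJL. Gate counts follow the standard cell definitions. */
enum CellType {
    RNN(1),  // a single tanh/ReLU transform
    LSTM(4), // input, forget, candidate, output
    GRU(3);  // reset, update, candidate

    final int gates;

    CellType(int gates) {
        this.gates = gates;
    }

    /** Rows in a stacked weight matrix: one stateSize-sized block per gate (assumed layout). */
    long stackedRows(long stateSize) {
        return gates * stateSize;
    }
}
```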
Constructor Details
RecurrentBlock
Creates a RecurrentBlock object.
Parameters:
builder - the Builder that has the necessary configurations
Method Details
getOutputShapes
Returns the expected output shapes of the block for the specified input shapes.
Parameters:
inputs - the shapes of the inputs
Returns:
the expected output shapes of the block
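A sketch of the shape arithmetic, under assumptions this page does not spell out: the input is rank 3, laid out as (batch, time, channel) when batchFirst is true and (time, batch, channel) otherwise; the output keeps the first two axes and ends in numDirections * stateSize features; and a returned hidden state has shape (numLayers * numDirections, batch, stateSize). The helper uses plain long[] arrays rather than DJL's Shape, and its names are hypothetical.

```java
import java.util.Arrays;

/** Hypothetical shape helper; mirrors RecurrentBlock's fields but is not DJL code. */
final class RnnShapeSketch {
    private RnnShapeSketch() {}

    /** Output shape: first two axes unchanged, last axis = numDirections * stateSize. */
    static long[] outputShape(long[] input, long stateSize, boolean bidirectional) {
        requireRank3(input);
        long numDirections = bidirectional ? 2 : 1;
        return new long[] {input[0], input[1], numDirections * stateSize};
    }

    /** Hidden-state shape: (numLayers * numDirections, batch, stateSize). */
    static long[] stateShape(long[] input, long stateSize, int numLayers,
                             boolean bidirectional, boolean batchFirst) {
        requireRank3(input);
        long numDirections = bidirectional ? 2 : 1;
        long batch = batchFirst ? input[0] : input[1]; // the batch axis depends on the layout
        return new long[] {numLayers * numDirections, batch, stateSize};
    }

    private static void requireRank3(long[] input) {
        if (input.length != 3) {
            throw new IllegalArgumentException("expected rank-3 input, got " + Arrays.toString(input));
        }
    }
}
```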
beforeInitialize
Performs any action necessary before initialization, for example keeping the input information or verifying the layout.
Overrides:
beforeInitialize in class AbstractBaseBlock
Parameters:
inputShapes - the expected shapes of the input
prepare
Sets the shape of Parameters.
Overrides:
prepare in class AbstractBaseBlock
Parameters:
inputs - the shapes of inputs
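Since prepare is where the parameter shapes are set, here is a sketch of how the fields could combine into parameter sizes. It assumes the cuDNN-style layout common to fused RNN implementations: per layer and direction, an input-to-hidden weight with gates * stateSize rows, a hidden-to-hidden weight, and optionally two bias vectors, with layers after the first reading numDirections * stateSize features. This page does not list RecurrentBlock's actual parameters, so treat the layout and the method name as assumptions.

```java
/** Hypothetical parameter-count sketch; names and layout are assumptions, not DJL's code. */
final class RnnParamSketch {
    private RnnParamSketch() {}

    static long parameterCount(long inputSize, long stateSize, int numLayers, int gates,
                               boolean bidirectional, boolean hasBiases) {
        int numDirections = bidirectional ? 2 : 1;
        long rows = gates * stateSize; // one stateSize block per gate, stacked
        long total = 0;
        for (int layer = 0; layer < numLayers; layer++) {
            // Layers after the first consume the concatenated outputs of all directions.
            long layerInput = layer == 0 ? inputSize : numDirections * stateSize;
            for (int dir = 0; dir < numDirections; dir++) {
                total += rows * layerInput; // input-to-hidden weight [rows, layerInput]
                total += rows * stateSize;  // hidden-to-hidden weight [rows, stateSize]
                if (hasBiases) {
                    total += 2 * rows;      // input and hidden biases, [rows] each
                }
            }
        }
        return total;
    }
}
```

For a single-layer, unidirectional LSTM with 10 input features and a state size of 20, this gives 4 * 20 * (10 + 20) + 2 * 4 * 20 = 2560 parameters.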
loadMetadata
public void loadMetadata(byte loadVersion, DataInputStream is) throws IOException, MalformedModelException
Override this to load additional metadata with the parameter values.
If you override AbstractBaseBlock.saveMetadata(DataOutputStream) or need to provide backward compatibility with older binary formats, you probably need to override this method as well. The default implementation checks whether the version number matches; if it does not, it throws a MalformedModelException. It then restores the input shapes.
Overrides:
loadMetadata in class AbstractBaseBlock
Parameters:
loadVersion - the version used for loading this metadata
is - the input stream to load from
Throws:
IOException - if loading fails
MalformedModelException - if the data can be loaded but has the wrong format
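The version-check-then-restore pattern described for loadMetadata can be sketched with java.io alone. The exception class, the version constant and the on-disk encoding of the shape are hypothetical stand-ins, not DJL's MalformedModelException or its actual format; as in the loadMetadata signature, the caller reads the version byte and passes it in.

```java
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;

/** Hypothetical stand-in for a "data has the wrong format" exception. */
class UnsupportedVersionException extends Exception {
    UnsupportedVersionException(String message) {
        super(message);
    }
}

/** Hypothetical metadata reader/writer illustrating a versioned format. */
final class MetadataSketch {
    static final byte VERSION = 2; // hypothetical current format version

    private MetadataSketch() {}

    /** Writes the version byte, then the input shape as a length-prefixed list of longs. */
    static void save(DataOutputStream os, long[] inputShape) throws IOException {
        os.writeByte(VERSION);
        os.writeInt(inputShape.length);
        for (long dim : inputShape) {
            os.writeLong(dim);
        }
    }

    /** Rejects unknown versions first, then restores the input shape. */
    static long[] load(byte loadVersion, DataInputStream is)
            throws IOException, UnsupportedVersionException {
        if (loadVersion != VERSION) {
            throw new UnsupportedVersionException("Unsupported encoding version: " + loadVersion);
        }
        long[] shape = new long[is.readInt()];
        for (int i = 0; i < shape.length; i++) {
            shape[i] = is.readLong();
        }
        return shape;
    }
}
```

Writing the version first lets a reader reject an incompatible format before it interprets any of the remaining bytes.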
getNumDirections
protected int getNumDirections()