Package ai.djl.nn.recurrent
Class RNN
java.lang.Object
ai.djl.nn.AbstractBaseBlock
ai.djl.nn.AbstractBlock
ai.djl.nn.recurrent.RecurrentBlock
ai.djl.nn.recurrent.RNN
- All Implemented Interfaces:
Block
RNN is an implementation of recurrent neural networks which applies a single-gate
recurrent layer to input. Two kinds of activation function are supported: ReLU and Tanh.
The current implementation follows the [paper](https://crl.ucsd.edu/~elman/Papers/fsit.pdf) Finding Structure in Time (Elman, 1988).
The RNN operator is formulated as below:
With ReLU activation function: \(h_t = \mathrm{ReLU}(W_{ih} * x_t + b_{ih} + W_{hh} * h_{(t-1)} + b_{hh})\)
With Tanh activation function: \(h_t = \tanh(W_{ih} * x_t + b_{ih} + W_{hh} * h_{(t-1)} + b_{hh})\)
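The two formulas differ only in the activation applied to the same pre-activation sum. The following is a minimal plain-Java sketch of one time step, reading `*` as a matrix-vector product. It is illustrative only and is not the DJL implementation: the class `ElmanRnnStep`, its methods, and the sample weights are hypothetical names and values chosen for this example.

```java
import java.util.Arrays;
import java.util.function.DoubleUnaryOperator;

/** Illustrative sketch of one Elman RNN time step; not part of DJL. */
public class ElmanRnnStep {

    /** Computes w * v for a matrix w given as an array of rows. */
    public static double[] matVec(double[][] w, double[] v) {
        double[] out = new double[w.length];
        for (int i = 0; i < w.length; i++) {
            double sum = 0.0;
            for (int j = 0; j < v.length; j++) {
                sum += w[i][j] * v[j];
            }
            out[i] = sum;
        }
        return out;
    }

    /** Computes h_t = activation(W_ih * x_t + b_ih + W_hh * h_(t-1) + b_hh). */
    public static double[] step(
            double[][] wIh, double[] bIh,
            double[][] wHh, double[] bHh,
            double[] x, double[] hPrev,
            DoubleUnaryOperator activation) {
        double[] inputPart = matVec(wIh, x);       // W_ih * x_t
        double[] hiddenPart = matVec(wHh, hPrev);  // W_hh * h_(t-1)
        double[] h = new double[hPrev.length];
        for (int i = 0; i < h.length; i++) {
            double preActivation = inputPart[i] + bIh[i] + hiddenPart[i] + bHh[i];
            h[i] = activation.applyAsDouble(preActivation);
        }
        return h;
    }

    public static void main(String[] args) {
        // Hypothetical weights: input size 2, state size 2.
        double[][] wIh = {{0.5, -1.0}, {1.0, 1.0}};
        double[] bIh = {0.0, 0.0};
        double[][] wHh = {{1.0, 0.0}, {0.0, 1.0}};
        double[] bHh = {0.0, -1.0};
        double[] x = {1.0, 2.0};
        double[] hPrev = {0.5, 0.0};

        // The pre-activation sum is [-1.0, 2.0] for both variants.
        double[] relu = step(wIh, bIh, wHh, bHh, x, hPrev, v -> Math.max(0.0, v));
        double[] tanh = step(wIh, bIh, wHh, bHh, x, hPrev, Math::tanh);

        System.out.println(Arrays.toString(relu)); // prints [0.0, 2.0]
        System.out.println(Arrays.toString(tanh)); // [tanh(-1.0), tanh(2.0)]
    }
}
```

Feeding each returned `h` back in as `hPrev` for the next input vector unrolls the recurrence over a sequence. ReLU clips the negative component to zero, while Tanh squashes both components into the range (-1, 1).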
Nested Class Summary
Nested Classes
- static enum RNN.Activation: An enum that enumerates the type of activation.
- static final class RNN.Builder
Nested classes/interfaces inherited from class ai.djl.nn.recurrent.RecurrentBlock
RecurrentBlock.BaseBuilder<T extends RecurrentBlock.BaseBuilder>
Field Summary
Fields inherited from class ai.djl.nn.recurrent.RecurrentBlock
batchFirst, bidirectional, dropRate, gates, hasBiases, numLayers, returnState, stateSize
Fields inherited from class ai.djl.nn.AbstractBlock
children, parameters
Fields inherited from class ai.djl.nn.AbstractBaseBlock
inputNames, inputShapes, outputDataTypes, version
Method Summary
- static RNN.Builder builder(): Creates a builder to build an RNN.
- protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<String, Object> params): A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization.
Methods inherited from class ai.djl.nn.recurrent.RecurrentBlock
beforeInitialize, getNumDirections, getOutputShapes, loadMetadata, prepare
Methods inherited from class ai.djl.nn.AbstractBlock
addChildBlock, addChildBlock, addChildBlockSingleton, addParameter, getChildren, getDirectParameters
Methods inherited from class ai.djl.nn.AbstractBaseBlock
cast, clear, describeInput, forward, forward, forwardInternal, getInputShapes, getOutputDataTypes, getParameters, initialize, initializeChildBlocks, isInitialized, loadParameters, readInputShapes, saveInputShapes, saveMetadata, saveParameters, setInitializer, setInitializer, setInitializer, toString
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface ai.djl.nn.Block
forward, freezeParameters, freezeParameters, getOutputShapes
Method Details
forwardInternal
protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<String, Object> params)
A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization.
- Specified by:
forwardInternal in class AbstractBaseBlock
- Parameters:
parameterStore - the parameter store
inputs - the input NDList
training - true for a training forward pass
params - optional parameters
- Returns:
the output of the forward pass
builder
public static RNN.Builder builder()
Creates a builder to build an RNN.
- Returns:
a new builder