Package ai.djl.nn.recurrent
Class LSTM
java.lang.Object
ai.djl.nn.AbstractBaseBlock
ai.djl.nn.AbstractBlock
ai.djl.nn.recurrent.RecurrentBlock
ai.djl.nn.recurrent.LSTM
All Implemented Interfaces:
Block
LSTM is an implementation of a recurrent neural network block that applies a Long Short-Term
Memory recurrent layer to its input.
Reference paper - LONG SHORT-TERM MEMORY - Hochreiter and Schmidhuber, 1997. http://www.bioinf.jku.at/publications/older/2604.pdf
The LSTM operator is formulated as below:
$$
\begin{split}\begin{array}{ll}
i_t = \mathrm{sigmoid}(W_{ii} x_t + b_{ii} + W_{hi} h_{(t-1)} + b_{hi}) \\
f_t = \mathrm{sigmoid}(W_{if} x_t + b_{if} + W_{hf} h_{(t-1)} + b_{hf}) \\
g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{(t-1)} + b_{hg}) \\
o_t = \mathrm{sigmoid}(W_{io} x_t + b_{io} + W_{ho} h_{(t-1)} + b_{ho}) \\
c_t = f_t * c_{(t-1)} + i_t * g_t \\
h_t = o_t * \tanh(c_t)
\end{array}\end{split}
$$
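To make the formulation concrete, the following is a minimal sketch of a single time step written with plain Java arrays. It is not how DJL computes the layer internally; the class name `LstmStepSketch`, the helper names, and the gate ordering inside the arrays (0 = i, 1 = f, 2 = g, 3 = o) are assumptions made only for this illustration, and `*` in the equations is taken to be the element-wise product.

```java
import java.util.Arrays;

/** Illustrative only: one LSTM time step on plain arrays (hypothetical helper, not DJL API). */
public class LstmStepSketch {

    /** Logistic sigmoid used by the input, forget, and output gates. */
    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /** Computes W_x * x + b_x + W_h * h + b_h for a single gate. */
    static double[] affine(double[][] wx, double[] x, double[] bx,
                           double[][] wh, double[] h, double[] bh) {
        double[] out = new double[bx.length];
        for (int r = 0; r < out.length; r++) {
            double sum = bx[r] + bh[r];
            for (int k = 0; k < x.length; k++) {
                sum += wx[r][k] * x[k];
            }
            for (int k = 0; k < h.length; k++) {
                sum += wh[r][k] * h[k];
            }
            out[r] = sum;
        }
        return out;
    }

    /**
     * Applies one step of the equations above.
     * The first index of wx, bx, wh, bh selects the gate: 0 = i, 1 = f, 2 = g, 3 = o
     * (an assumed ordering for this sketch).
     *
     * @return a two-element array {h_t, c_t}
     */
    static double[][] step(double[] x, double[] hPrev, double[] cPrev,
                           double[][][] wx, double[][] bx,
                           double[][][] wh, double[][] bh) {
        double[] i = affine(wx[0], x, bx[0], wh[0], hPrev, bh[0]);
        double[] f = affine(wx[1], x, bx[1], wh[1], hPrev, bh[1]);
        double[] g = affine(wx[2], x, bx[2], wh[2], hPrev, bh[2]);
        double[] o = affine(wx[3], x, bx[3], wh[3], hPrev, bh[3]);

        int n = hPrev.length;
        double[] h = new double[n];
        double[] c = new double[n];
        for (int j = 0; j < n; j++) {
            double it = sigmoid(i[j]);
            double ft = sigmoid(f[j]);
            double gt = Math.tanh(g[j]);
            double ot = sigmoid(o[j]);
            c[j] = ft * cPrev[j] + it * gt;   // c_t = f_t * c_(t-1) + i_t * g_t
            h[j] = ot * Math.tanh(c[j]);      // h_t = o_t * tanh(c_t)
        }
        return new double[][] {h, c};
    }

    public static void main(String[] args) {
        // Input size 1, hidden size 1, all weights and biases left at zero.
        double[][][] wx = new double[4][1][1];
        double[][][] wh = new double[4][1][1];
        double[][] bx = new double[4][1];
        double[][] bh = new double[4][1];
        double[][] out = step(new double[] {2.0}, new double[] {0.0}, new double[] {1.0},
                wx, bx, wh, bh);
        System.out.println("h_t = " + Arrays.toString(out[0]));
        System.out.println("c_t = " + Arrays.toString(out[1]));
    }
}
```

With all weights and biases at zero, every sigmoid gate evaluates to 0.5 and the candidate g_t is 0, so the step simply halves the previous cell state. That makes it a convenient sanity check for the update rule.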
Nested Class Summary
Nested Classes
Modifier and Type: static final class
Class: LSTM.Builder
Nested classes/interfaces inherited from class ai.djl.nn.recurrent.RecurrentBlock
RecurrentBlock.BaseBuilder<T extends RecurrentBlock.BaseBuilder>
Field Summary
Fields inherited from class ai.djl.nn.recurrent.RecurrentBlock
batchFirst, bidirectional, dropRate, gates, hasBiases, numLayers, returnState, stateSize
Fields inherited from class ai.djl.nn.AbstractBlock
children, parameters
Fields inherited from class ai.djl.nn.AbstractBaseBlock
inputNames, inputShapes, outputDataTypes, version
Method Summary
Modifier and Type, Method, Description
static LSTM.Builder builder()
Creates a builder to build an LSTM.
protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<String, Object> params)
A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization.
Methods inherited from class ai.djl.nn.recurrent.RecurrentBlock
beforeInitialize, getNumDirections, getOutputShapes, loadMetadata, prepare
Methods inherited from class ai.djl.nn.AbstractBlock
addChildBlock, addChildBlock, addChildBlockSingleton, addParameter, getChildren, getDirectParameters
Methods inherited from class ai.djl.nn.AbstractBaseBlock
cast, clear, describeInput, forward, forward, forwardInternal, getInputShapes, getOutputDataTypes, getParameters, initialize, initializeChildBlocks, isInitialized, loadParameters, readInputShapes, saveInputShapes, saveMetadata, saveParameters, setInitializer, setInitializer, setInitializer, toString
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface ai.djl.nn.Block
forward, freezeParameters, freezeParameters, getOutputShapes
Method Details
forwardInternal
protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<String, Object> params)
A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization.
Specified by:
forwardInternal in class AbstractBaseBlock
Parameters:
parameterStore - the parameter store
inputs - the input NDList
training - true for a training forward pass
params - optional parameters
Returns:
the output of the forward pass
builder
Creates a builder to build an LSTM.
Returns:
a new builder