Package ai.djl.nn.core
Class Embedding<T>
java.lang.Object
ai.djl.nn.AbstractBaseBlock
ai.djl.nn.AbstractBlock
ai.djl.nn.core.Embedding<T>
- Type Parameters:
T - the type of item that is embedded, i.e. mapped to an array
- All Implemented Interfaces:
Block, AbstractEmbedding<T>, AbstractIndexedEmbedding<T>
- Direct Known Subclasses:
TrainableWordEmbedding
An Embedding block maps a collection of items to 1-dimensional representative NDArrays.
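Conceptually, this amounts to a weight table with one row per known item, where an item's 1-dimensional representation is simply its row. The plain-Java sketch below illustrates the idea without depending on DJL; the class `ItemEmbeddingSketch`, its vocabulary map, and the hand-written weights are hypothetical and only stand in for the block's `embedding` parameter.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical sketch: each known item owns one row of a weight table. */
class ItemEmbeddingSketch<T> {
    private final Map<T, Integer> indexOf = new HashMap<>(); // item -> row index
    private final float[][] weight; // shape (numEmbeddings, embeddingSize)

    ItemEmbeddingSketch(List<T> items, float[][] weight) {
        if (items.size() != weight.length) {
            throw new IllegalArgumentException("need exactly one weight row per item");
        }
        for (int i = 0; i < items.size(); i++) {
            indexOf.put(items.get(i), i);
        }
        this.weight = weight;
    }

    /** Returns a copy of the item's 1-D representation (length embeddingSize). */
    float[] embed(T item) {
        Integer row = indexOf.get(item);
        if (row == null) {
            throw new IllegalArgumentException("unknown item: " + item);
        }
        return weight[row].clone();
    }

    public static void main(String[] args) {
        ItemEmbeddingSketch<String> emb = new ItemEmbeddingSketch<>(
                List.of("cat", "dog"),
                new float[][] {{0.1f, 0.2f, 0.3f}, {0.4f, 0.5f, 0.6f}});
        System.out.println(Arrays.toString(emb.embed("dog")));
    }
}
```

Returning a copy of the row keeps callers from accidentally mutating the table, which in a real block would be the trainable parameter.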
Nested Class Summary
- static class Embedding.BaseBuilder<T, B extends Embedding.BaseBuilder<T, B>>
- protected class
- protected class
Field Summary
- protected Parameter embedding
- protected int embeddingSize
- protected AbstractIndexedEmbedding<T> fallthroughEmbedding
- protected int numEmbeddings
- protected SparseFormat sparseFormat

Fields inherited from class ai.djl.nn.AbstractBlock: children, parameters
Fields inherited from class ai.djl.nn.AbstractBaseBlock: inputNames, inputShapes, outputDataTypes, version

Constructor Summary
- protected Embedding(NDArray embedding): Constructs a pretrained embedding.
- protected Embedding(NDArray embedding, SparseFormat format): Constructs a pretrained embedding.
- protected Embedding(Embedding.BaseBuilder<T, ?> baseBuilder)

Method Summary
- NDArray embed(NDManager manager, T[] items): Embeds an array of items.
- static NDList embedding(NDArray input, NDArray weight, SparseFormat sparse): A simple lookup table that looks up embeddings in a fixed-size dictionary.
- protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<String, Object> params): A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization.
- Shape[] getOutputShapes(Shape[] inputShapes): Returns the expected output shapes of the block for the specified input shapes.
- void loadParameters(NDManager manager, DataInputStream is): Loads the parameters from the given input stream.
- void prepare(Shape[] inputShapes): Sets the shape of Parameters.
- void saveParameters(DataOutputStream os): Writes the parameters of the block to the given output stream.

Methods inherited from class ai.djl.nn.AbstractBlock: addChildBlock, addChildBlock, addChildBlockSingleton, addParameter, getChildren, getDirectParameters
Methods inherited from class ai.djl.nn.AbstractBaseBlock: beforeInitialize, cast, clear, describeInput, forward, forward, forwardInternal, getInputShapes, getOutputDataTypes, getParameters, initialize, initializeChildBlocks, isInitialized, loadMetadata, readInputShapes, saveInputShapes, saveMetadata, setInitializer, setInitializer, setInitializer, toString
Methods inherited from class java.lang.Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface ai.djl.nn.core.AbstractEmbedding: hasItem
Methods inherited from interface ai.djl.nn.core.AbstractIndexedEmbedding: decode, embed, encode, unembed
Methods inherited from interface ai.djl.nn.Block: forward, freezeParameters, freezeParameters, getOutputShapes
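
Field Details

numEmbeddings
protected int numEmbeddings

embeddingSize
protected int embeddingSize

sparseFormat
protected SparseFormat sparseFormat

fallthroughEmbedding
protected AbstractIndexedEmbedding<T> fallthroughEmbedding

embedding
protected Parameter embedding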

Constructor Details

Embedding
protected Embedding(Embedding.BaseBuilder<T, ?> baseBuilder)

Embedding
protected Embedding(NDArray embedding)
Constructs a pretrained embedding.
- Parameters:
embedding - the embedding array

Embedding
protected Embedding(NDArray embedding, SparseFormat format)
Constructs a pretrained embedding. Because it is created with pretrained data, it is created as a frozen block. If you wish to update it, call Block.freezeParameters(boolean) with false.
- Parameters:
embedding - the embedding array
format - whether to compute a row-sparse gradient in the backward calculation
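A row-sparse gradient exploits the fact that a lookup only touches the rows it indexes: every other row of the weight matrix receives a zero gradient, so only the touched rows need to be stored. The following plain-Java sketch shows that idea; it is a hypothetical illustration (`RowSparseGradientSketch` and its map-based representation are assumptions), not DJL's SparseFormat implementation.

```java
import java.util.Arrays;
import java.util.Map;
import java.util.TreeMap;

/** Hypothetical sketch of a row-sparse gradient for an embedding lookup. */
class RowSparseGradientSketch {
    /**
     * indices[i] is the weight row used by lookup i, and upstream[i] is the gradient of the
     * loss with respect to that lookup's output. Only rows that were looked up get an entry,
     * and repeated indices accumulate into the same row.
     */
    static Map<Integer, float[]> rowSparseGradient(int[] indices, float[][] upstream) {
        if (indices.length != upstream.length) {
            throw new IllegalArgumentException("one upstream gradient per index is required");
        }
        Map<Integer, float[]> grad = new TreeMap<>(); // row index -> gradient row, sorted by row
        for (int i = 0; i < indices.length; i++) {
            int size = upstream[i].length;
            float[] row = grad.computeIfAbsent(indices[i], k -> new float[size]);
            for (int j = 0; j < size; j++) {
                row[j] += upstream[i][j];
            }
        }
        return grad;
    }

    public static void main(String[] args) {
        // Three lookups touch only rows 7 and 42, however many rows the full table has.
        Map<Integer, float[]> g = rowSparseGradient(
                new int[] {7, 42, 7},
                new float[][] {{1f, 1f}, {0.5f, 0.5f}, {2f, 2f}});
        g.forEach((row, v) -> System.out.println(row + " -> " + Arrays.toString(v)));
    }
}
```

Row 7 is looked up twice, so its two upstream gradients are summed; a dense gradient would instead allocate a full (numEmbeddings, embeddingSize) array that is zero almost everywhere.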

Method Details

prepare
Sets the shape of Parameters.
- Overrides:
prepare in class AbstractBaseBlock
- Parameters:
inputShapes - the shapes of inputs

getOutputShapes
Returns the expected output shapes of the block for the specified input shapes.
- Specified by:
getOutputShapes in interface Block
- Parameters:
inputShapes - the shapes of the inputs
- Returns:
the expected output shapes of the block
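Under the usual lookup-table convention, which is consistent with the embed result shape (items.length, embeddingSize) documented for this class, the weight parameter has shape (numEmbeddings, embeddingSize) and the output shape is the input shape with embeddingSize appended. The sketch below encodes that bookkeeping with plain long[] shapes; `EmbeddingShapeSketch` is a hypothetical helper rather than DJL's Shape API, and the appended-dimension rule for multi-dimensional inputs is an assumption.

```java
import java.util.Arrays;

/** Hypothetical sketch of the shape bookkeeping for a lookup-table embedding. */
class EmbeddingShapeSketch {
    /** Weight shape: one row per embeddable index, one column per embedding dimension. */
    static long[] weightShape(int numEmbeddings, int embeddingSize) {
        return new long[] {numEmbeddings, embeddingSize};
    }

    /** Output shape: every index in the input is replaced by a vector of length embeddingSize. */
    static long[] outputShape(long[] inputShape, int embeddingSize) {
        long[] out = Arrays.copyOf(inputShape, inputShape.length + 1);
        out[inputShape.length] = embeddingSize;
        return out;
    }

    public static void main(String[] args) {
        // A batch of 32 sequences, each 10 indices long, embedded into 64 dimensions.
        System.out.println(Arrays.toString(outputShape(new long[] {32, 10}, 64)));
    }
}
```

For a 1-D input of n items this yields (n, embeddingSize), matching the shape stated for embed.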

forwardInternal
protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<String, Object> params)
A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization.
- Specified by:
forwardInternal in class AbstractBaseBlock
- Parameters:
parameterStore - the parameter store
inputs - the input NDList
training - true for a training forward pass
params - optional parameters
- Returns:
the output of the forward pass

saveParameters
Writes the parameters of the block to the given output stream.
- Specified by:
saveParameters in interface Block
- Overrides:
saveParameters in class AbstractBaseBlock
- Parameters:
os - the output stream to save the parameters to
- Throws:
IOException - if an I/O error occurs

loadParameters
public void loadParameters(NDManager manager, DataInputStream is) throws IOException, MalformedModelException
Loads the parameters from the given input stream.
- Specified by:
loadParameters in interface Block
- Overrides:
loadParameters in class AbstractBaseBlock
- Parameters:
manager - an NDManager to create the parameter arrays
is - the input stream that streams the parameter values
- Throws:
IOException - if an I/O error occurs
MalformedModelException - if the model file is corrupted or unsupported
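saveParameters and loadParameters have to agree on a byte layout so that a saved block can be restored. The sketch below shows such a round trip for a weight table using java.io.DataOutputStream and DataInputStream; the layout (row count, column count, then row-major floats) and the class `WeightStreamSketch` are hypothetical, this is not DJL's actual parameter format, and a plain IOException stands in for MalformedModelException on a corrupted header.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Arrays;

/** Hypothetical sketch: a save/load round trip for an embedding weight table. */
class WeightStreamSketch {
    /** Writes rows, cols, then every value in row-major order; assumes a rectangular table. */
    static void save(float[][] weight, DataOutputStream os) throws IOException {
        int rows = weight.length;
        int cols = rows == 0 ? 0 : weight[0].length;
        os.writeInt(rows);
        os.writeInt(cols);
        for (float[] row : weight) {
            for (float v : row) {
                os.writeFloat(v);
            }
        }
    }

    /** Reads a table written by save, rejecting an obviously corrupted header. */
    static float[][] load(DataInputStream is) throws IOException {
        int rows = is.readInt();
        int cols = is.readInt();
        if (rows < 0 || cols < 0) {
            throw new IOException("corrupted header: " + rows + "x" + cols);
        }
        float[][] weight = new float[rows][cols];
        for (float[] row : weight) {
            for (int j = 0; j < cols; j++) {
                row[j] = is.readFloat();
            }
        }
        return weight;
    }

    public static void main(String[] args) throws IOException {
        float[][] weight = {{1f, 2f}, {3f, 4f}};
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream os = new DataOutputStream(bytes)) {
            save(weight, os);
        }
        try (DataInputStream is = new DataInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {
            System.out.println(Arrays.deepToString(load(is)));
        }
    }
}
```

Writing the dimensions first lets the reader allocate the table before consuming the values and gives it something to validate.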

embed
Embeds an array of items.
- Specified by:
embed in interface AbstractEmbedding<T>
- Parameters:
manager - the manager for the new embeddings
items - the items to embed
- Returns:
the embedding NDArray of Shape(items.length, embeddingSize)
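For an array of items, embed returns one row per item, giving an array of shape (items.length, embeddingSize). The following sketch mirrors that contract with plain Java arrays. `BatchEmbedSketch` is hypothetical, and so is its choice to route unknown items to a reserved row 0; that is not necessarily how this class uses its fallthroughEmbedding field.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical sketch: embeds an array of items into an (items.length, embeddingSize) table. */
class BatchEmbedSketch<T> {
    private static final int UNKNOWN_ROW = 0; // hypothetical: row 0 is reserved for unknown items
    private final Map<T, Integer> indexOf = new HashMap<>();
    private final float[][] weight; // shape (numEmbeddings, embeddingSize)

    /** vocabulary.get(i) is stored in row i + 1, because row 0 is the unknown-item row. */
    BatchEmbedSketch(List<T> vocabulary, float[][] weight) {
        if (weight.length != vocabulary.size() + 1) {
            throw new IllegalArgumentException("expected " + (vocabulary.size() + 1) + " rows");
        }
        for (int i = 0; i < vocabulary.size(); i++) {
            indexOf.put(vocabulary.get(i), i + 1);
        }
        this.weight = weight;
    }

    /** Returns one copied row per item, so the result has shape (items.length, embeddingSize). */
    float[][] embed(T[] items) {
        float[][] out = new float[items.length][];
        for (int i = 0; i < items.length; i++) {
            int row = indexOf.getOrDefault(items[i], UNKNOWN_ROW);
            out[i] = weight[row].clone();
        }
        return out;
    }

    public static void main(String[] args) {
        BatchEmbedSketch<String> emb = new BatchEmbedSketch<>(
                List.of("cat", "dog"),
                new float[][] {{0f, 0f}, {1f, 1f}, {2f, 2f}});
        float[][] out = emb.embed(new String[] {"dog", "bird", "cat"});
        System.out.println(Arrays.deepToString(out));
    }
}
```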

embedding
A simple lookup table that looks up embeddings in a fixed-size dictionary.
- Parameters:
input - NDArray containing indices into the embedding matrix
weight - the embedding matrix, with the number of rows equal to the maximum possible index + 1 and the number of columns equal to the embedding size
sparse - SparseFormat of the gradient
- Returns:
output NDArray
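The static embedding method behaves like a gather over the rows of weight: each index in input is replaced by the corresponding row, so valid indices run from 0 to the row count of weight minus one. The sketch below applies that rule to a 2-D batch of indices using plain Java arrays; `LookupTableSketch` is hypothetical, and it has no counterpart to the sparse argument because it computes no gradients.

```java
import java.util.Arrays;

/** Hypothetical sketch of a lookup-table embedding on plain Java arrays. */
class LookupTableSketch {
    /**
     * input holds indices of shape (batch, seqLen); weight has shape (maxIndex + 1, embeddingSize).
     * Returns shape (batch, seqLen, embeddingSize): each index is replaced by a copy of its row.
     */
    static float[][][] embedding(int[][] input, float[][] weight) {
        float[][][] out = new float[input.length][][];
        for (int b = 0; b < input.length; b++) {
            out[b] = new float[input[b].length][];
            for (int s = 0; s < input[b].length; s++) {
                int idx = input[b][s];
                if (idx < 0 || idx >= weight.length) {
                    throw new IndexOutOfBoundsException(
                            "index " + idx + " outside [0, " + weight.length + ")");
                }
                out[b][s] = weight[idx].clone();
            }
        }
        return out;
    }

    public static void main(String[] args) {
        // Maximum index 2 -> 3 rows; embedding size 2.
        float[][] weight = {{0f, 0f}, {1f, 10f}, {2f, 20f}};
        float[][][] out = embedding(new int[][] {{2, 0}, {1, 1}}, weight);
        System.out.println(Arrays.deepToString(out));
    }
}
```

The bounds check makes the "rows = maximum possible index + 1" requirement explicit: an index equal to the row count is rejected.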