Package ai.djl.modality.nlp.embedding
Class TrainableWordEmbedding
- All Implemented Interfaces:
WordEmbedding, Block, AbstractEmbedding<String>, AbstractIndexedEmbedding<String>
TrainableWordEmbedding is an implementation of WordEmbedding and Embedding based on a DefaultVocabulary. This WordEmbedding is ideal when no pre-trained embeddings are available.
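As a rough mental model, a trainable word embedding pairs a vocabulary (word to index) with a weight matrix of shape [vocabulary size, embedding size] whose rows start out random and are learned during training. The plain-Java sketch below illustrates that idea without DJL; the class `ToyWordEmbedding`, its methods, and the small uniform initialization are hypothetical choices for illustration, not the DJL implementation.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

/** Hypothetical illustration: a vocabulary plus a randomly initialized, trainable weight matrix. */
class ToyWordEmbedding {
    private final Map<String, Integer> wordToIndex = new HashMap<>();
    // One row per distinct vocabulary entry; each row has embeddingSize columns.
    private final float[][] weights;

    ToyWordEmbedding(List<String> vocabulary, int embeddingSize, long seed) {
        for (String word : vocabulary) {
            // Duplicates keep their first index.
            wordToIndex.putIfAbsent(word, wordToIndex.size());
        }
        Random random = new Random(seed);
        weights = new float[wordToIndex.size()][embeddingSize];
        for (float[] row : weights) {
            for (int j = 0; j < row.length; j++) {
                // Small values in [-0.05, 0.05): with no pre-trained vectors, training must learn them.
                row[j] = (random.nextFloat() - 0.5f) * 0.1f;
            }
        }
    }

    int vocabularySize() {
        return weights.length;
    }

    int embeddingSize() {
        return weights.length == 0 ? 0 : weights[0].length;
    }

    /** Returns a copy of the word's row; throws if the word is not in the vocabulary. */
    float[] lookup(String word) {
        Integer index = wordToIndex.get(word);
        if (index == null) {
            throw new IllegalArgumentException("Unknown word: " + word);
        }
        return weights[index].clone();
    }
}
```

With this layout the matrix holds vocabulary size × embedding size values; a 10,000-word vocabulary with an embedding size of 300, for example, means 3,000,000 trainable parameters.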
Nested Class Summary
Nested classes/interfaces inherited from class ai.djl.nn.core.Embedding:
Embedding.BaseBuilder<T, B extends Embedding.BaseBuilder<T, B>>, Embedding.DefaultEmbedding, Embedding.DefaultItem
Field Summary
Fields inherited from class ai.djl.nn.core.Embedding:
embedding, embeddingSize, fallthroughEmbedding, numEmbeddings, sparseFormat
Fields inherited from class ai.djl.nn.AbstractBlock:
children, parameters
Fields inherited from class ai.djl.nn.AbstractBaseBlock:
inputNames, inputShapes, outputDataTypes, version
Constructor Summary
TrainableWordEmbedding(TrainableWordEmbedding.Builder builder) - Constructs a new instance of TrainableWordEmbedding from the TrainableWordEmbedding.Builder.
TrainableWordEmbedding(Vocabulary vocabulary, int embeddingSize) - Constructs a new instance of TrainableWordEmbedding from a Vocabulary (for example, a DefaultVocabulary) and a given embedding size.
Method Summary
static TrainableWordEmbedding.Builder builder() - Creates a builder to build an Embedding.
String decode(byte[] byteArray) - Decodes the given byte array into an object of input parameter type.
long embed(String item) - Embeds an item.
NDArray embedWord(NDArray index) - Embeds the word after it has been pre-processed using WordEmbedding.preprocessWordToEmbed(String).
byte[] encode(String input) - Encodes an object of input type into a byte array.
static TrainableWordEmbedding fromPretrained(NDArray embedding, List<String> items) - Constructs a pretrained embedding.
static TrainableWordEmbedding fromPretrained(NDArray embedding, List<String> items, SparseFormat sparseFormat) - Constructs a pretrained embedding.
boolean hasItem(String item) - Returns whether an item is in the embedding.
long preprocessWordToEmbed(String word) - Pre-processes the word to embed into an index to pass into the model.
String unembed(long index) - Returns the item corresponding to the given index.
String unembedWord(NDArray word) - Returns the closest matching word for the given index.
boolean vocabularyContains(String word) - Returns whether an embedding exists for a word.
Methods inherited from class ai.djl.nn.core.Embedding:
embed, embedding, forwardInternal, getOutputShapes, loadParameters, prepare, saveParameters
Methods inherited from class ai.djl.nn.AbstractBlock:
addChildBlock, addChildBlock, addChildBlockSingleton, addParameter, getChildren, getDirectParameters
Methods inherited from class ai.djl.nn.AbstractBaseBlock:
beforeInitialize, cast, clear, describeInput, forward, forward, forwardInternal, getInputShapes, getOutputDataTypes, getParameters, initialize, initializeChildBlocks, isInitialized, loadMetadata, readInputShapes, saveInputShapes, saveMetadata, setInitializer, setInitializer, setInitializer, toString
Methods inherited from class java.lang.Object:
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface ai.djl.nn.Block:
forward, freezeParameters, freezeParameters, getOutputShapes
Methods inherited from interface ai.djl.modality.nlp.embedding.WordEmbedding:
embedWord, embedWord
-
Constructor Details
-
TrainableWordEmbedding
Constructs a new instance of TrainableWordEmbedding from the TrainableWordEmbedding.Builder.
- Parameters:
builder - the TrainableWordEmbedding.Builder
-
TrainableWordEmbedding
Constructs a new instance of TrainableWordEmbedding from a Vocabulary (for example, a DefaultVocabulary) and a given embedding size.
- Parameters:
vocabulary - a Vocabulary to get tokens from
embeddingSize - the required embedding size
-
-
Method Details
-
fromPretrained
Constructs a pretrained embedding.
Because it is created from pre-trained data, it is created as a frozen block. If you wish to update it, call Block.freezeParameters(boolean) with false.
- Parameters:
embedding - the embedding array
items - the items in the embedding, in the same order as the embedding array
- Returns:
- the created embedding
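The contract described for fromPretrained (one row per item, in the same order as `items`, and a block that starts frozen until `freezeParameters(false)` is called) can be mirrored in a small plain-Java sketch. `PretrainedTable` is hypothetical: it uses `float[][]` in place of an NDArray and a boolean flag in place of DJL's parameter freezing.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical stand-in for an embedding built from pre-trained vectors. */
class PretrainedTable {
    private final float[][] vectors;
    private final Map<String, Integer> index = new HashMap<>();
    private boolean frozen;

    PretrainedTable(float[][] vectors, List<String> items) {
        // Row i of the matrix must belong to items.get(i).
        if (vectors.length != items.size()) {
            throw new IllegalArgumentException(
                    "expected " + items.size() + " rows but got " + vectors.length);
        }
        this.vectors = vectors;
        for (int i = 0; i < items.size(); i++) {
            index.put(items.get(i), i);
        }
        // Pre-trained weights start frozen so training does not overwrite them.
        this.frozen = true;
    }

    /** Mirrors the freeze/unfreeze switch: pass false to allow updates. */
    void freezeParameters(boolean freeze) {
        this.frozen = freeze;
    }

    boolean isFrozen() {
        return frozen;
    }

    float[] vectorFor(String item) {
        Integer i = index.get(item);
        if (i == null) {
            throw new IllegalArgumentException("Unknown item: " + item);
        }
        return vectors[i];
    }
}
```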
-
fromPretrained
public static TrainableWordEmbedding fromPretrained(NDArray embedding, List<String> items, SparseFormat sparseFormat)
Constructs a pretrained embedding.
Because it is created from pre-trained data, it is created as a frozen block. If you wish to update it, call Block.freezeParameters(boolean) with false.
- Parameters:
embedding - the embedding array
items - the items in the embedding, in the same order as the embedding array
sparseFormat - whether to compute a row-sparse gradient in the backward calculation
- Returns:
- the created embedding
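In an embedding lookup, only the rows whose indices appear in the batch receive a non-zero gradient, which is what makes a row-sparse gradient attractive for large vocabularies. The sketch below is a hypothetical plain-Java illustration of that idea (`EmbeddingGradients` is not DJL code and says nothing about how SparseFormat is implemented internally): it accumulates the same gradient densely and row-sparsely, with repeated indices summed.

```java
import java.util.Map;
import java.util.TreeMap;

/** Hypothetical illustration of dense versus row-sparse gradients for an embedding lookup. */
final class EmbeddingGradients {
    private EmbeddingGradients() {}

    /** Dense gradient: a full numEmbeddings x dim matrix, mostly zeros. */
    static float[][] dense(int numEmbeddings, int[] indices, float[][] upstream) {
        int dim = upstream.length == 0 ? 0 : upstream[0].length;
        float[][] grad = new float[numEmbeddings][dim];
        for (int b = 0; b < indices.length; b++) {
            for (int j = 0; j < dim; j++) {
                grad[indices[b]][j] += upstream[b][j];
            }
        }
        return grad;
    }

    /** Row-sparse gradient: only rows that were actually looked up are stored. */
    static Map<Integer, float[]> rowSparse(int[] indices, float[][] upstream) {
        int dim = upstream.length == 0 ? 0 : upstream[0].length;
        Map<Integer, float[]> grad = new TreeMap<>();
        for (int b = 0; b < indices.length; b++) {
            float[] row = grad.computeIfAbsent(indices[b], k -> new float[dim]);
            for (int j = 0; j < dim; j++) {
                // A word that appears several times in the batch accumulates all its gradients.
                row[j] += upstream[b][j];
            }
        }
        return grad;
    }
}
```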
-
vocabularyContains
Returns whether an embedding exists for a word.
- Specified by:
vocabularyContains in interface WordEmbedding
- Parameters:
word - the word to check
- Returns:
- true if an embedding exists
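Since vocabularyContains reports whether a word has an embedding, it can serve as a filter that drops out-of-vocabulary tokens before they are embedded. The sketch below models that check with a plain `Set`; `VocabularyFilter` is a hypothetical helper, not part of DJL.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Set;

/** Hypothetical helper that keeps only the tokens an embedding knows about. */
final class VocabularyFilter {
    private final Set<String> known;

    VocabularyFilter(Set<String> known) {
        this.known = Set.copyOf(known);
    }

    /** Same question vocabularyContains answers: does this word have an embedding? */
    boolean vocabularyContains(String word) {
        return known.contains(word);
    }

    /** Returns the in-vocabulary tokens, preserving their original order. */
    List<String> keepKnown(List<String> tokens) {
        List<String> kept = new ArrayList<>();
        for (String token : tokens) {
            if (vocabularyContains(token)) {
                kept.add(token);
            }
        }
        return kept;
    }
}
```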
-
preprocessWordToEmbed
Pre-processes the word to embed into an index to pass into the model. Make sure to call WordEmbedding.embedWord(NDManager, long) after this.
- Specified by:
preprocessWordToEmbed in interface WordEmbedding
- Parameters:
word - the word to embed
- Returns:
- the word that is ready to embed
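The two-step contract, preprocessWordToEmbed (word to long index) followed by embedWord (index to vector), can be sketched with plain Java arrays. `TwoStepEmbedding` is hypothetical; the DJL methods operate on NDArray and NDManager rather than `float[]`, and throwing on unknown words is an assumption of this sketch.

```java
import java.util.List;

/** Hypothetical sketch of the preprocess-then-embed contract. */
final class TwoStepEmbedding {
    private final List<String> vocabulary;
    private final float[][] weights;

    TwoStepEmbedding(List<String> vocabulary, float[][] weights) {
        if (vocabulary.size() != weights.length) {
            throw new IllegalArgumentException("one weight row per word is required");
        }
        this.vocabulary = List.copyOf(vocabulary);
        this.weights = weights;
    }

    /** Step 1: turn the word into the long index the model consumes. */
    long preprocessWordToEmbed(String word) {
        int index = vocabulary.indexOf(word);
        if (index < 0) {
            throw new IllegalArgumentException("Unknown word: " + word);
        }
        return index;
    }

    /** Step 2: turn that index into a copy of its embedding vector. */
    float[] embedWord(long index) {
        return weights[Math.toIntExact(index)].clone();
    }
}
```

Keeping the index step separate lets a data pipeline convert text to indices once, outside the model, while the vector lookup happens inside it.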
-
embedWord
Embeds the word after it has been pre-processed using WordEmbedding.preprocessWordToEmbed(String).
- Specified by:
embedWord in interface WordEmbedding
- Parameters:
index - the index of the word to embed
- Returns:
- the embedded word
- Throws:
EmbeddingException - if there is an error while trying to embed
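Embedding an index is mathematically the same as multiplying a one-hot vector by the weight matrix, which is why the lookup can be done as a cheap row selection. The hypothetical `LookupMath` class below (plain Java, not DJL code) computes both forms so they can be compared.

```java
/** Hypothetical illustration: a row lookup equals a one-hot vector times the weight matrix. */
final class LookupMath {
    private LookupMath() {}

    /** Direct row selection: O(dim) work. */
    static float[] gather(float[][] weights, int index) {
        return weights[index].clone();
    }

    /** One-hot vector (length = number of rows) times the matrix: O(rows * dim) work. */
    static float[] oneHotTimesMatrix(float[][] weights, int index) {
        int rows = weights.length;
        int dim = rows == 0 ? 0 : weights[0].length;
        float[] oneHot = new float[rows];
        oneHot[index] = 1f;
        float[] result = new float[dim];
        for (int r = 0; r < rows; r++) {
            for (int j = 0; j < dim; j++) {
                result[j] += oneHot[r] * weights[r][j];
            }
        }
        return result;
    }
}
```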
-
unembedWord
Returns the closest matching word for the given index.
- Specified by:
unembedWord in interface WordEmbedding
- Parameters:
word - the word embedding to find the matching string word for
- Returns:
- a word similar to the passed in embedding
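One common way to find the closest matching word for an embedding vector is a nearest-neighbor search by cosine similarity over all rows of the weight matrix. This is an assumption for illustration only: the documentation here does not say how the match is computed, and `NearestWord` is a hypothetical plain-Java class.

```java
import java.util.List;

/** Hypothetical nearest-word search by cosine similarity (not the DJL implementation). */
final class NearestWord {
    private final List<String> words;
    private final float[][] vectors;

    NearestWord(List<String> words, float[][] vectors) {
        if (words.size() != vectors.length) {
            throw new IllegalArgumentException("one vector per word is required");
        }
        this.words = List.copyOf(words);
        this.vectors = vectors;
    }

    /** Returns the word whose vector has the highest cosine similarity to the query. */
    String closest(float[] query) {
        String best = null;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < vectors.length; i++) {
            double score = cosine(query, vectors[i]);
            if (score > bestScore) {
                bestScore = score;
                best = words.get(i);
            }
        }
        return best;
    }

    /** Assumes both vectors have the same length. */
    private static double cosine(float[] a, float[] b) {
        double dot = 0, normA = 0, normB = 0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }
        // A zero vector has no direction, so treat it as unrelated.
        if (normA == 0 || normB == 0) {
            return 0;
        }
        return dot / (Math.sqrt(normA) * Math.sqrt(normB));
    }
}
```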
-
encode
Encodes an object of input type into a byte array. This is used when saving and loading Embedding objects.
- Specified by:
encode in interface AbstractIndexedEmbedding<String>
- Parameters:
input - the input object to be encoded
- Returns:
- the encoded byte array.
-
decode
Decodes the given byte array into an object of input parameter type.
- Specified by:
decode in interface AbstractIndexedEmbedding<String>
- Parameters:
byteArray - the byte array to be decoded
- Returns:
- the decoded object of input parameter type
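For saving and loading to work, the encode and decode methods need to be inverses of each other. A minimal sketch for String items, assuming UTF-8 as the byte encoding (the documentation does not name a charset), with `StringCodec` as a hypothetical helper:

```java
import java.nio.charset.StandardCharsets;

/** Hypothetical encode/decode pair for String items, assuming UTF-8. */
final class StringCodec {
    private StringCodec() {}

    /** Encodes the item with an explicit charset so the bytes do not depend on the platform. */
    static byte[] encode(String input) {
        return input.getBytes(StandardCharsets.UTF_8);
    }

    /** Decodes bytes written by encode back into the original String. */
    static String decode(byte[] byteArray) {
        return new String(byteArray, StandardCharsets.UTF_8);
    }
}
```

Naming the charset explicitly matters because `String.getBytes()` without one uses the platform default, which before Java 18 could differ between machines and make saved data unreadable elsewhere.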
-
embed
Embeds an item.
- Specified by:
embed in interface AbstractIndexedEmbedding<String>
- Parameters:
item - the item to embed
- Returns:
- the index of the item in the embedding
-
unembed
Returns the item corresponding to the given index.
- Specified by:
unembed in interface AbstractIndexedEmbedding<String>
- Parameters:
index - the index
- Returns:
- the item corresponding to the given index
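The embed and unembed methods form a bidirectional mapping between items and indices, so unembed(embed(item)) should give back the item. The sketch below is hypothetical plain Java (`IndexedItems` is not DJL code); it also includes a hasItem check, and it throws on unknown items where a real embedding might instead defer to a fallthrough embedding.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical bidirectional item/index mapping, mirroring embed and unembed. */
final class IndexedItems {
    private final Map<String, Long> itemToIndex = new HashMap<>();
    private final List<String> indexToItem = new ArrayList<>();

    IndexedItems(List<String> items) {
        for (String item : items) {
            // Each distinct item gets the next free index; duplicates are skipped.
            if (!itemToIndex.containsKey(item)) {
                itemToIndex.put(item, (long) indexToItem.size());
                indexToItem.add(item);
            }
        }
    }

    boolean hasItem(String item) {
        return itemToIndex.containsKey(item);
    }

    /** Returns the index of the item; throws for unknown items in this sketch. */
    long embed(String item) {
        Long index = itemToIndex.get(item);
        if (index == null) {
            throw new IllegalArgumentException("Unknown item: " + item);
        }
        return index;
    }

    /** Returns the item stored at the given index. */
    String unembed(long index) {
        return indexToItem.get(Math.toIntExact(index));
    }
}
```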
-
builder
Creates a builder to build an Embedding.
- Returns:
- a new builder
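The constructor that takes a TrainableWordEmbedding.Builder follows the builder pattern: settings are collected step by step and validated once when the object is built. The sketch below shows that general pattern in plain Java; `ToyEmbeddingConfig`, its setter names, and its validation rules are hypothetical and do not reproduce the DJL builder's methods.

```java
import java.util.List;

/** Hypothetical configuration object created through a fluent builder. */
final class ToyEmbeddingConfig {
    final List<String> vocabulary;
    final int embeddingSize;

    private ToyEmbeddingConfig(Builder builder) {
        this.vocabulary = List.copyOf(builder.vocabulary);
        this.embeddingSize = builder.embeddingSize;
    }

    static Builder builder() {
        return new Builder();
    }

    static final class Builder {
        private List<String> vocabulary;
        private int embeddingSize;

        Builder vocabulary(List<String> vocabulary) {
            this.vocabulary = vocabulary;
            return this;
        }

        Builder embeddingSize(int embeddingSize) {
            this.embeddingSize = embeddingSize;
            return this;
        }

        /** Validates required settings before constructing the immutable object. */
        ToyEmbeddingConfig build() {
            if (vocabulary == null || vocabulary.isEmpty()) {
                throw new IllegalStateException("a non-empty vocabulary is required");
            }
            if (embeddingSize <= 0) {
                throw new IllegalStateException("embeddingSize must be positive");
            }
            return new ToyEmbeddingConfig(this);
        }
    }
}
```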
-
hasItem
Returns whether an item is in the embedding.
- Specified by:
hasItem in interface AbstractEmbedding<String>
- Parameters:
item - the item to test
- Returns:
- true if the item is in the embedding
-