Package ai.djl.nn.transformer
Class ScaledDotProductAttentionBlock.Builder
java.lang.Object
ai.djl.nn.transformer.ScaledDotProductAttentionBlock.Builder
Enclosing class: ScaledDotProductAttentionBlock
A builder for ScaledDotProductAttentionBlocks.
Method Summary
Method - Description
build() - Creates a new ScaledDotProductAttentionBlock with the current configuration.
optAttentionProbsDropoutProb(float attentionProbsDropoutProb) - Sets the probability of applying dropout to the attention probability distribution.
setEmbeddingSize(int embeddingSize) - Sets the embedding size to be used for the internal token representation.
setHeadCount(int headCount) - Sets the number of attention heads; must divide the embedding size without remainder.
Method Details
setEmbeddingSize
Sets the embedding size to be used for the internal token representation.
Parameters:
embeddingSize - the embedding size to be used for the internal token representation
Returns:
this builder
setHeadCount
Sets the number of attention heads. The head count must divide the embedding size without remainder. For example, if embeddingSize = 10, a headCount of 3 is not valid, while a headCount of 1, 2, or 5 is.
Parameters:
headCount - the number of attention heads
Returns:
this builder
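The divisibility rule above can be checked before configuring the builder. The following is a minimal sketch in plain Java, not part of DJL: the class `HeadCountCheck` and its methods are hypothetical helpers, and the per-head size of `embeddingSize / headCount` is an assumption based on the usual multi-head attention layout rather than something this page states.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper (not a DJL class) illustrating the head-count rule.
final class HeadCountCheck {

    // True if headCount divides embeddingSize without remainder.
    static boolean isValidHeadCount(int embeddingSize, int headCount) {
        return embeddingSize > 0 && headCount > 0 && embeddingSize % headCount == 0;
    }

    // Assumed per-head width: the embedding is split evenly across heads.
    static int perHeadSize(int embeddingSize, int headCount) {
        if (!isValidHeadCount(embeddingSize, headCount)) {
            throw new IllegalArgumentException(
                    "headCount " + headCount + " does not divide embeddingSize " + embeddingSize);
        }
        return embeddingSize / headCount;
    }

    // Lists every head count that satisfies the divisibility rule.
    static List<Integer> validHeadCounts(int embeddingSize) {
        List<Integer> result = new ArrayList<>();
        for (int h = 1; h <= embeddingSize; h++) {
            if (embeddingSize % h == 0) {
                result.add(h);
            }
        }
        return result;
    }

    public static void main(String[] args) {
        System.out.println(isValidHeadCount(10, 3)); // prints false
        System.out.println(validHeadCounts(10));     // prints [1, 2, 5, 10]
    }
}
```

Under the pure divisibility rule a head count equal to the embedding size (10 here) also passes, leaving one dimension per head under the assumed layout.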
optAttentionProbsDropoutProb
public ScaledDotProductAttentionBlock.Builder optAttentionProbsDropoutProb(float attentionProbsDropoutProb)
Sets the probability of applying dropout to the attention probability distribution. Because it acts on the attention probabilities, this dropout can randomly remove a complete token from the result at a given position.
Parameters:
attentionProbsDropoutProb - the probability of applying dropout to the attention probability distribution
Returns:
this builder
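To illustrate why dropping an attention probability removes a whole token at a position, here is a small plain-Java sketch. It is not DJL's implementation: the class `AttentionDropoutSketch` is hypothetical, and rescaling the surviving entries by 1 / (1 - p) (so-called inverted dropout) is an assumption about the convention used. Each entry of the array stands for the weight one position assigns to one token; when an entry is zeroed, that token contributes nothing to the output at that position.

```java
import java.util.Arrays;
import java.util.Random;

// Hypothetical illustration (not DJL code) of dropout on one row of attention probabilities.
final class AttentionDropoutSketch {

    // Zeroes each probability independently with probability dropProb and
    // rescales the survivors by 1 / (1 - dropProb) (assumed inverted dropout).
    static float[] dropout(float[] probs, float dropProb, Random rng) {
        if (dropProb < 0f || dropProb >= 1f) {
            throw new IllegalArgumentException("dropProb must be in [0, 1)");
        }
        float keepScale = 1f / (1f - dropProb);
        float[] out = new float[probs.length];
        for (int i = 0; i < probs.length; i++) {
            boolean dropped = rng.nextFloat() < dropProb;
            // A dropped entry means token i is ignored entirely at this position.
            out[i] = dropped ? 0f : probs[i] * keepScale;
        }
        return out;
    }

    public static void main(String[] args) {
        float[] row = {0.5f, 0.25f, 0.125f, 0.125f}; // weights over four tokens
        System.out.println(Arrays.toString(dropout(row, 0.5f, new Random(42))));
    }
}
```

With a drop probability of 0 the row is returned unchanged, which corresponds to leaving this optional setting disabled in the sketch.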
build
Creates a new ScaledDotProductAttentionBlock with the current configuration.
Returns:
a new ScaledDotProductAttentionBlock with the current configuration