Class BertBlock.Builder

java.lang.Object
ai.djl.nn.transformer.BertBlock.Builder
Enclosing class:
BertBlock

public static final class BertBlock.Builder extends Object
The Builder to construct a BertBlock type of Block.
  • Method Details

    • setTokenDictionarySize

      public BertBlock.Builder setTokenDictionarySize(int tokenDictionarySize)
      Sets the number of tokens in the dictionary.
      Parameters:
      tokenDictionarySize - the number of tokens in the dictionary
      Returns:
      this builder
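The dictionary size sets the size of the token embedding table, often a sizable share of a BERT model's parameters. Note that base() and large() leave this value to the caller. The following is a minimal sketch, assuming (as in the original BERT design) that every dictionary entry is looked up as one embeddingSize-wide vector. The same arithmetic applies to the token-type table sized by optTypeDictionarySize. EmbeddingTableSize and tableParams are hypothetical helpers, not part of DJL.

```java
// Hypothetical helper (not DJL): estimates embedding-table sizes, assuming
// each dictionary entry maps to one embeddingSize-wide vector.
final class EmbeddingTableSize {
    private EmbeddingTableSize() {}

    /** Number of weights in a dictionarySize x embeddingSize lookup table. */
    static long tableParams(int dictionarySize, int embeddingSize) {
        // Widen to long before multiplying so large vocabularies cannot overflow int.
        return (long) dictionarySize * embeddingSize;
    }

    public static void main(String[] args) {
        // 30522 is the vocabulary size of the released bert-base-uncased checkpoint.
        System.out.println("token table: " + tableParams(30522, 768)); // 23440896
        System.out.println("type table:  " + tableParams(2, 768));     // 1536
    }
}
```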
    • optTypeDictionarySize

      public BertBlock.Builder optTypeDictionarySize(int typeDictionarySize)
      Sets the number of possible token types. This should be a very small number (2-16).
      Parameters:
      typeDictionarySize - the number of possible token types. This should be a very small number (2-16)
      Returns:
      this builder
    • optEmbeddingSize

      public BertBlock.Builder optEmbeddingSize(int embeddingSize)
      Sets the embedding size to use for input tokens. This size must be divisible by the number of attention heads.
      Parameters:
      embeddingSize - the embedding size to use for input tokens.
      Returns:
      this builder
    • optTransformerBlockCount

      public BertBlock.Builder optTransformerBlockCount(int transformerBlockCount)
      Sets the number of transformer blocks to use.
      Parameters:
      transformerBlockCount - the number of transformer blocks to use
      Returns:
      this builder
    • optAttentionHeadCount

      public BertBlock.Builder optAttentionHeadCount(int attentionHeadCount)
      Sets the number of attention heads to use in each transformer block. This number must divide the embedding size evenly (without remainder).
      Parameters:
      attentionHeadCount - the number of attention heads to use in each transformer block.
      Returns:
      this builder
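Each attention head operates on an embeddingSize / attentionHeadCount slice of the embedding, which is why the two values must divide evenly. The sketch below shows a check you could run before calling build(). HeadSplit and perHeadSize are hypothetical names, and this page does not say whether BertBlock performs the same validation itself.

```java
// Hypothetical helper (not DJL): computes the per-head width and rejects
// head counts that do not divide the embedding size.
final class HeadSplit {
    private HeadSplit() {}

    /** Returns embeddingSize / attentionHeadCount, or throws if it is not exact. */
    static int perHeadSize(int embeddingSize, int attentionHeadCount) {
        if (attentionHeadCount <= 0) {
            throw new IllegalArgumentException("attentionHeadCount must be positive");
        }
        if (embeddingSize % attentionHeadCount != 0) {
            throw new IllegalArgumentException("embeddingSize " + embeddingSize
                    + " is not divisible by " + attentionHeadCount + " heads");
        }
        return embeddingSize / attentionHeadCount;
    }

    public static void main(String[] args) {
        System.out.println(perHeadSize(768, 12));  // 64 (BERT BASE)
        System.out.println(perHeadSize(1024, 16)); // 64 (BERT LARGE)
        try {
            perHeadSize(768, 10); // 768 % 10 != 0
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage());
        }
    }
}
```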
    • optHiddenSize

      public BertBlock.Builder optHiddenSize(int hiddenSize)
      Sets the size of the hidden layers in the fully connected networks used.
      Parameters:
      hiddenSize - the size of the hidden layers in the fully connected networks used.
      Returns:
      this builder
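To show how optHiddenSize and optTransformerBlockCount drive model size, here is a rough parameter count for the fully connected part of each block. It assumes, as in the original BERT paper, that each block's feed-forward network is a Dense(embeddingSize → hiddenSize) layer followed by a Dense(hiddenSize → embeddingSize) layer, both with biases. FeedForwardParams is a hypothetical helper, and the layout of DJL's internal layers is not confirmed here.

```java
// Hypothetical helper (not DJL): parameter count of a two-layer feed-forward
// network E -> H -> E with biases, per block and across all blocks.
final class FeedForwardParams {
    private FeedForwardParams() {}

    /** Weights + biases of Dense(E -> H) followed by Dense(H -> E). */
    static long perBlock(int embeddingSize, int hiddenSize) {
        long up = (long) embeddingSize * hiddenSize + hiddenSize;
        long down = (long) hiddenSize * embeddingSize + embeddingSize;
        return up + down;
    }

    /** Total over all transformer blocks. */
    static long total(int embeddingSize, int hiddenSize, int transformerBlockCount) {
        return perBlock(embeddingSize, hiddenSize) * transformerBlockCount;
    }

    public static void main(String[] args) {
        System.out.println(perBlock(768, 3072));  // 4722432
        System.out.println(total(768, 3072, 12)); // 56669184
    }
}
```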
    • optHiddenDropoutProbability

      public BertBlock.Builder optHiddenDropoutProbability(float hiddenDropoutProbability)
      Sets the dropout probability in the hidden fully connected networks.
      Parameters:
      hiddenDropoutProbability - the dropout probability in the hidden fully connected networks.
      Returns:
      this builder
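A dropout probability p means each hidden activation is zeroed with probability p during training. The sketch below uses the common "inverted dropout" formulation, which rescales the surviving activations by 1/(1 - p) so that no rescaling is needed at inference time. It illustrates the technique only and is not DJL's dropout operator. InvertedDropout is a hypothetical name.

```java
import java.util.Random;

// Hypothetical illustration (not DJL): inverted dropout on one activation vector.
final class InvertedDropout {
    private InvertedDropout() {}

    /**
     * Zeroes each activation with probability p and scales survivors by 1/(1-p),
     * so the expected value of each output equals its input.
     */
    static float[] apply(float[] activations, float p, Random rng) {
        if (p < 0f || p >= 1f) {
            throw new IllegalArgumentException("dropout probability must be in [0, 1)");
        }
        float scale = 1f / (1f - p);
        float[] out = new float[activations.length];
        for (int i = 0; i < activations.length; i++) {
            out[i] = rng.nextFloat() < p ? 0f : activations[i] * scale;
        }
        return out;
    }

    public static void main(String[] args) {
        float[] hidden = {1f, 2f, 3f, 4f};
        float[] dropped = apply(hidden, 0.5f, new Random(42));
        // Each entry is either 0 or twice the original value.
        System.out.println(java.util.Arrays.toString(dropped));
    }
}
```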
    • optMaxSequenceLength

      public BertBlock.Builder optMaxSequenceLength(int maxSequenceLength)
      Sets the maximum sequence length this model can process. The memory and compute requirements of the attention mechanism are O(n²) in the sequence length n, so large values can easily exhaust your GPU memory!
      Parameters:
      maxSequenceLength - the maximum sequence length this model can process.
      Returns:
      this builder
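To make the O(n²) warning concrete: the attention score tensor of one transformer block has shape [batch, heads, n, n]. The following back-of-the-envelope sketch assumes float32 scores (4 bytes) and counts only that one tensor for a single block. During training, each block keeps its own copy for the backward pass, and softmax outputs and other activations add more. AttentionMemory is a hypothetical helper.

```java
// Hypothetical helper (not DJL): bytes needed for one block's attention scores.
final class AttentionMemory {
    private AttentionMemory() {}

    /** Bytes for a score tensor of shape [batchSize, heads, seqLen, seqLen]. */
    static long scoreBytes(int batchSize, int heads, int seqLen, int bytesPerElement) {
        // The cast makes the whole left-to-right product use long arithmetic.
        return (long) batchSize * heads * seqLen * seqLen * bytesPerElement;
    }

    public static void main(String[] args) {
        int[] lengths = {128, 256, 512, 1024};
        for (int n : lengths) {
            double mib = scoreBytes(32, 12, n, 4) / (1024.0 * 1024.0);
            System.out.printf("n=%d: %.0f MiB%n", n, mib);
        }
    }
}
```

With batch size 32 and 12 heads this prints 24, 96, 384 and 1536 MiB per block: each doubling of n quadruples the memory.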
    • nano

      public BertBlock.Builder nano()
      Sets this builder's params to a tiny configuration for testing on laptops.
      Returns:
      this builder
    • micro

      public BertBlock.Builder micro()
      Sets this builder's params to a minimal configuration that nevertheless performs quite well.
      Returns:
      this builder
    • base

      public BertBlock.Builder base()
      Sets this builder's params to the BASE config of the original BERT paper (except for the dictionary size).
      Returns:
      this builder
    • large

      public BertBlock.Builder large()
      Sets this builder's params to the LARGE config of the original BERT paper (except for the dictionary size).
      Returns:
      this builder
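For reference, the BASE and LARGE hyperparameters reported in the original BERT paper (Devlin et al., 2018) are summarized below, together with a check that they satisfy the divisibility rule from optAttentionHeadCount. These are the paper's numbers. This page does not list the exact values that base() and large() assign, so the mapping onto builder parameters and the class name BertPaperConfigs are assumptions.

```java
// Hypothetical summary (not DJL): hyperparameters reported in the BERT paper.
final class BertPaperConfigs {
    private BertPaperConfigs() {}

    static final class Config {
        final String name;
        final int transformerBlockCount; // L in the paper
        final int embeddingSize;         // H in the paper
        final int attentionHeadCount;    // A in the paper
        final int hiddenSize;            // feed-forward size, 4H in the paper

        Config(String name, int transformerBlockCount, int embeddingSize,
               int attentionHeadCount, int hiddenSize) {
            this.name = name;
            this.transformerBlockCount = transformerBlockCount;
            this.embeddingSize = embeddingSize;
            this.attentionHeadCount = attentionHeadCount;
            this.hiddenSize = hiddenSize;
        }

        boolean headsDivideEmbedding() {
            return embeddingSize % attentionHeadCount == 0;
        }
    }

    static final Config BASE = new Config("BASE", 12, 768, 12, 3072);
    static final Config LARGE = new Config("LARGE", 24, 1024, 16, 4096);

    public static void main(String[] args) {
        for (Config c : new Config[] {BASE, LARGE}) {
            System.out.printf("%s: L=%d H=%d A=%d FF=%d per-head=%d%n",
                    c.name, c.transformerBlockCount, c.embeddingSize,
                    c.attentionHeadCount, c.hiddenSize,
                    c.embeddingSize / c.attentionHeadCount);
        }
    }
}
```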
    • build

      public BertBlock build()
      Returns a new BertBlock with the parameters of this builder.
      Returns:
      a new BertBlock with the parameters of this builder.
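The builder methods above all return this builder, so calls can be chained and finished with build(). The self-contained sketch below mimics that fluent pattern with a hypothetical ToyBertConfig class. It is not DJL code, and its defaults and checks are invented for illustration. It also shows one consequence of presets that assign fields. Apply a preset such as large() first and individual overrides afterwards, because a preset called later silently replaces earlier settings.

```java
// Hypothetical toy (not DJL): mirrors the fluent style of BertBlock.Builder.
final class ToyBertConfig {
    final int tokenDictionarySize;
    final int embeddingSize;
    final int attentionHeadCount;
    final int transformerBlockCount;

    private ToyBertConfig(Builder b) {
        this.tokenDictionarySize = b.tokenDictionarySize;
        this.embeddingSize = b.embeddingSize;
        this.attentionHeadCount = b.attentionHeadCount;
        this.transformerBlockCount = b.transformerBlockCount;
    }

    static Builder builder() {
        return new Builder();
    }

    static final class Builder {
        private int tokenDictionarySize = -1; // required: no default
        private int embeddingSize = 768;      // invented defaults for the toy
        private int attentionHeadCount = 12;
        private int transformerBlockCount = 12;

        Builder setTokenDictionarySize(int n) { tokenDictionarySize = n; return this; }
        Builder optEmbeddingSize(int n) { embeddingSize = n; return this; }
        Builder optAttentionHeadCount(int n) { attentionHeadCount = n; return this; }
        Builder optTransformerBlockCount(int n) { transformerBlockCount = n; return this; }

        /** Preset: overwrites the size fields but leaves the dictionary size alone. */
        Builder large() {
            embeddingSize = 1024;
            attentionHeadCount = 16;
            transformerBlockCount = 24;
            return this;
        }

        /** Validates the accumulated settings and creates an immutable config. */
        ToyBertConfig build() {
            if (tokenDictionarySize <= 0) {
                throw new IllegalStateException("tokenDictionarySize must be set");
            }
            if (attentionHeadCount <= 0 || embeddingSize % attentionHeadCount != 0) {
                throw new IllegalStateException("embeddingSize must be divisible by attentionHeadCount");
            }
            return new ToyBertConfig(this);
        }
    }

    public static void main(String[] args) {
        // Preset first, then targeted overrides.
        ToyBertConfig cfg = ToyBertConfig.builder()
                .large()
                .optTransformerBlockCount(4)
                .setTokenDictionarySize(30522)
                .build();
        System.out.println(cfg.embeddingSize + " " + cfg.transformerBlockCount); // 1024 4
    }
}
```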