Class SeqBatcher

java.lang.Object
ai.djl.modality.nlp.generate.SeqBatcher

public class SeqBatcher extends Object
SeqBatcher stores the search state (BatchTensorList) and the control variables (e.g. seqLength, offSets), and provides batch operations (merge, trim, exitCriteria, etc.) on the BatchTensorList.
  • Method Details

    • getData

      public BatchTensorList getData()
      Returns the batch data which is stored as a BatchTensorList.
      Returns:
      the batch data stored as BatchTensorList
    • addBatch

      public void addBatch(SeqBatcher seqBatcherNew)
      Adds a new batch.

      Modifies the batch dimension and the left padding.

      Parameters:
      seqBatcherNew - the seqBatcher to add.
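
      The effect of merging on the batch dimension and the left padding can be illustrated with plain arrays. This is only a sketch, not the DJL implementation: the class LeftPadMergeSketch, its methods, and the use of long[][] in place of BatchTensorList are assumptions made for illustration. The shorter batch is padded on the left to the longer sequence length, then the rows are stacked.

```java
import java.util.Arrays;

// Hypothetical helper, not part of DJL: merges two left-padded batches of token ids.
final class LeftPadMergeSketch {

    /** Left-pads every row of batch to targetLen using padId. */
    static long[][] leftPad(long[][] batch, int targetLen, long padId) {
        long[][] out = new long[batch.length][targetLen];
        for (int i = 0; i < batch.length; i++) {
            int shift = targetLen - batch[i].length; // extra padding for this row
            Arrays.fill(out[i], 0, shift, padId);
            System.arraycopy(batch[i], 0, out[i], shift, batch[i].length);
        }
        return out;
    }

    /** Stacks two batches along the batch dimension after aligning their lengths. */
    static long[][] merge(long[][] a, long[][] b, long padId) {
        int lenA = a.length == 0 ? 0 : a[0].length;
        int lenB = b.length == 0 ? 0 : b[0].length;
        int target = Math.max(lenA, lenB);
        long[][] pa = leftPad(a, target, padId);
        long[][] pb = leftPad(b, target, padId);
        long[][] merged = new long[pa.length + pb.length][];
        System.arraycopy(pa, 0, merged, 0, pa.length);
        System.arraycopy(pb, 0, merged, pa.length, pb.length);
        return merged;
    }

    public static void main(String[] args) {
        long[][] existing = {{0, 5, 6}, {7, 8, 9}}; // sequence length 3, pad id 0
        long[][] incoming = {{1, 2, 3, 4, 5}};      // sequence length 5
        System.out.println(Arrays.deepToString(merge(existing, incoming, 0L)));
    }
}
```

      In this sketch each existing row gains two extra padding positions, so any per-row offset that tracks where the real tokens start would have to grow by the same amount.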
    • exitCriteria

      public void exitCriteria(NDArray outputIds, long maxLength, long eosTokenId)
      Checks which sequences in the batch need to exit, according to criteria such as EOS or maxLength.

      It iterates over the batch and is therefore also considered a batch operation.

      Parameters:
      outputIds - output token ids in an incremental forward call
      maxLength - max total sequence length
      eosTokenId - end of sentence token id
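
      A minimal sketch of such an exit check, using plain arrays rather than NDArray. The class ExitCriteriaSketch, the method findExits, and the per-sequence length array (assumed to exclude padding) are hypothetical and only illustrate the two criteria named above; the actual method returns void and records its result internally.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper, not part of DJL: finds which batch entries should exit.
final class ExitCriteriaSketch {

    /**
     * Returns the batch indices whose sequence should exit: either the newest
     * token equals eosTokenId, or the sequence length has reached maxLength.
     */
    static List<Integer> findExits(
            long[] lastTokenIds, long[] seqLengths, long maxLength, long eosTokenId) {
        List<Integer> exits = new ArrayList<>();
        for (int i = 0; i < lastTokenIds.length; i++) {
            boolean hitEos = lastTokenIds[i] == eosTokenId;
            boolean hitMax = seqLengths[i] >= maxLength;
            if (hitEos || hitMax) {
                exits.add(i);
            }
        }
        return exits;
    }

    public static void main(String[] args) {
        long[] lastTokenIds = {5, 2, 9}; // newest token of each sequence
        long[] seqLengths = {4, 4, 10};  // length of each sequence without padding
        System.out.println(findExits(lastTokenIds, seqLengths, 10, 2));
    }
}
```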
    • collectAndTrim

      public Map<Long,NDArray> collectAndTrim()
      Collects the finished sequences and trims the left padding.
      Returns:
      a map from request id to output token ids
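
      The two steps, collecting finished sequences and trimming padding, can be sketched with plain arrays. Everything here is an assumption for illustration: the class CollectAndTrimSketch, the offsets array (index of the first real token in each row), and the finished flags are hypothetical stand-ins for the state SeqBatcher keeps internally.

```java
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical helper, not part of DJL: collects finished rows and trims shared padding.
final class CollectAndTrimSketch {

    /** Maps each finished request id to its tokens with the left padding removed. */
    static Map<Long, long[]> collect(
            long[][] batch, int[] offsets, long[] requestIds, boolean[] finished) {
        Map<Long, long[]> out = new LinkedHashMap<>();
        for (int i = 0; i < batch.length; i++) {
            if (finished[i]) {
                out.put(requestIds[i],
                        Arrays.copyOfRange(batch[i], offsets[i], batch[i].length));
            }
        }
        return out;
    }

    /** Drops finished rows, then removes the padding columns shared by all remaining rows. */
    static long[][] trim(long[][] batch, int[] offsets, boolean[] finished) {
        int minOffset = Integer.MAX_VALUE;
        int remaining = 0;
        for (int i = 0; i < batch.length; i++) {
            if (!finished[i]) {
                minOffset = Math.min(minOffset, offsets[i]);
                remaining++;
            }
        }
        long[][] out = new long[remaining][];
        int row = 0;
        for (int i = 0; i < batch.length; i++) {
            if (!finished[i]) {
                out[row++] = Arrays.copyOfRange(batch[i], minOffset, batch[i].length);
            }
        }
        return out;
    }

    public static void main(String[] args) {
        long[][] batch = {{0, 0, 1, 2, 3}, {0, 0, 0, 4, 5}, {6, 7, 8, 9, 10}};
        int[] offsets = {2, 3, 0};
        long[] requestIds = {10L, 11L, 12L};
        boolean[] finished = {false, false, true};
        Map<Long, long[]> done = collect(batch, offsets, requestIds, finished);
        System.out.println(Arrays.toString(done.get(12L)));
        System.out.println(Arrays.deepToString(trim(batch, offsets, finished)));
    }
}
```

      Once the longest sequence has left the batch, the remaining rows share two leading padding columns, which the trim step removes.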
    • sequenceComplete

      public boolean sequenceComplete()
      Reports the completion status of the sequences in the batch.
      Returns:
      the boolean indicating whether all sequences are empty