Package ai.djl.modality.nlp.generate
Class SeqBatcher
java.lang.Object
ai.djl.modality.nlp.generate.SeqBatcher
SeqBatcher stores the search state (BatchTensorList), the control variables (e.g.
seqLength, offSets, etc.), and the batch operations (merge, trim, exitCriteria, etc.) on
BatchTensorList.
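The following plain-Java sketch makes the control variables concrete. It is for illustration only: the class LeftPaddedBatch and the pad id 0 are hypothetical, and the batch is held in long arrays, whereas DJL keeps this state as NDArrays inside BatchTensorList. Each prompt is left-padded to the common seqLength, and offSets records how many padding tokens precede each row's real tokens.

```java
import java.util.Arrays;

// Hypothetical, simplified stand-in for SeqBatcher's control variables (not DJL code).
class LeftPaddedBatch {
    static final long PAD_ID = 0L; // assumed padding token id

    final long[][] tokens; // [batchSize][seqLength], every row left-padded
    final long[] offSets;  // number of padding tokens at the start of each row
    final long seqLength;  // common length shared by all rows

    LeftPaddedBatch(long[][] prompts) {
        int maxLen = 0;
        for (long[] prompt : prompts) {
            maxLen = Math.max(maxLen, prompt.length);
        }
        seqLength = maxLen;
        tokens = new long[prompts.length][maxLen];
        offSets = new long[prompts.length];
        for (int i = 0; i < prompts.length; i++) {
            int pad = maxLen - prompts[i].length;
            offSets[i] = pad;
            // Padding goes on the left, real tokens are right-aligned.
            Arrays.fill(tokens[i], 0, pad, PAD_ID);
            System.arraycopy(prompts[i], 0, tokens[i], pad, prompts[i].length);
        }
    }

    public static void main(String[] args) {
        LeftPaddedBatch batch = new LeftPaddedBatch(new long[][] {{5, 6, 7}, {8}});
        System.out.println(Arrays.deepToString(batch.tokens)); // [[5, 6, 7], [0, 0, 8]]
        System.out.println(Arrays.toString(batch.offSets));    // [0, 2]
    }
}
```

In this layout the real tokens of row i occupy positions offSets[i] through seqLength - 1.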
Method Summary
Modifier and Type | Method | Description
void | addBatch(SeqBatcher seqBatcherNew) | Adds a new batch.
Map<Long, NDArray> | collectAndTrim() | Collects the finished sequences and trims the left padding.
void | exitCriteria(NDArray outputIds, long maxLength, long eosTokenId) | Checks which batch entries need to exit, according to criteria such as EOS or maxLength.
BatchTensorList | getData() | Returns the batch data, which is stored as a BatchTensorList.
boolean | sequenceComplete() | Computes the position ids by linear search from the left.
Method Details
getData
Returns the batch data, which is stored as a BatchTensorList.
- Returns:
the batch data stored as a BatchTensorList
addBatch
Adds a new batch. Modifies the batch dimension and the left padding.
- Parameters:
seqBatcherNew - the SeqBatcher to add
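The merge described for addBatch can be illustrated with the hypothetical MergeableBatch below, a simplified sketch on plain arrays rather than the DJL implementation. It assumes both batches are already left-padded and non-empty, and that whichever batch is shorter absorbs the length difference as extra left padding, with its offSets increased by the same amount. The real method operates on the tensors stored in BatchTensorList, which this sketch does not model.

```java
import java.util.Arrays;

// Hypothetical, simplified left-padded batch (not DJL's SeqBatcher).
class MergeableBatch {
    static final long PAD_ID = 0L; // assumed padding token id

    long[][] tokens; // [batch][seqLength], all rows share one length
    long[] offSets;  // left-padding count per row

    MergeableBatch(long[][] tokens, long[] offSets) {
        this.tokens = tokens;
        this.offSets = offSets;
    }

    int seqLength() {
        return tokens[0].length; // assumes a non-empty batch
    }

    /** Appends the rows of other; the shorter batch receives extra left padding. */
    void addBatch(MergeableBatch other) {
        int len = Math.max(seqLength(), other.seqLength());
        int n = tokens.length;
        int m = other.tokens.length;
        long[][] mergedTokens = new long[n + m][];
        long[] mergedOffSets = new long[n + m];
        for (int i = 0; i < n + m; i++) {
            long[] row = i < n ? tokens[i] : other.tokens[i - n];
            long offset = i < n ? offSets[i] : other.offSets[i - n];
            int extra = len - row.length; // 0 for rows of the longer batch
            long[] padded = new long[len];
            Arrays.fill(padded, 0, extra, PAD_ID);
            System.arraycopy(row, 0, padded, extra, row.length);
            mergedTokens[i] = padded;
            mergedOffSets[i] = offset + extra; // extra padding shifts the real tokens right
        }
        tokens = mergedTokens;
        offSets = mergedOffSets;
    }

    public static void main(String[] args) {
        MergeableBatch running = new MergeableBatch(new long[][] {{5, 6, 7, 9}}, new long[] {0});
        MergeableBatch incoming = new MergeableBatch(new long[][] {{0, 8}, {3, 4}}, new long[] {1, 0});
        running.addBatch(incoming);
        System.out.println(Arrays.deepToString(running.tokens)); // [[5, 6, 7, 9], [0, 0, 0, 8], [0, 0, 3, 4]]
        System.out.println(Arrays.toString(running.offSets));    // [0, 3, 2]
    }
}
```

Padding on the left keeps the newest token of every row in the same final column, which is what lets one incremental forward call append a token to all rows at once.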
exitCriteria
Checks which batch entries need to exit, according to criteria such as EOS or maxLength. Because it iterates over the batch, it is also considered a batch operation.
- Parameters:
outputIds - the output token ids of an incremental forward call
maxLength - the maximum total sequence length
eosTokenId - the end-of-sentence token id
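A hypothetical sketch of the per-row check that exitCriteria describes, written against plain arrays instead of NDArray. Two details are assumptions rather than documented behavior: a row's length is taken to be seqLength minus its left padding (offSets[i]), and each finished row is recorded only once, mapped to the seqLength at which it first met a criterion. The names ExitCheck and finished are illustrative.

```java
import java.util.Map;
import java.util.TreeMap;

// Hypothetical sketch of an exit check over a batch (not DJL's exitCriteria).
class ExitCheck {
    /**
     * Marks every row whose newest token is eosTokenId, or whose unpadded length
     * (seqLength - offSets[i], an assumption) has reached maxLength.
     */
    static void exitCriteria(long[] outputIds, long[] offSets, long seqLength,
                             long maxLength, long eosTokenId, Map<Integer, Long> finished) {
        for (int i = 0; i < outputIds.length; i++) {
            boolean hitEos = outputIds[i] == eosTokenId;
            boolean hitMaxLength = seqLength - offSets[i] >= maxLength;
            if (hitEos || hitMaxLength) {
                // A row that already finished keeps its first end position.
                finished.putIfAbsent(i, seqLength);
            }
        }
    }

    public static void main(String[] args) {
        Map<Integer, Long> finished = new TreeMap<>();
        // Row 1 just produced EOS (id 2); row 2 has 6 - 1 = 5 real tokens, reaching maxLength 5.
        exitCriteria(new long[] {7, 2, 9}, new long[] {3, 0, 1}, 6, 5, 2, finished);
        System.out.println(finished); // {1=6, 2=6}
    }
}
```

In this sketch rows are only marked, not removed; removal is left to a separate collect step, mirroring the split between exitCriteria and collectAndTrim.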
collectAndTrim
Collects the finished sequences and trims the left padding.
- Returns:
a map from request id to output token ids
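The sketch below illustrates collecting and trimming on plain arrays. The class CollectTrim, its requestIds field, and the map of finished rows to exclusive end positions are hypothetical stand-ins for state that SeqBatcher keeps in NDArrays. Finished rows are removed and returned without their left padding, keyed by request id; afterwards, any padding columns shared by every remaining row are dropped, which shortens the sequence length for the rest of the batch.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical sketch of collecting finished rows and trimming left padding (not DJL code).
class CollectTrim {
    long[][] tokens;   // [batch][seqLength], left-padded
    long[] offSets;    // padding count at the start of each row
    long[] requestIds; // request id for each row (assumed representation)

    CollectTrim(long[][] tokens, long[] offSets, long[] requestIds) {
        this.tokens = tokens;
        this.offSets = offSets;
        this.requestIds = requestIds;
    }

    /**
     * finished maps a row index to the exclusive end position of its valid tokens.
     * Returns the unpadded tokens of finished rows keyed by request id, removes those
     * rows, then drops the padding columns that all remaining rows share.
     */
    Map<Long, long[]> collectAndTrim(Map<Integer, Integer> finished) {
        Map<Long, long[]> results = new LinkedHashMap<>();
        List<Integer> keep = new ArrayList<>();
        for (int i = 0; i < tokens.length; i++) {
            Integer end = finished.get(i);
            if (end != null) {
                results.put(requestIds[i], Arrays.copyOfRange(tokens[i], (int) offSets[i], end));
            } else {
                keep.add(i);
            }
        }
        // The smallest remaining offset is the number of columns that are padding in every row.
        long trim = keep.isEmpty() ? 0 : Long.MAX_VALUE;
        for (int i : keep) {
            trim = Math.min(trim, offSets[i]);
        }
        long[][] newTokens = new long[keep.size()][];
        long[] newOffSets = new long[keep.size()];
        long[] newIds = new long[keep.size()];
        for (int k = 0; k < keep.size(); k++) {
            int i = keep.get(k);
            newTokens[k] = Arrays.copyOfRange(tokens[i], (int) trim, tokens[i].length);
            newOffSets[k] = offSets[i] - trim;
            newIds[k] = requestIds[i];
        }
        tokens = newTokens;
        offSets = newOffSets;
        requestIds = newIds;
        return results;
    }

    public static void main(String[] args) {
        CollectTrim batch = new CollectTrim(
                new long[][] {{0, 0, 5, 6}, {1, 2, 3, 4}, {0, 7, 8, 9}},
                new long[] {2, 0, 1},
                new long[] {10, 11, 12});
        // Row 1 finished with end position 3, so its last token (4) is discarded.
        Map<Long, long[]> done = batch.collectAndTrim(Map.of(1, 3));
        done.forEach((id, ids) -> System.out.println(id + " -> " + Arrays.toString(ids))); // 11 -> [1, 2, 3]
        System.out.println(Arrays.deepToString(batch.tokens)); // [[0, 5, 6], [7, 8, 9]]
        System.out.println(Arrays.toString(batch.offSets));    // [1, 0]
    }
}
```

In this sketch, storing an end position per finished row lets a row that stopped early ignore any tokens the batch kept generating for it before collection.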
sequenceComplete
public boolean sequenceComplete()
Computes the position ids by linear search from the left.
- Returns:
the boolean indicating whether all sequences are empty