Class BulkDataIterable

java.lang.Object
ai.djl.training.dataset.DataIterable
ai.djl.training.dataset.BulkDataIterable
All Implemented Interfaces:
Iterable<Batch>, Iterator<Batch>

public class BulkDataIterable extends DataIterable
BulkDataIterable specializes DataIterable by using ArrayDataset.getByRange(NDManager, long, long) or ArrayDataset.getByIndices(NDManager, long...) to create Batch instances more efficiently.
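The following is a minimal sketch of the bulk-batching idea. It is not DJL's implementation. It uses plain Java int arrays in place of NDArrays, and the names BulkBatchSketch, isContiguous and fetchBatch are hypothetical. It assumes the efficiency gain comes from building a batch with one bulk lookup when the sampled indices are contiguous, and gathering only the requested rows otherwise, instead of fetching and batchifying samples one at a time.

```java
import java.util.Arrays;
import java.util.List;

/**
 * Hypothetical illustration using plain arrays instead of DJL NDArrays:
 * a contiguous run of indices is served by one slice, and any other
 * selection falls back to gathering the requested rows.
 */
final class BulkBatchSketch {

    // Non-empty, ascending, gap-free: the same condition isRange documents.
    static boolean isContiguous(List<Long> indices) {
        if (indices.isEmpty()) {
            return false;
        }
        for (int i = 1; i < indices.size(); i++) {
            if (indices.get(i) != indices.get(i - 1) + 1) {
                return false;
            }
        }
        return true;
    }

    /** Builds one batch of rows from data for the sampled indices. */
    static int[][] fetchBatch(int[][] data, List<Long> indices) {
        if (isContiguous(indices)) {
            // Stands in for a by-range lookup: one copyOfRange call covers the batch.
            int from = Math.toIntExact(indices.get(0));
            return Arrays.copyOfRange(data, from, from + indices.size());
        }
        // Stands in for a by-indices lookup: gather the requested rows in order.
        int[][] batch = new int[indices.size()][];
        for (int i = 0; i < indices.size(); i++) {
            batch[i] = data[Math.toIntExact(indices.get(i))];
        }
        return batch;
    }

    public static void main(String[] args) {
        int[][] data = {{0, 0}, {1, 1}, {2, 2}, {3, 3}};
        System.out.println(Arrays.deepToString(fetchBatch(data, List.of(1L, 2L)))); // [[1, 1], [2, 2]]
        System.out.println(Arrays.deepToString(fetchBatch(data, List.of(3L, 0L)))); // [[3, 3], [0, 0]]
    }
}
```

In this sketch, the contiguity test plays the role that isRange plays in the class. The two branches stand in for getByRange and getByIndices, respectively.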
  • Constructor Details

    • BulkDataIterable

      public BulkDataIterable(ArrayDataset dataset, NDManager manager, Sampler sampler, Batchifier dataBatchifier, Batchifier labelBatchifier, Pipeline pipeline, Pipeline targetPipeline, ExecutorService executor, int preFetchNumber, Device device)
      Creates a new instance of BulkDataIterable with the given parameters.
      Parameters:
      dataset - the dataset to iterate on
      manager - the manager to create the arrays
      sampler - a sampler to sample data with
      dataBatchifier - a batchifier for data
      labelBatchifier - a batchifier for labels
      pipeline - the pipeline of transforms to apply on the data
      targetPipeline - the pipeline of transforms to apply on the labels
      executor - the ExecutorService used to prefetch batches
      preFetchNumber - the number of samples to prefetch
      device - the Device on which to place the created batches
  • Method Details

    • fetch

      protected Batch fetch(List<Long> indices, int progress) throws IOException
      Overrides:
      fetch in class DataIterable
      Throws:
      IOException
    • isRange

      public static boolean isRange(List<Long> indices)
      Checks whether the given indices actually represent a range.
      Parameters:
      indices - the indices to examine
      Returns:
      whether the given indices are sorted in ascending order with no gaps and contain at least one element
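A short loop can check the documented contract. The sketch below is a hypothetical stand-alone re-implementation, not the library's source, and the class name RangeCheck is only illustrative. It assumes that "no gaps" means each index is exactly one greater than the index before it.

```java
import java.util.List;

/** Hypothetical re-implementation of the contract documented for isRange. */
final class RangeCheck {

    /** True when indices is non-empty and each element is the previous one plus 1. */
    static boolean isRange(List<Long> indices) {
        if (indices.isEmpty()) {
            return false; // the contract requires at least one element
        }
        long expected = indices.get(0);
        for (Long index : indices) {
            if (index != expected) {
                return false; // out of order, duplicated, or a gap
            }
            expected++;
        }
        return true;
    }

    public static void main(String[] args) {
        System.out.println(isRange(List.of(3L, 4L, 5L))); // true
        System.out.println(isRange(List.of(3L, 5L)));     // false: gap
        System.out.println(isRange(List.of(5L, 4L)));     // false: descending
        System.out.println(isRange(List.of(7L)));         // true: single element
        System.out.println(isRange(List.of()));           // false: empty
    }
}
```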