Package ai.djl.training.dataset
Class BatchSampler
java.lang.Object
ai.djl.training.dataset.BatchSampler
All Implemented Interfaces:
    Sampler
BatchSampler is a Sampler that returns a single epoch over the data.
BatchSampler wraps another Sampler.SubSampler and groups the indices it yields into mini-batches.
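To illustrate this wrapping, here is a minimal, self-contained sketch of how a stream of indices can be grouped into mini-batches. The index stream stands in for the output of a Sampler.SubSampler. This does not use DJL and is not the library's implementation. The class BatchingSketch and its batches method are hypothetical names chosen for this example, and the dropLast flag mirrors the constructor parameter of the same name.

```java
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;

// Hypothetical sketch of batching an index stream; not DJL's BatchSampler source.
public final class BatchingSketch {

    private BatchingSketch() {}

    // Wraps an index iterator (standing in for a SubSampler's output) and yields batches of indices.
    public static Iterator<List<Long>> batches(Iterator<Long> indices, int batchSize, boolean dropLast) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive");
        }
        return new BatchIterator(indices, batchSize, dropLast);
    }

    private static final class BatchIterator implements Iterator<List<Long>> {
        private final Iterator<Long> indices;
        private final int batchSize;
        private final boolean dropLast;
        private List<Long> pending; // next batch to return, or null when exhausted

        BatchIterator(Iterator<Long> indices, int batchSize, boolean dropLast) {
            this.indices = indices;
            this.batchSize = batchSize;
            this.dropLast = dropLast;
            this.pending = advance();
        }

        // Pulls up to batchSize indices; returns null if nothing (or only a dropped partial batch) remains.
        private List<Long> advance() {
            List<Long> batch = new ArrayList<>(batchSize);
            while (batch.size() < batchSize && indices.hasNext()) {
                batch.add(indices.next());
            }
            if (batch.isEmpty() || (dropLast && batch.size() < batchSize)) {
                return null;
            }
            return batch;
        }

        @Override
        public boolean hasNext() {
            return pending != null;
        }

        @Override
        public List<Long> next() {
            if (pending == null) {
                throw new NoSuchElementException();
            }
            List<Long> current = pending;
            pending = advance(); // prefetch the following batch
            return current;
        }
    }

    public static void main(String[] args) {
        List<Long> order = new ArrayList<>();
        for (long i = 0; i < 10; i++) {
            order.add(i); // sequential visiting order for the demo
        }
        Iterator<List<Long>> it = batches(order.iterator(), 4, false);
        while (it.hasNext()) {
            System.out.println(it.next()); // [0, 1, 2, 3], then [4, 5, 6, 7], then [8, 9]
        }
    }
}
```

The iterator computes one batch ahead so that hasNext() can already tell whether a trailing partial batch will be dropped. A lazier design that only reads from the sub-sampler inside next() would also work.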
Nested Class Summary
Nested classes/interfaces inherited from interface ai.djl.training.dataset.Sampler
Sampler.SubSampler
Constructor Summary
BatchSampler(Sampler.SubSampler subSampler, int batchSize)
    Creates a new instance of BatchSampler that samples from the given Sampler.SubSampler and yields mini-batches of indices, keeping the last (possibly smaller) batch.
BatchSampler(Sampler.SubSampler subSampler, int batchSize, boolean dropLast)
    Creates a new instance of BatchSampler that samples from the given Sampler.SubSampler and yields mini-batches of indices, optionally dropping the last incomplete batch.
Method Summary
int getBatchSize()
    Returns the batch size of the Sampler.
Iterator<List<Long>> sample(RandomAccessDataset dataset)
    Fetches an iterator that iterates through the given RandomAccessDataset in mini-batches of indices.
Constructor Details
BatchSampler
public BatchSampler(Sampler.SubSampler subSampler, int batchSize)
Creates a new instance of BatchSampler that samples from the given Sampler.SubSampler and yields mini-batches of indices. The last batch is not dropped: if the size of the dataset is not a multiple of the batch size, the last batch may be smaller than the batch size.
Parameters:
    subSampler - the Sampler.SubSampler to sample from
    batchSize - the required batch size
BatchSampler
public BatchSampler(Sampler.SubSampler subSampler, int batchSize, boolean dropLast)
Creates a new instance of BatchSampler that samples from the given Sampler.SubSampler and yields mini-batches of indices.
Parameters:
    subSampler - the Sampler.SubSampler to sample from
    batchSize - the required batch size
    dropLast - whether the BatchSampler should drop the last few samples when the size of the dataset is not a multiple of the batch size
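The effect of dropLast on the number of batches per epoch comes down to integer division. The sketch below computes the batch count and the size of the final batch for a dataset of a given length, assuming the sub-sampler visits every index exactly once per epoch. BatchCountSketch and its methods are hypothetical helpers for illustration, not part of DJL.

```java
// Hypothetical helper showing how dropLast affects batches per epoch; not part of DJL.
public final class BatchCountSketch {

    private BatchCountSketch() {}

    // Number of batches in one pass over datasetSize indices.
    public static long numBatches(long datasetSize, int batchSize, boolean dropLast) {
        if (batchSize <= 0 || datasetSize < 0) {
            throw new IllegalArgumentException("batchSize must be positive and datasetSize non-negative");
        }
        long fullBatches = datasetSize / batchSize;
        boolean hasRemainder = datasetSize % batchSize != 0;
        // A trailing partial batch is only emitted when it is not dropped.
        return (hasRemainder && !dropLast) ? fullBatches + 1 : fullBatches;
    }

    // Size of the final emitted batch, or 0 if no batch is emitted at all.
    public static long lastBatchSize(long datasetSize, int batchSize, boolean dropLast) {
        if (numBatches(datasetSize, batchSize, dropLast) == 0) {
            return 0;
        }
        long remainder = datasetSize % batchSize;
        return (remainder == 0 || dropLast) ? batchSize : remainder;
    }

    public static void main(String[] args) {
        System.out.println(numBatches(10, 4, false) + " " + lastBatchSize(10, 4, false)); // 3 2
        System.out.println(numBatches(10, 4, true) + " " + lastBatchSize(10, 4, true));   // 2 4
    }
}
```

For 10 samples and a batch size of 4, keeping the last batch gives batches of 4, 4 and 2 indices. With dropLast set to true, only the two full batches are produced and the remaining 2 samples are skipped for that epoch.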
Method Details
sample
public Iterator<List<Long>> sample(RandomAccessDataset dataset)
Fetches an iterator that iterates through the given RandomAccessDataset in mini-batches of indices.
Specified by:
    sample in interface Sampler
Parameters:
    dataset - the RandomAccessDataset to sample from
Returns:
    an iterator that iterates through the given RandomAccessDataset in mini-batches of indices
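Because sample yields indices rather than records, each index still has to be resolved against the dataset. The following sketch does not depend on DJL. It shows one way a consumer might resolve index batches, using a plain List<String> as a stand-in for a RandomAccessDataset. The class IndexBatchConsumerSketch, its gather method, and the hard-coded batches are all hypothetical.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

// Hypothetical sketch: resolving index batches against a random-access store; not DJL code.
public final class IndexBatchConsumerSketch {

    private IndexBatchConsumerSketch() {}

    // Looks up each index of each batch in the record list, preserving batch boundaries.
    public static List<List<String>> gather(Iterator<List<Long>> indexBatches, List<String> records) {
        List<List<String>> out = new ArrayList<>();
        while (indexBatches.hasNext()) {
            List<String> batch = new ArrayList<>();
            for (long index : indexBatches.next()) {
                batch.add(records.get(Math.toIntExact(index))); // random access by index
            }
            out.add(batch);
        }
        return out;
    }

    public static void main(String[] args) {
        List<String> records = Arrays.asList("a", "b", "c", "d", "e");
        // A shuffled visiting order split into batches of 2, standing in for what sample(...) might yield.
        List<List<Long>> batches = Arrays.asList(
                Arrays.asList(3L, 0L), Arrays.asList(4L, 1L), Arrays.asList(2L));
        System.out.println(gather(batches.iterator(), records)); // [[d, a], [e, b], [c]]
    }
}
```

Keeping the sampler limited to indices means the same sampling logic can be reused with any dataset that supports lookup by position.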
getBatchSize
public int getBatchSize()
Returns the batch size of the Sampler.
Specified by:
    getBatchSize in interface Sampler
Returns:
    the batch size of the Sampler, or -1 if the batch size is not fixed
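Since the Sampler contract uses -1 as a sentinel for "no fixed batch size", code that works with arbitrary Sampler implementations may want to handle that case explicitly. A small hypothetical helper that converts the reported value into an OptionalInt could look like this. BatchSizeContractSketch and fixedBatchSize are illustrative names, not DJL API, and rejecting other non-positive values is an assumption of this sketch.

```java
import java.util.OptionalInt;

// Hypothetical helper interpreting the getBatchSize() contract; not part of DJL.
public final class BatchSizeContractSketch {

    private BatchSizeContractSketch() {}

    // Maps the reported value to an OptionalInt; -1 means the sampler has no fixed batch size.
    public static OptionalInt fixedBatchSize(int reported) {
        if (reported == -1) {
            return OptionalInt.empty();
        }
        if (reported <= 0) {
            // Assumption: any other non-positive value is treated as a contract violation.
            throw new IllegalArgumentException("unexpected batch size: " + reported);
        }
        return OptionalInt.of(reported);
    }

    public static void main(String[] args) {
        System.out.println(fixedBatchSize(32)); // OptionalInt[32]
        System.out.println(fixedBatchSize(-1)); // OptionalInt.empty
    }
}
```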