Package ai.djl.training.dataset
Class ArrayDataset
java.lang.Object
ai.djl.training.dataset.RandomAccessDataset
ai.djl.training.dataset.ArrayDataset
- All Implemented Interfaces:
Dataset
ArrayDataset is an implementation of RandomAccessDataset that consists entirely of
large NDArrays. It is recommended only for datasets that are small enough to fit in memory
and that come in array formats. Otherwise, consider using RandomAccessDataset directly
instead.
There can be multiple data and label NDArrays within the dataset. Each sample will be
retrieved by indexing each NDArray along the first dimension.
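The first-dimension indexing described above can be illustrated without DJL. The following is a minimal plain-Java sketch in which nested float arrays stand in for NDArrays; the class and method names (FirstAxisIndexingSketch, sampleAt) are hypothetical and are not part of the DJL API.

```java
import java.util.Arrays;

// Plain-Java model of first-dimension indexing. This is NOT DJL code:
// float[][] stands in for an NDArray whose first axis is the sample axis.
final class FirstAxisIndexingSketch {

    // Returns row `index` of every array, mirroring how one sample is
    // assembled from several arrays that share the same first dimension.
    static float[][] sampleAt(float[][][] arrays, int index) {
        float[][] sample = new float[arrays.length][];
        for (int i = 0; i < arrays.length; i++) {
            sample[i] = arrays[i][index].clone();
        }
        return sample;
    }

    public static void main(String[] args) {
        float[][] data1 = {{1f, 2f}, {3f, 4f}, {5f, 6f}}; // 3 samples, 2 values each
        float[][] data2 = {{10f}, {20f}, {30f}};          // 3 samples, 1 value each
        float[][] sample = sampleAt(new float[][][] {data1, data2}, 1);
        System.out.println(Arrays.deepToString(sample)); // prints [[3.0, 4.0], [20.0]]
    }
}
```

Sample 1 takes row 1 from each array, so every array in the dataset must have the same length along its first dimension.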
The following is an example of how to use ArrayDataset:
    ArrayDataset dataset = new ArrayDataset.Builder()
            .setData(data1, data2)
            .optLabels(labels1, labels2, labels3)
            .setSampling(20, false)
            .build();
Suppose you get a Batch from trainer.iterateDataset(dataset) or
dataset.getData(manager). The data of this batch is an NDList with one NDArray for
each data input; in this case, that is 2 arrays. Similarly, the labels are an NDList with 3 arrays.
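To make the effect of setSampling(20, false) more concrete, here is a plain-Java sketch of how a non-random sampler with batch size 20 could partition the sample indices. It assumes that the final, smaller batch is kept rather than dropped; that behaviour is an assumption of this sketch, and the names SequentialBatchSketch and batchRanges are hypothetical, not DJL API.

```java
import java.util.ArrayList;
import java.util.List;

// Plain-Java model of sequential batching. This is NOT DJL code.
final class SequentialBatchSketch {

    // Splits indices 0..size-1 into consecutive [from, to) ranges that each
    // cover at most batchSize samples. Assumption: the last partial batch is kept.
    static List<long[]> batchRanges(long size, int batchSize) {
        List<long[]> ranges = new ArrayList<>();
        for (long from = 0; from < size; from += batchSize) {
            ranges.add(new long[] {from, Math.min(from + batchSize, size)});
        }
        return ranges;
    }

    public static void main(String[] args) {
        // 50 samples with batch size 20 give the ranges [0, 20), [20, 40), [40, 50).
        for (long[] range : batchRanges(50, 20)) {
            System.out.println(range[0] + " to " + range[1]);
        }
    }
}
```

Under this model, each range would yield one Batch whose data list holds 2 arrays and whose label list holds 3 arrays, each sliced to the same range along the first dimension.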
Nested Class Summary
Nested Classes
Nested classes/interfaces inherited from class ai.djl.training.dataset.RandomAccessDataset
RandomAccessDataset.BaseBuilder<T extends RandomAccessDataset.BaseBuilder<T>>
Nested classes/interfaces inherited from interface ai.djl.training.dataset.Dataset
Dataset.Usage
Field Summary
Fields
Fields inherited from class ai.djl.training.dataset.RandomAccessDataset
dataBatchifier, device, labelBatchifier, limit, pipeline, prefetchNumber, sampler, targetPipeline
Constructor Summary
Constructors
ArrayDataset(RandomAccessDataset.BaseBuilder<?> builder)
    Creates a new instance of ArrayDataset with the arguments in ArrayDataset.Builder.
Method Summary
Methods
protected long availableSize()
    Returns the number of records available to be read in this Dataset.
Record get(NDManager manager, long index)
    Gets the Record for the given index from the dataset.
Batch getByIndices(NDManager manager, long... indices)
    Gets the Batch for the given indices from the dataset.
Batch getByRange(NDManager manager, long fromIndex, long toIndex)
    Gets the Batch for the given range from the dataset.
Iterable<Batch> getData(NDManager manager, Sampler sampler, ExecutorService executorService)
    Fetches an iterator that can iterate through the Dataset with a custom sampler, using multiple threads.
protected RandomAccessDataset newSubDataset(int[] indices, int from, int to)
protected RandomAccessDataset newSubDataset(List<Long> subIndices)
void prepare(ai.djl.util.Progress progress)
    Prepares the dataset for use with tracked progress.
Methods inherited from class ai.djl.training.dataset.RandomAccessDataset
getData, getData, getData, randomSplit, size, subDataset, subDataset, subDataset, subDataset, toArray
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Methods inherited from interface ai.djl.training.dataset.Dataset
matchingTranslatorOptions, prepare
-
Field Details
-
data
-
labels
-
-
Constructor Details
-
ArrayDataset
Creates a new instance of ArrayDataset with the arguments in ArrayDataset.Builder.
- Parameters:
builder - a builder with the required arguments
-
-
Method Details
-
availableSize
protected long availableSize()
Returns the number of records available to be read in this Dataset.
- Specified by:
availableSize in class RandomAccessDataset
- Returns:
the number of records available to be read in this Dataset
-
get
Gets the Record for the given index from the dataset.
- Specified by:
get in class RandomAccessDataset
- Parameters:
manager - the manager used to create the arrays
index - the index of the requested data item
- Returns:
a Record that contains the data and label of the requested data item
-
getByIndices
Gets the Batch for the given indices from the dataset.
- Parameters:
manager - the manager used to create the arrays
indices - indices of the requested data items
- Returns:
a Batch that contains the data and label of the requested data items
-
getByRange
Gets the Batch for the given range from the dataset.
- Parameters:
manager - the manager used to create the arrays
fromIndex - low endpoint (inclusive) of the dataset
toIndex - high endpoint (exclusive) of the dataset
- Returns:
a Batch that contains the data and label of the requested data items
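The difference between selecting by range and selecting by explicit indices can be sketched in plain Java. The example below models the half-open range (fromIndex inclusive, toIndex exclusive) documented for getByRange and an index-list selection in the spirit of getByIndices. It assumes that rows come back in the order the indices are given; the names RangeAndIndexSketch, byRange and byIndices are hypothetical, and the code does not call DJL.

```java
import java.util.Arrays;

// Plain-Java model of range and index selection. This is NOT DJL code:
// float[][] stands in for an NDArray whose first axis is the sample axis.
final class RangeAndIndexSketch {

    // Rows fromIndex (inclusive) up to toIndex (exclusive).
    static float[][] byRange(float[][] array, int fromIndex, int toIndex) {
        return Arrays.copyOfRange(array, fromIndex, toIndex);
    }

    // Rows at the given positions, in the order requested (an assumption of this sketch).
    static float[][] byIndices(float[][] array, int... indices) {
        float[][] rows = new float[indices.length][];
        for (int i = 0; i < indices.length; i++) {
            rows[i] = array[indices[i]];
        }
        return rows;
    }

    public static void main(String[] args) {
        float[][] data = {{0f}, {1f}, {2f}, {3f}, {4f}};
        System.out.println(Arrays.deepToString(byRange(data, 1, 3)));   // prints [[1.0], [2.0]]
        System.out.println(Arrays.deepToString(byIndices(data, 4, 0))); // prints [[4.0], [0.0]]
    }
}
```

A range from 1 to 3 therefore covers two samples, not three, because the high endpoint is excluded.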
-
newSubDataset
- Overrides:
newSubDataset in class RandomAccessDataset
-
newSubDataset
- Overrides:
newSubDataset in class RandomAccessDataset
-
getData
public Iterable<Batch> getData(NDManager manager, Sampler sampler, ExecutorService executorService) throws IOException, TranslateException
Fetches an iterator that can iterate through the Dataset with a custom sampler, using multiple threads.
- Overrides:
getData in class RandomAccessDataset
- Parameters:
manager - the manager to create the arrays
sampler - the sampler to use to iterate through the dataset
executorService - the executorService to multi-thread with
- Returns:
an Iterable of Batch that contains batches of data from the dataset
- Throws:
IOException - for various exceptions depending on the dataset
TranslateException - if there is an error while processing input
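As a rough illustration of how a sampler and an ExecutorService can work together, the following plain-Java sketch submits one task per batch of indices and collects the results in submission order. It is only one plausible arrangement and is not DJL's implementation; the names ParallelBatchSketch and loadBatches are hypothetical, and a list of index arrays stands in for the sampler.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

// Plain-Java model of multi-threaded batch loading. This is NOT DJL code.
final class ParallelBatchSketch {

    // Builds one batch per index array on the executor and returns the
    // batches in the same order the index arrays were supplied.
    static List<float[]> loadBatches(float[] data, List<int[]> batches, ExecutorService executor) {
        List<Future<float[]>> pending = new ArrayList<>();
        for (int[] indices : batches) {
            pending.add(executor.submit(() -> {
                float[] batch = new float[indices.length];
                for (int i = 0; i < indices.length; i++) {
                    batch[i] = data[indices[i]];
                }
                return batch;
            }));
        }
        List<float[]> result = new ArrayList<>();
        try {
            for (Future<float[]> future : pending) {
                result.add(future.get()); // waits for that batch to finish
            }
        } catch (InterruptedException | ExecutionException e) {
            throw new IllegalStateException("batch loading failed", e);
        }
        return result;
    }

    public static void main(String[] args) {
        ExecutorService executor = Executors.newFixedThreadPool(2);
        try {
            float[] data = {10f, 11f, 12f, 13f, 14f};
            List<int[]> batches = Arrays.asList(new int[] {0, 1}, new int[] {2, 3}, new int[] {4});
            for (float[] batch : loadBatches(data, batches, executor)) {
                System.out.println(Arrays.toString(batch));
            }
        } finally {
            executor.shutdown();
        }
    }
}
```

Collecting the futures in submission order keeps the batch order deterministic even though the batches themselves may be built concurrently.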
-
prepare
Prepares the dataset for use with tracked progress.
- Parameters:
progress - the progress tracker
- Throws:
IOException - for various exceptions depending on the dataset
-