hydrax.TrainingBatch
- final class hydrax.TrainingBatch(group: _GroupState[D], index: int, indices: Sequence[int], epoch: int, epoch_batch: int, batch_num: int, seed: int)
Bases: Batch[D], Generic[D]
A batch of training data loaded by a Dataloader.
Batches are returned by Dataloader.__next__(). You can determine whether a Batch is a TrainingBatch by checking is_training.
Caution
Don’t derive from or instantiate this type; it is created internally.
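Because training and validation batches both come out of Dataloader.__next__(), a loop typically branches on is_training. The sketch below relies only on that documented flag. The names dispatch_batch, train_step and validate_step are hypothetical, and the SimpleNamespace objects merely stand in for real batches so the snippet runs without hydrax.

```python
from types import SimpleNamespace
from typing import Any, Callable


def dispatch_batch(
    batch: Any,
    train_step: Callable[[Any], Any],
    validate_step: Callable[[Any], Any],
) -> Any:
    """Route a batch from a Dataloader to the matching step function.

    Hypothetical helper: it reads only the documented ``is_training`` flag.
    """
    if batch.is_training:
        return train_step(batch)
    return validate_step(batch)


# Stand-ins that mimic only the flags a real batch exposes.
train_like = SimpleNamespace(is_training=True, is_validation=False)
valid_like = SimpleNamespace(is_training=False, is_validation=True)

routed = [
    dispatch_batch(b, lambda _: "train", lambda _: "validate")
    for b in (train_like, valid_like)
]
```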
- property additional: mappingproxy[str, Sequence[Any]]
The additional data returned by the loader for each item of the batch.
This data is a read-only mapping of the form { 'key_0': [value_0, ...], ... }, where the number of values for each key is equal to the batch size of the DataGroup. A value will be None if the loader did not return additional data with the corresponding key for the corresponding item in this batch.
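To make the documented layout concrete, the sketch below turns a list of per-item dicts (what a loader might return for each item) into the column-oriented { key: [value_0, ...] } form, inserting None where an item lacked a key. It is illustrative only and is not hydrax’s implementation; columnize_additional is a hypothetical name.

```python
from types import MappingProxyType
from typing import Any, Mapping, Sequence


def columnize_additional(
    per_item: Sequence[Mapping[str, Any]],
) -> MappingProxyType:
    """Turn per-item dicts into the {key: [value_0, ...]} layout.

    Illustrative only; this is not hydrax's implementation.
    """
    keys: list[str] = []
    for item in per_item:
        for key in item:
            if key not in keys:  # keep first-seen key order
                keys.append(key)
    # One entry per item for every key; None where the item lacked that key.
    return MappingProxyType({k: [item.get(k) for item in per_item] for k in keys})


additional = columnize_additional([{"label": 3}, {"label": 7, "path": "b.png"}])
# additional["path"] == [None, "b.png"]
```

MappingProxyType is used here because, like the documented mappingproxy, it exposes a read-only view.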
- property arrays: mappingproxy[str, Array]
The JAX arrays for this batch, as defined by the DataGroup.
The leading dimension of each array is the batch size of the DataGroup. The shapes and dtypes are as specified by group.loader_layouts if the batch is uncached; otherwise they are as specified by group.cache_layouts. The current layout is available as the layouts property.
Tip
To convert an uncached batch to a cached batch, use cache().
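When wiring a model to a batch, it can help to check the documented contract that every array’s leading dimension equals the DataGroup batch size. This is a minimal sketch assuming you pass in batch.arrays and the batch size you configured; leading_dim_mismatches is a hypothetical helper, and it does not inspect dtypes or the remaining dimensions (those are described by layouts).

```python
from types import SimpleNamespace
from typing import Any, Mapping


def leading_dim_mismatches(arrays: Mapping[str, Any], batch_size: int) -> list[str]:
    """Return the names of arrays whose leading dimension is not ``batch_size``.

    Hypothetical helper: it only reads ``.shape``, so it accepts JAX arrays,
    NumPy arrays, or shape-only stand-ins alike.
    """
    return [
        name
        for name, array in arrays.items()
        # A 0-d array has no leading dimension, so it always counts as a mismatch.
        if len(array.shape) == 0 or array.shape[0] != batch_size
    ]


# Shape-only stand-ins for batch.arrays, checked against a batch size of 8.
stand_in_arrays = {
    "image": SimpleNamespace(shape=(8, 32, 32, 3)),
    "label": SimpleNamespace(shape=(4,)),
}
mismatched = leading_dim_mismatches(stand_in_arrays, batch_size=8)  # ["label"]
```

In a real loop the call would be leading_dim_mismatches(batch.arrays, batch_size), with batch_size being whatever the DataGroup was configured with.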
- property batch_num: int
The overall zero-based index of this batch.
- cache(arrays: dict[str, Array], additional: dict[str, list[object]] | None = None, *, persist: bool = True) -> None
Converts an uncached batch into a cached batch by providing the arrays and additional data to cache.
Important
This method raises a RuntimeError if the batch is already cached or if cache layouts were not specified in the DataGroup.
- Parameters:
arrays – The JAX arrays to cache. These must exactly correspond to DataGroup.cache_layouts.
additional – Optional additional data to cache. The contents must be pickleable, and the length of each list must be exactly equal to DataGroup.batch_size. If not specified, the batch’s current additional data is retained.
persist – If False, the batch will not actually be persisted to disk and so will need to be reloaded. Persistence is also skipped if cache_readonly was specified when creating the DataGroup.
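The parameter requirements above can be checked before calling cache(). The sketch below is a hypothetical pre-flight check, not part of hydrax: check_cache_args is an illustrative name, cache_layout_names is assumed to be the array names from your DataGroup.cache_layouts, and only names, list lengths and pickleability are checked (not shapes or dtypes).

```python
import pickle
from typing import Any, Collection, Mapping, Optional, Sequence


def check_cache_args(
    arrays: Mapping[str, Any],
    cache_layout_names: Collection[str],
    batch_size: int,
    additional: Optional[Mapping[str, Sequence[Any]]] = None,
) -> list[str]:
    """Return human-readable problems with arguments intended for cache().

    Hypothetical pre-flight check mirroring the documented requirements;
    it compares array names only, not shapes or dtypes.
    """
    problems: list[str] = []
    if set(arrays) != set(cache_layout_names):
        problems.append(
            f"array names {sorted(arrays)} != cache layouts {sorted(cache_layout_names)}"
        )
    for key, values in (additional or {}).items():
        if len(values) != batch_size:
            problems.append(
                f"additional[{key!r}] has {len(values)} items, expected {batch_size}"
            )
        try:
            pickle.dumps(values)
        except Exception as exc:  # pickling can raise several exception types
            problems.append(f"additional[{key!r}] is not pickleable: {exc}")
    return problems


issues = check_cache_args(
    arrays={"features": object()},  # stand-in for a JAX array
    cache_layout_names=["features"],
    batch_size=2,
    additional={"path": ["a.png", "b.png"], "fn": [lambda x: x, None]},
)
```

Here the lambda in "fn" is reported because lambdas cannot be pickled. A caller might invoke batch.cache(arrays, additional) only when the list is empty, while still guarding against the documented RuntimeError.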
- property data: Sequence[D]
The data descriptors of the data in this batch.
The length of this sequence is equal to the batch size of the
group.
- property epoch: int
The zero-based training epoch number for this batch.
- property epoch_batch: int
The zero-based index of this batch within the current epoch.
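Since epoch, epoch_batch and batch_num are all zero-based, a progress log usually adds 1 to the per-epoch counters and uses batch_num directly as a global step. A minimal sketch: format_progress is hypothetical, batches_per_epoch is an assumption supplied by the caller (it is not a TrainingBatch property), and the SimpleNamespace stand-in carries made-up counter values.

```python
from types import SimpleNamespace
from typing import Any


def format_progress(batch: Any, batches_per_epoch: int) -> str:
    """Build a log line from the documented zero-based batch counters.

    ``batches_per_epoch`` is supplied by the caller; it is an assumption
    about your own configuration, not a TrainingBatch property.
    """
    return (
        f"epoch {batch.epoch + 1}, "
        f"batch {batch.epoch_batch + 1}/{batches_per_epoch}, "
        f"global step {batch.batch_num}"
    )


line = format_progress(
    SimpleNamespace(epoch=2, epoch_batch=4, batch_num=204), batches_per_epoch=100
)
```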
- get_additional(key: str, index: int | None = None, *, default: Any = None) -> Any
Returns the additional data returned by the loader under the specified key.
- Parameters:
key – The key returned by the loader.
index – The index of the data item within this batch. If this is not specified, a sequence with one entry per item in the batch is returned.
default – The default value to return if the additional data is not present.
- get_array(name: str) -> Array
Returns the specified JAX array for this batch, as defined in the DataGroup.
- Parameters:
name – The name of the array, as defined in the DataGroup.
See also
This is equivalent to arrays[name]. See arrays.
- get_data(index: int) -> D
Returns the data descriptor for the specified item in this batch.
- Parameters:
index – The index of the data item within this batch.
See also
This is equivalent to data[index]. See data.
- property is_training: bool
Always True for a TrainingBatch.
- property is_validation: bool
Always False for a TrainingBatch.
- property layouts: mappingproxy[str, tuple[tuple[int, ...], dtype, dtype, int, int]]
The layouts of arrays.
If this batch is uncached, this is group.loader_layouts; otherwise it is group.cache_layouts.
- property seed: int
The deterministic seed for randomness associated with this batch.
- property seeds: Sequence[int]
The deterministic seeds for randomness associated with each item of this batch.
Each seed is the same seed that was passed to the Dataloader’s loader_func for the corresponding item of this batch.
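Because each entry of seeds matches the seed given to loader_func for that item, randomness derived from it is reproducible per item, whether it is drawn in the loader or later in the training step. The sketch below uses the standard-library random module; per_item_flip_flags is hypothetical, and the flip is only an example augmentation. For batch-level randomness inside JAX code, the analogous step would be to build a key from batch.seed with jax.random.PRNGKey.

```python
import random
from typing import Sequence


def per_item_flip_flags(seeds: Sequence[int], p: float = 0.5) -> list[bool]:
    """Decide, per item, whether to apply a random horizontal flip.

    Each item gets its own generator seeded from its entry in ``seeds``,
    so the decision depends only on that item's seed.
    """
    return [random.Random(seed).random() < p for seed in seeds]


flags = per_item_flip_flags([11, 22, 33])
```

Calling per_item_flip_flags(batch.seeds) again for the same batch yields identical decisions.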