hydrax.TrainingBatch#
- final class hydrax.TrainingBatch(group: _GroupState[D], index: int, indices: Sequence[int], epoch: int, epoch_batch: int, batch_num: int, seed: int)#
Bases: Batch[D], Generic[D]

A batch of training data loaded by a Dataloader.

Batches are returned by Dataloader.__next__(). You can determine whether a Batch is a TrainingBatch by checking is_training.

Caution

Don’t derive from or instantiate this type; it is created internally.
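The is_training check described above can be sketched as follows. Because real TrainingBatch objects are created internally by the Dataloader, the batch here is a hypothetical stand-in used purely to illustrate the documented flag:

```python
# Illustrative stand-in: real TrainingBatch objects are created internally
# by the hydrax Dataloader and must not be instantiated directly.
from dataclasses import dataclass

@dataclass
class StubBatch:
    is_training: bool  # mirrors the documented Batch.is_training flag

def dispatch(batch) -> str:
    # The documented way to tell batch kinds apart is is_training.
    return "train_step" if batch.is_training else "validation_step"

assert dispatch(StubBatch(is_training=True)) == "train_step"
assert dispatch(StubBatch(is_training=False)) == "validation_step"
```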
- property additional: mappingproxy[str, Sequence[Any]]#
The additional data returned by the loader for each item of the batch.
This data is a readonly mapping of the form { 'key_0': [value_0, ...], ... }, where the number of values for each key is equal to the batch size of the DataGroup. A value will be None if the loader did not return additional data with the corresponding key for the corresponding item in this batch.
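As a concrete illustration of that shape, a batch of size 3 with one additional key might look like the plain-Python mapping below (a hand-made stand-in, not a real batch object):

```python
from types import MappingProxyType

batch_size = 3
# Readonly mapping of the documented form {'key_0': [value_0, ...], ...};
# None marks an item whose loader returned no data under this key.
additional = MappingProxyType({"caption": ["a cat", None, "a dog"]})

# One value per item in the batch:
assert all(len(values) == batch_size for values in additional.values())
assert additional["caption"][1] is None  # item 1 had no caption
```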
- property arrays: mappingproxy[str, Array]#
The JAX arrays for this batch, as defined by the DataGroup.

The leading dimension of each array is the batch size of the DataGroup. The shapes and dtypes are as specified by group.loader_layouts if the batch is uncached, otherwise they are as specified by group.cache_layouts. The current layout is available as the layouts property.

Tip

To convert an uncached batch to a cached batch, use cache().
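The leading-dimension invariant can be checked without hydrax at all; here bare shape tuples stand in for the real JAX arrays, purely for illustration:

```python
batch_size = 8
# Shape tuples standing in for the real JAX arrays of one batch.
array_shapes = {
    "images": (8, 32, 32, 3),
    "labels": (8,),
}
# The documented invariant: every array's leading dimension equals
# the batch size of the DataGroup.
assert all(shape[0] == batch_size for shape in array_shapes.values())
```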
- property batch_num: int#
The overall zero-based index of this batch.
- cache(arrays: dict[str, Array], additional: dict[str, list[object]] | None = None, *, persist: bool = True) None #
Converts an uncached batch into a cached batch by providing the arrays and additional data to cache.
Important
This method will raise a RuntimeError if the batch is already cached, or if cache layouts were not specified in the DataGroup.

- Parameters:
arrays – The JAX arrays to cache. These must exactly correspond to DataGroup.cache_layouts.
additional – Optional additional data to cache. The contents must be pickleable and the length of each list must be exactly equal to DataGroup.batch_size. If not specified, the batch’s current additional data is retained.
persist – If False, the batch will not actually be persisted to disk and so will need to be reloaded. Additionally, persistence is skipped if cache_readonly was specified when creating the DataGroup.
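The "must exactly correspond to cache_layouts" contract can be mimicked in plain Python. Note this is a sketch only: cache_layouts is reduced here to (shape, dtype-name) pairs, which is a simplifying assumption (the real layouts are richer tuples, as shown under the layouts property below):

```python
def check_against_layouts(arrays_meta, cache_layouts):
    # Mimics the documented contract: the provided arrays must exactly
    # correspond to DataGroup.cache_layouts, else a RuntimeError is raised.
    if set(arrays_meta) != set(cache_layouts):
        raise RuntimeError("array names do not match cache_layouts")
    for name, meta in arrays_meta.items():
        if meta != cache_layouts[name]:
            raise RuntimeError(f"layout mismatch for {name!r}")

cache_layouts = {"images": ((8, 64, 64, 3), "float16")}
# Matching shapes and dtypes pass silently:
check_against_layouts({"images": ((8, 64, 64, 3), "float16")}, cache_layouts)
```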
- property data: Sequence[D]#
The data descriptors of the data in this batch.
The length of this sequence is equal to the batch size of the group.
- property epoch: int#
The zero-based training epoch number for this batch.
- property epoch_batch: int#
The zero-based index of this batch within the current epoch.
- get_additional(key: str, index: int | None = None, *, default: Any = None) Any #
Returns the additional data returned by the loader with the specified name.
- Parameters:
key – The key returned by the loader.
index – The index of the data item within this batch. If this is not specified, a sequence corresponding to each item in the batch is returned.
default – The default value to return if the additional data is not present.
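The documented lookup-with-default behaviour can be mimicked in plain Python; the additional mapping here is a hand-made stand-in, not a real batch:

```python
from types import MappingProxyType

# Stand-in for TrainingBatch.additional for a batch of three items.
additional = MappingProxyType({"label": ["cat", None, "dog"]})

def get_additional(key, index=None, default=None):
    # Mirrors the documented behaviour: missing keys and per-item None
    # values fall back to `default`; omitting index returns the sequence.
    values = additional.get(key)
    if values is None:
        return default
    if index is None:
        return values
    value = values[index]
    return default if value is None else value

assert get_additional("label", 0) == "cat"
assert get_additional("label", 1, default="unknown") == "unknown"
assert get_additional("missing", 2, default=0) == 0
```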
- get_array(name: str) Array #
Returns the specified JAX array for this batch, as defined in the DataGroup.

- Parameters:
name – The name of the array, as defined in the DataGroup.

See also

This is equivalent to arrays[name]. See arrays.
- get_data(index: int) D #
Returns the data descriptor for the specified item in this batch.
- Parameters:
index – The index of the data item within this batch.

See also

This is equivalent to data[index]. See data.
- property is_training: bool#
Always True for a TrainingBatch.
- property is_validation: bool#
Always False for a TrainingBatch.
- property layouts: mappingproxy[str, tuple[tuple[int, ...], dtype, dtype, int, int]]#
The layouts of arrays.

If this batch is uncached, this is group.loader_layouts; otherwise it is group.cache_layouts.
- property seed: int#
The deterministic seed for randomness associated with this batch.
- property seeds: Sequence[int]#
The deterministic seeds for randomness associated with each item of this batch.
Each seed is the same seed that was passed to the Dataloader’s loader_func for the corresponding item of this batch.
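The useful property of these per-item seeds is determinism: seeding a generator with an item's seed reproduces that item's randomness exactly each time the item is revisited. A minimal stdlib illustration (the seed values are made up for the example):

```python
import random

def item_augmentation(seed: int) -> float:
    # Stand-in for loader-side randomness: any RNG seeded with the item's
    # seed yields the same stream every time that seed is reused.
    return random.Random(seed).random()

seeds = [101, 102, 103]  # e.g. the seeds sequence for a batch of three
first = [item_augmentation(s) for s in seeds]
again = [item_augmentation(s) for s in seeds]
assert first == again  # deterministic per item
```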