hydrax.Batch
- class hydrax.Batch(group: _GroupState[D], epoch: int, index: int, indices: Sequence[int])
Bases: Generic[D]
An abstract base class representing a batch of data loaded by a Dataloader.
Caution
Don’t derive from or instantiate this type; it is created internally.
Batches are returned by Dataloader.__next__(). If is_training is True, this is a TrainingBatch. Otherwise, is_validation is True and this instance is a ValidationBatch. If you did not opt in to validation on the Dataloader, all of your batches will be TrainingBatch instances.
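The loop below is a minimal sketch of routing batches by is_training and is_validation. It assumes the Dataloader can be iterated with a plain for loop (the text above only states that batches come from Dataloader.__next__()). The names run_epoch, train_step and eval_step are hypothetical; the two step functions stand in for your own code and are assumed to return a scalar loss.

```python
from typing import Any, Callable, Iterable


def run_epoch(
    loader: Iterable[Any],
    train_step: Callable[[Any], float],
    eval_step: Callable[[Any], float],
) -> tuple[list[float], list[float]]:
    """Route each batch to a training or a validation step (hypothetical helper).

    `loader` is assumed to yield hydrax Batch objects; `train_step` and
    `eval_step` are hypothetical user functions returning a scalar loss.
    """
    train_losses: list[float] = []
    val_losses: list[float] = []
    for batch in loader:
        if batch.is_training:
            train_losses.append(float(train_step(batch)))
        elif batch.is_validation:
            val_losses.append(float(eval_step(batch)))
        # Only plain floats are kept. Dropping our reference to the batch
        # lets its arrays be released so the Dataloader does not stall.
        del batch
    return train_losses, val_losses
```

Converting each loss to a Python float means no JAX array from the batch outlives the loop iteration.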
If you are not using DataGroup caching, you can ignore the following explanation: all of your batches will be loaded uncached, and you cannot call cache().
If you are using DataGroup caching and specified both loader_arrays and cache_arrays, you will receive batches in both the cached and uncached layouts. After transforming an uncached batch out of its loader layout (for example, by applying a fixed encoder network), convert it into the cached form by calling cache(), even if you do not intend to persist that batch. In future epochs, your cached data will be loaded directly, unless you opted out with cache_readonly = True or persist = False.
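The workflow just described can be sketched as a small helper. This is an illustration under assumptions, not the library's own code: it assumes a Batch exposes a boolean `uncached` attribute (the text refers to batches being loaded "uncached", but that attribute name is an assumption), and encode_and_cache and encoder are hypothetical names. The encoder is assumed to return a dict matching DataGroup.cache_layouts.

```python
from typing import Any, Callable, Mapping


def encode_and_cache(
    batch: Any,
    encoder: Callable[[Mapping[str, Any]], dict[str, Any]],
    *,
    persist: bool = True,
) -> Mapping[str, Any]:
    """Return the batch's arrays in the cached layout (hypothetical helper).

    ASSUMPTION: `batch.uncached` is a bool attribute on hydrax Batch.
    `encoder` is a hypothetical fixed network whose output matches
    DataGroup.cache_layouts.
    """
    if batch.uncached:
        # Transform the loader-layout arrays into the cache layout...
        encoded = encoder(batch.arrays)
        # ...then convert the batch. With persist=False nothing is written
        # to disk, but the batch is still converted to its cached form.
        batch.cache(encoded, persist=persist)
    # Cached batches (including the one just converted) expose
    # cache-layout arrays through `arrays`.
    return batch.arrays
```

Batches that were already loaded from the cache skip the encoder entirely, which avoids the RuntimeError that cache() raises on an already-cached batch.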
Attention
Do not retain references to a Batch or its arrays beyond the loop iteration that processes it, and be careful not to do so accidentally. In limited cases, such as asynchronous logging, it is okay to retain references while that work is in progress, but make sure every reference is released once it is no longer needed. As long as a Batch’s arrays are referenced (including by the Batch object itself), the Batch consumes one unit of the Dataloader’s depth. If all of the depth is consumed this way, the Dataloader will stall, or deadlock if the references are never released.
- property additional: mappingproxy[str, Sequence[Any]]
The additional data returned by the loader for each item of the batch.
This data is a read-only mapping of the form { 'key_0': [value_0, ...], ... }, where the number of values for each key is equal to the batch size of the DataGroup. A value is None if the loader did not return additional data with the corresponding key for the corresponding item in this batch.
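Because the mapping is keyed first by name and then by item, it can be handy to regroup it per item. The hypothetical helper item_additional below is a sketch that works on any mapping of the documented shape (in a real loop you would pass batch.additional). It treats None as "not returned by the loader" and omits those keys.

```python
from types import MappingProxyType
from typing import Any, Mapping, Sequence


def item_additional(
    additional: Mapping[str, Sequence[Any]], index: int
) -> dict[str, Any]:
    """Collect the additional data for one item of a batch (hypothetical helper).

    `additional` has the form {'key_0': [value_0, ...], ...} with one value
    per item; None marks a key the loader did not return for that item,
    so such keys are left out of the result.
    """
    return {
        key: values[index]
        for key, values in additional.items()
        if values[index] is not None
    }


# A hand-built read-only mapping of the documented shape (batch size 2).
example = MappingProxyType({"label": [3, None], "source": ["a.png", "b.png"]})
print(item_additional(example, 1))  # {'source': 'b.png'}
```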
- property arrays: mappingproxy[str, Array]
The JAX arrays for this batch, as defined by the DataGroup.
The leading dimension of each array is the batch size of the DataGroup. The shapes and dtypes are as specified by group.loader_layouts if the batch is uncached; otherwise they are as specified by group.cache_layouts. The current layout is available as the layouts property.
Tip
To convert an uncached batch to a cached batch, use cache().
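The leading-dimension guarantee makes a cheap sanity check possible, for example while debugging a custom loader. The helper check_leading_dim below is hypothetical. Since len(batch.data) also equals the batch size, a call such as check_leading_dim(batch.arrays, len(batch.data)) needs nothing beyond the documented properties. The helper relies only on each value having a .shape tuple, as JAX arrays do.

```python
from typing import Any, Mapping


def check_leading_dim(arrays: Mapping[str, Any], batch_size: int) -> None:
    """Raise ValueError if any array's leading dimension != batch_size.

    Hypothetical debugging helper; `arrays` would typically be
    `batch.arrays`, and each value only needs a `.shape` tuple.
    """
    for name, array in arrays.items():
        # A 0-d array has shape () and therefore no leading dimension.
        if not array.shape or array.shape[0] != batch_size:
            raise ValueError(
                f"array {name!r} has shape {array.shape}, "
                f"expected leading dimension {batch_size}"
            )
```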
- cache(arrays: dict[str, Array], additional: dict[str, list[object]] | None = None, *, persist: bool = True) → None
Converts an uncached batch into a cached batch by providing the arrays and additional data to cache.
Important
This method raises a RuntimeError if the batch is already cached, or if cache layouts were not specified in the DataGroup.
- Parameters:
arrays – The JAX arrays to cache. These must exactly correspond to DataGroup.cache_layouts.
additional – Optional additional data to cache. The contents must be pickleable, and the length of each list must be exactly equal to DataGroup.batch_size. If not specified, the batch’s current additional data is retained.
persist – If False, the batch is not actually persisted to disk, so it will need to be reloaded. Persistence is also skipped if cache_readonly was specified when creating the DataGroup.
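The constraints on additional can be checked before calling cache(), which gives a clearer error than discovering a bad value later. The wrapper cache_with_additional below is a hypothetical sketch, not part of hydrax. Its batch_size parameter is whatever you set as DataGroup.batch_size. cache() may well perform its own validation, so this only front-loads the documented rules (list length equals batch size, values pickleable).

```python
import pickle
from typing import Any, Mapping


def cache_with_additional(
    batch: Any,
    arrays: Mapping[str, Any],
    additional: Mapping[str, list[object]],
    batch_size: int,
    *,
    persist: bool = True,
) -> None:
    """Validate `additional`, then call batch.cache() (hypothetical wrapper).

    `batch_size` should equal DataGroup.batch_size.
    """
    for key, values in additional.items():
        if len(values) != batch_size:
            raise ValueError(
                f"additional[{key!r}] has {len(values)} entries, "
                f"expected {batch_size}"
            )
        try:
            # The documented requirement is that the contents are pickleable.
            pickle.dumps(values)
        except Exception as exc:
            raise ValueError(f"additional[{key!r}] is not pickleable") from exc
    batch.cache(dict(arrays), dict(additional), persist=persist)
```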
- property data: Sequence[D]
The data descriptors of the data in this batch.
The length of this sequence is equal to the batch size of the group.
- get_additional(key: str, index: int | None = None, *, default: Any = None) → Any
Returns the additional data returned by the loader under the specified key.
- Parameters:
key – The key returned by the loader.
index – The index of the data item within this batch. If this is not specified, a sequence corresponding to each item in the batch is returned.
default – The default value to return if the additional data is not present.
- get_array(name: str) → Array
Returns the specified JAX array for this batch, as defined in the DataGroup.
- Parameters:
name – The name of the array, as defined in the DataGroup.
See also
This is equivalent to arrays[name]. See arrays.
- get_data(index: int) → D
Returns the data descriptor for the specified item in this batch.
- Parameters:
index – The index of the data item within this batch.
See also
This is equivalent to data[index]. See data.
- property is_training: bool
True if this instance is a TrainingBatch, and False otherwise.
- property is_validation: bool
True if this instance is a ValidationBatch, and False otherwise.