hydrax.Batch
- class hydrax.Batch(group: _GroupState[D], epoch: int, index: int, indices: Sequence[int])
Bases: Generic[D]
An abstract base class representing a batch of data loaded by a Dataloader.
Caution
Don’t derive from or instantiate this type; it is created internally.
Batches are returned by Dataloader.__next__().
If is_training is True, this is a TrainingBatch. Otherwise, is_validation is True and this instance is a ValidationBatch. If you did not opt in to validation on the Dataloader, all your batches will be TrainingBatch.
If you are not using DataGroup caching, you can ignore the following explanation. All of your Batches will be loaded uncached and you cannot call cache().
If you are using DataGroup caching and specified both loader_arrays and cache_arrays then you will receive batches in both the cached and uncached layouts. Even if you do not intend to persist a given Batch, you should convert it into the cached form by calling cache() after transforming your data from its uncached layout, for example by applying a fixed encoder network. In future epochs, your cached data will be loaded directly, unless you opted out with cache_readonly = True or persist = False.
Attention
Do not retain references to a Batch or its arrays beyond one model batch loop, and be careful not to do so accidentally. In limited cases, such as asynchronous logging, it is okay to retain references while that is occurring, but it’s important to ensure all references are released once they are no longer needed. As long as a Batch’s arrays are referenced (including by the Batch object itself), it consumes one Dataloader depth. If all of the depth is consumed in this way the Dataloader will stall (or deadlock, if the references are never released).
- property additional: mappingproxy[str, Sequence[Any]]
The additional data returned by the loader for each item of the batch.
This data is a readonly mapping of the form { 'key_0': [value_0, ...], ... }, where the number of values for each key is equal to the batch size of the DataGroup. A value will be None if the loader did not return additional data with the corresponding key for the corresponding item in this batch.
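The layout described above can be pictured with a small standalone sketch. It does not use hydrax; `collate_additional` is a hypothetical helper, written only to show how per-item dictionaries could end up as one read-only mapping with one slot per batch item and None where an item supplied nothing for a key.

```python
from types import MappingProxyType
from typing import Any, Mapping, Sequence


def collate_additional(
    per_item: Sequence[Mapping[str, Any]],
) -> Mapping[str, Sequence[Any]]:
    """Hypothetical helper: merge per-item dicts into one read-only mapping."""
    # Collect every key that any item supplied, preserving first-seen order.
    keys = {}
    for item in per_item:
        keys.update(dict.fromkeys(item))
    # One list per key, one slot per item; missing entries become None.
    merged = {key: [item.get(key) for item in per_item] for key in keys}
    return MappingProxyType(merged)


# Three items in the batch; only some of them supplied a "caption".
additional = collate_additional([
    {"caption": "a cat", "source": "set_a"},
    {"source": "set_b"},
    {"caption": "a dog", "source": "set_a"},
])
```

Because the result is a mappingproxy, item assignment on it raises a TypeError, which matches the "readonly" wording of the property.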
- property arrays: mappingproxy[str, Array]
The JAX arrays for this batch, as defined by the DataGroup.
The leading dimension of each array is the batch size of the DataGroup. The shapes and dtypes are as specified by group.loader_layouts if the batch is uncached, otherwise they are as specified by group.cache_layouts. The current layout is available as the layouts property.
Tip
To convert an uncached batch to a cached batch, use cache().
- cache(arrays: dict[str, Array], additional: dict[str, list[object]] | None = None, *, persist: bool = True) -> None
Converts an uncached batch into a cached batch by providing the arrays and additional data to cache.
Important
This method will raise a RuntimeError if the batch is already cached, or if cache layouts were not specified in the DataGroup.
- Parameters:
arrays – The JAX arrays to cache. These must exactly correspond to DataGroup.cache_layouts.
additional – Optional additional data to cache. The contents must be pickleable and the length of each list must be exactly equal to DataGroup.batch_size. If not specified, the batch’s current additional data is retained.
persist – If False the batch will not actually be persisted to disk and so will need to be reloaded. Additionally, persistence is skipped if cache_readonly was specified when creating the DataGroup.
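The calling pattern for cache() can be sketched without hydrax. `StubBatch` below is a stand-in that mimics only the contract stated above (arrays are replaced, and a second call raises RuntimeError); its `_cached` flag, the `has_cache_layouts` argument, and the `encode` function are all hypothetical, and plain Python lists stand in for JAX arrays.

```python
from typing import Any


class StubBatch:
    """Stand-in that mimics only the documented cache() contract."""

    def __init__(self, arrays: dict, has_cache_layouts: bool = True):
        self.arrays = arrays
        self._has_cache_layouts = has_cache_layouts  # hypothetical
        self._cached = False  # hypothetical internal flag

    def cache(self, arrays: dict, additional: Any = None, *, persist: bool = True) -> None:
        # Documented behaviour: reject a second call or missing cache layouts.
        if self._cached or not self._has_cache_layouts:
            raise RuntimeError("batch is already cached or has no cache layouts")
        self.arrays = arrays
        self._cached = True


def encode(pixels: list) -> list:
    """Hypothetical fixed encoder: stands in for a frozen network."""
    return [round(p / 255.0, 4) for p in pixels]


# Transform the uncached layout, then hand the result to cache().
batch = StubBatch({"pixels": [0.0, 127.5, 255.0]})
batch.cache({"latents": encode(batch.arrays["pixels"])}, persist=False)

# A second conversion of the same batch is rejected.
try:
    batch.cache({"latents": []})
    second_call = "accepted"
except RuntimeError:
    second_call = "rejected"
```

With the real library the arrays passed to cache() would need to match DataGroup.cache_layouts exactly; the stub does not check shapes or dtypes.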
- property data: Sequence[D]
The data descriptors of the data in this batch.
The length of this sequence is equal to the batch size of the group.
- get_additional(key: str, index: int | None = None, *, default: Any = None) -> Any
Returns the additional data returned by the loader with the specified name.
- Parameters:
key – The key returned by the loader.
index – The index of the data item within this batch. If this is not specified, a sequence corresponding to each item in the batch is returned.
default – The default value to return if the additional data is not present.
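One plausible reading of these parameters is modelled below as a free function over a plain mapping. This is not the library's implementation: in particular, returning `default` for an unknown key and for a None entry are assumptions, since the text above only says the default is used when the additional data "is not present".

```python
from typing import Any, Mapping, Optional, Sequence


def get_additional(
    additional: Mapping[str, Sequence[Any]],
    key: str,
    index: Optional[int] = None,
    *,
    default: Any = None,
) -> Any:
    """Hypothetical free-function model of Batch.get_additional()."""
    values = additional.get(key)
    if values is None:
        # Assumption: an unknown key yields the default.
        return default
    if index is None:
        # No index given: return one entry per batch item.
        return values
    value = values[index]
    # Assumption: a None entry means "not present" and yields the default.
    return default if value is None else value


additional = {"caption": ["a cat", None, "a dog"]}
whole = get_additional(additional, "caption")
second = get_additional(additional, "caption", 1, default="")
missing = get_additional(additional, "label", 0, default=-1)
```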
- get_array(name: str) -> Array
Returns the specified JAX array for this batch, as defined in the DataGroup.
- Parameters:
name – The name of the array, as defined in the DataGroup.
See also
This is equivalent to arrays[name]. See arrays.
- get_data(index: int) -> D
Returns the data descriptor for the specified item in this batch.
- Parameters:
index – The index of the data item within this batch.
See also
This is equivalent to data[index]. See data.
- property is_training: bool
True if this instance is a TrainingBatch, and False otherwise.
- property is_validation: bool
True if this instance is a ValidationBatch, and False otherwise.
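A loop that branches on these two flags might look like the sketch below. It runs without hydrax: `StubBatch` and `stub_loader` are hypothetical stand-ins that expose only is_training and is_validation, and the loop keeps small summaries rather than the batches themselves, in line with the class-level advice not to retain references to a Batch.

```python
from dataclasses import dataclass


@dataclass
class StubBatch:
    """Stand-in exposing only the two documented flags."""
    is_training: bool

    @property
    def is_validation(self) -> bool:
        return not self.is_training


def stub_loader(flags):
    """Hypothetical loader: yields one stub batch per flag."""
    for flag in flags:
        yield StubBatch(is_training=flag)


def run_epoch(loader) -> dict:
    counts = {"train": 0, "validation": 0}
    for batch in loader:
        if batch.is_training:
            counts["train"] += 1  # a training step would go here
        else:
            # is_validation is True whenever is_training is False.
            assert batch.is_validation
            counts["validation"] += 1  # an evaluation step would go here
        # Keep only small summaries; do not append `batch` to a list.
    return counts


counts = run_epoch(stub_loader([True, True, False, True]))
```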