hydrax.Batch

class hydrax.Batch(group: _GroupState[D], epoch: int, index: int, indices: Sequence[int])

Bases: Generic[D]

An abstract base class representing a batch of data loaded by a Dataloader.

Caution

Don’t derive from or instantiate this type; it is created internally.

Batches are returned by Dataloader.__next__().

If is_training is True, this is a TrainingBatch. Otherwise, is_validation is True and this instance is a ValidationBatch. If you did not opt in to validation on the Dataloader, all of your batches will be TrainingBatch instances.
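A minimal sketch of a loop that dispatches on these flags. `FakeBatch` and `route` are hypothetical stand-ins, not part of hydrax, so the example runs without it; with a real Dataloader, the batches would come from `Dataloader.__next__()`.

```python
from dataclasses import dataclass


@dataclass
class FakeBatch:
    """Hypothetical stand-in modelling only the two documented flags."""

    is_training: bool

    @property
    def is_validation(self) -> bool:
        # Per the docs, exactly one of the two flags is True.
        return not self.is_training


def route(batches):
    """Count training and validation batches, branching on is_training."""
    counts = {"train": 0, "validate": 0}
    for batch in batches:
        if batch.is_training:
            counts["train"] += 1  # a real loop would run a training step here
        else:
            assert batch.is_validation
            counts["validate"] += 1  # ...and an evaluation step here
    return counts


print(route([FakeBatch(True), FakeBatch(True), FakeBatch(False)]))
# {'train': 2, 'validate': 1}
```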

If you are not using DataGroup caching, you can skip the rest of this explanation: all of your batches will be loaded uncached, and you cannot call cache().

If you are using DataGroup caching and specified both loader_arrays and cache_arrays, you will receive batches in both the cached and the uncached layouts. Even if you do not intend to persist a given Batch, you should convert it into the cached layout by calling cache() after transforming your data from its uncached layout, for example by applying a fixed encoder network. In future epochs, your cached data will be loaded directly, unless you opted out with cache_readonly = True or persist = False.
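The conversion step might look like the sketch below. `FakeBatch` mimics only the documented behaviour of `uncached`, `arrays`, and `cache()`; the array names `pixels` and `latents` and the `encode` function are hypothetical, and plain lists stand in for JAX arrays.

```python
from types import MappingProxyType


class FakeBatch:
    """Hypothetical stand-in for the documented uncached/cache() contract."""

    def __init__(self, arrays, cache_enabled=True):
        self._arrays = dict(arrays)
        self._uncached = True
        self._cache_enabled = cache_enabled
        self.last_persist = None  # records the persist flag passed to cache()

    @property
    def uncached(self):
        return self._uncached

    @property
    def arrays(self):
        return MappingProxyType(self._arrays)

    def cache(self, arrays, additional=None, *, persist=True):
        # Documented: RuntimeError if already cached or no cache layouts exist.
        if not self._uncached or not self._cache_enabled:
            raise RuntimeError("batch is already cached or caching is not configured")
        self._arrays = dict(arrays)
        self._uncached = False
        self.last_persist = persist


def encode(pixels):
    """Hypothetical fixed encoder: reduces each item to the sum of its values."""
    return [sum(item) for item in pixels]


def prepare(batch):
    """Return the cached-layout array, encoding and caching first if needed."""
    if batch.uncached:
        # Uncached batches carry the loader layout ("pixels"); convert them.
        batch.cache({"latents": encode(batch.arrays["pixels"])})
    # From here on the batch is in the cached layout ("latents").
    return batch.arrays["latents"]


batch = FakeBatch({"pixels": [[1, 2], [3, 4]]})
print(prepare(batch))  # [3, 7]
print(batch.uncached)  # False
```

Because `prepare` checks `uncached` first, the same code path also handles batches that arrive already cached in later epochs.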

Attention

Do not retain references to a Batch or its arrays beyond one iteration of your model’s batch loop, and be careful not to do so accidentally. In limited cases, such as asynchronous logging, it is acceptable to retain references while that work is in progress, but you must ensure that all references are released once they are no longer needed. As long as a Batch’s arrays are referenced (including by the Batch object itself), the batch occupies one unit of the Dataloader’s depth. If the entire depth is occupied in this way, the Dataloader will stall (or deadlock, if the references are never released).
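A minimal sketch of the safe pattern: reduce each batch to plain Python values (for example, a scalar loss) and drop the batch before moving on. `FakeBatch` and `train_step` are hypothetical; the `weakref` probe only demonstrates, under CPython’s reference counting, that no batch outlives its iteration.

```python
import weakref


class FakeBatch:
    """Hypothetical stand-in; a real Batch occupies Dataloader depth while referenced."""

    def __init__(self, values):
        self.values = values


def train_step(batch):
    """Hypothetical step that returns a plain Python float, not an array."""
    return sum(batch.values) / len(batch.values)


losses = []  # holds only Python floats, never batches or their arrays
probes = []  # weak references, used here only to check that batches were freed
for values in ([1.0, 3.0], [2.0, 4.0]):
    batch = FakeBatch(values)
    probes.append(weakref.ref(batch))
    losses.append(train_step(batch))
    del batch  # release the loop's reference before the next batch is fetched

print(losses)  # [2.0, 3.0]
print(all(probe() is None for probe in probes))  # True on CPython
```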

property additional: mappingproxy[str, Sequence[Any]]

The additional data returned by the loader for each item of the batch.

This data is a readonly mapping of the form { 'key_0': [value_0, ...], ... }, where the number of values for each key is equal to the batch size of the DataGroup. A value will be None if the loader did not return additional data with the corresponding key for the corresponding item in this batch.
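To make the layout concrete, the sketch below builds such a mapping from hypothetical per-item loader results. It illustrates the documented shape only; `collate_additional` is not hydrax’s internal implementation.

```python
from types import MappingProxyType

# Hypothetical per-item loader results for a batch of size 3: the loader
# returned "caption" for items 0 and 2, and "source" only for item 1.
per_item = [
    {"caption": "a cat"},
    {"source": "web"},
    {"caption": "a dog"},
]


def collate_additional(per_item):
    """Build the documented {key: [value_0, ...]} layout, padding gaps with None."""
    keys = sorted({key for item in per_item for key in item})
    columns = {key: [item.get(key) for item in per_item] for key in keys}
    return MappingProxyType(columns)  # read-only view, like Batch.additional


additional = collate_additional(per_item)
print(dict(additional))
# {'caption': ['a cat', None, 'a dog'], 'source': [None, 'web', None]}
```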

property arrays: mappingproxy[str, Array]

The JAX arrays for this batch, as defined by the DataGroup.

The leading dimension of each array is the batch size of the DataGroup. The shapes and dtypes are as specified by group.loader_layouts if the batch is uncached, otherwise they are as specified by group.cache_layouts. The current layout is available as the layouts property.

Tip

To convert an uncached batch to a cached batch, use cache().

cache(arrays: dict[str, Array], additional: dict[str, list[object]] | None = None, *, persist: bool = True) → None

Converts an uncached batch into a cached batch by providing the arrays and additional data to cache.

Important

This method will raise a RuntimeError if the batch is already cached, or if cache layouts were not specified in the DataGroup.

Parameters:
  • arrays – The JAX arrays to cache. These must exactly correspond to DataGroup.cache_layouts.

  • additional – Optional additional data to cache. The contents must be pickleable and the length of each list must be exactly equal to DataGroup.batch_size. If not specified, the batch’s current additional data is retained.

  • persist – If False, the batch is not actually persisted to disk, so it will need to be reloaded in future epochs. Persistence is also skipped if cache_readonly was specified when creating the DataGroup.
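Before calling cache(), it can help to check its documented requirements yourself. The sketch below is a hypothetical helper, not part of hydrax. It takes the cache layout names and the batch size as plain arguments (with a real DataGroup these would come from cache_layouts and batch_size), uses lists in place of JAX arrays, and compares only array names, not shapes or dtypes.

```python
import pickle


def check_cache_args(cache_names, batch_size, arrays, additional=None):
    """Hypothetical pre-flight check mirroring the documented cache() requirements."""
    # Array names must correspond exactly to the cache layouts.
    if set(arrays) != set(cache_names):
        missing = sorted(set(cache_names) - set(arrays))
        extra = sorted(set(arrays) - set(cache_names))
        raise ValueError(f"arrays do not match cache layouts: missing={missing}, extra={extra}")
    for key, values in (additional or {}).items():
        # Each additional list must have exactly one entry per batch item.
        if len(values) != batch_size:
            raise ValueError(f"len(additional[{key!r}]) is {len(values)}, expected {batch_size}")
        pickle.dumps(values)  # raises if any value cannot be pickled


# Passes: names match and each additional list has batch_size entries.
check_cache_args({"latents"}, 2, {"latents": [0.1, 0.2]}, {"caption": ["a", "b"]})

try:
    check_cache_args({"latents"}, 2, {"latents": [0.1, 0.2]}, {"caption": ["a"]})
except ValueError as err:
    print(err)  # len(additional['caption']) is 1, expected 2
```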

property data: Sequence[D]

The data descriptors of the data in this batch.

The length of this sequence is equal to the batch size of the group.

get_additional(key: str, index: int | None = None, *, default: Any = None) → Any

Returns the additional data returned by the loader under the specified key.

Parameters:
  • key – The key returned by the loader.

  • index – The index of the data item within this batch. If this is not specified, a sequence corresponding to each item in the batch is returned.

  • default – The default value to return if the additional data is not present.

get_array(name: str) → Array

Returns the specified JAX array for this batch, as defined in the DataGroup.

Parameters:

name – The name of the array, as defined in the DataGroup.

See also

This is equivalent to arrays[name]. See arrays.

get_data(index: int) → D

Returns the data descriptor for the specified item in this batch.

Parameters:

index – The index of the data item within this batch.

See also

This is equivalent to data[index]. See data.

property group: DataGroup[D]

The DataGroup from which this batch was loaded.

property group_index: int

The index of this batch within group.

property indices: Sequence[int]

The indices within group of the data items in this batch.

property is_training: bool

True if this instance is a TrainingBatch, and False otherwise.

property is_validation: bool

True if this instance is a ValidationBatch, and False otherwise.

property layouts: mappingproxy[str, tuple[tuple[int, ...], dtype, dtype, int, int]]

The layouts of arrays.

If this batch is uncached, this is group.loader_layouts, otherwise it is group.cache_layouts.

property uncached: bool

True if the batch is in the loader format specified by the DataGroup, and False if it is in the cached format.

Tip

To convert an uncached batch to a cached batch, use cache().