hydrax.TrainingBatch

final class hydrax.TrainingBatch(group: _GroupState[D], index: int, indices: Sequence[int], epoch: int, epoch_batch: int, batch_num: int, seed: int)

Bases: Batch[D], Generic[D]

A batch of training data loaded by a Dataloader.

Batches are returned by Dataloader.__next__(). You can determine if a Batch is a TrainingBatch by checking is_training.
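As a sketch of that dispatch, the helper below drains batches with next() and routes each one on is_training. The names dispatch_batches, train_step, and eval_step are hypothetical. The helper assumes only that the Dataloader follows Python's iterator protocol, so next() raises StopIteration once the batches run out.

```python
from collections.abc import Callable, Iterator
from typing import Any


def dispatch_batches(
    loader: Iterator[Any],
    train_step: Callable[[Any], None],  # hypothetical user callback
    eval_step: Callable[[Any], None],  # hypothetical user callback
) -> tuple[int, int]:
    """Route each batch by is_training; return (training, validation) counts."""
    n_train = n_eval = 0
    while True:
        try:
            batch = next(loader)  # invokes the loader's __next__()
        except StopIteration:
            break
        if batch.is_training:
            train_step(batch)  # batch is a TrainingBatch
            n_train += 1
        else:
            eval_step(batch)
            n_eval += 1
    return n_train, n_eval
```

For example, `dispatch_batches(dataloader, train_step, eval_step)` runs until the Dataloader is exhausted.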

Caution

Don’t derive from or instantiate this type; it is created internally.

property additional: mappingproxy[str, Sequence[Any]]

The additional data returned by the loader for each item of the batch.

This data is a read-only mapping of the form { 'key_0': [value_0, ...], ... }, with one value per item of the batch, so the number of values for each key equals the batch size of the DataGroup. A value is None if the loader did not return additional data with that key for the corresponding item in this batch.
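The shape of this mapping can be modeled in plain Python. The function below is an illustration, not hydrax's implementation. It assumes each item's additional data arrives as a dict and collates those dicts into a read-only { key: (value_0, ...) } mapping, filling in None wherever an item lacks a key. collate_additional is a hypothetical name.

```python
from collections.abc import Mapping, Sequence
from types import MappingProxyType
from typing import Any


def collate_additional(
    per_item: Sequence[Mapping[str, Any]],
) -> Mapping[str, tuple[Any, ...]]:
    """Collate per-item dicts into a read-only {key: (value per item)} mapping."""
    keys: dict[str, None] = {}
    for item in per_item:
        for key in item:
            keys.setdefault(key, None)  # keep first-seen key order
    # Every key gets exactly len(per_item) values; missing entries become None.
    return MappingProxyType(
        {key: tuple(item.get(key) for item in per_item) for key in keys}
    )
```

With two items where only the first returned "path", the result is {'label': (1, 2), 'path': ('a.png', None)}.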

property arrays: mappingproxy[str, Array]

The JAX arrays for this batch, as defined by the DataGroup.

The leading dimension of each array is the batch size of the DataGroup. The shapes and dtypes are as specified by group.loader_layouts if the batch is uncached, otherwise they are as specified by group.cache_layouts. The current layout is available as the layouts property.

Tip

To convert an uncached batch to a cached batch, use cache().
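Because data has one entry per item (its length equals the group's batch size), the leading-dimension guarantee can be checked at runtime. check_leading_dim is a hypothetical helper. It relies only on each array exposing a shape tuple, as JAX arrays do.

```python
from collections.abc import Mapping
from typing import Any


def check_leading_dim(arrays: Mapping[str, Any], batch_size: int) -> None:
    """Raise ValueError if any array's leading dimension is not batch_size."""
    for name, array in arrays.items():
        # array.shape is a tuple of ints; index 0 is the batch dimension.
        if not array.shape or array.shape[0] != batch_size:
            raise ValueError(
                f"array {name!r} has shape {array.shape}, "
                f"expected leading dimension {batch_size}"
            )
```

Call it as `check_leading_dim(batch.arrays, len(batch.data))`.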

property batch_num: int

The overall zero-based index of this batch.

cache(arrays: dict[str, Array], additional: dict[str, list[object]] | None = None, *, persist: bool = True) -> None

Converts an uncached batch into a cached batch by providing the arrays and additional data to cache.

Important

This method will raise a RuntimeError if the batch is already cached, or if cache layouts were not specified in the DataGroup.

Parameters:
  • arrays – The JAX arrays to cache. These must exactly correspond to DataGroup.cache_layouts.

  • additional – Optional additional data to cache. The contents must be pickleable and the length of each list must be exactly equal to DataGroup.batch_size. If not specified, the batch’s current additional data is retained.

  • persist – If False, the batch is not persisted to disk and will therefore need to be reloaded. Persistence is also skipped if cache_readonly was specified when creating the DataGroup.
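A common pattern is to cache a batch the first time it is seen and leave it alone afterwards. The sketch below guards on uncached because cache() raises RuntimeError for a batch that is already cached. ensure_cached and preprocess are hypothetical. preprocess stands in for user code that turns the loader arrays into a dict matching DataGroup.cache_layouts. additional is omitted, so the batch's current additional data is retained.

```python
from collections.abc import Callable, Mapping
from typing import Any


def ensure_cached(
    batch: Any,
    preprocess: Callable[[Mapping[str, Any]], dict[str, Any]],  # hypothetical
    *,
    persist: bool = True,
) -> bool:
    """Cache batch if it is still uncached; return True if cache() was called."""
    if not batch.uncached:
        return False  # already cached: calling cache() again would raise
    cached_arrays = preprocess(batch.arrays)  # must match DataGroup.cache_layouts
    batch.cache(cached_arrays, persist=persist)  # additional data is retained
    return True
```

After the call, batch.arrays and batch.layouts follow group.cache_layouts rather than group.loader_layouts.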

property data: Sequence[D]

The data descriptors of the data in this batch.

The length of this sequence is equal to the batch size of the group.

property epoch: int

The zero-based training epoch number for this batch.

property epoch_batch: int

The zero-based index of this batch within the current epoch.

get_additional(key: str, index: int | None = None, *, default: Any = None) -> Any

Returns the additional data returned by the loader with the specified key.

Parameters:
  • key – The key returned by the loader.

  • index – The index of the data item within this batch. If this is not specified, a sequence corresponding to each item in the batch is returned.

  • default – The default value to return if the additional data is not present.

get_array(name: str) -> Array

Returns the specified JAX array for this batch, as defined in the DataGroup.

Parameters:

name – The name of the array, as defined in the DataGroup.

See also

This is equivalent to arrays[name]. See arrays.

get_data(index: int) -> D

Returns the data descriptor for the specified item in this batch.

Parameters:

index – The index of the data item within this batch.

See also

This is equivalent to data[index]. See data.

property group: DataGroup[D]

The DataGroup from which this batch was loaded.

property group_index: int

The index of this batch within group.

property indices: Sequence[int]

The corresponding indices in group of the data in this batch.

property is_training: bool

Always True for a TrainingBatch.

property is_validation: bool

Always False for a TrainingBatch.

property layouts: mappingproxy[str, tuple[tuple[int, ...], dtype, dtype, int, int]]

The layouts of arrays.

If this batch is uncached, this is group.loader_layouts, otherwise it is group.cache_layouts.

property seed: int

The deterministic seed for randomness associated with this batch.

property seeds: Sequence[int]

The deterministic seeds for randomness associated with each item of this batch.

Each seed is the same seed that was passed to the Dataloader loader_func for the corresponding item of this batch.
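Because each seed is deterministic, per-item randomness such as augmentation can be made reproducible by deriving a generator from it. The sketch below uses the standard library's random.Random for portability; flip_mask is a hypothetical helper. In JAX code one would more typically derive a key from the seed, for example with jax.random.PRNGKey(seed).

```python
import random
from collections.abc import Sequence


def flip_mask(seeds: Sequence[int], p: float = 0.5) -> list[bool]:
    """Return one reproducible decision per item: True with probability p."""
    # A fresh generator per seed means each item's decision depends only on its
    # own seed, not on the order in which items are processed.
    return [random.Random(seed).random() < p for seed in seeds]
```

Calling `flip_mask(batch.seeds)` yields the same decisions every time the same batch is produced.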

property uncached: bool

True if the batch is in the loader format specified by the DataGroup, and False if it is in the cached format.

Tip

To convert an uncached batch to a cached batch, use cache().