hydrax.Dataloader

final class hydrax.Dataloader(loader_func: Callable[[D, mappingproxy[str, ndarray], int | None], dict[str, Any] | None], training: Iterable[DataGroup[D]] | DataGroup[D], *, validation: tuple[str, int, Iterable[DataGroup[D]] | DataGroup[D]] | None = None, depth: int = 3, loader_count: int = 1, loader_nice: int = 0, cacher_count: int = 1, start_at: tuple[int, int] = (0, 0), end_at: tuple[str, int] | None = None, interleave_groups: bool = True, deterministic: bool = True, timeout_sec: int = 60, startup_func: Callable[[], None] | None = None, placement_func: Callable[[mappingproxy[str, ndarray], mappingproxy[str, tuple[tuple[int, ...], dtype, dtype, int, int]], mappingproxy[str, Sequence[Any]], DataGroup[D]], dict[str, Array]] | None = None)

Bases: Generic[D]

A zero-copy multiprocess JAX dataloader.

Caution

Don’t derive from Dataloader. Everything customizable is provided as an argument.

Parameters:
  • loader_func – A callable which accepts a DataGroup data descriptor, a dictionary of arrays (as specified by the DataGroup) corresponding to a single batch element, and an integer seed for use in random augmentation, and returns a (possibly empty) dictionary of additional data. Loader processes call this function repeatedly to load data items, so it cannot be a lambda: it must be loadable from a child process. It must fully populate the passed-in arrays (which may contain stale data from a previous batch) and must not retain any reference to them after returning. The seed argument is None for validation batches, as validation should operate consistently across epochs. The returned dictionary of additional data must be pickleable, as it is sent back to the main process. Do not return any of the input arrays; they are already shared with the main process for zero-copy operation. Avoid sending additional arrays via the return dictionary; instead, add more zero-copy arrays to the DataGroup and fill them in. If for some reason an element cannot be loaded, raise an exception or allow one to propagate, which eventually results in the corresponding batch being dropped.

  • training – A DataGroup for training, or any iterable of them. A single pass through all training DataGroups constitutes an epoch.

  • validation – An optional tuple specifying a validation mode, interval, and data. The validation mode is either "batch" or "epoch", the interval is the number of batches or epochs between validation runs, and the data is a DataGroup or any iterable of them.

  • depth – The maximum number of batches that can exist at any point in time. Memory usage is proportional to the size of the largest possible batch multiplied by depth. This should be at least two (one batch being processed, one batch loading), but should be larger if the dataloader needs to work further ahead to amortize loading-time outliers. The default is 3, allowing the dataloader to work one batch ahead.

  • loader_count – The number of loader processes. Each loader process loads a single item at a time. This defaults to 1, but should probably be higher. Optimally, it should be tuned to saturate the available throughput of your data origin (disk/network) without introducing unnecessary process context switching. This may be 0 only if all data is exclusively loaded from cache, i.e. no DataGroup loader_arrays are specified.

  • loader_nice – An optional niceness applied to loader processes. Ignored if unsupported by the underlying operating system.

  • cacher_count – The number of cache writer threads. Each cache writer saves a single cache entry at a time. This defaults to 1, but should probably be higher. This may be 0 only if all DataGroup caches are readonly.

  • start_at – A tuple specifying how far to skip ahead before loading, for example to resume from a checkpoint. The first element is the epoch to skip to, and the second is a number of additional batches to skip. The number of batches to skip can exceed the number of batches in an epoch, in which case additional epochs are skipped. The default is (0, 0), indicating to start at the beginning.

  • end_at – An optional tuple specifying when to stop. The first element is either "batch" or "epoch", and the second specifies which zero-indexed batch or epoch to stop before. So ("epoch", 1) stops after epoch 0. This argument specifies an absolute position and is not relative to start_at. If this is not specified, the dataloader runs until it is interrupted by interrupt() or its controlling with block is exited.

  • interleave_groups – If multiple training DataGroups are specified, this controls how batches from the different groups are interleaved within an epoch. If False, groups are loaded sequentially in the order specified. If True, the default, batches from different groups are interleaved, with the least-utilized earliest-index group being selected for each batch.

  • deterministic – If False, batches are permitted to be processed out of order when a later batch is ready before a preceding one. The default is True. Setting this to False is not compatible with hydrax.tqdm.

  • timeout_sec – Raise BatchTimeoutError if no batch has completed within the specified number of seconds. The default is 60; a value of 0 or less disables the timeout. A value less than 20 is not recommended.

  • startup_func – An optional callable which is called by each loader process once at startup immediately before loading commences. This callable cannot be a lambda, as it must be loadable from a child process.

  • placement_func – If specified, this function is called with a dictionary of NumPy arrays and is responsible for orchestrating their placement on JAX devices. In addition to the arrays, it is provided their layouts, any additional data, and the associated DataGroup. See https://jax.readthedocs.io/en/latest/faq.html#faq-data-placement for details on data placement in JAX. The default implementation places batches uncommitted on the default JAX device. A minimal sketch follows this list.
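
For illustration, here is one possible placement_func that commits every array in the batch to a single accelerator instead of leaving placement uncommitted. This is a sketch only: the argument names are hypothetical, per-array dtype handling is omitted, and it assumes at least one GPU is visible to JAX.

import jax

def my_placement(arrays, layouts, additional, group):
    # sketch: commit each NumPy array to the first GPU rather than the
    # default uncommitted placement on the default device
    device = jax.devices("gpu")[0]  # assumes a GPU is present
    return {name: jax.device_put(array, device) for name, array in arrays.items()}

dataloader = Dataloader(my_loader, train, placement_func = my_placement)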

Tip

In Hydrax, a single Dataloader is usually responsible for producing both your training and validation batches, in order to conserve resources and ensure perfectly smooth loading throughout.

Example:

from hydrax import Dataloader, DataGroup, TrainingBatch, ValidationBatch

def my_loader(data, arrays, seed):
    # load data from the data source into 'arrays', optionally augmenting using 'seed'.
    # if 'seed' is None, this element belongs to a validation batch.
    # return any additional data for the batch
    ...

if __name__ == "__main__":
    my_data = ...
    array_defs = {
        "array_name": ((dim, ...), numpy_dtype, jax_dtype),
        ...
    }

    train = DataGroup(batch_size, my_data[1000:], loader_arrays=array_defs)
    valid = DataGroup(batch_size, my_data[:1000], loader_arrays=array_defs)

    dataloader = Dataloader(
        my_loader,
        train,
        validation = ("epoch", 1, valid), # run validation after every epoch
        end_at = ("epoch", 5)             # run 5 epochs in total
    )

    with dataloader: # a with block is required
        # consider using hydrax.tqdm.tbatches instead of a vanilla for loop here
        for batch in dataloader:
            if isinstance(batch, TrainingBatch):
                run_training_batch(batch)
            elif isinstance(batch, ValidationBatch):
                run_validation_batch(batch)

            del batch # important, release batch before waiting for next one or cleaning up
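
As a concrete illustration, the my_loader stub above might be fleshed out as follows. The .npy file path, the "image" array name, and the flip augmentation are hypothetical; only the fill-in-place contract and the seed semantics come from the loader_func documentation above.

import numpy as np

def my_loader(data, arrays, seed):
    # hypothetical: 'data' is a path to an .npy file on disk and the
    # DataGroup defines a single "image" array
    image = np.load(data)
    if seed is not None:                  # training batches only
        rng = np.random.default_rng(seed)
        if rng.random() < 0.5:            # example augmentation: horizontal flip
            image = image[:, ::-1]
    arrays["image"][...] = image          # fully populate the shared array in place
    return {}                             # no additional data for this element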

Important

Read the documentation for your loader_func carefully. If you receive a warning from Hydrax about your loader, fix your code; otherwise your batch data may change out from underneath you, leading to significant training issues such as NaNs.

If you’re using Python’s built-in for loop to iterate over batches, it’s important to remember not to accidentally retain a reference to a batch while “going around” the loop. Python’s local variables are function-scoped, not block-scoped, so a loop variable outlives the loop body. See the code example above for a way to address this with del.

If you are experiencing deadlocks as a result of retaining batch or array references between iterations, consider using debug_batch_references() or gc.get_referrers to find out what’s holding on to your batches, though do keep in mind that JAX dispatch will retain references while running ahead. You can check your work by running the Dataloader with depth = 1, which will immediately deadlock if the first batch is not properly released.
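
As a quick check, a depth = 1 run deadlocks on the second batch if the first was not released. A sketch, reusing the names from the example above:

with Dataloader(my_loader, train, depth = 1) as dataloader:
    for batch in dataloader:
        run_training_batch(batch)
        del batch  # if a reference leaks, the next iteration deadlocks here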

Warning

Do not attempt to construct a Dataloader inside a loader process. Ensure your training code is guarded with if __name__ == '__main__':, or is otherwise prevented from running. As a last resort, you can check hydrax.is_worker() and bail.

Note

The Dataloader installs a handler for KeyboardInterrupt (Ctrl+C / SIGINT), which stops the flow of batches as soon as possible. After the dataloader has completed, you can check if this occurred by reading its interrupted property. You may want to save a checkpoint along with the number of the last completed batch, so that you can resume from where you left off with start_at.

If you send a second KeyboardInterrupt, Hydrax will raise a KeyboardInterrupt at the beginning of the next batch. This exception may cause you to lose progress unless you or a framework takes care to save a checkpoint in response.

If you send a third KeyboardInterrupt, the Python interpreter is immediately stopped and control is returned to you. You will lose all progress since the last checkpoint.
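
For example, a resumable training loop might count completed training batches and record the count when interrupted. save_checkpoint and model_state are hypothetical; only the interrupted property and the start_at semantics (skipped batches may exceed an epoch) come from this documentation.

batches_done = 0
with dataloader:
    for batch in dataloader:
        if isinstance(batch, TrainingBatch):
            run_training_batch(batch)
            batches_done += 1
        del batch

if dataloader.interrupted:
    # (0, batches_done) skips exactly the batches already completed
    save_checkpoint(model_state, start_at = (0, batches_done))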

__enter__() → Dataloader[D]

Use via a with block.

__next__() → Batch[D]

Retrieves the next Batch.

property batches_per_epoch: int

The total number of batches per epoch.

property batches_per_validation: int

The total number of batches per validation run.

If no validation data was specified, this is 0.

property deterministic: bool

True if the dataloader is deterministic, and False otherwise.

property first_batch: int

The index of the first batch to load.

Controlled by the start_at argument.

idle_usec() → int

Returns the total amount of time, in microseconds, that __next__() has spent waiting for a batch since the last call to idle_usec.

This represents the amount of time that the dataloader has stalled JAX dispatch. Ideally, this value should always be zero. If it is consistently high, you either have too few loaders (loader_count) or are bottlenecked by a shared resource (disk / network / CPU / swap). If it spikes, you may need to increase depth so loaders can work further ahead and amortize longer loading times.

hydrax.tqdm consumes this metric if report_interval is specified.
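
A sketch of polling this metric by hand, with an arbitrary 100-batch reporting interval:

with dataloader:
    for i, batch in enumerate(dataloader):
        run_training_batch(batch)
        del batch
        if (i + 1) % 100 == 0:
            stall_ms = dataloader.idle_usec() / 1000
            print(f"dispatch stalled {stall_ms:.1f} ms over the last 100 batches")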

interrupt()

Interrupts the dataloader, so no further batches are returned by __next__().

property interrupted: bool

True if this dataloader has been interrupted, and False otherwise.

property last_batch: int | None

The index of the batch to end at, or None if no end point was specified.

Controlled by the end_at argument.