hydrax.Dataloader#
- final class hydrax.Dataloader(loader_func: Callable[[D, mappingproxy[str, ndarray], int | None], dict[str, Any] | None], training: Iterable[DataGroup[D]] | DataGroup[D], *, validation: tuple[str, int, Iterable[DataGroup[D]] | DataGroup[D]] | None = None, depth: int = 3, loader_count: int = 1, loader_nice: int = 0, cacher_count: int = 1, start_at: tuple[int, int] = (0, 0), end_at: tuple[str, int] | None = None, interleave_groups: bool = True, deterministic: bool = True, timeout_sec: int = 60, startup_func: Callable[[], None] | None = None, placement_func: Callable[[mappingproxy[str, ndarray], mappingproxy[str, tuple[tuple[int, ...], dtype, dtype, int, int]], mappingproxy[str, Sequence[Any]], DataGroup[D]], dict[str, Array]] | None = None)#
Bases: Generic[D]

A zero-copy multiprocess JAX dataloader.
Caution
Don't derive from Dataloader. Everything customizable is provided as an argument.

- Parameters:
  - loader_func – A callable which accepts a DataGroup data descriptor, a dictionary of arrays (as specified by the DataGroup) corresponding to a single batch element, and an integer seed for use in random augmentation, and which returns a (possibly empty) dictionary of additional data. It is called repeatedly by loader processes in order to load data items, and it cannot be a lambda, as it must be loadable from a child process. It must fully populate the passed-in arrays (which may contain stale data from a previous batch), and must not retain any reference to them after completing. The seed argument will be None in the case of validation batches, as validation should operate consistently across epochs. The dictionary of additional data must be pickleable, as it is returned to the main process. Do not return any of the input arrays; they are already shared with the main process for zero-copy operation. Avoid sending additional arrays via the return dictionary; instead, add additional zero-copy arrays to the DataGroups and fill them in. If an element cannot be loaded, you must raise an exception or allow one to propagate, which will eventually result in the corresponding batch being dropped. A fuller sketch of a loader appears after the usage example below.
  - training – A DataGroup for training, or any iterable of them. A single pass through all training DataGroups constitutes an epoch.
  - validation – An optional tuple specifying a validation mode, interval, and data. The validation mode can be either "batch" or "epoch", and the interval is how many batches or epochs pass between validation runs.
  - depth – The maximum number of batches that can exist at any point in time. Memory usage is proportional to the size of the largest possible batch multiplied by the loader depth. This should be at least two (one batch being processed, one batch loading), but should be larger if the dataloader needs to work further ahead to amortize loading-time outliers. The default is 3, allowing the dataloader to work one batch ahead.
  - loader_count – The number of loader processes. Each loader process loads a single item at a time. This defaults to 1, but should probably be higher. Optimally, it should be tuned to saturate the available throughput of your data origin (disk/network) without introducing unnecessary process context switching. This may be 0 only if all data is loaded exclusively from cache, i.e. no DataGroup loader_arrays are specified.
  - loader_nice – An optional niceness applied to loader processes. Ignored if unsupported by the underlying operating system.
  - cacher_count – The number of cache writer threads. Each cache writer saves a single cache entry at a time. This defaults to 1, but should probably be higher. This may be 0 only if all DataGroup caches are read-only.
  - start_at – A tuple specifying how far to skip ahead before loading, for example to resume from a checkpoint. The first element is the epoch to skip to, and the second is a number of additional batches to skip. The number of batches to skip can exceed the number of batches in an epoch, in which case additional epochs are skipped. The default is (0, 0), indicating to start at the beginning.
  - end_at – An optional tuple specifying when to stop. The first element is either "batch" or "epoch", and the second specifies which zero-indexed batch or epoch to stop before, so ("epoch", 1) stops after epoch 0. This argument specifies an absolute position and is not relative to start_at. If this is not specified, the dataloader runs until it is interrupted by interrupt() or its controlling with block is exited.
  - interleave_groups – If multiple training DataGroups are specified, this controls how batches from the different groups are interleaved within an epoch. If False, groups are loaded sequentially in the order specified. If True, the default, batches from different groups are interleaved, with the least-utilized earliest-index group being selected for each batch.
  - deterministic – If False, batches are permitted to be processed out of order in the event that a later batch is ready before a preceding batch. This option is not compatible with hydrax.tqdm.
  - timeout_sec – Raise BatchTimeoutError if no batches have completed within the specified timeout. The default is 60, and 0 or less disables the timeout. A value less than 20 is not recommended.
  - startup_func – An optional callable which is called once by each loader process at startup, immediately before loading commences. This callable cannot be a lambda, as it must be loadable from a child process.
  - placement_func – If specified, this function is called with a dictionary of NumPy arrays and is responsible for orchestrating the placement of the arrays on JAX devices. In addition to the arrays, it is provided their layouts, additional data, and the associated DataGroup. See https://jax.readthedocs.io/en/latest/faq.html#faq-data-placement for details on data placement in JAX. The default implementation places batches uncommitted on the default JAX device. A sketch of a custom placement function appears immediately below.
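For illustration, a custom placement function that commits every array to an explicitly chosen device might look like the following sketch. The function and parameter names are hypothetical, and the use of jax.device_put is an assumption about a reasonable implementation; only the call signature and return type come from the documentation above.

    import jax

    # Sketch of a placement_func: receives the batch's NumPy arrays plus their
    # layouts, additional data, and DataGroup, and returns JAX arrays.
    def my_placement(arrays, layouts, extras, group):
        device = jax.devices()[0]  # choose a device explicitly
        # jax.device_put copies each host array onto the chosen device,
        # committing the result there (unlike the uncommitted default placement).
        return {name: jax.device_put(array, device) for name, array in arrays.items()}

A function like this would be passed as placement_func=my_placement when constructing the Dataloader.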
Tip
In Hydrax, a single Dataloader is usually responsible for producing both your training and validation batches, in order to conserve resources and ensure perfectly smooth loading throughout.
Example:
    from hydrax import Dataloader, DataGroup, TrainingBatch, ValidationBatch

    def my_loader(data, arrays, seed):
        # load data from the data source into arrays, optionally augmenting using 'seed'.
        # if 'seed' is None, this is data for a validation batch.
        # return any additional data for the batch
        ...

    if __name__ == "__main__":
        my_data = ...

        array_defs = {
            "array_name": ((dim, ...), numpy_dtype, jax_dtype),
            ...
        }

        train = DataGroup(batch_size, my_data[1000:], loader_arrays=array_defs)
        valid = DataGroup(batch_size, my_data[:1000], loader_arrays=array_defs)

        dataloader = Dataloader(
            my_loader,
            train,
            validation = ("epoch", 1, valid),  # run validation after every epoch
            end_at = ("epoch", 5)              # run 5 epochs in total
        )

        with dataloader:  # a with block is required
            # consider using hydrax.tqdm.tbatches instead of a vanilla for loop here
            for batch in dataloader:
                if isinstance(batch, TrainingBatch):
                    run_training_batch(batch)
                elif isinstance(batch, ValidationBatch):
                    run_validation_batch(batch)

                del batch  # important, release batch before waiting for next one or cleaning up
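To make the loader contract concrete, here is a hedged sketch of a loader_func that fills the shared arrays in place. The descriptor fields ("path", "label", "id"), the array names ("image", "label"), and the augmentation are purely illustrative; the in-place writes and the small pickleable return value are the parts that reflect the contract described above.

    import numpy as np

    def my_loader(data, arrays, seed):
        # 'data' is whatever descriptor you stored in the DataGroup (hypothetical fields here)
        image = np.load(data["path"])

        if seed is not None:
            # training batch: apply a random augmentation derived from 'seed'
            rng = np.random.default_rng(seed)
            if rng.random() < 0.5:
                image = image[:, ::-1]  # e.g. a horizontal flip

        # fully populate the shared arrays in place; do not keep references to them
        arrays["image"][...] = image
        arrays["label"][...] = data["label"]

        # return only small, pickleable extras; never the input arrays themselves
        return {"sample_id": data["id"]}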
Important
Read the documentation for your loader_func carefully. If you receive a warning from Hydrax about your loader, you should fix your code. Failure to do this could result in your batch data changing out from underneath you, leading to significant training issues such as NaNs.

If you're using Python's built-in for loop to iterate over batches, remember not to accidentally retain a reference to a batch while "going around" the loop. Python's local variables are not scoped. See the code example above for a way to address this with del.

If you are experiencing deadlocks as a result of retaining batch or array references between iterations, consider using debug_batch_references() or gc.get_referrers to find out what's holding on to your batches, though do keep in mind that JAX dispatch will retain references while running ahead. You can check your work by running the Dataloader with depth = 1, which will immediately deadlock if the first batch is not properly released.
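If you need to track down a leaked reference, a minimal check along these lines can help; gc.get_referrers is from the standard library, and run_training_batch stands in for your own training step:

    import gc

    # Inside your 'with dataloader:' block. Note that JAX dispatch may
    # legitimately hold a reference to a batch while running ahead.
    for batch in dataloader:
        run_training_batch(batch)
        referrers = gc.get_referrers(batch)  # ideally only this frame remains
        print(len(referrers), referrers)
        del referrers, batch  # release everything before waiting for the next batch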
Warning

Do not attempt to construct a Dataloader inside a loader process. Ensure your training code is guarded with if __name__ == '__main__':, or is otherwise prevented from running. As a last resort, you can check hydrax.is_worker() and bail.
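A minimal guard might look like the following sketch; hydrax.is_worker() is the check mentioned above, and main() stands in for your own training entry point:

    import hydrax

    def main():
        ...  # construct the Dataloader and run training here

    if __name__ == "__main__" and not hydrax.is_worker():
        main()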
Note

The Dataloader installs a handler for KeyboardInterrupt (Ctrl+C / SIGINT), which stops the flow of batches as soon as possible. After the dataloader has completed, you can check whether this occurred by reading its interrupted property. You may want to save a checkpoint along with the number of the last completed batch, so that you can resume from where you left off with start_at.

If you send a second KeyboardInterrupt, Hydrax will raise a KeyboardInterrupt at the beginning of the next batch. This exception may cause you to lose progress unless you or a framework takes care to save a checkpoint in response.

If you send a third KeyboardInterrupt, the Python interpreter is immediately stopped and control is returned to you. You will lose all progress since the last checkpoint.
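As a sketch of interrupt-aware training, you might count completed training batches and persist that position when interrupted, so a later run can pass it back via start_at. Here save_checkpoint and model_state are your own code, not part of Hydrax:

    from hydrax import TrainingBatch

    train_batches_done = 0
    with dataloader:
        for batch in dataloader:
            if isinstance(batch, TrainingBatch):
                run_training_batch(batch)
                train_batches_done += 1
            del batch

    if dataloader.interrupted:
        # enough state to resume later via start_at=(epoch, additional_batches)
        epoch, extra = divmod(train_batches_done, dataloader.batches_per_epoch)
        save_checkpoint(model_state, resume_at=(epoch, extra))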
- __enter__() → Dataloader[D]#
Use via a with block.
- property batches_per_epoch: int#
The total number of batches per epoch.
- property batches_per_validation: int#
The total number of batches per validation run.
If no validation data was specified, this is 0.
- property deterministic: bool#
True if the dataloader is deterministic, and False otherwise.
- property first_batch: int#
The index of the first batch to load.
Controlled by the start_at argument.
- idle_usec() → int#
Returns the total amount of time, in microseconds, that __next__() has spent waiting for a batch since the last call to idle_usec.
This represents the amount of time that the dataloader has stalled JAX dispatch. Ideally, this value should always be zero. If it is consistently high, you either have too few loaders (loader_count) or are bottlenecked by a shared resource (disk / network / CPU / swap). If it has spikes, you may need to increase depth to allow loaders to work further ahead and amortize longer loading times.
hydrax.tqdm consumes this metric if report_interval is specified.
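For example, you can sample this metric periodically inside your batch loop to see how long loading stalled dispatch over the last interval (the 100-batch reporting interval here is arbitrary):

    with dataloader:
        for step, batch in enumerate(dataloader):
            run_training_batch(batch)
            del batch

            if (step + 1) % 100 == 0:
                stalled_ms = dataloader.idle_usec() / 1000  # waiting time since the last call
                print(f"loading stalled dispatch for {stalled_ms:.1f} ms over the last 100 batches")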
- interrupt()#
Interrupts the dataloader, so no further batches are returned by __next__().
- property interrupted: bool#
True if this dataloader has been interrupted, and False otherwise.
- property last_batch: int | None#
The index of the batch to end at, or None if no end point was specified.
Controlled by the end_at argument.