grain.experimental.RebatchIterDataset

class grain.experimental.RebatchIterDataset(parent, batch_size, drop_remainder=False)

Rebatches the input PyTree elements.

__init__(parent, batch_size, drop_remainder=False)

An IterDataset that rebatches elements.

Parameters:
  • parent (IterDataset) – The parent IterDataset whose elements are to be rebatched.

  • batch_size (int) – The number of elements to batch together.

  • drop_remainder (bool) – Whether to drop the last batch if it is smaller than batch_size.
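The following is a minimal pure-Python sketch of what rebatching might look like. It assumes, without confirmation from this page, that each input element is already batched along a leading dimension, and that rebatching concatenates rows across input elements and re-slices them into outputs of batch_size rows. rebatch_sketch is a hypothetical helper, not part of Grain. The PyTree is simplified to a flat dict of equal-length lists.

```python
from typing import Dict, Iterable, Iterator, List

# Simplified stand-in for a PyTree (assumption): a flat dict mapping each leaf
# name to a list of rows. All leaves of one element share the same length,
# which plays the role of the leading "batch" dimension.
Batch = Dict[str, List[int]]


def rebatch_sketch(
    batches: Iterable[Batch], batch_size: int, drop_remainder: bool = False
) -> Iterator[Batch]:
    """Hypothetical helper: re-slices batched elements into batch_size rows."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    buffer: Batch = {}
    buffered = 0  # number of rows currently held in the buffer
    for batch in batches:
        sizes = {len(rows) for rows in batch.values()}
        if len(sizes) != 1:
            raise ValueError("all leaves must share the same leading dimension")
        # Append this element's rows to the buffer, leaf by leaf.
        for key, rows in batch.items():
            buffer.setdefault(key, []).extend(rows)
        buffered += sizes.pop()
        # Emit full batches as soon as enough rows have accumulated.
        while buffered >= batch_size:
            yield {k: v[:batch_size] for k, v in buffer.items()}
            buffer = {k: v[batch_size:] for k, v in buffer.items()}
            buffered -= batch_size
    # The final partial batch is kept unless drop_remainder is True.
    if buffered > 0 and not drop_remainder:
        yield buffer


inputs = [{"x": [0, 1, 2]}, {"x": [3, 4, 5]}, {"x": [6]}]
print(list(rebatch_sketch(inputs, batch_size=2)))
# [{'x': [0, 1]}, {'x': [2, 3]}, {'x': [4, 5]}, {'x': [6]}]
print(list(rebatch_sketch(inputs, batch_size=2, drop_remainder=True)))
# [{'x': [0, 1]}, {'x': [2, 3]}, {'x': [4, 5]}]
```

Under this assumption, rebatching differs from batch, which stacks whole elements along a new first dimension. Here the existing leading dimension is reused, so input elements of any size can be split or merged into outputs of exactly batch_size rows, with only the last output possibly smaller.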

Methods

__init__(parent, batch_size[, drop_remainder])

An IterDataset that rebatches elements.

apply(transformations)

Returns a dataset with the given transformation(s) applied.

batch(batch_size, *[, drop_remainder, batch_fn])

Returns a dataset of elements batched along a new first dimension.

filter(transform)

Returns a dataset containing only the elements that match the filter.

map(transform)

Returns a dataset containing the elements transformed by transform.

map_with_index(transform)

Returns a dataset containing the elements transformed by transform, which also receives each element's index.

mp_prefetch([options, worker_init_fn, ...])

Returns a dataset prefetching elements in multiple processes.

pipe(func, /, *args, **kwargs)

Syntactic sugar for applying a callable to this dataset.

prefetch(multiprocessing_options)

Deprecated; use mp_prefetch instead.

random_map(transform, *[, seed])

Returns a dataset containing the elements transformed by the random transform, using seed for its random number generation.

seed(seed)

Returns a dataset that uses the seed for default seed generation.

Attributes

parents