get_sampler
composer.utils.dist.get_sampler(dataset, *, drop_last=False, shuffle=False, num_replicas=None, rank=None)[source]#
Constructs a DistributedSampler for a dataset.
The DistributedSampler assumes that each rank has a complete copy of the dataset. It ensures that each rank sees a unique shard for each epoch containing len(dataset) / get_world_size() samples.
Parameters
- dataset (Dataset) – The dataset.
- drop_last (bool) – Whether to drop the last batch.
- shuffle (bool) – Whether to shuffle the dataset.
- num_replicas (int, optional) – The number of replicas. If None, defaults to the world size.
- rank (int, optional) – The rank. If None, defaults to the global rank.
Returns
torch.utils.data.distributed.DistributedSampler – The sampler.
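The per-rank sharding described above can be sketched in plain Python. This is a simplified model of what a DistributedSampler does, not Composer's actual implementation: every rank shuffles with the same seed, then takes every num_replicas-th index, so each epoch it sees len(dataset) / num_replicas samples.

```python
import random

def shard_indices(dataset_len, rank, num_replicas, shuffle=False, seed=0):
    """Sketch of DistributedSampler-style sharding (illustrative helper,
    not part of the Composer API)."""
    indices = list(range(dataset_len))
    if shuffle:
        # Every rank must use the same seed so the shards stay disjoint.
        random.Random(seed).shuffle(indices)
    # Rank r takes indices r, r + num_replicas, r + 2 * num_replicas, ...
    return indices[rank::num_replicas]

# With 8 samples and 4 replicas, each rank gets a unique shard of 2 samples.
shards = [shard_indices(8, r, num_replicas=4) for r in range(4)]
```

In real training code, the sampler returned by get_sampler would typically be passed to a torch DataLoader via its sampler argument, so each process only iterates over its own shard.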