get_sampler


composer.utils.dist.get_sampler(dataset, *, drop_last=False, shuffle=False, num_replicas=None, rank=None)[source]#

Constructs a DistributedSampler for a dataset.

The DistributedSampler assumes that each rank has a complete copy of the dataset. It ensures that each rank sees a unique shard on each epoch containing `len(dataset) / get_world_size()` samples.
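The per-rank sharding can be illustrated with a small pure-Python sketch. This mimics the interleaved index assignment that torch's DistributedSampler uses (each rank takes every `num_replicas`-th index starting at its own rank, with padding or truncation to make the shards equal length); `shard_indices` is a hypothetical helper for illustration, not part of the Composer API:

```python
import math

def shard_indices(dataset_len, num_replicas, rank, drop_last=False):
    """Sketch of DistributedSampler-style sharding for one rank."""
    indices = list(range(dataset_len))
    if drop_last:
        # Drop the tail so the total divides evenly across ranks.
        total = (dataset_len // num_replicas) * num_replicas
        indices = indices[:total]
    else:
        # Pad by repeating indices from the start so every rank
        # receives the same number of samples.
        total = math.ceil(dataset_len / num_replicas) * num_replicas
        indices += indices[: total - dataset_len]
    # Each rank takes a strided slice: rank, rank + num_replicas, ...
    return indices[rank::num_replicas]

# 10 samples across 4 ranks, padded: every rank sees 3 indices.
shards = [shard_indices(10, 4, r) for r in range(4)]
```

With `drop_last=True` the same call would instead give each rank 2 indices, discarding samples 8 and 9.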

Parameters

- dataset (torch.utils.data.Dataset) – The dataset to sample from.
- drop_last (bool, optional) – Whether to drop the tail of the dataset so the samples divide evenly across ranks. Default: False.
- shuffle (bool, optional) – Whether to shuffle the samples each epoch. Default: False.
- num_replicas (int, optional) – The number of ranks participating in sampling. Defaults to the world size.
- rank (int, optional) – The rank of the current process. Defaults to the global rank.

Returns

torch.utils.data.distributed.DistributedSampler – The sampler.