torch.utils.data — PyTorch 2.7 documentation

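At the heart of the PyTorch data loading utility is the torch.utils.data.DataLoader class. It represents a Python iterable over a dataset, with support for map-style and iterable-style datasets, customizing data loading order, automatic batching, single- and multi-process data loading, and automatic memory pinning.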

These options are configured by the constructor arguments of a DataLoader, which has signature:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, *, prefetch_factor=2, persistent_workers=False)

The sections below describe in detail the effects and usages of these options.

Dataset Types

The most important argument of the DataLoader constructor is dataset, which indicates a dataset object to load data from. PyTorch supports two different types of datasets:

Map-style datasets

A map-style dataset is one that implements the __getitem__() and __len__() protocols, and represents a map from (possibly non-integral) indices/keys to data samples.

For example, such a dataset, when accessed with dataset[idx], could read the idx-th image and its corresponding label from a folder on the disk.

See Dataset for more details.

Iterable-style datasets

An iterable-style dataset is an instance of a subclass of IterableDataset that implements the __iter__() protocol, and represents an iterable over data samples. This type of dataset is particularly suitable for cases where random reads are expensive or even impossible, and where the batch size depends on the fetched data.

For example, such a dataset, when iter(dataset) is called, could return a stream of data read from a database, a remote server, or even logs generated in real time.

See IterableDataset for more details.

Note

When using an IterableDataset with multi-process data loading, the same dataset object is replicated on each worker process, and thus the replicas must be configured differently to avoid duplicated data. See the IterableDataset documentation for how to achieve this.

Data Loading Order and Sampler

For iterable-style datasets, data loading order is entirely controlled by the user-defined iterable. This allows easier implementations of chunk-reading and dynamic batch size (e.g., by yielding a batched sample each time).
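A minimal sketch of that pattern, not taken from the library: the hypothetical ChunkedRange dataset reads contiguous chunks whose sizes it chooses itself and yields each one already batched. Automatic batching is switched off with batch_size=None so each chunk passes through the DataLoader unchanged.

```python
import torch
from torch.utils.data import DataLoader, IterableDataset


class ChunkedRange(IterableDataset):
    """Hypothetical dataset: streams range(n) in chunks whose sizes it picks itself."""

    def __init__(self, n, chunk_sizes):
        super().__init__()
        self.n = n
        self.chunk_sizes = chunk_sizes  # e.g. sizes dictated by the storage layout

    def __iter__(self):
        start = 0
        for size in self.chunk_sizes:
            if start >= self.n:
                break
            end = min(start + size, self.n)
            # One contiguous "bulk read" per step, already batched.
            yield torch.arange(start, end)
            start = end


# batch_size=None disables automatic batching, so each chunk is yielded as-is.
loader = DataLoader(ChunkedRange(7, chunk_sizes=[3, 1, 4]), batch_size=None)
chunks = [chunk.tolist() for chunk in loader]
print(chunks)  # [[0, 1, 2], [3], [4, 5, 6]]
```

The batch size here varies from step to step because the dataset, not the DataLoader, decides where each batch ends.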

The rest of this section concerns the case with map-style datasets. torch.utils.data.Sampler classes are used to specify the sequence of indices/keys used in data loading. They represent iterable objects over the indices of datasets. E.g., in the common case with stochastic gradient descent (SGD), a Sampler could randomly permute a list of indices and yield each one at a time, or yield a small number of them for mini-batch SGD.

A sequential or shuffled sampler will be automatically constructed based on the shuffle argument to a DataLoader. Alternatively, users may use the sampler argument to specify a custom Sampler object that each time yields the next index/key to fetch.

A custom Sampler that yields a list of batch indices at a time can be passed as the batch_sampler argument. Automatic batching can also be enabled via the batch_size and drop_last arguments. See the next section for more details on this.
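A small sketch of the two arguments side by side, assuming a hypothetical EvenFirstSampler that visits even indices before odd ones. With sampler=, the DataLoader groups single indices into batches itself; with batch_sampler=, the supplied BatchSampler does the grouping and the batching arguments are left at their defaults.

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, Sampler, TensorDataset


class EvenFirstSampler(Sampler[int]):
    """Hypothetical sampler: yields even indices first, then odd ones."""

    def __init__(self, n):
        self.n = n

    def __iter__(self):
        yield from range(0, self.n, 2)
        yield from range(1, self.n, 2)

    def __len__(self):
        return self.n


dataset = TensorDataset(torch.arange(6))

# sampler= yields one index at a time; batch_size/drop_last group them into batches.
by_sampler = DataLoader(dataset, sampler=EvenFirstSampler(6), batch_size=2)

# batch_sampler= yields whole lists of indices; batch_size, shuffle, sampler
# and drop_last must then be left at their defaults.
by_batch_sampler = DataLoader(
    dataset, batch_sampler=BatchSampler(EvenFirstSampler(6), batch_size=2, drop_last=False)
)

print([batch[0].tolist() for batch in by_sampler])        # [[0, 2], [4, 1], [3, 5]]
print([batch[0].tolist() for batch in by_batch_sampler])  # [[0, 2], [4, 1], [3, 5]]
```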

Note

Neither sampler nor batch_sampler is compatible with iterable-style datasets, since such datasets have no notion of a key or an index.

Loading Batched and Non-Batched Data

DataLoader supports automatically collating individual fetched data samples into batches via arguments batch_size, drop_last, batch_sampler, and collate_fn (which has a default function).

Automatic batching (default)

This is the most common case, and corresponds to fetching a minibatch of data and collating them into batched samples, i.e., containing Tensors with one dimension being the batch dimension (usually the first).

When batch_size (default 1) is not None, the data loader yields batched samples instead of individual samples. The batch_size and drop_last arguments are used to specify how the data loader obtains batches of dataset keys. For map-style datasets, users can alternatively specify batch_sampler, which yields a list of keys at a time.

Note

The batch_size and drop_last arguments are essentially used to construct a batch_sampler from sampler. For map-style datasets, the sampler is either provided by the user or constructed based on the shuffle argument. For iterable-style datasets, the sampler is a dummy infinite one. See Data Loading Order and Sampler for more details on samplers.

Note

When fetching from iterable-style datasets with multi-processing, the drop_last argument drops the last non-full batch of each worker’s dataset replica.

After fetching a list of samples using the indices from sampler, the function passed as the collate_fn argument is used to collate lists of samples into batches.

In this case, loading from a map-style dataset is roughly equivalent to:

for indices in batch_sampler:
    yield collate_fn([dataset[i] for i in indices])

and loading from an iterable-style dataset is roughly equivalent to:

dataset_iter = iter(dataset)
for indices in batch_sampler:
    yield collate_fn([next(dataset_iter) for _ in indices])
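To make the map-style equivalence concrete, the sketch below builds the same BatchSampler that batch_size=4 implies and collates by hand with default_collate, then checks the result against the DataLoader's own batches. It is an illustration of the "roughly equivalent" loop, not the DataLoader's actual internals.

```python
import torch
from torch.utils.data import (BatchSampler, DataLoader, SequentialSampler,
                              TensorDataset, default_collate)

dataset = TensorDataset(torch.arange(10).float(), torch.arange(10) % 3)

loader = DataLoader(dataset, batch_size=4, drop_last=False)

# Hand-rolled version of the map-style loop described above.
batch_sampler = BatchSampler(SequentialSampler(dataset), batch_size=4, drop_last=False)
manual = [default_collate([dataset[i] for i in indices]) for indices in batch_sampler]

for from_loader, by_hand in zip(loader, manual):
    # Each batch is a list [inputs, labels] with a leading batch dimension.
    assert all(torch.equal(a, b) for a, b in zip(from_loader, by_hand))

print(len(manual), [batch[0].shape[0] for batch in manual])  # 3 [4, 4, 2]
```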

A custom collate_fn can be used to customize collation, e.g., padding sequential data to the max length of a batch. See Working with collate_fn for more about collate_fn.
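As a sketch of such a custom collate_fn, the hypothetical pad_collate below pads variable-length integer sequences to the longest one in the batch with torch.nn.utils.rnn.pad_sequence and also returns the original lengths. The padding value 0 is an assumption; pick one that cannot collide with real tokens in your data.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader


def pad_collate(batch):
    """Hypothetical collate_fn: pads variable-length 1-D sequences to the batch max."""
    sequences = [torch.as_tensor(seq) for seq in batch]
    lengths = torch.tensor([len(seq) for seq in sequences])
    # pad_sequence stacks along a new first dimension when batch_first=True.
    padded = pad_sequence(sequences, batch_first=True, padding_value=0)
    return padded, lengths


# A plain list works as a map-style dataset (it has __getitem__ and __len__).
data = [[1, 2, 3], [4], [5, 6]]
loader = DataLoader(data, batch_size=3, collate_fn=pad_collate)
padded, lengths = next(iter(loader))
print(padded.tolist())   # [[1, 2, 3], [4, 0, 0], [5, 6, 0]]
print(lengths.tolist())  # [3, 1, 2]
```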

Disable automatic batching

In certain cases, users may want to handle batching manually in dataset code, or simply load individual samples. For example, it could be cheaper to directly load batched data (e.g., bulk reads from a database or reading continuous chunks of memory), or the batch size is data dependent, or the program is designed to work on individual samples. Under these scenarios, it’s likely better to not use automatic batching (where collate_fn is used to collate the samples), but let the data loader directly return each member of the dataset object.

When both batch_size and batch_sampler are None (default value for batch_sampler is already None), automatic batching is disabled. Each sample obtained from the dataset is processed with the function passed as the collate_fn argument.

When automatic batching is disabled, the default collate_fn simply converts NumPy arrays into PyTorch Tensors, and keeps everything else untouched.

In this case, loading from a map-style dataset is roughly equivalent to:

for index in sampler:
    yield collate_fn(dataset[index])

and loading from an iterable-style dataset is roughly equivalent to:

for data in iter(dataset):
    yield collate_fn(data)
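The following sketch disables automatic batching on a hypothetical map-style dataset whose samples are NumPy rows. No batch dimension is added, but the default collate_fn still converts each NumPy array into a Tensor.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader


class NumpyRows:
    """Hypothetical map-style dataset whose samples are NumPy rows."""

    def __init__(self, array):
        self.array = array

    def __getitem__(self, index):
        return self.array[index]

    def __len__(self):
        return len(self.array)


rows = NumpyRows(np.arange(6).reshape(3, 2))

# batch_size=None (and no batch_sampler) disables automatic batching.
loader = DataLoader(rows, batch_size=None)
samples = list(loader)

# Each sample keeps its own shape (no batch dimension is added), but the
# default collate_fn has converted the NumPy array into a Tensor.
print([type(s).__name__ for s in samples])  # ['Tensor', 'Tensor', 'Tensor']
print([s.tolist() for s in samples])        # [[0, 1], [2, 3], [4, 5]]
```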

See Working with collate_fn for more about collate_fn.

Working with collate_fn

The use of collate_fn is slightly different when automatic batching is enabled or disabled.

When automatic batching is disabled, collate_fn is called with each individual data sample, and the output is yielded from the data loader iterator. In this case, the default collate_fn simply converts NumPy arrays into PyTorch tensors.

When automatic batching is enabled, collate_fn is called with a list of data samples each time. It is expected to collate the input samples into a batch for yielding from the data loader iterator. The rest of this section describes the behavior of the default collate_fn (default_collate()).

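For instance, if each data sample consists of a 3-channel image and an integral class label, i.e., each element of the dataset returns a tuple (image, class_index), the default collate_fn collates a list of such tuples into a single sequence (a list) holding a batched image tensor and a batched class label Tensor. In particular, the default collate_fn has the following properties: it always prepends a new dimension as the batch dimension; it automatically converts NumPy arrays and Python numerical values into PyTorch Tensors; and it preserves the data structure, e.g., if each sample is a dictionary, it outputs a dictionary with the same set of keys but batched Tensors as values (or lists if the values cannot be converted into Tensors). The same holds for lists, tuples, namedtuples, etc.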
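A quick check of these properties on two hypothetical (image, class_index) samples, where the "images" are just constant 3x4x4 tensors:

```python
import torch
from torch.utils.data import default_collate

# Two hypothetical samples: a 3-channel 4x4 "image" and an integer class label.
samples = [(torch.zeros(3, 4, 4), 0), (torch.ones(3, 4, 4), 2)]

images, labels = default_collate(samples)
print(images.shape)  # torch.Size([2, 3, 4, 4]): new leading batch dimension
print(labels)        # tensor([0, 2]): Python ints converted to a Tensor
```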

Users may use customized collate_fn to achieve custom batching, e.g., collating along a dimension other than the first, padding sequences of various lengths, or adding support for custom data types.

If you run into a situation where the outputs of DataLoader have dimensions or types that are different from your expectation, you may want to check your collate_fn.

Single- and Multi-process Data Loading

A DataLoader uses single-process data loading by default.

Within a Python process, the Global Interpreter Lock (GIL) prevents truly parallel execution of Python code across threads. To avoid blocking computation code with data loading, PyTorch provides an easy switch to perform multi-process data loading by simply setting the argument num_workers to a positive integer.

Single-process data loading (default)

In this mode, data fetching is done in the same process in which the DataLoader is initialized. Therefore, data loading may block computing. However, this mode may be preferred when resources used for sharing data among processes (e.g., shared memory, file descriptors) are limited, or when the entire dataset is small and can be loaded entirely in memory. Additionally, single-process loading often shows more readable error traces and thus is useful for debugging.

Multi-process data loading

Setting the argument num_workers as a positive integer will turn on multi-process data loading with the specified number of loader worker processes.

Warning

After several iterations, the loader worker processes will consume the same amount of CPU memory as the parent process for all Python objects in the parent process which are accessed from the worker processes. This can be problematic if the Dataset contains a lot of data (e.g., you are loading a very large list of filenames at Dataset construction time) and/or you are using a lot of workers (overall memory usage is number of workers * size of parent process). The simplest workaround is to replace Python objects with non-refcounted representations such as Pandas, NumPy or PyArrow objects. Check out issue #13246 for more details on why this occurs and for example code showing how to work around these problems.
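A minimal sketch of the NumPy workaround, assuming a hypothetical dataset that only needs a list of file paths: the paths are stored in one fixed-width string array instead of a Python list of str objects, and converted back to str only for the item being fetched. The rationale in the comments paraphrases the copy-on-access explanation above; measure memory in your own setup.

```python
import numpy as np
from torch.utils.data import Dataset


class FileListDataset(Dataset):
    """Hypothetical dataset holding file paths in one NumPy array.

    A Python list of str would be one refcounted object per path; touching
    them in a forked worker updates refcounts and gradually copies the pages.
    A single NumPy array of fixed-width strings is one object with a flat buffer.
    """

    def __init__(self, paths):
        self.paths = np.array(paths)  # dtype like '<U11', not dtype=object

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        # Convert back to a Python str only for the single item requested.
        path = str(self.paths[index])
        return path  # a real dataset would open and decode the file here


ds = FileListDataset(["img_000.png", "img_001.png", "img_002.png"])
print(ds.paths.dtype, len(ds), ds[1])  # <U11 3 img_001.png
```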

In this mode, each time an iterator of a DataLoader is created (e.g., when you call enumerate(dataloader)), num_workers worker processes are created. At this point, the dataset, collate_fn, and worker_init_fn are passed to each worker, where they are used to initialize and fetch data. This means that dataset access, together with its internal IO and transforms (including collate_fn), runs in the worker process.

torch.utils.data.get_worker_info() returns various useful information in a worker process (including the worker id, dataset replica, initial seed, etc.), and returns None in the main process. Users may use this function in dataset code and/or worker_init_fn to individually configure each dataset replica, and to determine whether the code is running in a worker process. For example, this can be particularly helpful in sharding the dataset.

For map-style datasets, the main process generates the indices using sampler and sends them to the workers. So any shuffle randomization is done in the main process, which guides loading by assigning indices to load.

For iterable-style datasets, since each worker process gets a replica of the dataset object, naive multi-process loading will often result in duplicated data. Using torch.utils.data.get_worker_info() and/or worker_init_fn, users may configure each replica independently. (See the IterableDataset documentation for how to achieve this.) For similar reasons, in multi-process loading, the drop_last argument drops the last non-full batch of each worker’s iterable-style dataset replica.

Workers are shut down once the end of the iteration is reached, or when the iterator becomes garbage collected.

Warning

It is generally not recommended to return CUDA tensors in multi-process loading because of many subtleties in using CUDA and sharing CUDA tensors in multiprocessing (see CUDA in multiprocessing). Instead, we recommend using automatic memory pinning (i.e., setting pin_memory=True), which enables fast data transfer to CUDA-enabled GPUs.

Platform-specific behaviors

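Since workers rely on Python multiprocessing, worker launch behavior is different on Windows compared to Unix. On Linux, fork() is the default multiprocessing start method; using fork(), child workers can typically access the dataset and Python argument functions directly through the cloned address space. On Windows and macOS, spawn() is the default start method; using spawn(), another interpreter is launched which runs your main script, followed by the internal worker function that receives the dataset, collate_fn, and other arguments through pickle serialization.

This separate serialization means that you should take two steps to ensure you are compatible with Windows while using multi-process data loading: wrap most of your main script's code within an if __name__ == '__main__': block, so that it does not run again when each worker is launched; and make sure that any custom collate_fn, worker_init_fn, or dataset code is declared as top-level definitions, outside of the __main__ check, so that it is available in the workers.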
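A sketch of a script laid out along those two steps. SquaresDataset and collate_as_list are hypothetical names; what matters is that they are defined at the top level, while only the code that builds and iterates the loader sits behind the __main__ guard.

```python
from torch.utils.data import DataLoader, Dataset


# Top-level definitions: spawned workers must be able to import these by name.
class SquaresDataset(Dataset):
    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, index):
        return index * index


def collate_as_list(batch):
    # Keep samples as a plain Python list instead of stacking them into a tensor.
    return list(batch)


def main(num_workers=2):
    loader = DataLoader(SquaresDataset(5), batch_size=2,
                        num_workers=num_workers, collate_fn=collate_as_list)
    return list(loader)


# Only the entry point goes under the guard, so a spawned worker that
# re-imports this module does not start loading data itself.
if __name__ == "__main__":
    print(main())  # [[0, 1], [4, 9], [16]]
```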

Randomness in multi-process data loading

By default, each worker will have its PyTorch seed set to base_seed + worker_id, where base_seed is a long generated by the main process using its RNG (thereby necessarily consuming an RNG state) or by a specified generator. However, seeds for other libraries may be duplicated upon initializing workers, causing each worker to return identical random numbers. (See this section in the FAQ.)

In worker_init_fn, you may access the PyTorch seed set for each worker with either torch.utils.data.get_worker_info().seed or torch.initial_seed(), and use it to seed other libraries before data loading.
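A common sketch of such a worker_init_fn, assuming NumPy and Python's random module are the "other libraries" in question. It derives their seeds from torch.initial_seed(), reduced modulo 2**32 because NumPy's legacy seeding accepts only 32-bit values; a seeded Generator passed to the DataLoader fixes base_seed.

```python
import random

import numpy as np
import torch
from torch.utils.data import DataLoader


def seed_worker(worker_id):
    # torch.initial_seed() is base_seed + worker_id inside a worker; NumPy
    # only accepts 32-bit seeds, so reduce it modulo 2**32.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


g = torch.Generator()
g.manual_seed(0)  # fixes base_seed, and hence every worker's seed

# Workers start only when the loader is iterated; seed_worker runs in each one.
loader = DataLoader(list(range(8)), batch_size=2, num_workers=2,
                    worker_init_fn=seed_worker, generator=g)
```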

Memory Pinning

Host to GPU copies are much faster when they originate from pinned (page-locked) memory. See Use pinned memory buffers for more details on when and how to use pinned memory generally.

For data loading, passing pin_memory=True to a DataLoader will automatically put the fetched data Tensors in pinned memory, and thus enables faster data transfer to CUDA-enabled GPUs.

The default memory pinning logic only recognizes Tensors and maps and iterables containing Tensors. By default, if the pinning logic sees a batch that is a custom type (which will occur if you have a collate_fn that returns a custom batch type), or if each element of your batch is a custom type, the pinning logic will not recognize them, and it will return that batch (or those elements) without pinning the memory. To enable memory pinning for custom batch or data type(s), define a pin_memory() method on your custom type(s).

See the example below.

Example:

import torch
from torch.utils.data import DataLoader, TensorDataset

class SimpleCustomBatch:
    def __init__(self, data):
        transposed_data = list(zip(*data))
        self.inp = torch.stack(transposed_data[0], 0)
        self.tgt = torch.stack(transposed_data[1], 0)

    # custom memory pinning method on custom type
    def pin_memory(self):
        self.inp = self.inp.pin_memory()
        self.tgt = self.tgt.pin_memory()
        return self

def collate_wrapper(batch):
    return SimpleCustomBatch(batch)

inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
dataset = TensorDataset(inps, tgts)

loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper,
                    pin_memory=True)

for batch_ndx, sample in enumerate(loader):
    print(sample.inp.is_pinned())
    print(sample.tgt.is_pinned())

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=None, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=None, persistent_workers=False, pin_memory_device='', in_order=True)

Data loader combines a dataset and a sampler, and provides an iterable over the given dataset.

The DataLoader supports both map-style and iterable-style datasets with single- or multi-process loading, customizing loading order and optional automatic batching (collation) and memory pinning.

See torch.utils.data documentation page for more details.

Parameters

Warning

If the spawn start method is used, worker_init_fn cannot be an unpicklable object, e.g., a lambda function. See Multiprocessing best practices for more details related to multiprocessing in PyTorch.
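One way to pass extra parameters to worker_init_fn without a lambda is functools.partial over a top-level function, which pickles by reference. The sketch below uses a hypothetical offset_worker_init and only demonstrates the pickling difference; it does not start any workers.

```python
import functools
import pickle


def offset_worker_init(worker_id, offset):
    # Hypothetical worker_init_fn taking an extra parameter.
    print(f"worker {worker_id} starting with offset {offset}")


# A lambda cannot be pickled by reference, so it fails under spawn.
bad_init = lambda worker_id: offset_worker_init(worker_id, offset=10)  # noqa: E731

# functools.partial of a top-level function pickles fine.
good_init = functools.partial(offset_worker_init, offset=10)

for name, fn in [("lambda", bad_init), ("partial", good_init)]:
    try:
        pickle.dumps(fn)
        print(name, "is picklable")
    except Exception as exc:  # typically pickle.PicklingError for lambdas
        print(name, "is NOT picklable:", type(exc).__name__)
```

The partial object would then be passed as DataLoader(..., worker_init_fn=good_init).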

Warning

len(dataloader) heuristic is based on the length of the sampler used. When dataset is an IterableDataset, it instead returns an estimate based on len(dataset) / batch_size, with proper rounding depending on drop_last, regardless of multi-process loading configurations. This represents the best guess PyTorch can make because PyTorch trusts user dataset code in correctly handling multi-process loading to avoid duplicate data.

However, if sharding results in multiple workers having incomplete last batches, this estimate can still be inaccurate, because (1) an otherwise complete batch can be broken into multiple ones and (2) more than one batch worth of samples can be dropped when drop_last is set. Unfortunately, PyTorch cannot detect such cases in general.

See Dataset Types for more details on these two types of datasets and how IterableDataset interacts with Multi-process data loading.
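A small illustration of the estimate, using a hypothetical IterableDataset that reports a length of 10. With batch_size=3, len(dataloader) rounds up without drop_last and down with it, exactly as described above; it is only computed from len(dataset), not from the actual stream.

```python
from torch.utils.data import DataLoader, IterableDataset


class CountingStream(IterableDataset):
    """Hypothetical stream of n integers that also reports its length."""

    def __init__(self, n):
        super().__init__()
        self.n = n

    def __iter__(self):
        return iter(range(self.n))

    def __len__(self):
        return self.n


stream = CountingStream(10)
print(len(DataLoader(stream, batch_size=3)))                  # 4 == ceil(10 / 3)
print(len(DataLoader(stream, batch_size=3, drop_last=True)))  # 3 == 10 // 3
```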

Warning

Setting in_order to False can harm reproducibility and may lead to a skewed data distribution being fed to the trainer in cases with imbalanced data.

class torch.utils.data.Dataset

An abstract class representing a Dataset.

All datasets that represent a map from keys to data samples should subclass it. All subclasses should override __getitem__(), supporting fetching a data sample for a given key. Subclasses could also optionally override __len__(), which is expected to return the size of the dataset by many Sampler implementations and the default options of DataLoader. Subclasses could also optionally implement __getitems__() to speed up loading of batched samples. This method accepts a list of sample indices for a batch and returns a list of samples.
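A sketch of a map-style dataset that implements __getitems__() alongside __getitem__(). SquaresWithBatchedFetch is hypothetical; it counts how often the batched method is called so you can see that, with automatic batching and single-process loading, the DataLoader fetches each batch through one __getitems__() call.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class SquaresWithBatchedFetch(Dataset):
    """Hypothetical map-style dataset that also implements __getitems__."""

    def __init__(self, n):
        self.data = torch.arange(n) ** 2
        self.batched_calls = 0

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

    def __getitems__(self, indices):
        # One vectorized lookup for the whole batch; must return a list of samples.
        self.batched_calls += 1
        return list(self.data[torch.tensor(indices)])


ds = SquaresWithBatchedFetch(5)
batches = [batch.tolist() for batch in DataLoader(ds, batch_size=2)]
print(batches, ds.batched_calls)  # [[0, 1], [4, 9], [16]] 3
```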

Note

DataLoader by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided.
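A minimal sketch of that requirement, using a hypothetical dataset keyed by names and a hypothetical sampler that yields those names in sorted order. The sampler's keys are passed straight to __getitem__(), and automatic batching still works on the fetched values.

```python
from torch.utils.data import DataLoader, Dataset, Sampler


class ScoresByName(Dataset):
    """Hypothetical map-style dataset keyed by strings instead of integers."""

    def __init__(self, scores):
        self.scores = scores  # dict: name -> score

    def __len__(self):
        return len(self.scores)

    def __getitem__(self, name):
        return self.scores[name]


class SortedKeySampler(Sampler[str]):
    """Yields the dataset's keys (here: names in sorted order)."""

    def __init__(self, keys):
        self.keys = sorted(keys)

    def __iter__(self):
        return iter(self.keys)

    def __len__(self):
        return len(self.keys)


scores = {"carol": 7, "alice": 9, "bob": 4}
ds = ScoresByName(scores)
loader = DataLoader(ds, sampler=SortedKeySampler(scores), batch_size=2)
print([batch.tolist() for batch in loader])  # [[9, 4], [7]]
```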

class torch.utils.data.IterableDataset

An iterable Dataset.

All datasets that represent an iterable of data samples should subclass it. Such form of datasets is particularly useful when data come from a stream.

All subclasses should override __iter__(), which should return an iterator of samples in this dataset.

When a subclass is used with DataLoader, each item in the dataset will be yielded from the DataLoader iterator. When num_workers > 0, each worker process will have a different copy of the dataset object, so it is often desired to configure each copy independently to avoid having duplicate data returned from the workers. get_worker_info(), when called in a worker process, returns information about the worker. It can be used in either the dataset’s __iter__() method or the DataLoader’s worker_init_fn option to modify each copy’s behavior.

Example 1: splitting workload across all workers in __iter__():

import math
import torch

class MyIterableDataset(torch.utils.data.IterableDataset):
    def __init__(self, start, end):
        super().__init__()
        assert end > start, "this example code only works with end > start"
        self.start = start
        self.end = end

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:  # single-process data loading, return the full iterator
            iter_start = self.start
            iter_end = self.end
        else:  # in a worker process
            # split workload
            per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
            worker_id = worker_info.id
            iter_start = self.start + worker_id * per_worker
            iter_end = min(iter_start + per_worker, self.end)
        return iter(range(iter_start, iter_end))

# should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
ds = MyIterableDataset(start=3, end=7)

# Single-process loading
print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
# [tensor([3]), tensor([4]), tensor([5]), tensor([6])]

# Multi-process loading with two worker processes
# Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6].
print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
# [tensor([3]), tensor([5]), tensor([4]), tensor([6])]

# With even more workers
print(list(torch.utils.data.DataLoader(ds, num_workers=12)))
# [tensor([3]), tensor([4]), tensor([5]), tensor([6])]

Example 2: splitting workload across all workers using worker_init_fn:

import math
import torch

class MyIterableDataset(torch.utils.data.IterableDataset):
    def __init__(self, start, end):
        super().__init__()
        assert end > start, "this example code only works with end > start"
        self.start = start
        self.end = end

    def __iter__(self):
        return iter(range(self.start, self.end))

# should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
ds = MyIterableDataset(start=3, end=7)

# Single-process loading
print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
# [tensor([3]), tensor([4]), tensor([5]), tensor([6])]

# Directly doing multi-process loading yields duplicate data
print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
# [tensor([3]), tensor([3]), tensor([4]), tensor([4]), tensor([5]), tensor([5]), tensor([6]), tensor([6])]

# Define a `worker_init_fn` that configures each dataset copy differently
def worker_init_fn(worker_id):
    worker_info = torch.utils.data.get_worker_info()
    dataset = worker_info.dataset  # the dataset copy in this worker process
    overall_start = dataset.start
    overall_end = dataset.end
    # configure the dataset to only process the split workload
    per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
    worker_id = worker_info.id
    dataset.start = overall_start + worker_id * per_worker
    dataset.end = min(dataset.start + per_worker, overall_end)

# Multi-process loading with the custom `worker_init_fn`
# Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6].
print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
# [tensor([3]), tensor([5]), tensor([4]), tensor([6])]

# With even more workers
print(list(torch.utils.data.DataLoader(ds, num_workers=12, worker_init_fn=worker_init_fn)))
# [tensor([3]), tensor([4]), tensor([5]), tensor([6])]

class torch.utils.data.TensorDataset(*tensors)

Dataset wrapping tensors.

Each sample will be retrieved by indexing tensors along the first dimension.

Parameters

*tensors (Tensor) – tensors that have the same size of the first dimension.
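A short usage sketch: features and labels sharing a first dimension of 4 are wrapped together, and indexing returns one row from each tensor.

```python
import torch
from torch.utils.data import TensorDataset

features = torch.randn(4, 3)
labels = torch.tensor([0, 1, 0, 1])
ds = TensorDataset(features, labels)

x, y = ds[2]  # index both tensors along dim 0
print(len(ds), x.shape, y.item())  # 4 torch.Size([3]) 0
```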

class torch.utils.data.StackDataset(*args, **kwargs)

Dataset as a stacking of multiple datasets.

This class is useful to assemble different parts of complex input data, given as datasets.

Example

images = ImageDataset()
texts = TextDataset()
tuple_stack = StackDataset(images, texts)
tuple_stack[0] == (images[0], texts[0])
dict_stack = StackDataset(image=images, text=texts)
dict_stack[0] == {'image': images[0], 'text': texts[0]}

Parameters

class torch.utils.data.ConcatDataset(datasets)

Dataset as a concatenation of multiple datasets.

This class is useful to assemble different existing datasets.

Parameters

datasets (sequence) – List of datasets to be concatenated
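A small sketch with two plain lists standing in for map-style datasets. Indices run across the concatenation, negative indices count from the end, and cumulative_sizes records where each part ends.

```python
from torch.utils.data import ConcatDataset

# Plain lists stand in for map-style datasets here.
combined = ConcatDataset([["a", "b", "c"], ["d", "e"]])
print(len(combined), combined[3], combined[-1])  # 5 d e
print(combined.cumulative_sizes)                 # [3, 5]
```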

class torch.utils.data.ChainDataset(datasets)

Dataset for chaining multiple IterableDatasets.

This class is useful to assemble different existing dataset streams. The chaining operation is done on-the-fly, so concatenating large-scale datasets with this class will be efficient.

Parameters

datasets (iterable of IterableDataset) – datasets to be chained together
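A minimal sketch chaining two hypothetical RangeStream datasets; the second stream starts only after the first is exhausted, and nothing is materialized up front.

```python
from torch.utils.data import ChainDataset, IterableDataset


class RangeStream(IterableDataset):
    """Hypothetical stream over range(start, end)."""

    def __init__(self, start, end):
        super().__init__()
        self.start, self.end = start, end

    def __iter__(self):
        return iter(range(self.start, self.end))


chained = ChainDataset([RangeStream(0, 3), RangeStream(10, 12)])
print(list(chained))  # [0, 1, 2, 10, 11]
```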

class torch.utils.data.Subset(dataset, indices)

Subset of a dataset at specified indices.
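A short sketch: the subset's position i maps to the underlying dataset's element at indices[i], so the given order is preserved.

```python
from torch.utils.data import Subset

base = ["a", "b", "c", "d", "e"]
subset = Subset(base, indices=[4, 0, 2])
print(len(subset), [subset[i] for i in range(len(subset))])  # 3 ['e', 'a', 'c']
```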

Parameters

torch.utils.data._utils.collate.collate(batch, *, collate_fn_map=None)

General collate function that handles collection type of element within each batch.

The function also opens a function registry to deal with specific element types. default_collate_fn_map provides default collate functions for tensors, NumPy arrays, numbers and strings.

Parameters

Examples

import torch
from torch.utils.data._utils.collate import collate, default_collate_fn_map

def collate_tensor_fn(batch, *, collate_fn_map):
    # Extend this function to handle batch of tensors
    return torch.stack(batch, 0)

def custom_collate(batch):
    collate_map = {torch.Tensor: collate_tensor_fn}
    return collate(batch, collate_fn_map=collate_map)

Extend default_collate by in-place modifying default_collate_fn_map

default_collate_fn_map.update({torch.Tensor: collate_tensor_fn})

Note

Each collate function requires a positional argument for batch and a keyword argument for the dictionary of collate functions as collate_fn_map.

torch.utils.data.default_collate(batch)

Take in a batch of data and put the elements within the batch into a tensor with an additional outer dimension - batch size.

The exact output type can be a torch.Tensor, a Sequence of torch.Tensor, a Collection of torch.Tensor, or left unchanged, depending on the input type. This is used as the default function for collation when batch_size or batch_sampler is defined in DataLoader.

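Here is the general input type (based on the type of the element within the batch) to output type mapping:

torch.Tensor -> torch.Tensor (with an added outer dimension batch size)
NumPy Arrays -> torch.Tensor
float -> torch.Tensor
int -> torch.Tensor
str -> str (unchanged)
bytes -> bytes (unchanged)
Mapping[K, V_i] -> Mapping[K, default_collate([V_1, V_2, ...])]
NamedTuple[V1_i, V2_i, ...] -> NamedTuple[default_collate([V1_1, V1_2, ...]), default_collate([V2_1, V2_2, ...]), ...]
Sequence[V1_i, V2_i, ...] -> Sequence[default_collate([V1_1, V1_2, ...]), default_collate([V2_1, V2_2, ...]), ...] (a plain tuple is returned as a list)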

Parameters

batch – a single batch to be collated

Examples

Example with a batch of ints:

default_collate([0, 1, 2, 3])
# tensor([0, 1, 2, 3])

Example with a batch of strs:

default_collate(['a', 'b', 'c'])
# ['a', 'b', 'c']

Example with Map inside the batch:

default_collate([{'A': 0, 'B': 1}, {'A': 100, 'B': 100}])
# {'A': tensor([  0, 100]), 'B': tensor([  1, 100])}

Example with NamedTuple inside the batch:

Point = namedtuple('Point', ['x', 'y'])
default_collate([Point(0, 0), Point(1, 1)])
# Point(x=tensor([0, 1]), y=tensor([0, 1]))

Example with Tuple inside the batch:

default_collate([(0, 1), (2, 3)])
# [tensor([0, 2]), tensor([1, 3])]

Example with List inside the batch:

default_collate([[0, 1], [2, 3]])
# [tensor([0, 2]), tensor([1, 3])]

Two options to extend default_collate to handle specific type

Option 1: Write custom collate function and invoke default_collate

def custom_collate(batch):
    elem = batch[0]
    if isinstance(elem, CustomType):  # Some custom condition
        return ...
    else:  # Fall back to `default_collate`
        return default_collate(batch)

Option 2: In-place modify default_collate_fn_map

def collate_customtype_fn(batch, *, collate_fn_map=None):
    return ...

default_collate_fn_map.update({CustomType: collate_customtype_fn})
default_collate(batch)  # Handle `CustomType` automatically

torch.utils.data.default_convert(data)

Convert each NumPy array element into a torch.Tensor.

If the input is a Sequence, Collection, or Mapping, it tries to convert each element inside to a torch.Tensor. If the input is not a NumPy array, it is left unchanged. This is used as the default function for collation when both batch_sampler and batch_size are NOT defined in DataLoader.

The general input type to output type mapping is similar to that of default_collate(). See the description there for more details.

Parameters

data – a single data point to be converted

Examples

Example with int

default_convert(0)
# 0

Example with NumPy array

default_convert(np.array([0, 1]))
# tensor([0, 1])

Example with NamedTuple

Point = namedtuple('Point', ['x', 'y'])
default_convert(Point(0, 0))
# Point(x=0, y=0)
default_convert(Point(np.array(0), np.array(0)))
# Point(x=tensor(0), y=tensor(0))

Example with List

default_convert([np.array([0, 1]), np.array([2, 3])])
# [tensor([0, 1]), tensor([2, 3])]

torch.utils.data.get_worker_info()

Returns the information about the current DataLoader iterator worker process.

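When called in a worker, this returns an object guaranteed to have the following attributes:

id: the current worker id.
num_workers: the total number of workers.
seed: the random seed set for the current worker. This value is determined by the main process RNG and the worker id.
dataset: the copy of the dataset object in this process. Note that this will be a different object in a different process than the one in the main process.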

When called in the main process, this returns None.

Note

When used in a worker_init_fn passed over to DataLoader, this method can be useful to set up each worker process differently, for instance, using worker_id to configure the dataset object to only read a specific fraction of a sharded dataset, or using seed to seed other libraries used in dataset code.

Return type

Optional[WorkerInfo]

torch.utils.data.random_split(dataset, lengths, generator=<torch._C.Generator object>)

Randomly split a dataset into non-overlapping new datasets of given lengths.

If a list of fractions that sum up to 1 is given, the lengths will be computed automatically as floor(frac * len(dataset)) for each fraction provided.

After computing the lengths, if there are any remainders, 1 count will be distributed in round-robin fashion to the lengths until there are no remainders left.
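A worked instance of the fraction rule: 11 items split as [0.5, 0.3, 0.2] give floor lengths [5, 3, 2], which sum to 10, so the single leftover item is added to the first split.

```python
import torch
from torch.utils.data import random_split

# 11 items split 50/30/20: floor gives [5, 3, 2] (sum 10), and the one
# leftover item goes to the first split in round-robin order.
splits = random_split(range(11), [0.5, 0.3, 0.2],
                      generator=torch.Generator().manual_seed(0))
print([len(s) for s in splits])  # [6, 3, 2]

# The splits are disjoint and together cover every index exactly once.
print(sorted(i for s in splits for i in s) == list(range(11)))  # True
```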

Optionally fix the generator for reproducible results, e.g.:

Example

generator1 = torch.Generator().manual_seed(42)
generator2 = torch.Generator().manual_seed(42)
random_split(range(10), [3, 7], generator=generator1)
random_split(range(30), [0.3, 0.3, 0.4], generator=generator2)

Parameters

Return type

list[torch.utils.data.dataset.Subset[~_T]]

class torch.utils.data.Sampler(data_source=None)

Base class for all Samplers.

Every Sampler subclass has to provide an __iter__() method, providing a way to iterate over indices or lists of indices (batches) of dataset elements, and may provide a __len__() method that returns the length of the returned iterators.

Parameters

data_source (Dataset) – This argument is not used and will be removed in 2.2.0. You may still have custom implementation that utilizes it.

Example

from typing import Iterator, List

import torch
from torch.utils.data import Sampler

class AscendingSequenceLengthSampler(Sampler[int]):
    def __init__(self, data: List[str]) -> None:
        self.data = data

    def __len__(self) -> int:
        return len(self.data)

    def __iter__(self) -> Iterator[int]:
        sizes = torch.tensor([len(x) for x in self.data])
        yield from torch.argsort(sizes).tolist()

class AscendingSequenceLengthBatchSampler(Sampler[List[int]]):
    def __init__(self, data: List[str], batch_size: int) -> None:
        self.data = data
        self.batch_size = batch_size

    def __len__(self) -> int:
        return (len(self.data) + self.batch_size - 1) // self.batch_size

    def __iter__(self) -> Iterator[List[int]]:
        sizes = torch.tensor([len(x) for x in self.data])
        # torch.split yields chunks of exactly batch_size (the last may be smaller)
        for batch in torch.split(torch.argsort(sizes), self.batch_size):
            yield batch.tolist()

Note

The __len__() method isn’t strictly required by DataLoader, but is expected in any calculation involving the length of a DataLoader.

class torch.utils.data.SequentialSampler(data_source)

Samples elements sequentially, always in the same order.

Parameters

data_source (Dataset) – dataset to sample from

class torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None, generator=None)

Samples elements randomly. If without replacement, then sample from a shuffled dataset.

If with replacement, then user can specify num_samples to draw.
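A brief sketch of both modes on a 5-element dataset, with a seeded Generator for reproducibility: without replacement the sampler yields a permutation, and with replacement it yields num_samples draws in which repeats are allowed.

```python
import torch
from torch.utils.data import RandomSampler

g = torch.Generator().manual_seed(0)

# Without replacement: a permutation of all indices.
perm = list(RandomSampler(range(5), generator=g))

# With replacement: num_samples draws, repeats allowed.
draws = list(RandomSampler(range(5), replacement=True, num_samples=8, generator=g))
print(sorted(perm), len(draws))  # [0, 1, 2, 3, 4] 8
```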

Parameters

class torch.utils.data.SubsetRandomSampler(indices, generator=None)

Samples elements randomly from a given list of indices, without replacement.
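A short sketch, assuming a hypothetical list of held-out indices: each pass visits exactly those indices once, in a shuffled order.

```python
import torch
from torch.utils.data import SubsetRandomSampler

held_out = [1, 4, 7]  # hypothetical validation indices
sampler = SubsetRandomSampler(held_out, generator=torch.Generator().manual_seed(0))
order = list(sampler)
print(sorted(order), len(sampler))  # [1, 4, 7] 3
```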

Parameters

class torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True, generator=None)

Samples elements from [0,..,len(weights)-1] with given probabilities (weights).

Parameters

Example

list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
# [4, 4, 1, 4, 5]  (sampled at random, so results vary)
list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
# [0, 1, 4, 3, 2]  (sampled at random, so results vary)

class torch.utils.data.BatchSampler(sampler, batch_size, drop_last)

Wraps another sampler to yield a mini-batch of indices.

Parameters

Example

list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
# [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
# [[0, 1, 2], [3, 4, 5], [6, 7, 8]]

class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None, shuffle=True, seed=0, drop_last=False)

Sampler that restricts data loading to a subset of the dataset.

It is especially useful in conjunction with torch.nn.parallel.DistributedDataParallel. In such a case, each process can pass a DistributedSampler instance as a DataLoader sampler, and load a subset of the original dataset that is exclusive to it.
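To see the sharding without launching a distributed job, the sketch below passes num_replicas and rank explicitly, so no process group is needed. With 11 samples and 2 replicas (drop_last=False), the index list is padded to 12 by repeating the first index, and each rank takes every second index starting at its rank.

```python
from torch.utils.data.distributed import DistributedSampler

dataset = list(range(11))  # 11 samples, 2 replicas -> padded to 12

# Passing num_replicas and rank explicitly avoids needing an initialized
# process group, which is convenient for inspecting the sharding.
shards = [
    list(DistributedSampler(dataset, num_replicas=2, rank=r, shuffle=False))
    for r in range(2)
]
print(shards)  # [[0, 2, 4, 6, 8, 10], [1, 3, 5, 7, 9, 0]]
```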

Note

The dataset is assumed to be of constant size, and any instance of it is assumed to always return the same elements in the same order.

Parameters

Warning

In distributed mode, calling the set_epoch() method at the beginning of each epoch before creating the DataLoader iterator is necessary to make shuffling work properly across multiple epochs. Otherwise, the same ordering will be always used.

Example:

sampler = DistributedSampler(dataset) if is_distributed else None
loader = DataLoader(dataset, shuffle=(sampler is None),
                    sampler=sampler)
for epoch in range(start_epoch, n_epochs):
    if is_distributed:
        sampler.set_epoch(epoch)
    train(loader)