CUDA semantics (PyTorch 2.7 documentation)

torch.cuda is used to set up and run CUDA operations. It keeps track of the currently selected GPU, and all CUDA tensors you allocate will by default be created on that device. The selected device can be changed with a torch.cuda.device context manager.

However, once a tensor is allocated, you can do operations on it irrespective of the selected device, and the results will always be placed on the same device as the tensor.

Cross-GPU operations are not allowed by default, with the exception of copy_() and other methods with copy-like functionality such as to() and cuda(). Unless you enable peer-to-peer memory access, any attempts to launch ops on tensors spread across different devices will raise an error.

Below you can find a small example showcasing this:

import torch

cuda = torch.device('cuda')     # Default CUDA device
cuda0 = torch.device('cuda:0')
cuda2 = torch.device('cuda:2')  # GPU 2 (these are 0-indexed)

x = torch.tensor([1., 2.], device=cuda0)
# x.device is device(type='cuda', index=0)
y = torch.tensor([1., 2.]).cuda()
# y.device is device(type='cuda', index=0)

with torch.cuda.device(1):
    # allocates a tensor on GPU 1
    a = torch.tensor([1., 2.], device=cuda)

    # transfers a tensor from CPU to GPU 1
    b = torch.tensor([1., 2.]).cuda()
    # a.device and b.device are device(type='cuda', index=1)

    # You can also use ``Tensor.to`` to transfer a tensor:
    b2 = torch.tensor([1., 2.]).to(device=cuda)
    # b.device and b2.device are device(type='cuda', index=1)

    c = a + b
    # c.device is device(type='cuda', index=1)

    z = x + y
    # z.device is device(type='cuda', index=0)

    # even within a context, you can specify the device
    # (or give a GPU index to the .cuda call)
    d = torch.randn(2, device=cuda2)
    e = torch.randn(2).to(cuda2)
    f = torch.randn(2).cuda(cuda2)
    # d.device, e.device, and f.device are all device(type='cuda', index=2)

TensorFloat-32 (TF32) on Ampere (and later) devices

Starting in PyTorch 1.7, there is a new flag called allow_tf32. This flag defaults to True in PyTorch 1.7 to PyTorch 1.11, and False in PyTorch 1.12 and later. This flag controls whether PyTorch is allowed to use the TensorFloat32 (TF32) tensor cores, available on NVIDIA GPUs since Ampere, internally to compute matmul (matrix multiplies and batched matrix multiplies) and convolutions.

TF32 tensor cores are designed to achieve better performance on matmul and convolutions on torch.float32 tensors by rounding input data to have 10 bits of mantissa, and accumulating results with FP32 precision, maintaining FP32 dynamic range.
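To make the 10-bit mantissa concrete, the sketch below emulates TF32 input rounding in pure Python by clearing the low 13 of float32's 23 explicit mantissa bits. It assumes round-to-nearest-even and ignores NaN/infinity handling; the helper name `round_to_tf32` is hypothetical, and the rounding performed by the actual tensor-core hardware may differ.

```python
import struct

# float32 stores 23 explicit mantissa bits; TF32 keeps 10, so 13 are dropped.
_DROPPED_BITS = 23 - 10


def round_to_tf32(x: float) -> float:
    """Emulate rounding a float32 value to TF32 precision (hypothetical helper).

    Assumes round-to-nearest-even; NaN and infinity are not handled.
    """
    # Reinterpret the float32 bit pattern as an unsigned 32-bit integer.
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    # Round to nearest, ties to even: add just under half an ulp, plus the
    # lowest kept bit so that exact ties end up on an even mantissa.
    lowest_kept = (bits >> _DROPPED_BITS) & 1
    bits += (1 << (_DROPPED_BITS - 1)) - 1 + lowest_kept
    # Clear the dropped mantissa bits (a carry may correctly bump the exponent).
    bits &= ~((1 << _DROPPED_BITS) - 1) & 0xFFFFFFFF
    (rounded,) = struct.unpack("<f", struct.pack("<I", bits))
    return rounded


if __name__ == "__main__":
    print(round_to_tf32(1.0 + 2**-10))  # representable in TF32, unchanged
    print(round_to_tf32(1.0 + 2**-12))  # below half an ulp, rounds down to 1.0
    # Spacing of TF32 (2**-10) relative to float32 (2**-23) near 1.0:
    print(2**-10 / 2**-23)
```

The relative spacing printed last is why TF32 results differ from FP32 results by a few orders of magnitude in relative error, while the shared 8-bit exponent keeps the same dynamic range.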

Matmuls and convolutions are controlled separately, and their corresponding flags can be accessed at:

import torch

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True

# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

The precision of matmuls can also be set more broadly (not limited to CUDA) via torch.set_float32_matmul_precision(). Note that besides matmuls and convolutions themselves, functions and nn modules that internally use matmuls or convolutions are also affected. These include nn.Linear, nn.Conv*, cdist, tensordot, affine grid and grid sample, adaptive log softmax, GRU and LSTM.
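As a brief sketch of the broader switch, torch.set_float32_matmul_precision() accepts "highest", "high", or "medium", and torch.get_float32_matmul_precision() reads the current setting back. The helper below, whose name is hypothetical, sets a precision for a block of code and restores the previous value afterwards; it does not need a GPU to run.

```python
import contextlib

import torch


@contextlib.contextmanager
def float32_matmul_precision(precision: str):
    """Temporarily set the float32 matmul precision (hypothetical helper).

    "highest" keeps full FP32 internal precision; "high" and "medium" permit
    faster internal formats (such as TF32) where the hardware supports them.
    """
    previous = torch.get_float32_matmul_precision()
    torch.set_float32_matmul_precision(precision)
    try:
        yield
    finally:
        # Restore the caller's setting even if the block raised.
        torch.set_float32_matmul_precision(previous)


if __name__ == "__main__":
    with float32_matmul_precision("high"):
        print(torch.get_float32_matmul_precision())  # high
    print(torch.get_float32_matmul_precision())  # the previous value again
```

The setting is process-wide, so a context manager like this only scopes it in time, not per thread.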

To get an idea of the precision and speed, see the example code and benchmark data (on A100) below:

import torch

a_full = torch.randn(10240, 10240, dtype=torch.double, device='cuda')
b_full = torch.randn(10240, 10240, dtype=torch.double, device='cuda')
ab_full = a_full @ b_full
mean = ab_full.abs().mean()  # 80.7277

a = a_full.float()
b = b_full.float()

# Do matmul at TF32 mode.
torch.backends.cuda.matmul.allow_tf32 = True
ab_tf32 = a @ b  # takes 0.016s on GA100
error = (ab_tf32 - ab_full).abs().max()  # 0.1747
relative_error = error / mean  # 0.0022

# Do matmul with TF32 disabled.
torch.backends.cuda.matmul.allow_tf32 = False
ab_fp32 = a @ b  # takes 0.11s on GA100
error = (ab_fp32 - ab_full).abs().max()  # 0.0031
relative_error = error / mean  # 0.000039

From the above example, we can see that with TF32 enabled, the matmul is ~7x faster on A100, and that the relative error compared to double precision is approximately 2 orders of magnitude larger. Note that the exact ratio of TF32 to single precision speed depends on the hardware generation, as properties such as the ratio of memory bandwidth to compute as well as the ratio of TF32 to FP32 matmul throughput may vary from generation to generation or model to model. If full FP32 precision is needed, users can disable TF32 by:

torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

To toggle the TF32 flags off in C++, you can do

at::globalContext().setAllowTF32CuBLAS(false);
at::globalContext().setAllowTF32CuDNN(false);
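When only a precision-sensitive part of a program needs full FP32, one option is to turn off both TF32 flags for that region and restore them afterwards. This is a sketch; the `full_fp32` name is hypothetical, and because the flags are process-wide, other threads running at the same time would also see the change.

```python
import contextlib

import torch


@contextlib.contextmanager
def full_fp32():
    """Disable TF32 for matmuls and cuDNN convolutions within a block (hypothetical helper)."""
    saved_matmul = torch.backends.cuda.matmul.allow_tf32
    saved_cudnn = torch.backends.cudnn.allow_tf32
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False
    try:
        yield
    finally:
        # Restore whatever the caller had configured, even if the block raised.
        torch.backends.cuda.matmul.allow_tf32 = saved_matmul
        torch.backends.cudnn.allow_tf32 = saved_cudnn
```

For example, `with full_fp32(): ab = a @ b` computes that one matmul without TF32 regardless of the global setting.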

For more information about TF32, see:

Reduced Precision Reduction in FP16 GEMMs

(Distinct from full FP16 accumulation that is intended for hardware that has higher throughput with FP16 accumulation than FP32 accumulation, see Full FP16 accumulation)

FP16 GEMMs may be computed with some intermediate reductions done in reduced precision (e.g., in fp16 rather than fp32). These selective reductions in precision can allow for higher performance on certain workloads (particularly those with a large k dimension) and GPU architectures, at the cost of numerical precision and a potential for overflow.

Some example benchmark data on V100:

[--------------------------- bench_gemm_transformer --------------------------]
      [  m ,  k  ,  n  ]    |  allow_fp16_reduc=True  |  allow_fp16_reduc=False
1 threads: --------------------------------------------------------------------
      [4096, 4048, 4096]    |          1634.6         |          1639.8
      [4096, 4056, 4096]    |          1670.8         |          1661.9
      [4096, 4080, 4096]    |          1664.2         |          1658.3
      [4096, 4096, 4096]    |          1639.4         |          1651.0
      [4096, 4104, 4096]    |          1677.4         |          1674.9
      [4096, 4128, 4096]    |          1655.7         |          1646.0
      [4096, 4144, 4096]    |          1796.8         |          2519.6
      [4096, 5096, 4096]    |          2094.6         |          3190.0
      [4096, 5104, 4096]    |          2144.0         |          2663.5
      [4096, 5112, 4096]    |          2149.1         |          2766.9
      [4096, 5120, 4096]    |          2142.8         |          2631.0
      [4096, 9728, 4096]    |          3875.1         |          5779.8
      [4096, 16384, 4096]   |          6182.9         |          9656.5
(times in microseconds).

If full precision reductions are needed, users can disable reduced precision reductions in fp16 GEMMs with:

torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

To toggle the reduced precision reduction flags in C++, one can do

at::globalContext().setAllowFP16ReductionCuBLAS(false);

Reduced Precision Reduction in BF16 GEMMs

A similar flag (as above) exists for BFloat16 GEMMs. Note that this switch is set to True by default for BF16; if you observe numerical instability in your workload, you may wish to set it to False.

If reduced precision reductions are not desired, users can disable reduced precision reductions in bf16 GEMMs with:

torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False

To disable reduced precision reductions in bf16 GEMMs from C++, one can do

at::globalContext().setAllowBF16ReductionCuBLAS(false);

Full FP16 Accumulation in FP16 GEMMs

Certain GPUs have increased performance when doing _all_ FP16 GEMM accumulation in FP16, at the cost of numerical precision and greater likelihood of overflow. Note that this setting only has an effect on GPUs of compute capability 7.0 (Volta) or newer.

This behavior can be enabled via:

torch.backends.cuda.matmul.allow_fp16_accumulation = True

To toggle FP16 accumulation in C++, one can do

at::globalContext().setAllowFP16AccumulationCuBLAS(true);
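The precision and overflow risks of accumulating in FP16 can be illustrated with a toy model in pure Python: accumulate a length-k dot product while rounding the running sum to FP16 after every addition, and compare it with an FP32 accumulator. This is only a simplified sequential model (real cuBLAS kernels use their own blocking and reduction order), and the helper names are hypothetical.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        # struct refuses values beyond the FP16 range; model that as overflow to inf.
        return math.copysign(math.inf, x)


def to_fp32(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 single-precision value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def dot(a, b, accumulate):
    """Dot product of FP16 inputs; `accumulate` rounds the running sum after each step."""
    acc = 0.0
    for x, y in zip(a, b):
        product = to_fp16(x) * to_fp16(y)  # exact in double for FP16 operands
        acc = accumulate(acc + product)
    return acc


if __name__ == "__main__":
    k = 4096
    print("fp32 accumulation:", dot([0.1] * k, [1.0] * k, to_fp32))
    print("fp16 accumulation:", dot([0.1] * k, [1.0] * k, to_fp16))
    # Large magnitudes overflow FP16 (largest finite value is 65504):
    print("fp16 overflow:", dot([300.0] * 1000, [1.0] * 1000, to_fp16))
```

With a large k, the FP16 running sum eventually reaches a magnitude where each small addend is less than half an ulp, so the sum stops growing; this is one way a large k dimension can hurt accuracy.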

Asynchronous execution

By default, GPU operations are asynchronous. When you call a function that uses the GPU, the operations are enqueued to the particular device, but not necessarily executed until later. This allows us to execute more computations in parallel, including operations on CPU or other GPUs.

In general, the effect of asynchronous computation is invisible to the caller, because (1) each device executes operations in the order they are queued, and (2) PyTorch automatically performs necessary synchronization when copying data between CPU and GPU or between two GPUs. Hence, computation will proceed as if every operation was executed synchronously.

You can force synchronous computation by setting the environment variable CUDA_LAUNCH_BLOCKING=1. This can be handy when an error occurs on the GPU. (With asynchronous execution, such an error isn’t reported until after the operation is actually executed, so the stack trace does not show where it was requested.)

A consequence of the asynchronous computation is that time measurements without synchronizations are not accurate. To get precise measurements, one should either call torch.cuda.synchronize() before measuring, or use torch.cuda.Event to record times as follows:

import torch

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()

# Run some things here

end_event.record()
torch.cuda.synchronize()  # Wait for the events to be recorded!
elapsed_time_ms = start_event.elapsed_time(end_event)
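The event-based pattern can be wrapped in a small helper. The sketch below (the `time_ms` name is hypothetical) records CUDA events around the call when a GPU is present and falls back to time.perf_counter() on CPU, where operations run synchronously anyway.

```python
import time

import torch


def time_ms(fn, *args, warmup: int = 1, **kwargs) -> float:
    """Return the elapsed time of fn(*args, **kwargs) in milliseconds (hypothetical helper)."""
    for _ in range(warmup):
        fn(*args, **kwargs)  # exclude one-time costs such as lazy initialization

    if torch.cuda.is_available():
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
        fn(*args, **kwargs)
        end_event.record()
        torch.cuda.synchronize()  # wait for the events to be recorded
        return start_event.elapsed_time(end_event)

    start = time.perf_counter()
    fn(*args, **kwargs)
    return (time.perf_counter() - start) * 1000.0


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    a = torch.randn(1024, 1024, device=device)
    print(f"matmul: {time_ms(torch.matmul, a, a):.3f} ms")
```

The events are recorded on the current stream, so this measures work that fn enqueues on that stream.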

As an exception, several functions such as to() and copy_() admit an explicit non_blocking argument, which lets the caller bypass synchronization when it is unnecessary. Another exception is CUDA streams, explained below.

CUDA streams

A CUDA stream is a linear sequence of execution that belongs to a specific device. You normally do not need to create one explicitly: by default, each device uses its own “default” stream.

Operations inside each stream are serialized in the order they are created, but operations from different streams can execute concurrently in any relative order, unless explicit synchronization functions (such as synchronize() or wait_stream()) are used. For example, the following code is incorrect:

import torch

cuda = torch.device('cuda')
s = torch.cuda.Stream()  # Create a new stream.
A = torch.empty((100, 100), device=cuda).normal_(0.0, 1.0)
with torch.cuda.stream(s):
    # sum() may start execution before normal_() finishes!
    B = torch.sum(A)

When the “current stream” is the default stream, PyTorch automatically performs necessary synchronization when data is moved around, as explained above. However, when using non-default streams, it is the user’s responsibility to ensure proper synchronization. The fixed version of this example is:

import torch

cuda = torch.device('cuda')
s = torch.cuda.Stream()  # Create a new stream.
A = torch.empty((100, 100), device=cuda).normal_(0.0, 1.0)
s.wait_stream(torch.cuda.default_stream(cuda))  # NEW!
with torch.cuda.stream(s):
    B = torch.sum(A)
A.record_stream(s)  # NEW!

There are two new additions. The torch.cuda.Stream.wait_stream() call ensures that the normal_() execution has finished before we start running sum(A) on a side stream. The torch.Tensor.record_stream() call ensures that we do not deallocate A before sum(A) has completed. You can also manually wait on the stream at some later point in time with torch.cuda.default_stream(cuda).wait_stream(s) (note that it is pointless to wait immediately, since that will prevent the stream execution from running in parallel with other work on the default stream). See the documentation for torch.Tensor.record_stream() for more details on when to use one or the other.

Note that this synchronization is necessary even when there is no read dependency, e.g., as seen in this example:

import torch

cuda = torch.device('cuda')
s = torch.cuda.Stream()  # Create a new stream.
A = torch.empty((100, 100), device=cuda)
s.wait_stream(torch.cuda.default_stream(cuda))  # STILL REQUIRED!
with torch.cuda.stream(s):
    A.normal_(0.0, 1.0)
    A.record_stream(s)

Despite the computation on s not reading the contents of A and no other uses of A, it is still necessary to synchronize, because A may correspond to memory reallocated by the CUDA caching allocator, with pending operations from the old (deallocated) memory.

Stream semantics of backward passes

Each backward CUDA op runs on the same stream that was used for its corresponding forward op. If your forward pass runs independent ops in parallel on different streams, this helps the backward pass exploit that same parallelism.

The stream semantics of a backward call with respect to surrounding ops are the same as for any other call. The backward pass inserts internal syncs to ensure this even when backward ops run on multiple streams as described in the previous paragraph. More concretely, when calling autograd.backward, autograd.grad, or tensor.backward, and optionally supplying CUDA tensor(s) as the initial gradient(s) (e.g., autograd.backward(..., grad_tensors=initial_grads), autograd.grad(..., grad_outputs=initial_grads), or tensor.backward(..., gradient=initial_grad)), the acts of

  1. optionally populating initial gradient(s),
  2. invoking the backward pass, and
  3. using the gradients

have the same stream-semantics relationship as any group of ops:

s = torch.cuda.Stream()

# Safe, grads are used in the same stream context as backward()
with torch.cuda.stream(s):
    loss.backward()
    use grads

# Unsafe
with torch.cuda.stream(s):
    loss.backward()
use grads

# Safe, with synchronization
with torch.cuda.stream(s):
    loss.backward()
torch.cuda.current_stream().wait_stream(s)
use grads

# Safe, populating initial grad and invoking backward are in the same stream context
with torch.cuda.stream(s):
    loss.backward(gradient=torch.ones_like(loss))

# Unsafe, populating initial_grad and invoking backward are in different stream contexts,
# without synchronization
initial_grad = torch.ones_like(loss)
with torch.cuda.stream(s):
    loss.backward(gradient=initial_grad)

# Safe, with synchronization
initial_grad = torch.ones_like(loss)
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    initial_grad.record_stream(s)
    loss.backward(gradient=initial_grad)

BC note: Using grads on the default stream

In prior versions of PyTorch (1.9 and earlier), the autograd engine always synced the default stream with all backward ops, so the following pattern:

with torch.cuda.stream(s):
    loss.backward()
use grads

was safe as long as use grads happened on the default stream. In present PyTorch, that pattern is no longer safe. If backward() and use grads are in different stream contexts, you must sync the streams:

with torch.cuda.stream(s):
    loss.backward()
torch.cuda.current_stream().wait_stream(s)
use grads

even if use grads is on the default stream.

Memory management

PyTorch uses a caching memory allocator to speed up memory allocations. This allows fast memory deallocation without device synchronizations. However, the unused memory managed by the allocator will still show as if used in nvidia-smi. You can use memory_allocated() and max_memory_allocated() to monitor memory occupied by tensors, and use memory_reserved() and max_memory_reserved() to monitor the total amount of memory managed by the caching allocator. Calling empty_cache() releases all unused cached memory from PyTorch so that it can be used by other GPU applications. However, GPU memory occupied by tensors is not freed, so empty_cache() cannot increase the amount of GPU memory available for PyTorch.
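To see the gap between memory occupied by tensors and memory held by the caching allocator, the counters mentioned above can be compared side by side. A minimal sketch, assuming the current device; the `memory_report` helper is hypothetical and simply returns zeros when CUDA is not available.

```python
import torch

MiB = 1024 ** 2


def memory_report(device=None) -> dict:
    """Summarize caching-allocator counters in MiB (hypothetical helper)."""
    if not torch.cuda.is_available():
        return {"allocated": 0.0, "max_allocated": 0.0, "reserved": 0.0, "max_reserved": 0.0}
    return {
        # Memory occupied by live tensors.
        "allocated": torch.cuda.memory_allocated(device) / MiB,
        "max_allocated": torch.cuda.max_memory_allocated(device) / MiB,
        # Memory held by the caching allocator, including cached blocks not in use.
        "reserved": torch.cuda.memory_reserved(device) / MiB,
        "max_reserved": torch.cuda.max_memory_reserved(device) / MiB,
    }


if __name__ == "__main__":
    if torch.cuda.is_available():
        x = torch.empty(64 * MiB, dtype=torch.uint8, device="cuda")
        print("with tensor:      ", memory_report())
        del x  # the freed block stays cached, so "reserved" is not expected to drop
        print("after del:        ", memory_report())
        torch.cuda.empty_cache()  # hand unused cached blocks back to the driver
        print("after empty_cache:", memory_report())
```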

To better understand how CUDA memory is being used over time, Understanding CUDA Memory Usage describes tools for capturing and visualizing traces of memory use.

For more advanced users, we offer more comprehensive memory benchmarking via memory_stats(). We also offer the capability to capture a complete snapshot of the memory allocator state via memory_snapshot(), which can help you understand the underlying allocation patterns produced by your code.

Optimizing memory usage with PYTORCH_CUDA_ALLOC_CONF

Use of a caching allocator can interfere with memory checking tools such as cuda-memcheck. To debug memory errors using cuda-memcheck, set PYTORCH_NO_CUDA_MEMORY_CACHING=1 in your environment to disable caching.

The behavior of the caching allocator can be controlled via the environment variable PYTORCH_CUDA_ALLOC_CONF. The format is PYTORCH_CUDA_ALLOC_CONF=<option>:<value>,<option2>:<value2>... Available options:

Note

Some stats reported by the CUDA memory management API are specific to backend:native, and are not meaningful with backend:cudaMallocAsync. See each function’s docstring for details.
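The variable is typically set before the process makes its first CUDA allocation, for example in the launching shell or at the top of a script before importing torch. The sketch below builds the `<option>:<value>,...` string from a dictionary; the helper name is hypothetical, and while `expandable_segments` and `max_split_size_mb` are options the allocator accepts, check the available options for your PyTorch version before relying on them.

```python
import os


def format_alloc_conf(options: dict) -> str:
    """Join options into PYTORCH_CUDA_ALLOC_CONF syntax: <option>:<value>,<option2>:<value2>."""
    parts = []
    for name, value in options.items():
        if isinstance(value, bool):
            value = "True" if value else "False"  # render booleans as True/False
        parts.append(f"{name}:{value}")
    return ",".join(parts)


if __name__ == "__main__":
    conf = format_alloc_conf({"expandable_segments": True, "max_split_size_mb": 128})
    # Must happen before the first CUDA allocation (safest: before importing torch).
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = conf
    print(conf)  # expandable_segments:True,max_split_size_mb:128
```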

Using custom memory allocators for CUDA

It is possible to define allocators as simple functions in C/C++ and compile them as a shared library. The code below shows a basic allocator that just traces all the memory operations.

#include <sys/types.h>
#include <cuda_runtime_api.h>
#include <iostream>
// Compile with g++ alloc.cc -o alloc.so -I/usr/local/cuda/include -shared -fPIC
extern "C" {
void* my_malloc(ssize_t size, int device, cudaStream_t stream) {
   void *ptr;
   cudaMalloc(&ptr, size);
   std::cout << "alloc " << ptr << " " << size << std::endl;
   return ptr;
}

void my_free(void* ptr, ssize_t size, int device, cudaStream_t stream) {
   std::cout << "free " << ptr << " " << stream << std::endl;
   cudaFree(ptr);
}
}

This can be used in python through the torch.cuda.memory.CUDAPluggableAllocator. The user is responsible for supplying the path to the .so file and the name of the alloc/free functions that match the signatures specified above.

import torch

# Load the allocator
new_alloc = torch.cuda.memory.CUDAPluggableAllocator(
    'alloc.so', 'my_malloc', 'my_free')
# Swap the current allocator
torch.cuda.memory.change_current_allocator(new_alloc)
# This will allocate memory in the device using the new allocator
b = torch.zeros(10, device='cuda')

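The allocator can only be swapped before the current allocator has been instantiated; in the following example, the initial allocation instantiates it, so the swap fails:

import torch

# Do an initial memory allocation
b = torch.zeros(10, device='cuda')
# Load the allocator
new_alloc = torch.cuda.memory.CUDAPluggableAllocator(
    'alloc.so', 'my_malloc', 'my_free')
# This will error since the current allocator was already instantiated
torch.cuda.memory.change_current_allocator(new_alloc)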

Mixing different CUDA system allocators in the same program

Depending on your use case, change_current_allocator() may not be what you want to use, since it swaps the CUDA allocator for the entire program (similar to PYTORCH_CUDA_ALLOC_CONF=backend:cudaMallocAsync). For instance, if the swapped allocator doesn’t have a caching mechanism, you will lose all the benefits of PyTorch’s CUDACachingAllocator. Instead, you can selectively mark a region of PyTorch code to use a custom allocator using torch.cuda.MemPool. This will let you use multiple CUDA system allocators in the same PyTorch program, along with most of the benefits of the CUDACachingAllocator (e.g. caching). Using torch.cuda.MemPool, you can utilize custom allocators that enable several features, such as:

Note

While cudaMallocManaged offers convenient automatic memory management using CUDA Unified Virtual Memory (UVM), it is not recommended for DL workloads. For DL workloads that fit in GPU memory, explicit placement consistently outperforms UVM, since there are no page faults and access patterns remain predictable. When GPU memory gets saturated, UVM has to perform costly double transfers, evicting pages to CPU before bringing in new ones.

The code below shows ncclMemAlloc wrapped in a torch.cuda.memory.CUDAPluggableAllocator.

import os

import torch
import torch.distributed as dist
from torch.cuda.memory import CUDAPluggableAllocator
from torch.distributed.distributed_c10d import _get_default_group
from torch.utils import cpp_extension


# create allocator
nccl_allocator_source = """
#include <nccl.h>
#include <iostream>
extern "C" {

void* nccl_alloc_plug(size_t size, int device, void* stream) {
  std::cout << "Using ncclMemAlloc" << std::endl;
  void* ptr;
  ncclResult_t err = ncclMemAlloc(&ptr, size);
  return ptr;
}

void nccl_free_plug(void* ptr, size_t size, int device, void* stream) {
  std::cout << "Using ncclMemFree" << std::endl;
  ncclResult_t err = ncclMemFree(ptr);
}

}
"""
nccl_allocator_libname = "nccl_allocator"
nccl_allocator = cpp_extension.load_inline(
    name=nccl_allocator_libname,
    cpp_sources=nccl_allocator_source,
    with_cuda=True,
    extra_ldflags=["-lnccl"],
    verbose=True,
    is_python_module=False,
    build_directory="./",
)

allocator = CUDAPluggableAllocator(
    f"./{nccl_allocator_libname}.so", "nccl_alloc_plug", "nccl_free_plug"
).allocator()

# setup distributed
rank = int(os.getenv("RANK"))
local_rank = int(os.getenv("LOCAL_RANK"))
world_size = int(os.getenv("WORLD_SIZE"))
torch.cuda.set_device(local_rank)
dist.init_process_group(backend="nccl")
device = torch.device(f"cuda:{local_rank}")
default_pg = _get_default_group()
backend = default_pg._get_backend(device)

# Note: for convenience, ProcessGroupNCCL backend provides
# the ncclMemAlloc allocator as backend.mem_allocator
allocator = backend.mem_allocator

You can now define a new memory pool by passing this allocator to torch.cuda.MemPool:

pool = torch.cuda.MemPool(allocator)

The pool can then be used with the torch.cuda.use_mem_pool context manager to allocate tensors into that pool:

with torch.cuda.use_mem_pool(pool):
    # tensor gets allocated with ncclMemAlloc passed in the pool
    tensor = torch.arange(1024 * 1024 * 2, device=device)
    print(f"tensor ptr on rank {rank} is {hex(tensor.data_ptr())}")

# register user buffers using ncclCommRegister (called under the hood)
backend.register_mem_pool(pool)

# Collective uses Zero Copy NVLS
dist.all_reduce(tensor[0:4])
torch.cuda.synchronize()
print(tensor[0:4])

Note the usage of register_mem_pool in the above example. This is an extra step for NVLS reductions, where the user buffers need to be registered with NCCL. A user can de-register the buffers with a similar deregister_mem_pool call.

To reclaim memory, users will first need to ensure nothing is using the pool. When none of the tensors are holding a reference to the pool, empty_cache() will be called internally on deletion of the pool, hence returning all the memory to the system.

The following torch.cuda.MemPool.use_count() and torch.cuda.MemPool.snapshot() APIs can be used for debugging purposes:

pool = torch.cuda.MemPool(allocator)

# pool's use count should be 1 at this point as MemPool object
# holds a reference
assert pool.use_count() == 1

nelem_1mb = 1024 * 1024 // 4

with torch.cuda.use_mem_pool(pool):
    out_0 = torch.randn(nelem_1mb, device="cuda")

    # pool's use count should be 2 at this point as use_mem_pool
    # holds a reference
    assert pool.use_count() == 2

# pool's use count should be back to 1 at this point as use_mem_pool
# released its reference
assert pool.use_count() == 1

with torch.cuda.use_mem_pool(pool):
    # pool should have 1 segment since we made a small allocation (1 MB)
    # above and so the CUDACachingAllocator packed it into a 2 MB buffer
    assert len(pool.snapshot()) == 1

    out_1 = torch.randn(nelem_1mb, device="cuda")

    # pool should still have 1 segment since we made another small allocation
    # (1 MB) that got packed into the existing 2 MB buffer
    assert len(pool.snapshot()) == 1

    out_2 = torch.randn(nelem_1mb, device="cuda")

    # pool now should have 2 segments since the CUDACachingAllocator had
    # to make a new 2 MB buffer to accommodate out_2
    assert len(pool.snapshot()) == 2


cuBLAS workspaces

For each combination of cuBLAS handle and CUDA stream, a cuBLAS workspace will be allocated if that handle and stream combination executes a cuBLAS kernel that requires a workspace. In order to avoid repeatedly allocating workspaces, these workspaces are not deallocated unless torch._C._cuda_clearCublasWorkspaces() is called. The workspace size per allocation can be specified via the environment variable CUBLAS_WORKSPACE_CONFIG with the format :[SIZE]:[COUNT], where SIZE is given in KiB. As an example, the default workspace size per allocation is CUBLAS_WORKSPACE_CONFIG=:4096:2:16:8, which specifies a total size of 2 * 4096 + 8 * 16 KiB. To force cuBLAS to avoid using workspaces, set CUBLAS_WORKSPACE_CONFIG=:0:0.
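The arithmetic behind the default setting can be checked with a small parser: each `:SIZE:COUNT` pair contributes SIZE KiB times COUNT, so `:4096:2:16:8` totals 2 * 4096 + 8 * 16 = 8320 KiB. The function below is a hypothetical helper that mirrors the format described above; it is not part of PyTorch or cuBLAS.

```python
def cublas_workspace_kib(config: str) -> int:
    """Total workspace size in KiB for a CUBLAS_WORKSPACE_CONFIG value such as ":4096:2:16:8"."""
    if not config.startswith(":"):
        raise ValueError(f"expected a leading ':', got {config!r}")
    fields = config[1:].split(":")
    if len(fields) % 2 != 0:
        raise ValueError(f"expected :SIZE:COUNT pairs, got {config!r}")
    total = 0
    for size, count in zip(fields[0::2], fields[1::2]):
        total += int(size) * int(count)  # SIZE is in KiB, COUNT is the number of workspaces
    return total


if __name__ == "__main__":
    print(cublas_workspace_kib(":4096:2:16:8"))  # 8320
    print(cublas_workspace_kib(":0:0"))  # 0, i.e. no workspaces
```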

cuFFT plan cache

For each CUDA device, an LRU cache of cuFFT plans is used to speed up repeatedly running FFT methods (e.g., torch.fft.fft()) on CUDA tensors of the same geometry with the same configuration. Because some cuFFT plans may allocate GPU memory, these caches have a maximum capacity.

You may control and query the properties of the cache of current device with the following APIs:

To control and query plan caches of a non-default device, you can index the torch.backends.cuda.cufft_plan_cache object with either a torch.device object or a device index, and access one of the above attributes. E.g., to set the capacity of the cache for device 1, one can write torch.backends.cuda.cufft_plan_cache[1].max_size = 10.
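For example, the cache of every visible device can be bounded in a loop. This sketch assumes the `size` (number of cached plans), `max_size` (capacity), and `clear()` members of the cache objects; the `limit_cufft_plan_caches` helper is hypothetical, and it does nothing on machines without CUDA.

```python
import torch


def limit_cufft_plan_caches(max_plans: int) -> dict:
    """Cap every device's cuFFT plan cache and return {device_index: cached_plan_count}."""
    sizes = {}
    if not torch.cuda.is_available():
        return sizes  # nothing to configure without CUDA
    for index in range(torch.cuda.device_count()):
        cache = torch.backends.cuda.cufft_plan_cache[index]
        cache.max_size = max_plans  # capacity of this device's LRU plan cache
        sizes[index] = cache.size   # number of plans currently cached for this device
    return sizes


if __name__ == "__main__":
    print(limit_cufft_plan_caches(10))
    if torch.cuda.is_available():
        torch.backends.cuda.cufft_plan_cache.clear()  # drop cached plans on the current device
```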

Just-in-Time Compilation

PyTorch just-in-time compiles some operations, like torch.special.zeta, when performed on CUDA tensors. This compilation can be time-consuming (up to a few seconds depending on your hardware and software) and may occur multiple times for a single operator since many PyTorch operators actually select from a variety of kernels, each of which must be compiled once, depending on their input. This compilation occurs once per process, or just once if a kernel cache is used.

By default, PyTorch creates a kernel cache in $XDG_CACHE_HOME/torch/kernels if XDG_CACHE_HOME is defined and $HOME/.cache/torch/kernels if it’s not (except on Windows, where the kernel cache is not yet supported). The caching behavior can be directly controlled with two environment variables. If USE_PYTORCH_KERNEL_CACHE is set to 0 then no cache will be used, and if PYTORCH_KERNEL_CACHE_PATH is set then that path will be used as a kernel cache instead of the default location.
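The lookup order described above can be written out as a small function. This is a sketch of the documented rules only, not PyTorch's implementation; `kernel_cache_dir` is a hypothetical name, it assumes that disabling the cache takes precedence over a custom path, and it takes the environment as a dictionary so different settings are easy to try.

```python
import posixpath
from typing import Mapping, Optional


def kernel_cache_dir(env: Mapping[str, str]) -> Optional[str]:
    """Resolve the JIT kernel cache directory from environment variables (non-Windows rules)."""
    if env.get("USE_PYTORCH_KERNEL_CACHE") == "0":
        return None  # caching disabled
    if "PYTORCH_KERNEL_CACHE_PATH" in env:
        return env["PYTORCH_KERNEL_CACHE_PATH"]  # explicit override of the default location
    if "XDG_CACHE_HOME" in env:
        return posixpath.join(env["XDG_CACHE_HOME"], "torch", "kernels")
    return posixpath.join(env["HOME"], ".cache", "torch", "kernels")


if __name__ == "__main__":
    print(kernel_cache_dir({"HOME": "/home/alice"}))  # /home/alice/.cache/torch/kernels
```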

Best practices

Device-agnostic code

Due to the structure of PyTorch, you may need to explicitly write device-agnostic (CPU or GPU) code; an example may be creating a new tensor as the initial hidden state of a recurrent neural network.

The first step is to determine whether the GPU should be used or not. A common pattern is to use Python’s argparse module to read in user arguments, and have a flag that can be used to disable CUDA, in combination with is_available(). In the following, args.device results in a torch.device object that can be used to move tensors to CPU or CUDA.

import argparse
import torch

parser = argparse.ArgumentParser(description='PyTorch Example')
parser.add_argument('--disable-cuda', action='store_true',
                    help='Disable CUDA')
args = parser.parse_args()
args.device = None
if not args.disable_cuda and torch.cuda.is_available():
    args.device = torch.device('cuda')
else:
    args.device = torch.device('cpu')

Note

When assessing the availability of CUDA in a given environment (is_available()), PyTorch’s default behavior is to call the CUDA Runtime API method cudaGetDeviceCount. Because this call in turn initializes the CUDA Driver API (via cuInit) if it is not already initialized, subsequent forks of a process that has run is_available() will fail with a CUDA initialization error.

One can set PYTORCH_NVML_BASED_CUDA_CHECK=1 in your environment before importing PyTorch modules that execute is_available() (or before executing it directly) in order to direct is_available() to attempt an NVML-based assessment (nvmlDeviceGetCount_v2). If the NVML-based assessment is successful (i.e. NVML discovery/initialization does not fail), is_available() calls will not poison subsequent forks.

If NVML discovery/initialization fails, is_available() will fall back to the standard CUDA Runtime API assessment and the aforementioned fork constraint will apply.

Note that the above NVML-based CUDA availability assessment provides a weaker guarantee than the default CUDA Runtime API approach (which requires CUDA initialization to succeed). In some circumstances, the NVML-based check may succeed while later CUDA initialization fails.

Now that we have args.device, we can use it to create a Tensor on the desired device.

x = torch.empty((8, 42), device=args.device)
net = Network().to(device=args.device)

This can be used in a number of cases to produce device agnostic code. Below is an example when using a dataloader:

cuda0 = torch.device('cuda:0')  # CUDA GPU 0
for i, x in enumerate(train_loader):
    x = x.to(cuda0)

When working with multiple GPUs on a system, you can use the CUDA_VISIBLE_DEVICES environment flag to manage which GPUs are available to PyTorch. As mentioned above, to manually control which GPU a tensor is created on, the best practice is to use a torch.cuda.device context manager.

print("Outside device is 0")  # On device 0 (default in most scenarios)
with torch.cuda.device(1):
    print("Inside device is 1")  # On device 1
print("Outside device is still 0")  # On device 0

If you have a tensor and would like to create a new tensor of the same type on the same device, then you can use a torch.Tensor.new_* method (see torch.Tensor). Whilst the previously mentioned torch.* factory functions (Creation Ops) depend on the current GPU context and the attribute arguments you pass in, torch.Tensor.new_* methods preserve the device and other attributes of the tensor.

This is the recommended practice when creating modules in which new tensors need to be created internally during the forward pass.

import torch

cuda = torch.device('cuda')
x_cpu = torch.empty(2)
x_gpu = torch.empty(2, device=cuda)
x_cpu_long = torch.empty(2, dtype=torch.int64)

y_cpu = x_cpu.new_full([3, 2], fill_value=0.3)
print(y_cpu)

tensor([[ 0.3000,  0.3000],
        [ 0.3000,  0.3000],
        [ 0.3000,  0.3000]])

y_gpu = x_gpu.new_full([3, 2], fill_value=-5)
print(y_gpu)

tensor([[-5.0000, -5.0000],
        [-5.0000, -5.0000],
        [-5.0000, -5.0000]], device='cuda:0')

y_cpu_long = x_cpu_long.new_tensor([[1, 2, 3]])
print(y_cpu_long)

tensor([[ 1,  2,  3]])
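The same idea covers the recurrent-network hidden state mentioned at the start of this section. In the sketch below (the `TinyRNN` module is hypothetical), the initial hidden state is created with new_zeros() from the input, so it lands on whatever device and dtype the input uses without a device argument being threaded through the model.

```python
import torch
import torch.nn as nn


class TinyRNN(nn.Module):
    """A minimal recurrent model whose initial state follows the input's device and dtype."""

    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.cell = nn.RNNCell(input_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x has shape (seq_len, batch, input_size).
        # new_zeros copies x's device and dtype, so this works on CPU and GPU alike.
        h = x.new_zeros(x.size(1), self.hidden_size)
        for step in x:  # iterate over the sequence dimension
            h = self.cell(step, h)
        return h  # final hidden state, shape (batch, hidden_size)


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = TinyRNN(input_size=4, hidden_size=8).to(device)
    out = model(torch.randn(5, 3, 4, device=device))
    print(out.shape, out.device)
```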

If you want to create a tensor of the same type and size of another tensor, and fill it with either ones or zeros, ones_like() or zeros_like() are provided as convenient helper functions (which also preserve torch.device and torch.dtype of a Tensor).

x_cpu = torch.empty(2, 3)
x_gpu = torch.empty(2, 3, device='cuda')

y_cpu = torch.ones_like(x_cpu)
y_gpu = torch.zeros_like(x_gpu)

Use pinned memory buffers

Warning

This is an advanced tip. If you overuse pinned memory, it can cause serious problems when running low on RAM, and you should be aware that pinning is often an expensive operation.

Host to GPU copies are much faster when they originate from pinned (page-locked) memory. CPU tensors and storages expose a pin_memory() method that returns a copy of the object, with data put in a pinned region.

Also, once you pin a tensor or storage, you can use asynchronous GPU copies. Just pass an additional non_blocking=True argument to a to() or a cuda() call. This can be used to overlap data transfers with computation.

You can make the DataLoader return batches placed in pinned memory by passing pin_memory=True to its constructor.
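Putting these pieces together, a batch can be pinned and then copied with non_blocking=True so the transfer may overlap with work already queued on the GPU. A minimal sketch: the `to_device` helper is hypothetical, it only pins when the target is a CUDA device, and a DataLoader created with pin_memory=True would already deliver pinned batches, making the explicit pin_memory() call a no-op path.

```python
import torch


def to_device(batch: torch.Tensor, device: torch.device) -> torch.Tensor:
    """Copy a CPU batch to `device`, asynchronously when the target is a GPU (hypothetical helper)."""
    if device.type != "cuda":
        return batch.to(device)
    if not batch.is_pinned():
        batch = batch.pin_memory()  # page-locked copy, needed for an asynchronous transfer
    # The copy is enqueued on the current stream, so later work on that stream
    # is ordered after it.
    return batch.to(device, non_blocking=True)


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    x = to_device(torch.randn(256, 1024), device)
    print(x.device)
```

Pinning inside the training loop costs a host-side copy on every call, so when batches come from a DataLoader, passing pin_memory=True to it is usually the cheaper place to do the pinning.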

Use nn.parallel.DistributedDataParallel instead of multiprocessing or nn.DataParallel

Most use cases involving batched inputs and multiple GPUs should default to using DistributedDataParallel to utilize more than one GPU.

There are significant caveats to using CUDA models with multiprocessing; unless care is taken to meet the data handling requirements exactly, it is likely that your program will have incorrect or undefined behavior.

It is recommended to use DistributedDataParallel, instead of DataParallel to do multi-GPU training, even if there is only a single node.

The difference between DistributedDataParallel and DataParallel is: DistributedDataParallel uses multiprocessing, where a process is created for each GPU, while DataParallel uses multithreading. By using multiprocessing, each GPU has its own dedicated process, which avoids the performance overhead caused by the GIL of the Python interpreter.

If you use DistributedDataParallel, you could use the torch.distributed.launch utility to launch your program, see Third-party backends.

CUDA Graphs

A CUDA graph is a record of the work (mostly kernels and their arguments) that a CUDA stream and its dependent streams perform. For general principles and details on the underlying CUDA API, see Getting Started with CUDA Graphs and the Graphs section of the CUDA C Programming Guide.

PyTorch supports the construction of CUDA graphs using stream capture, which puts a CUDA stream in capture mode. CUDA work issued to a capturing stream doesn’t actually run on the GPU. Instead, the work is recorded in a graph.

After capture, the graph can be launched to run the GPU work as many times as needed. Each replay runs the same kernels with the same arguments. For pointer arguments this means the same memory addresses are used. By filling input memory with new data (e.g., from a new batch) before each replay, you can rerun the same work on new data.

Why CUDA Graphs?

Replaying a graph sacrifices the dynamic flexibility of typical eager execution in exchange for greatly reduced CPU overhead. A graph’s arguments and kernels are fixed, so a graph replay skips all layers of argument setup and kernel dispatch, including Python, C++, and CUDA driver overheads. Under the hood, a replay submits the entire graph’s work to the GPU with a single call to cudaGraphLaunch. Kernels in a replay also execute slightly faster on the GPU, but eliding CPU overhead is the main benefit.

You should try CUDA graphs if all or part of your network is graph-safe (usually this means static shapes and static control flow, but see the other constraints) and you suspect its runtime is at least somewhat CPU-limited.
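One rough way to check whether a workload is CPU-limited is to time a chain of tiny kernels eagerly and as a graph replay. This is a hedged microbenchmark sketch: chain_of_small_ops, compare_eager_vs_graph, the tensor size, and the iteration counts are arbitrary assumptions, and the timings depend on your GPU, driver, and CPU. If replay is not noticeably faster than eager for your real workload, graphs are unlikely to help much.

```python
import time

import torch


def chain_of_small_ops(x: torch.Tensor, n_ops: int = 50) -> torch.Tensor:
    # Many tiny elementwise kernels, so launch overhead dominates the runtime.
    for _ in range(n_ops):
        x = x * 1.0001 + 0.0001
    return x


def compare_eager_vs_graph(iters: int = 100) -> tuple[float, float]:
    static_x = torch.ones(1024, device="cuda")

    # Warm up on a side stream, as the capture recipe requires.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            chain_of_small_ops(static_x)
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_y = chain_of_small_ops(static_x)  # output lives in g's private pool

    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        chain_of_small_ops(static_x)  # eager: one launch per op, every iteration
    torch.cuda.synchronize()
    eager_seconds = time.perf_counter() - t0

    t0 = time.perf_counter()
    for _ in range(iters):
        g.replay()  # graph: one launch for the whole chain
    torch.cuda.synchronize()
    graph_seconds = time.perf_counter() - t0

    del static_y
    return eager_seconds, graph_seconds
```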

PyTorch API

Warning

This API is in beta and may change in future releases.

PyTorch exposes graphs via a raw torch.cuda.CUDAGraph class and two convenience wrappers, torch.cuda.graph and torch.cuda.make_graphed_callables.

torch.cuda.graph is a simple, versatile context manager that captures CUDA work in its context. Before capture, warm up the workload to be captured by running a few eager iterations. Warmup must occur on a side stream. Because the graph reads from and writes to the same memory addresses in every replay, you must maintain long-lived references to tensors that hold input and output data during capture. To run the graph on new input data, copy new data to the capture’s input tensor(s), replay the graph, then read the new output from the capture’s output tensor(s). Example:

g = torch.cuda.CUDAGraph()

# Placeholder input used for capture
static_input = torch.empty((5,), device="cuda")

# Warmup before capture
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        static_output = static_input * 2
torch.cuda.current_stream().wait_stream(s)

# Captures the graph
# To allow capture, automatically sets a side stream as the current stream in the context
with torch.cuda.graph(g):
    static_output = static_input * 2

# Fills the graph's input memory with new data to compute on
static_input.copy_(torch.full((5,), 3, device="cuda"))
g.replay()
# static_output holds the results
print(static_output)  # full of 3 * 2 = 6

# Fills the graph's input memory with more data to compute on
static_input.copy_(torch.full((5,), 4, device="cuda"))
g.replay()
print(static_output)  # full of 4 * 2 = 8
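The capture recipe above can be wrapped in a small reusable helper. graph_callable is a hypothetical name, not a PyTorch API, and it assumes a single tensor input and a single tensor output with fixed shapes. The returned closure holds references to the graph and its static tensors, which keeps their memory alive across replays. For autograd-aware graphing, torch.cuda.make_graphed_callables is the built-in option.

```python
from typing import Callable

import torch


def graph_callable(
    fn: Callable[[torch.Tensor], torch.Tensor],
    sample: torch.Tensor,
    warmup_iters: int = 3,
) -> Callable[[torch.Tensor], torch.Tensor]:
    """Hypothetical helper: capture fn(static_input) once, return a replay function."""
    static_input = sample.clone()  # long-lived input buffer the graph reads from

    # Warm up on a side stream before capture.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(warmup_iters):
            fn(static_input)
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_output = fn(static_input)

    def replay(new_input: torch.Tensor) -> torch.Tensor:
        # new_input must have static_input's shape and dtype.
        static_input.copy_(new_input)
        g.replay()
        # Clone so the next replay does not overwrite what we hand back.
        return static_output.clone()

    return replay
```

Usage: double = graph_callable(lambda t: t * 2, torch.empty(5, device="cuda")), after which double(torch.full((5,), 3.0, device="cuda")) yields a tensor of 6s.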

See Whole-network capture, Usage with torch.cuda.amp, and Usage with multiple streams for realistic and advanced patterns.

make_graphed_callables is more sophisticated. make_graphed_callables accepts Python functions and torch.nn.Modules. For each passed function or Module, it creates separate graphs of the forward-pass and backward-pass work. See Partial-network capture.

Constraints

A set of ops is capturable if it doesn’t violate any of the following constraints.

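Constraints apply to all work in a torch.cuda.graph context and all work in the forward and backward passes of any callable you pass to torch.cuda.make_graphed_callables().

Violating any of these will likely cause a runtime error:

  1. Capture must occur on a non-default stream. (This is only a concern if you use the raw CUDAGraph.capture_begin and CUDAGraph.capture_end calls. torch.cuda.graph and make_graphed_callables() set a side stream for you.)
  2. Ops that synchronize the CPU with the GPU (e.g., .item() calls) are prohibited.

Violating any of these will likely cause silent numerical errors or undefined behavior:

  1. Within a process, only one capture may be underway at a time.
  2. No non-captured CUDA work may run in this process (on any thread) while capture is underway.
  3. CPU work is not captured. If the captured ops include CPU work, that work will be elided during replay.
  4. Every replay reads from and writes to the same (virtual) memory addresses.
  5. Dynamic control flow (based on CPU or GPU data) is prohibited.
  6. Dynamic shapes are prohibited. The graph assumes every tensor in the captured op sequence has the same size and layout in every replay.
  7. Using multiple streams in a capture is allowed, but there are restrictions (see Usage with multiple streams).

Non-constraints

  1. Once captured, the graph may be replayed on any stream.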
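A data-dependent Python `if` is a common reason a region is not capturable: the `.item()` it needs is a CPU sync, and even without one, only the branch taken during capture would be recorded. One workaround, not prescribed by this page, is to evaluate both branches and select the result on the GPU with torch.where. It costs the extra branch's compute and requires both branches to produce the same shape; eager_branch and static_branch are illustrative names.

```python
import torch


def eager_branch(x: torch.Tensor) -> torch.Tensor:
    # Not capturable: .item() syncs the CPU with the GPU, and the Python `if`
    # would bake whichever branch ran during capture into the graph.
    if x.sum().item() > 0:
        return x * 2
    return x - 1


def static_branch(x: torch.Tensor) -> torch.Tensor:
    # Capturable rewrite: both branches always run, and torch.where picks the
    # result on the device, so there is no sync and no CPU-side control flow.
    cond = x.sum() > 0  # 0-dim bool tensor, broadcast by torch.where
    return torch.where(cond, x * 2, x - 1)
```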

Whole-network capture

If your entire network is capturable, you can capture and replay an entire iteration:

N, D_in, H, D_out = 640, 4096, 2048, 1024
model = torch.nn.Sequential(torch.nn.Linear(D_in, H),
                            torch.nn.Dropout(p=0.2),
                            torch.nn.Linear(H, D_out),
                            torch.nn.Dropout(p=0.1)).cuda()
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Placeholders used for capture
static_input = torch.randn(N, D_in, device='cuda')
static_target = torch.randn(N, D_out, device='cuda')

# warmup
# Uses static_input and static_target here for convenience,
# but in a real setting, because the warmup includes optimizer.step()
# you must use a few batches of real data.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for i in range(3):
        optimizer.zero_grad(set_to_none=True)
        y_pred = model(static_input)
        loss = loss_fn(y_pred, static_target)
        loss.backward()
        optimizer.step()
torch.cuda.current_stream().wait_stream(s)

# capture
g = torch.cuda.CUDAGraph()
# Sets grads to None before capture, so backward() will create
# .grad attributes with allocations from the graph's private pool
optimizer.zero_grad(set_to_none=True)
with torch.cuda.graph(g):
    static_y_pred = model(static_input)
    static_loss = loss_fn(static_y_pred, static_target)
    static_loss.backward()
    optimizer.step()

real_inputs = [torch.rand_like(static_input) for _ in range(10)]
real_targets = [torch.rand_like(static_target) for _ in range(10)]

for data, target in zip(real_inputs, real_targets):
    # Fills the graph's input memory with new data to compute on
    static_input.copy_(data)
    static_target.copy_(target)
    # replay() includes forward, backward, and step.
    # You don't even need to call optimizer.zero_grad() between iterations
    # because the captured backward refills static .grad tensors in place.
    g.replay()
    # Params have been updated. static_y_pred, static_loss, and .grad
    # attributes hold values from computing on this iteration's data.
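Because every replay reuses the captured shapes, a final batch with fewer than N rows cannot be copied into static_input directly. A hedged sketch of one way to handle it: copy the batch into the front of the static buffer and zero the remainder. load_into_static is a hypothetical helper. The padded rows still flow through forward and backward, so a real training loop would need to mask them out of the loss, or simply drop the last partial batch with DataLoader(drop_last=True).

```python
import torch


def load_into_static(static: torch.Tensor, batch: torch.Tensor) -> int:
    """Hypothetical helper: copy `batch` into a graph's static input buffer,
    zero-padding a short final batch. Returns the number of valid rows."""
    n = batch.shape[0]
    if n > static.shape[0] or batch.shape[1:] != static.shape[1:]:
        raise ValueError(
            f"batch shape {tuple(batch.shape)} does not fit static buffer {tuple(static.shape)}"
        )
    static[:n].copy_(batch)  # in-place write keeps the captured memory address
    if n < static.shape[0]:
        static[n:].zero_()  # padding rows; mask them out downstream
    return n
```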

Partial-network capture

If some of your network is unsafe to capture (e.g., due to dynamic control flow, dynamic shapes, CPU syncs, or essential CPU-side logic), you can run the unsafe part(s) eagerly and use torch.cuda.make_graphed_callables() to graph only the capture-safe part(s).

By default, callables returned by make_graphed_callables() are autograd-aware, and can be used in the training loop as direct replacements for the functions or nn.Modules you passed.

make_graphed_callables() internally creates CUDAGraph objects, runs warmup iterations, and maintains static inputs and outputs as needed. Therefore (unlike with torch.cuda.graph) you don’t need to handle those manually.

In the following example, data-dependent dynamic control flow means the network isn’t capturable end-to-end, but make_graphed_callables() lets us capture and run graph-safe sections as graphs regardless:

import torch
from itertools import chain

N, D_in, H, D_out = 640, 4096, 2048, 1024

module1 = torch.nn.Linear(D_in, H).cuda()
module2 = torch.nn.Linear(H, D_out).cuda()
module3 = torch.nn.Linear(H, D_out).cuda()

loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(chain(module1.parameters(),
                                  module2.parameters(),
                                  module3.parameters()),
                            lr=0.1)

# Sample inputs used for capture
# requires_grad state of sample inputs must match
# requires_grad state of real inputs each callable will see.
x = torch.randn(N, D_in, device='cuda')
h = torch.randn(N, H, device='cuda', requires_grad=True)

module1 = torch.cuda.make_graphed_callables(module1, (x,))
module2 = torch.cuda.make_graphed_callables(module2, (h,))
module3 = torch.cuda.make_graphed_callables(module3, (h,))

real_inputs = [torch.rand_like(x) for _ in range(10)]
real_targets = [torch.randn(N, D_out, device="cuda") for _ in range(10)]

for data, target in zip(real_inputs, real_targets):
    optimizer.zero_grad(set_to_none=True)

    tmp = module1(data)  # forward ops run as a graph

    if tmp.sum().item() > 0:
        tmp = module2(tmp)  # forward ops run as a graph
    else:
        tmp = module3(tmp)  # forward ops run as a graph

    loss = loss_fn(tmp, target)
    # module2's or module3's (whichever was chosen) backward ops,
    # as well as module1's backward ops, run as graphs
    loss.backward()
    optimizer.step()
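Before relying on a graphed module in training, it can be worth confirming that it reproduces eager results. This sketch graphs a deep copy of a Linear layer and compares its forward output and input gradient against the eager original; check_graphed_matches_eager, the layer sizes, and the tolerance are illustrative assumptions. Close agreement is expected because the same kernels run, but treat the tolerance as something to tune.

```python
import copy

import torch


def check_graphed_matches_eager(atol: float = 1e-5) -> bool:
    torch.manual_seed(0)
    eager = torch.nn.Linear(16, 8).cuda()
    graphed = copy.deepcopy(eager)  # identical weights, graphed separately

    # The sample input's requires_grad must match what real inputs will have.
    sample = torch.randn(4, 16, device="cuda", requires_grad=True)
    graphed = torch.cuda.make_graphed_callables(graphed, (sample,))

    x = torch.randn(4, 16, device="cuda", requires_grad=True)

    out_eager = eager(x)
    out_eager.sum().backward()
    grad_eager = x.grad.clone()

    x.grad = None
    out_graphed = graphed(x)  # forward replays a graph
    out_graphed.sum().backward()  # backward replays a graph

    return torch.allclose(out_eager, out_graphed, atol=atol) and torch.allclose(
        grad_eager, x.grad, atol=atol
    )
```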

Usage with torch.cuda.amp

For typical optimizers, GradScaler.step syncs the CPU with the GPU, which is prohibited during capture. To avoid errors, either use partial-network capture, or (if forward, loss, and backward are capture-safe) capture forward, loss, and backward but not the optimizer step:

# Assumes model, loss_fn, optimizer, static_input, and static_target as in
# Whole-network capture, plus a GradScaler instance named scaler.

# warmup
# In a real setting, use a few batches of real data.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for i in range(3):
        optimizer.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast():
            y_pred = model(static_input)
            loss = loss_fn(y_pred, static_target)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
torch.cuda.current_stream().wait_stream(s)

# capture
g = torch.cuda.CUDAGraph()
optimizer.zero_grad(set_to_none=True)
with torch.cuda.graph(g):
    with torch.cuda.amp.autocast():
        static_y_pred = model(static_input)
        static_loss = loss_fn(static_y_pred, static_target)
    scaler.scale(static_loss).backward()
    # don't capture scaler.step(optimizer) or scaler.update()

real_inputs = [torch.rand_like(static_input) for _ in range(10)]
real_targets = [torch.rand_like(static_target) for _ in range(10)]

for data, target in zip(real_inputs, real_targets):
    static_input.copy_(data)
    static_target.copy_(target)
    g.replay()
    # Runs scaler.step and scaler.update eagerly
    scaler.step(optimizer)
    scaler.update()

Usage with multiple streams

Capture mode automatically propagates to any streams that sync with a capturing stream. Within capture, you may expose parallelism by issuing calls to different streams, but the overall stream dependency DAG must branch out from the initial capturing stream after capture begins and rejoin the initial stream before capture ends:

with torch.cuda.graph(g):
    # at context manager entrance, torch.cuda.current_stream()
    # is the initial capturing stream

    # INCORRECT (does not branch out from or rejoin initial stream)
    with torch.cuda.stream(s):
        cuda_work()

    # CORRECT:
    # branches out from initial stream
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        cuda_work()
    # rejoins initial stream before capture ends
    torch.cuda.current_stream().wait_stream(s)
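A self-contained sketch of the correct branch-and-rejoin pattern, assuming a toy workload: a * 3 on a side stream and a + 1 on the initial stream, summed after the rejoin. capture_two_branches is a hypothetical name. As a design choice of this sketch, the side branch's output buffer is allocated on the initial stream before branching, so its memory is never freed on one stream while another stream might still be using it.

```python
import torch


def capture_two_branches():
    a = torch.ones(1024, device="cuda")  # static input
    side = torch.cuda.Stream()

    def work() -> torch.Tensor:
        b = torch.empty_like(a)  # allocated on the initial (current) stream
        side.wait_stream(torch.cuda.current_stream())  # branch out
        with torch.cuda.stream(side):
            torch.mul(a, 3, out=b)  # side-stream branch
        c = a + 1  # initial-stream branch, may overlap with the side stream
        torch.cuda.current_stream().wait_stream(side)  # rejoin before using b
        return b + c

    # Warm up on a side stream before capture.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        work()
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        out = work()
    return g, a, out
```

Usage: fill a, call g.replay(), and read out; with a full of 2.0, out is full of 2 * 3 + (2 + 1) = 9.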

Note

To avoid confusion for power users looking at replays in Nsight Systems or nvprof: Unlike eager execution, the graph interprets a nontrivial stream DAG in capture as a hint, not a command. During replay, the graph may reorganize independent ops onto different streams or enqueue them in a different order (while respecting your original DAG’s overall dependencies).

Usage with DistributedDataParallel

NCCL < 2.9.6

NCCL versions earlier than 2.9.6 don’t allow collectives to be captured. You must use partial-network capture, which defers allreduces to happen outside graphed sections of backward.

Call make_graphed_callables() on graphable network sections before wrapping the network with DDP.

NCCL >= 2.9.6

NCCL versions 2.9.6 or later allow collectives in the graph. Approaches that capture an entire backward pass are a viable option, but need three setup steps.

  1. Disable DDP’s internal async error handling:
    os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
    torch.distributed.init_process_group(...)
  2. Before full-backward capture, DDP must be constructed in a side-stream context:
    with torch.cuda.stream(s):
        model = DistributedDataParallel(model)
  3. Your warmup must run at least 11 DDP-enabled eager iterations before capture.
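The version split above can be encoded in a small helper that picks a capture approach. ddp_capture_strategy and required_warmup_iters are hypothetical names; the thresholds (NCCL 2.9.6 and 11 warmup iterations) come from the text. The version tuple could come from torch.cuda.nccl.version(), which returns a tuple such as (2, 21, 5) on recent builds; verify the format on yours.

```python
from typing import Sequence

FULL_BACKWARD_MIN_NCCL = (2, 9, 6)
MIN_DDP_WARMUP_ITERS = 11


def ddp_capture_strategy(nccl_version: Sequence[int]) -> str:
    # Tuples compare lexicographically, so (2, 10, 3) > (2, 9, 6) as intended.
    if tuple(nccl_version) >= FULL_BACKWARD_MIN_NCCL:
        return "full-backward"  # collectives may be captured
    return "partial"  # graph sections with make_graphed_callables() before DDP wrapping


def required_warmup_iters(strategy: str, requested: int = 3) -> int:
    # Full-backward capture needs at least 11 DDP-enabled eager iterations.
    if strategy == "full-backward":
        return max(requested, MIN_DDP_WARMUP_ITERS)
    return requested
```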

Graph memory management

A captured graph acts on the same virtual addresses every time it replays. If PyTorch frees the memory, a later replay can hit an illegal memory access. If PyTorch reassigns the memory to new tensors, the replay can corrupt the values seen by those tensors. Therefore, the virtual addresses used by the graph must be reserved for the graph across replays. The PyTorch caching allocator achieves this by detecting when capture is underway and satisfying the capture’s allocations from a graph-private memory pool. The private pool stays alive until its CUDAGraph object and all tensors created during capture go out of scope.

Private pools are maintained automatically. By default, the allocator creates a separate private pool for each capture. If you capture multiple graphs, this conservative approach ensures graph replays never corrupt each other’s values, but sometimes needlessly wastes memory.
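To observe a private pool's lifetime, you can watch torch.cuda.memory_reserved() around a capture and again after releasing both the graph and the tensors created during capture. This is a hedged sketch with an arbitrary 16 MiB input; pool_lifetime_demo and reserved_mib are illustrative names. The exact numbers depend on the caching allocator's block sizes, so the sketch prints them rather than asserting a particular result.

```python
import gc

import torch


def reserved_mib() -> float:
    return torch.cuda.memory_reserved() / 2**20


def pool_lifetime_demo() -> None:
    static_in = torch.zeros(1 << 22, device="cuda")  # 4 Mi float32 values = 16 MiB

    # Warm up on a side stream before capture.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            _ = static_in * 2
    torch.cuda.current_stream().wait_stream(s)
    print(f"before capture: {reserved_mib():.1f} MiB reserved")

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_out = static_in * 2  # allocated from g's private pool
    print(f"after capture:  {reserved_mib():.1f} MiB reserved")

    # The pool can be released only once the graph AND every tensor allocated
    # during capture (here static_out) are gone.
    del g, static_out
    gc.collect()
    torch.cuda.empty_cache()
    print(f"after release:  {reserved_mib():.1f} MiB reserved")
```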

Sharing memory across captures

To economize the memory stashed in private pools, torch.cuda.graph and torch.cuda.make_graphed_callables() optionally allow different captures to share the same private pool. It’s safe for a set of graphs to share a private pool if you know they’ll always be replayed in the same order they were captured, and never be replayed concurrently.

torch.cuda.graph’s pool argument is a hint to use a particular private pool, and can be used to share memory across graphs as shown:

g1 = torch.cuda.CUDAGraph()
g2 = torch.cuda.CUDAGraph()

# (create static inputs for g1 and g2, run warmups of their workloads...)

# Captures g1
with torch.cuda.graph(g1):
    static_out_1 = g1_workload(static_in_1)

# Captures g2, hinting that g2 may share a memory pool with g1
with torch.cuda.graph(g2, pool=g1.pool()):
    static_out_2 = g2_workload(static_in_2)

static_in_1.copy_(real_data_1)
static_in_2.copy_(real_data_2)
g1.replay()
g2.replay()
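A concrete, runnable version of the pool-sharing pattern, with two toy workloads (t + 1 and t * 10) standing in for g1_workload and g2_workload; capture_with_shared_pool is a hypothetical name. Because static_out_1 stays referenced, the allocator will not hand its memory to g2; only memory that g1's capture freed internally becomes reusable. Replays must follow capture order (g1, then g2).

```python
import torch


def capture_with_shared_pool():
    static_in_1 = torch.zeros(8, device="cuda")
    static_in_2 = torch.zeros(8, device="cuda")

    def g1_workload(t: torch.Tensor) -> torch.Tensor:  # toy stand-in
        return t + 1

    def g2_workload(t: torch.Tensor) -> torch.Tensor:  # toy stand-in
        return t * 10

    # Warm up both workloads on a side stream before capture.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            g1_workload(static_in_1)
            g2_workload(static_in_2)
    torch.cuda.current_stream().wait_stream(s)

    g1 = torch.cuda.CUDAGraph()
    g2 = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g1):
        static_out_1 = g1_workload(static_in_1)
    # g2 draws from g1's private pool instead of creating its own.
    with torch.cuda.graph(g2, pool=g1.pool()):
        static_out_2 = g2_workload(static_in_2)

    return (g1, static_in_1, static_out_1), (g2, static_in_2, static_out_2)
```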

With torch.cuda.make_graphed_callables(), if you want to graph several callables and you know they’ll always run in the same order (and never concurrently) pass them as a tuple in the same order they’ll run in the live workload, and make_graphed_callables() will capture their graphs using a shared private pool.

If, in the live workload, your callables will run in an order that occasionally changes, or if they’ll run concurrently, passing them as a tuple to a single invocation of make_graphed_callables() is not allowed. Instead, you must call make_graphed_callables() separately for each one.
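A sketch of the tuple form, assuming two toy Linear layers where m1 always feeds m2 and the two never run concurrently; graph_in_order is a hypothetical name. sample_args is then a tuple holding one argument tuple per callable, and the result is a tuple of graphed callables. If the order could vary, you would instead call make_graphed_callables() once per module, as in Partial-network capture.

```python
import torch


def graph_in_order():
    m1 = torch.nn.Linear(16, 32).cuda()
    m2 = torch.nn.Linear(32, 4).cuda()

    # Sample inputs: m1 sees plain data; m2 sees m1's output, which requires grad.
    x = torch.randn(8, 16, device="cuda")
    h = torch.randn(8, 32, device="cuda", requires_grad=True)

    # One call with a tuple: graphs are captured in this order from a shared
    # private pool, so the live workload must always run m1, then m2.
    m1, m2 = torch.cuda.make_graphed_callables((m1, m2), ((x,), (h,)))
    return m1, m2
```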