Distributed RPC Framework - PyTorch 2.7 documentation

The distributed RPC framework provides mechanisms for multi-machine model training through a set of primitives to allow for remote communication, and a higher-level API to automatically differentiate models split across several machines.

Warning

APIs in the RPC package are stable. There are multiple ongoing work items to improve performance and error handling, which will ship in future releases.

Warning

CUDA support was introduced in PyTorch 1.9 and is still a beta feature. Not all features of the RPC package are yet compatible with CUDA support and thus their use is discouraged. These unsupported features include: RRefs, JIT compatibility, dist autograd and dist optimizer, and profiling. These shortcomings will be addressed in future releases.

Basics

The distributed RPC framework makes it easy to run functions remotely, supports referencing remote objects without copying the real data around, and provides autograd and optimizer APIs to transparently run backward and update parameters across RPC boundaries. These features can be categorized into four sets of APIs.

  1. Remote Procedure Call (RPC) supports running a function on the specified destination worker with the given arguments and getting the return value back or creating a reference to the return value. There are three main RPC APIs: rpc_sync() (synchronous), rpc_async() (asynchronous), and remote() (asynchronous and returns a reference to the remote return value). Use the synchronous API if the user code cannot proceed without the return value. Otherwise, use the asynchronous API to get a future, and wait on the future when the return value is needed on the caller. The remote() API is useful when the requirement is to create something remotely but never fetch it to the caller. Consider a case where a driver process is setting up a parameter server and a trainer. The driver can create an embedding table on the parameter server and then share the reference to the embedding table with the trainer, but will never use the embedding table locally itself. In this case, rpc_sync() and rpc_async() are no longer appropriate, as they always imply that the return value will be returned to the caller immediately or in the future.
  2. Remote Reference (RRef) serves as a distributed shared pointer to a local or remote object. It can be shared with other workers, and reference counting is handled transparently. Each RRef has only one owner, and the object lives only on that owner. Non-owner workers holding RRefs can get copies of the object from the owner by explicitly requesting it. This is useful when a worker needs to access some data object but is neither the creator (the caller of remote()) nor the owner of the object. The distributed optimizer, as we will discuss below, is one example of such use cases.
  3. Distributed Autograd stitches together local autograd engines on all the workers involved in the forward pass, and automatically reaches out to them during the backward pass to compute gradients. This is especially helpful if the forward pass needs to span multiple machines when conducting, e.g., distributed model parallel training, parameter-server training, etc. With this feature, user code no longer needs to worry about how to send gradients across RPC boundaries and in which order the local autograd engines should be launched, which can become quite complicated when there are nested and inter-dependent RPC calls in the forward pass.
  4. Distributed Optimizer’s constructor takes an Optimizer() (e.g., SGD(), Adagrad(), etc.) and a list of parameter RRefs, creates an Optimizer() instance on each distinct RRef owner, and updates parameters accordingly when running step(). When you have distributed forward and backward passes, parameters and gradients will be scattered across multiple workers, and hence an optimizer is required on each of the involved workers. Distributed Optimizer wraps all those local optimizers into one, and provides a concise constructor and step() API.

RPC

Before using RPC and distributed autograd primitives, initialization must take place. To initialize the RPC framework, use init_rpc(), which initializes the RPC framework, the RRef framework, and distributed autograd.
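A quick way to try initialization and the three RPC APIs without a second machine is a world of size 1, in which the only worker sends RPCs to itself. The sketch below assumes that setup; the `free_port` helper, the worker name "worker0", and the localhost rendezvous are illustrative choices rather than requirements of the API.

```python
import os
import socket

import torch
import torch.distributed.rpc as rpc


def free_port() -> int:
    # Hypothetical helper: ask the OS for an unused TCP port on localhost.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("localhost", 0))
        return s.getsockname()[1]


os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(free_port())

# Assumption: world_size=1, so "worker0" is both the caller and the callee.
rpc.init_rpc("worker0", rank=0, world_size=1)

# rpc_sync: blocks until the return value is back.
sync_result = rpc.rpc_sync("worker0", torch.add, args=(torch.ones(2), 3))

# rpc_async: returns a Future right away; wait() when the value is needed.
fut = rpc.rpc_async("worker0", torch.add, args=(torch.ones(2), 1))
async_result = fut.wait()

# remote: returns an RRef; the value stays on the owner until to_here().
rref = rpc.remote("worker0", torch.add, args=(torch.ones(2), 2))
remote_result = rref.to_here()

rpc.shutdown()
print(sync_result, async_result, remote_result)
```

In a real deployment each worker runs in its own process with a distinct rank, as in the two-process examples in this section.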

torch.distributed.rpc.init_rpc(name, backend=None, rank=-1, world_size=None, rpc_backend_options=None)

Initializes RPC primitives such as the local RPC agent and distributed autograd, which immediately makes the current process ready to send and receive RPCs.

Parameters

The following APIs allow users to remotely execute functions as well as create references (RRefs) to remote data objects. In these APIs, when passing a Tensor as an argument or a return value, the destination worker will try to create a Tensor with the same metadata (i.e., shape, stride, etc.). We intentionally disallow transmitting CUDA tensors because it might crash if the device lists on source and destination workers do not match. In such cases, applications can always explicitly move the input tensors to CPU on the caller and move them to the desired devices on the callee if necessary.

Warning

TorchScript support in RPC is a prototype feature and subject to change. Since v1.5.0, torch.distributed.rpc supports calling TorchScript functions as RPC target functions, and this will help improve parallelism on the callee side, as executing TorchScript functions does not require the GIL.

torch.distributed.rpc.rpc_sync(to, func, args=None, kwargs=None, timeout=-1.0)

Make a blocking RPC call to run function func on worker to. RPC messages are sent and received in parallel to execution of Python code. This method is thread-safe.

Parameters

Returns

Returns the result of running func with args and kwargs.

Example:

Make sure that MASTER_ADDR and MASTER_PORT are set properly on both workers. Refer to the init_process_group() API for more details. For example,

```shell
export MASTER_ADDR=localhost
export MASTER_PORT=5678
```

Then run the following code in two different processes:

On worker 0:

```python
import torch
import torch.distributed.rpc as rpc
rpc.init_rpc("worker0", rank=0, world_size=2)
ret = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 3))
rpc.shutdown()
```

On worker 1:

```python
import torch.distributed.rpc as rpc
rpc.init_rpc("worker1", rank=1, world_size=2)
rpc.shutdown()
```

Below is an example of running a TorchScript function using RPC.

On both workers:

```python
import torch

@torch.jit.script
def my_script_add(tensor: torch.Tensor, scalar: int):
    return torch.add(tensor, scalar)
```

On worker 0:

```python
import torch.distributed.rpc as rpc
rpc.init_rpc("worker0", rank=0, world_size=2)
ret = rpc.rpc_sync("worker1", my_script_add, args=(torch.ones(2), 3))
rpc.shutdown()
```

On worker 1:

```python
import torch.distributed.rpc as rpc
rpc.init_rpc("worker1", rank=1, world_size=2)
rpc.shutdown()
```

torch.distributed.rpc.rpc_async(to, func, args=None, kwargs=None, timeout=-1.0)

Make a non-blocking RPC call to run function func on worker to. RPC messages are sent and received in parallel to execution of Python code. This method is thread-safe. This method will immediately return a Future that can be awaited on.

Parameters

Returns

Returns a Future object that can be waited on. When completed, the return value of func on args and kwargs can be retrieved from the Future object.

Warning

Using GPU tensors as arguments or return values of func is not supported since we don’t support sending GPU tensors over the wire. You need to explicitly copy GPU tensors to CPU before using them as arguments or return values of func.

Warning

The rpc_async API does not copy storages of argument tensors until sending them over the wire, which could be done by a different thread depending on the RPC backend type. The caller should make sure that the contents of those tensors stay intact until the returned Future completes.
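One way to satisfy this requirement is to hand rpc_async() a private copy of any tensor you intend to modify in place before the Future completes. The sketch below assumes a single process with a world of size 1 (the worker sends the RPC to itself), and `free_port` is a hypothetical helper. Waiting on the Future before mutating works as well; cloning trades an extra copy for not having to block.

```python
import os
import socket

import torch
import torch.distributed.rpc as rpc


def free_port() -> int:
    # Hypothetical helper: ask the OS for an unused TCP port on localhost.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("localhost", 0))
        return s.getsockname()[1]


os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(free_port())

# Assumption: world_size=1, so "worker0" sends the RPC to itself.
rpc.init_rpc("worker0", rank=0, world_size=1)

x = torch.ones(2)
# Pass a clone: the in-flight RPC owns its own storage, so the in-place
# update below cannot change what gets serialized.
fut = rpc.rpc_async("worker0", torch.add, args=(x.clone(), 1))
x.add_(100)
result = fut.wait()

rpc.shutdown()
print(result)  # tensor([2., 2.])
```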

Example:

Make sure that MASTER_ADDR and MASTER_PORT are set properly on both workers. Refer to the init_process_group() API for more details. For example,

```shell
export MASTER_ADDR=localhost
export MASTER_PORT=5678
```

Then run the following code in two different processes:

On worker 0:

```python
import torch
import torch.distributed.rpc as rpc
rpc.init_rpc("worker0", rank=0, world_size=2)
fut1 = rpc.rpc_async("worker1", torch.add, args=(torch.ones(2), 3))
fut2 = rpc.rpc_async("worker1", min, args=(1, 2))
result = fut1.wait() + fut2.wait()
rpc.shutdown()
```

On worker 1:

```python
import torch.distributed.rpc as rpc
rpc.init_rpc("worker1", rank=1, world_size=2)
rpc.shutdown()
```

Below is an example of running a TorchScript function using RPC.

On both workers:

```python
import torch

@torch.jit.script
def my_script_add(tensor: torch.Tensor, scalar: int):
    return torch.add(tensor, scalar)
```

On worker 0:

```python
import torch.distributed.rpc as rpc
rpc.init_rpc("worker0", rank=0, world_size=2)
fut = rpc.rpc_async("worker1", my_script_add, args=(torch.ones(2), 3))
ret = fut.wait()
rpc.shutdown()
```

On worker 1:

```python
import torch.distributed.rpc as rpc
rpc.init_rpc("worker1", rank=1, world_size=2)
rpc.shutdown()
```

torch.distributed.rpc.remote(to, func, args=None, kwargs=None, timeout=-1.0)

Make a remote call to run func on worker to and return an RRef to the result value immediately. Worker to will be the owner of the returned RRef, and the worker calling remote is a user. The owner manages the global reference count of its RRef, and the owner RRef is only destructed when globally there are no living references to it.

Parameters

Returns

A user RRef instance to the result value. Use the blocking API torch.distributed.rpc.RRef.to_here() to retrieve the result value locally.

Warning

The remote API does not copy storages of argument tensors until sending them over the wire, which could be done by a different thread depending on the RPC backend type. The caller should make sure that the contents of those tensors stay intact until the returned RRef is confirmed by the owner, which can be checked using the torch.distributed.rpc.RRef.confirmed_by_owner() API.

Warning

Errors such as timeouts for the remote API are handled on a best-effort basis. When remote calls initiated by remote fail, for example with a timeout error, the errors are handled and set on the resulting RRef asynchronously. If the RRef has not been used by the application before this handling (such as a to_here or fork call), then future uses of the RRef will appropriately raise errors. However, it is possible that the user application will use the RRef before the errors are handled. In this case, errors may not be raised, as they have not yet been handled.

Example:

Make sure that MASTER_ADDR and MASTER_PORT are set properly on both workers. Refer to the init_process_group() API for more details. For example,

```shell
export MASTER_ADDR=localhost
export MASTER_PORT=5678
```

Then run the following code in two different processes:

On worker 0:

```python
import torch
import torch.distributed.rpc as rpc
rpc.init_rpc("worker0", rank=0, world_size=2)
rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
x = rref1.to_here() + rref2.to_here()
rpc.shutdown()
```

On worker 1:

```python
import torch.distributed.rpc as rpc
rpc.init_rpc("worker1", rank=1, world_size=2)
rpc.shutdown()
```

Below is an example of running a TorchScript function using RPC.

On both workers:

```python
import torch

@torch.jit.script
def my_script_add(tensor: torch.Tensor, scalar: int):
    return torch.add(tensor, scalar)
```

On worker 0:

```python
import torch.distributed.rpc as rpc
rpc.init_rpc("worker0", rank=0, world_size=2)
rref = rpc.remote("worker1", my_script_add, args=(torch.ones(2), 3))
rref.to_here()
rpc.shutdown()
```

On worker 1:

```python
import torch.distributed.rpc as rpc
rpc.init_rpc("worker1", rank=1, world_size=2)
rpc.shutdown()
```

torch.distributed.rpc.get_worker_info(worker_name=None)

Get WorkerInfo of a given worker name. Use this WorkerInfo to avoid passing an expensive string on every invocation.

Parameters

worker_name (str) – the string name of a worker. If None, return the WorkerInfo of the current worker. (default None)

Returns

WorkerInfo instance for the given worker_name or WorkerInfo of the current worker if worker_name is None.
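The pattern this enables is to resolve a worker name once and reuse the resulting WorkerInfo as the `to` argument. The sketch below assumes a single process with a world of size 1 (so the only worker is "worker0" with ID 0), and `free_port` is a hypothetical helper; it also passes the per-call `timeout` argument (in seconds) from the rpc_sync() signature.

```python
import os
import socket

import torch
import torch.distributed.rpc as rpc


def free_port() -> int:
    # Hypothetical helper: ask the OS for an unused TCP port on localhost.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("localhost", 0))
        return s.getsockname()[1]


os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(free_port())

# Assumption: world_size=1, so the only worker is "worker0".
rpc.init_rpc("worker0", rank=0, world_size=1)

me = rpc.get_worker_info()  # WorkerInfo of the current worker
callee = rpc.get_worker_info("worker0")  # resolve the name once...

# ...then reuse the WorkerInfo instead of the string on every call.
results = [
    rpc.rpc_sync(callee, torch.add, args=(torch.ones(2), i), timeout=10)
    for i in range(3)
]

rpc.shutdown()
print(me.name, me.id, [r.tolist() for r in results])
```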

torch.distributed.rpc.shutdown(graceful=True, timeout=0)

Perform a shutdown of the RPC agent, and then destroy the RPC agent. This stops the local agent from accepting outstanding requests, and shuts down the RPC framework by terminating all RPC threads. If graceful=True, this will block until all local and remote RPC processes reach this method and wait for all outstanding work to complete. Otherwise, if graceful=False, this is a local shutdown, and it does not wait for other RPC processes to reach this method.

Warning

For Future objects returned by rpc_async(), future.wait() should not be called after shutdown().

Parameters

graceful (bool) – Whether to do a graceful shutdown or not. If True, this will 1) wait until there are no pending system messages for UserRRefs and delete them; 2) block until all local and remote RPC processes have reached this method and wait for all outstanding work to complete.

Example:

Make sure that MASTER_ADDR and MASTER_PORT are set properly on both workers. Refer to the init_process_group() API for more details. For example,

```shell
export MASTER_ADDR=localhost
export MASTER_PORT=5678
```

Then run the following code in two different processes:

On worker 0:

```python
import torch
import torch.distributed.rpc as rpc
rpc.init_rpc("worker0", rank=0, world_size=2)
# do some work
result = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(1), 1))
# ready to shutdown
rpc.shutdown()
```

On worker 1:

```python
import torch.distributed.rpc as rpc
rpc.init_rpc("worker1", rank=1, world_size=2)
# wait for worker 0 to finish work, and then shutdown.
rpc.shutdown()
```

class torch.distributed.rpc.WorkerInfo

A structure that encapsulates information of a worker in the system. Contains the name and ID of the worker. This class is not meant to be constructed directly; rather, an instance can be retrieved through get_worker_info() and the result can be passed in to functions such as rpc_sync(), rpc_async(), and remote() to avoid copying a string on every invocation.

property id

Globally unique id to identify the worker.

property name

The name of the worker.

The RPC package also provides decorators which allow applications to specify how a given function should be treated on the callee side.

torch.distributed.rpc.functions.async_execution(fn)

A decorator for a function indicating that the return value of the function is guaranteed to be a Future object and this function can run asynchronously on the RPC callee. More specifically, the callee extracts the Future returned by the wrapped function and installs subsequent processing steps as a callback to that Future. The installed callback will read the value from the Future when completed and send the value back as the RPC response. That also means the returned Future only exists on the callee side and is never sent through RPC. This decorator is useful when the wrapped function’s (fn) execution needs to pause and resume due to, e.g., containing rpc_async() or waiting for other signals.

Note

To enable asynchronous execution, applications must pass the function object returned by this decorator to RPC APIs. If RPC detects attributes installed by this decorator, it knows that this function returns a Future object and will handle that accordingly. However, this does not mean this decorator has to be the outermost one when defining a function. For example, when combined with @staticmethod or @classmethod, @rpc.functions.async_execution needs to be the inner decorator to allow the target function to be recognized as a static or class function. This target function can still execute asynchronously because, when accessed, the static or class method preserves attributes installed by @rpc.functions.async_execution.

Example:

The returned Future object can come from rpc_async(), then(), or the Future constructor. The example below shows directly using the Future returned by then().

```python
import torch
from torch.distributed import rpc

# omitting setup and shutdown RPC

# On all workers
@rpc.functions.async_execution
def async_add_chained(to, x, y, z):
    # This function runs on "worker1" and returns immediately when
    # the callback is installed through the then(cb) API. In the
    # meantime, the rpc_async to "worker2" can run concurrently.
    # When the return value of that rpc_async arrives at
    # "worker1", "worker1" will run the lambda function accordingly
    # and set the value for the previously returned Future, which
    # will then trigger RPC to send the result back to "worker0".
    return rpc.rpc_async(to, torch.add, args=(x, y)).then(
        lambda fut: fut.wait() + z
    )

# On worker0
ret = rpc.rpc_sync(
    "worker1",
    async_add_chained,
    args=("worker2", torch.ones(2), 1, 1)
)
print(ret)  # prints tensor([3., 3.])
```

When combined with TorchScript decorators, this decorator must be the outermost one.

```python
import torch
from torch import Tensor
from torch.futures import Future
from torch.distributed import rpc

# omitting setup and shutdown RPC

# On all workers
@torch.jit.script
def script_add(x: Tensor, y: Tensor) -> Tensor:
    return x + y

@rpc.functions.async_execution
@torch.jit.script
def async_add(to: str, x: Tensor, y: Tensor) -> Future[Tensor]:
    return rpc.rpc_async(to, script_add, (x, y))

# On worker0: y is annotated as Tensor, so pass a Tensor rather than an int
ret = rpc.rpc_sync(
    "worker1",
    async_add,
    args=("worker2", torch.ones(2), torch.ones(2))
)
print(ret)  # prints tensor([2., 2.])
```

When combined with a static or class method, this decorator must be the inner one.

```python
import torch
from torch.distributed import rpc

# omitting setup and shutdown RPC

# On all workers
class AsyncExecutionClass:

    @staticmethod
    @rpc.functions.async_execution
    def static_async_add(to, x, y, z):
        return rpc.rpc_async(to, torch.add, args=(x, y)).then(
            lambda fut: fut.wait() + z
        )

    @classmethod
    @rpc.functions.async_execution
    def class_async_add(cls, to, x, y, z):
        ret_fut = torch.futures.Future()
        rpc.rpc_async(to, torch.add, args=(x, y)).then(
            lambda fut: ret_fut.set_result(fut.wait() + z)
        )
        return ret_fut

    @rpc.functions.async_execution
    def bound_async_add(self, to, x, y, z):
        return rpc.rpc_async(to, torch.add, args=(x, y)).then(
            lambda fut: fut.wait() + z
        )

# On worker0
ret = rpc.rpc_sync(
    "worker1",
    AsyncExecutionClass.static_async_add,
    args=("worker2", torch.ones(2), 1, 2)
)
print(ret)  # prints tensor([4., 4.])

ret = rpc.rpc_sync(
    "worker1",
    AsyncExecutionClass.class_async_add,
    args=("worker2", torch.ones(2), 1, 2)
)
print(ret)  # prints tensor([4., 4.])
```

This decorator also works with RRef helpers, i.e., torch.distributed.rpc.RRef.rpc_sync(), torch.distributed.rpc.RRef.rpc_async(), and torch.distributed.rpc.RRef.remote().

```python
import torch
from torch.distributed import rpc

# reuse the AsyncExecutionClass class above
rref = rpc.remote("worker1", AsyncExecutionClass)
ret = rref.rpc_sync().static_async_add("worker2", torch.ones(2), 1, 2)
print(ret)  # prints tensor([4., 4.])

rref = rpc.remote("worker1", AsyncExecutionClass)
ret = rref.rpc_async().static_async_add("worker2", torch.ones(2), 1, 2).wait()
print(ret)  # prints tensor([4., 4.])

rref = rpc.remote("worker1", AsyncExecutionClass)
ret = rref.remote().static_async_add("worker2", torch.ones(2), 1, 2).to_here()
print(ret)  # prints tensor([4., 4.])
```

Backends

The RPC module can leverage different backends to perform the communication between the nodes. The backend to be used can be specified in the init_rpc() function, by passing a certain value of the BackendType enum. Regardless of what backend is used, the rest of the RPC API won’t change. Each backend also defines its own subclass of the RpcBackendOptions class, an instance of which can also be passed to init_rpc() to configure the backend’s behavior.

class torch.distributed.rpc.BackendType(value)

An enum class of available backends.

PyTorch ships with a builtin BackendType.TENSORPIPE backend. Additional ones can be registered using the register_backend() function.

class torch.distributed.rpc.RpcBackendOptions

An abstract structure encapsulating the options passed into the RPC backend. An instance of this class can be passed in to init_rpc() in order to initialize RPC with specific configurations, such as the RPC timeout and init_method to be used.

property init_method

URL specifying how to initialize the process group. Default is env://

property rpc_timeout

A float indicating the timeout to use for all RPCs. If an RPC does not complete in this timeframe, it will complete with an exception indicating that it has timed out.

TensorPipe Backend

The TensorPipe agent, which is the default, leverages the TensorPipe library, which provides a native point-to-point communication primitive specifically suited for machine learning that fundamentally addresses some of the limitations of Gloo. Compared to Gloo, it has the advantage of being asynchronous, which allows a large number of transfers to occur simultaneously, each at their own speed, without blocking each other. It will only open pipes between pairs of nodes when needed, on demand, and when one node fails only its incident pipes will be closed, while all other ones will keep working as normal. In addition, it is able to support multiple different transports (TCP, of course, but also shared memory, NVLink, InfiniBand, …) and can automatically detect their availability and negotiate the best transport to use for each pipe.

The TensorPipe backend was introduced in PyTorch v1.6 and is being actively developed. It initially supported only CPU tensors; CUDA support was added in PyTorch 1.9 as a beta feature (see the warning at the top of this page). It comes with a TCP-based transport, just like Gloo. It is also able to automatically chunk and multiplex large tensors over multiple sockets and threads in order to achieve very high bandwidths. The agent will be able to pick the best transport on its own, with no intervention required.

Example:

```python
import os
from torch.distributed import rpc
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29500'

rpc.init_rpc(
    "worker1",
    rank=0,
    world_size=2,
    rpc_backend_options=rpc.TensorPipeRpcBackendOptions(
        num_worker_threads=8,
        rpc_timeout=20  # 20 second timeout
    )
)

# omitting init_rpc invocation on worker2
```

class torch.distributed.rpc.TensorPipeRpcBackendOptions(*, num_worker_threads=16, rpc_timeout=60.0, init_method='env://', device_maps=None, devices=None, _transports=None, _channels=None)

The backend options for TensorPipeAgent, derived from RpcBackendOptions.

Parameters

property device_maps

The device map locations.

property devices

All devices used by the local agent.

property init_method

URL specifying how to initialize the process group. Default is env://

property num_worker_threads

The number of threads in the thread-pool used by TensorPipeAgent to execute requests.

property rpc_timeout

A float indicating the timeout to use for all RPCs. If an RPC does not complete in this timeframe, it will complete with an exception indicating that it has timed out.

set_device_map(to, device_map)

Set device mapping between each RPC caller and callee pair. This function can be called multiple times to incrementally add device placement configurations.

Parameters

Example

```python
import torch
import torch.distributed.rpc as rpc
from torch.distributed.rpc import TensorPipeRpcBackendOptions

# both workers
def add(x, y):
    print(x)  # tensor([1., 1.], device='cuda:1')
    return x + y, (x + y).to(2)

# on worker 0
options = TensorPipeRpcBackendOptions(
    num_worker_threads=8,
    device_maps={"worker1": {0: 1}}
    # maps worker0's cuda:0 to worker1's cuda:1
)
options.set_device_map("worker1", {1: 2})
# maps worker0's cuda:1 to worker1's cuda:2

rpc.init_rpc(
    "worker0",
    rank=0,
    world_size=2,
    backend=rpc.BackendType.TENSORPIPE,
    rpc_backend_options=options
)

x = torch.ones(2)
rets = rpc.rpc_sync("worker1", add, args=(x.to(0), 1))
# The first argument will be moved to cuda:1 on worker1. When
# sending the return value back, it will follow the inverse of
# the device map, and hence will be moved back to cuda:0 and
# cuda:1 on worker0
print(rets[0])  # tensor([2., 2.], device='cuda:0')
print(rets[1])  # tensor([2., 2.], device='cuda:1')
```

set_devices(devices)

Set local devices used by the TensorPipe RPC agent. When processing CUDA RPC requests, the TensorPipe RPC agent will properly synchronize CUDA streams for all devices in this List.

Parameters

devices (List of int, str, or torch.device) – local devices used by the TensorPipe RPC agent.

Note

The RPC framework does not automatically retry any rpc_sync(), rpc_async() and remote() calls. This is because there is no way the RPC framework can determine whether an operation is idempotent or not and whether it is safe to retry. As a result, it is the application’s responsibility to deal with failures and retry if necessary. RPC communication is based on TCP and as a result failures could happen due to network failures or intermittent network connectivity issues. In such scenarios, the application needs to retry appropriately with reasonable backoffs to ensure the network isn’t overwhelmed by aggressive retries.
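Because retries are left to the application, a small generic helper is one way to follow this advice. The sketch below is an illustration, not part of torch.distributed.rpc: `call_with_retries` is a hypothetical helper, and it assumes transient failures surface as RuntimeError (adjust `retry_on` to whatever your calls actually raise). Only wrap operations you know to be idempotent.

```python
import random
import time
from typing import Callable, Tuple, Type, TypeVar

T = TypeVar("T")


def call_with_retries(
    fn: Callable[[], T],
    max_attempts: int = 5,
    base_delay: float = 0.1,
    max_delay: float = 5.0,
    retry_on: Tuple[Type[BaseException], ...] = (RuntimeError,),
    sleep: Callable[[float], None] = time.sleep,
) -> T:
    """Call fn(), retrying on `retry_on` with capped, jittered exponential backoff.

    Only safe for idempotent operations: the caller cannot tell whether a
    failed call already had side effects on the callee.
    """
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")
    for attempt in range(max_attempts):
        try:
            return fn()
        except retry_on:
            if attempt == max_attempts - 1:
                raise  # out of attempts: surface the last error to the caller
            # Exponential growth (base, 2*base, 4*base, ...) capped at max_delay,
            # with "full jitter" so failing workers do not retry in lockstep.
            cap = min(max_delay, base_delay * (2 ** attempt))
            sleep(random.uniform(0.0, cap))
    raise AssertionError("unreachable")  # every iteration returns or raises


# Usage: a stand-in for a flaky, idempotent call that succeeds on attempt 3.
calls = {"n": 0}


def flaky() -> str:
    calls["n"] += 1
    if calls["n"] < 3:
        raise RuntimeError("transient network error")
    return "ok"


delays = []  # record the backoff delays instead of actually sleeping
result = call_with_retries(flaky, sleep=delays.append)
print(result, calls["n"], len(delays))  # ok 3 2
```

For an idempotent RPC, the call might look like `call_with_retries(lambda: rpc.rpc_sync("worker1", torch.add, args=(x, 1)))`. Injecting `sleep` keeps the helper testable without real waiting.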

RRef

Warning

RRefs are not currently supported when using CUDA tensors.

An RRef (Remote REFerence) is a reference to a value of some type T (e.g., Tensor) on a remote worker. This handle keeps the referenced remote value alive on the owner, but there is no implication that the value will be transferred to the local worker in the future. RRefs can be used in multi-machine training by holding references to nn.Modules that exist on other workers, and calling the appropriate functions to retrieve or modify their parameters during training. See Remote Reference Protocol for more details.

class torch.distributed.rpc.PyRRef(RRef)

A class encapsulating a reference to a value of some type on a remote worker. This handle will keep the referenced remote value alive on the worker. A UserRRef will be deleted when 1) there are no references to it in either the application code or the local RRef context, or 2) the application has called a graceful shutdown. Invoking methods on a deleted RRef leads to undefined behavior. The RRef implementation only offers best-effort error detection, and applications should not use UserRRefs after rpc.shutdown().
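The ownership-related methods documented below can be tried on a single machine. This is a sketch under stated assumptions: one process with a world of size 1, so every RRef in it is owned by "worker0", and `free_port` is a hypothetical helper. It creates one RRef from a local object and one through remote(), then queries ownership and fetches values.

```python
import os
import socket

import torch
import torch.distributed.rpc as rpc


def free_port() -> int:
    # Hypothetical helper: ask the OS for an unused TCP port on localhost.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("localhost", 0))
        return s.getsockname()[1]


os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(free_port())

# Assumption: world_size=1, so "worker0" owns every RRef created here.
rpc.init_rpc("worker0", rank=0, world_size=1)

# Created from a local object: this worker is the owner.
local_rref = rpc.RRef(torch.zeros(2, 2))
is_owner = local_rref.is_owner()
owner_name = local_rref.owner_name()
confirmed = local_rref.confirmed_by_owner()  # an OwnerRRef always returns True
value = local_rref.local_value()  # only legal on the owner

# Created by remote(): the callee (here, this same worker) owns the result.
remote_rref = rpc.remote("worker0", torch.add, args=(torch.ones(2), 1))
fetched = remote_rref.to_here()

rpc.shutdown()
print(is_owner, owner_name, confirmed, fetched)
```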

Warning

RRefs can only be serialized and deserialized by the RPC module. Serializing and deserializing RRefs without RPC (e.g., Python pickle, torch save() / load(), JIT save() / load(), etc.) will lead to errors.

Parameters

Example:

The following examples skip RPC initialization and shutdown code for simplicity. Refer to RPC docs for those details.

  1. Create an RRef using rpc.remote

```python
import torch
import torch.distributed.rpc as rpc
rref = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
# get a copy of value from the RRef
x = rref.to_here()
```

  2. Create an RRef from a local object

```python
import torch
from torch.distributed.rpc import RRef
x = torch.zeros(2, 2)
rref = RRef(x)
```

  3. Share an RRef with other workers

On both worker0 and worker1:

```python
def f(rref):
    return rref.to_here() + 1
```

On worker0:

```python
import torch
import torch.distributed.rpc as rpc
from torch.distributed.rpc import RRef
rref = RRef(torch.zeros(2, 2))
# the following RPC shares the rref with worker1, reference
# count is automatically updated.
rpc.rpc_sync("worker1", f, args=(rref,))
```

backward(self: torch._C._distributed_rpc.PyRRef, dist_autograd_ctx_id: int = -1, retain_graph: bool = False) → None

Runs the backward pass using the RRef as the root of the backward pass. If dist_autograd_ctx_id is provided, we perform a distributed backward pass using the provided ctx_id starting from the owner of the RRef. In this case, get_gradients() should be used to retrieve the gradients. If dist_autograd_ctx_id is None, it is assumed that this is a local autograd graph and we only perform a local backward pass. In the local case, the node calling this API has to be the owner of the RRef. The value of the RRef is expected to be a scalar Tensor.

Parameters

Example:

```python
import torch.distributed.autograd as dist_autograd
with dist_autograd.context() as context_id:
    rref.backward(context_id)
```

confirmed_by_owner(self: torch._C._distributed_rpc.PyRRef) → bool

Returns whether this RRef has been confirmed by the owner. OwnerRRef always returns true, while UserRRef only returns true when the owner knows about this UserRRef.

is_owner(self: torch._C._distributed_rpc.PyRRef) → bool

Returns whether or not the current node is the owner of this RRef.

local_value(self: torch._C._distributed_rpc.PyRRef) → object

If the current node is the owner, returns a reference to the local value. Otherwise, throws an exception.

owner(self: torch._C._distributed_rpc.PyRRef) → torch._C._distributed_rpc.WorkerInfo

Returns worker information of the node that owns this RRef.

owner_name(self: torch._C._distributed_rpc.PyRRef) → str

Returns worker name of the node that owns this RRef.

remote(self: torch._C._distributed_rpc.PyRRef, timeout: float = -1.0) → object

Create a helper proxy to easily launch a remote using the owner of the RRef as the destination to run functions on the object referenced by this RRef. More specifically, rref.remote().func_name(*args, **kwargs) is the same as the following:

```python
def run(rref, func_name, args, kwargs):
    return getattr(rref.local_value(), func_name)(*args, **kwargs)

rpc.remote(rref.owner(), run, args=(rref, func_name, args, kwargs))
```

Parameters

timeout (float, optional) – Timeout for rref.remote(). If the creation of this RRef is not successfully completed within the timeout, then the next time there is an attempt to use the RRef (such as to_here), a timeout will be raised. If not provided, the default RPC timeout will be used. Please see rpc.remote() for specific timeout semantics for RRef.

Example:

```python
import torch
from torch.distributed import rpc
rref = rpc.remote("worker1", torch.add, args=(torch.zeros(2, 2), 1))
rref.remote().size().to_here()  # returns torch.Size([2, 2])
rref.remote().view(1, 4).to_here()  # returns tensor([[1., 1., 1., 1.]])
```
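The equivalence between the proxy and the explicit run() helper can be checked directly. The sketch below assumes a single process with a world of size 1 (so the RRef's owner is the calling worker itself), and `free_port` is a hypothetical helper; `run` mirrors the helper given in the description of remote() above.

```python
import os
import socket

import torch
import torch.distributed.rpc as rpc


def free_port() -> int:
    # Hypothetical helper: ask the OS for an unused TCP port on localhost.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("localhost", 0))
        return s.getsockname()[1]


def run(rref, func_name, args, kwargs):
    # Executes on the owner: call func_name on the locally held value.
    return getattr(rref.local_value(), func_name)(*args, **kwargs)


os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(free_port())

# Assumption: world_size=1, so "worker0" owns the RRef below.
rpc.init_rpc("worker0", rank=0, world_size=1)

rref = rpc.remote("worker0", torch.add, args=(torch.zeros(2, 2), 1))

# The proxy form and the explicit helper should yield the same value.
via_proxy = rref.remote().view(1, 4).to_here()
via_helper = rpc.remote(rref.owner(), run, args=(rref, "view", (1, 4), {})).to_here()

rpc.shutdown()
print(via_proxy, via_helper)
```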

rpc_async(self: torch._C._distributed_rpc.PyRRef, timeout: float = -1.0) → object

Create a helper proxy to easily launch an rpc_async using the owner of the RRef as the destination to run functions on the object referenced by this RRef. More specifically, rref.rpc_async().func_name(*args, **kwargs) is the same as the following:

```python
def run(rref, func_name, args, kwargs):
    return getattr(rref.local_value(), func_name)(*args, **kwargs)

rpc.rpc_async(rref.owner(), run, args=(rref, func_name, args, kwargs))
```

Parameters

timeout (float, optional) – Timeout for rref.rpc_async(). If the call does not complete within this timeframe, an exception indicating so will be raised. If this argument is not provided, the default RPC timeout will be used.

Example:

```python
import torch
from torch.distributed import rpc
rref = rpc.remote("worker1", torch.add, args=(torch.zeros(2, 2), 1))
rref.rpc_async().size().wait()  # returns torch.Size([2, 2])
rref.rpc_async().view(1, 4).wait()  # returns tensor([[1., 1., 1., 1.]])
```

rpc_sync(self: torch._C._distributed_rpc.PyRRef, timeout: float = -1.0) → object

Create a helper proxy to easily launch an rpc_sync using the owner of the RRef as the destination to run functions on the object referenced by this RRef. More specifically, rref.rpc_sync().func_name(*args, **kwargs) is the same as the following:

```python
def run(rref, func_name, args, kwargs):
    return getattr(rref.local_value(), func_name)(*args, **kwargs)

rpc.rpc_sync(rref.owner(), run, args=(rref, func_name, args, kwargs))
```

Parameters

timeout (float, optional) – Timeout for rref.rpc_sync(). If the call does not complete within this timeframe, an exception indicating so will be raised. If this argument is not provided, the default RPC timeout will be used.

Example:

```python
import torch
from torch.distributed import rpc
rref = rpc.remote("worker1", torch.add, args=(torch.zeros(2, 2), 1))
rref.rpc_sync().size()  # returns torch.Size([2, 2])
rref.rpc_sync().view(1, 4)  # returns tensor([[1., 1., 1., 1.]])
```

to_here(self: torch._C._distributed_rpc.PyRRef, timeout: float = -1.0) → object

Blocking call that copies the value of the RRef from the owner to the local node and returns it. If the current node is the owner, returns a reference to the local value.

Parameters

timeout (float, optional) – Timeout for to_here. If the call does not complete within this timeframe, an exception indicating so will be raised. If this argument is not provided, the default RPC timeout (60s) will be used.
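The difference between calling to_here() on the owner and on another worker can be modeled in plain Python. In this hypothetical sketch (`ToyOwnedRRef` and its `current_worker` argument are illustrative, not part of the RPC API), a non-owner receives a deep copy, standing in for the serialized value sent over the wire, while the owner gets its own object back.

```python
import copy


class ToyOwnedRRef:
    """Hypothetical model of an RRef whose value is held by `owner`."""

    def __init__(self, owner, value):
        self.owner = owner
        self._value = value

    def to_here(self, current_worker):
        if current_worker == self.owner:
            # On the owner: return a reference to the local value, no copy.
            return self._value
        # On any other worker the value is serialized and sent over,
        # so the caller gets an independent copy (modeled with deepcopy).
        return copy.deepcopy(self._value)


rref = ToyOwnedRRef("worker1", {"step": 0})

local = rref.to_here("worker1")
local["step"] += 1         # same object as the owner's value

remote_copy = rref.to_here("worker0")
remote_copy["step"] = 100  # only changes the caller's copy

print(rref.to_here("worker1"))  # {'step': 1}
```

This is why mutating a value fetched with to_here() on a non-owner does not update the object held by the owner.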

More Information about RRef

RemoteModule

Warning

RemoteModule is not currently supported when using CUDA tensors.

RemoteModule is an easy way to create an nn.Module remotely on a different process. The actual module resides on a remote host, but the local host holds a handle to this module and can invoke it like a regular nn.Module. Each invocation, however, incurs RPC calls to the remote end and can be performed asynchronously if needed via additional APIs supported by RemoteModule.

class torch.distributed.nn.api.remote_module.RemoteModule(*args, **kwargs)

A RemoteModule instance can only be created after RPC initialization.

It creates a user-specified module on a specified remote node. It behaves like a regular nn.Module except that the forward method is executed on the remote node. It takes care of autograd recording to ensure the backward pass propagates gradients back to the corresponding remote module.

It generates two methods, forward_async and forward, based on the signature of the forward method of module_cls. forward_async runs asynchronously and returns a Future. The arguments of forward_async and forward are the same as those of the forward method of the module returned by module_cls.

For example, if module_cls returns an instance of nn.Linear, whose forward method has the signature def forward(input: Tensor) -> Tensor:, the generated RemoteModule will have two methods with the following signatures:

def forward(input: Tensor) -> Tensor:

def forward_async(input: Tensor) -> Future[Tensor]:

Parameters

Returns

A remote module instance that wraps the Module created by the user-provided module_cls. It has a blocking forward method and an asynchronous forward_async method that returns a future of the forward call on the user-provided module on the remote side.
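One simple way to picture the generated forward and forward_async pair is sketched below. The assumptions are deliberate simplifications: `ToyRemoteModule` and `Scale` are hypothetical, a single-thread `ThreadPoolExecutor` stands in for the remote worker, and arguments are forwarded with `*args`/`**kwargs` instead of being generated from the forward signature as RemoteModule does.

```python
from concurrent.futures import Future, ThreadPoolExecutor


class ToyRemoteModule:
    """Hypothetical stand-in: a thread pool plays the role of the remote worker."""

    def __init__(self, executor, module_cls, args=(), kwargs=None):
        self._executor = executor
        # Construct the wrapped module "remotely" (here: on a pool thread).
        self._module = executor.submit(module_cls, *args, **(kwargs or {})).result()

    def forward_async(self, *args, **kwargs) -> Future:
        # Same arguments as module_cls's forward, but returns a Future immediately.
        return self._executor.submit(self._module.forward, *args, **kwargs)

    def forward(self, *args, **kwargs):
        # Blocking variant: launch the asynchronous call and wait for its result.
        return self.forward_async(*args, **kwargs).result()


class Scale:
    """Hypothetical module with a forward method."""

    def __init__(self, factor):
        self.factor = factor

    def forward(self, xs):
        return [self.factor * x for x in xs]


with ThreadPoolExecutor(max_workers=1) as pool:
    m = ToyRemoteModule(pool, Scale, args=(3,))
    fut = m.forward_async([1, 2])
    print(fut.result())    # [3, 6]
    print(m.forward([4]))  # [12]
```

In this sketch the blocking forward is simply forward_async followed by a wait on the returned future.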

Example::

Run the following code in two different processes:

On worker 0:

```python
import torch
import torch.distributed.rpc as rpc
from torch import nn, Tensor
from torch.distributed.nn.api.remote_module import RemoteModule

rpc.init_rpc("worker0", rank=0, world_size=2)
remote_linear_module = RemoteModule(
    "worker1/cpu", nn.Linear, args=(20, 30),
)
input = torch.randn(128, 20)
ret_fut = remote_linear_module.forward_async(input)
ret = ret_fut.wait()
rpc.shutdown()
```

On worker 1:

```python
import torch
import torch.distributed.rpc as rpc

rpc.init_rpc("worker1", rank=1, world_size=2)
rpc.shutdown()
```

Furthermore, a more practical example that is combined with DistributedDataParallel (DDP) can be found in this tutorial.

get_module_rref()

Return an RRef (RRef[nn.Module]) pointing to the remote module.

Return type

RRef[Module]

remote_parameters(recurse=True)

Return a list of RRefs pointing to the remote module’s parameters.

This can typically be used in conjunction with DistributedOptimizer.

Parameters

recurse (bool) – if True, then returns parameters of the remote module and all submodules of the remote module. Otherwise, returns only parameters that are direct members of the remote module.

Returns

A list of RRefs (List[RRef[nn.Parameter]]) to the remote module’s parameters.

Return type

list[torch.distributed.rpc.api.RRef[torch.nn.parameter.Parameter]]
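The effect of the recurse flag, and the idea of returning references rather than copies of the parameters, can be sketched without PyTorch. `ToyModule`, `ToyRef` and `toy_remote_parameters` are hypothetical; the sketch assumes, for illustration only, that the owner wraps each parameter of its local module in a reference, and its `recurse` handling mirrors that of `nn.Module.parameters()`.

```python
class ToyRef:
    """Hypothetical stand-in for an RRef: a handle to a value, not a copy of it."""

    def __init__(self, value):
        self._value = value

    def local_value(self):
        return self._value


class ToyModule:
    def __init__(self, params=(), children=()):
        self.params = list(params)      # parameters registered directly on this module
        self.children = list(children)  # submodules

    def parameters(self, recurse=True):
        yield from self.params
        if recurse:
            for child in self.children:
                yield from child.parameters(recurse=True)


def toy_remote_parameters(module, recurse=True):
    # Runs where the module lives: one reference per parameter.
    return [ToyRef(p) for p in module.parameters(recurse)]


inner = ToyModule(params=["inner.weight", "inner.bias"])
outer = ToyModule(params=["outer.scale"], children=[inner])

print([r.local_value() for r in toy_remote_parameters(outer)])
# ['outer.scale', 'inner.weight', 'inner.bias']
print([r.local_value() for r in toy_remote_parameters(outer, recurse=False)])
# ['outer.scale']
```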

Distributed Autograd Framework

Warning

Distributed autograd is not currently supported when using CUDA tensors.

This module provides an RPC-based distributed autograd framework that can be used for applications such as model parallel training. In short, applications may send and receive gradient-recording tensors over RPC. In the forward pass, we record when gradient-recording tensors are sent over RPC, and during the backward pass we use this information to perform a distributed backward pass using RPC. For more details, see Distributed Autograd Design.

torch.distributed.autograd.backward(context_id: int, roots: List[Tensor], retain_graph=False) → None

Kicks off the distributed backward pass using the provided roots. This currently implements the FAST mode algorithm, which assumes that all RPC messages sent in the same distributed autograd context across workers are part of the autograd graph during the backward pass.

We use the provided roots to discover the autograd graph and compute appropriate dependencies. This method blocks until the entire autograd computation is done.

We accumulate the gradients in the appropriate torch.distributed.autograd.context on each of the nodes. The autograd context to be used is looked up given the context_id that is passed in when torch.distributed.autograd.backward() is called. If there is no valid autograd context corresponding to the given ID, we throw an error. You can retrieve the accumulated gradients using the get_gradients() API.

Parameters

Example::

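```python
import torch.distributed.autograd as dist_autograd

# `model`, `inputs`, `loss_func` and `target` are placeholders for the user's own objects.
with dist_autograd.context() as context_id:
    pred = model(inputs)
    loss = loss_func(pred, target)
    # `roots` is a list of tensors.
    dist_autograd.backward(context_id, [loss])
```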
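Using the roots to "discover the autograd graph and compute appropriate dependencies" refers to a standard reverse-mode technique: count how many consumers feed a gradient into each node, and process a node only after all of them have reported. The toy sketch below shows that technique on scalars in a single process. It is not the distributed FAST-mode implementation; `Node`, `mul`, `add`, `compute_dependencies` and `backward` are hypothetical, and the `grads` dict plays the role of the gradients stored in an autograd context.

```python
from collections import defaultdict, deque


class Node:
    """Toy scalar autograd node; `backward_fn` maps an output grad to input grads."""

    def __init__(self, value, inputs=(), backward_fn=None):
        self.value = value
        self.inputs = list(inputs)
        self.backward_fn = backward_fn


def mul(x, y):
    return Node(x.value * y.value, (x, y), lambda g: [g * y.value, g * x.value])


def add(x, y):
    return Node(x.value + y.value, (x, y), lambda g: [g, g])


def compute_dependencies(roots):
    # Discover the graph from the roots and count, for every node,
    # how many edges feed a gradient into it.
    deps = defaultdict(int)
    seen, stack = set(), list(roots)
    while stack:
        node = stack.pop()
        if id(node) in seen:
            continue
        seen.add(id(node))
        for inp in node.inputs:
            deps[id(inp)] += 1
            stack.append(inp)
    return deps


def backward(roots, grads):
    # `grads` stands in for the gradient map held by an autograd context.
    # Assumes roots are distinct and no root is an input of another node.
    deps = compute_dependencies(roots)
    pending = defaultdict(float)
    ready = deque()
    for root in roots:
        pending[id(root)] += 1.0  # d(root)/d(root) = 1
        ready.append(root)
    while ready:
        node = ready.popleft()
        g = pending.pop(id(node))
        if not node.inputs:
            # Leaf: accumulate, as repeated backward passes in one context would.
            grads[node] = grads.get(node, 0.0) + g
            continue
        for inp, g_in in zip(node.inputs, node.backward_fn(g)):
            pending[id(inp)] += g_in
            deps[id(inp)] -= 1
            if deps[id(inp)] == 0:
                # Every consumer has contributed, so this node's gradient is final.
                ready.append(inp)
    return grads


a, b = Node(2.0), Node(3.0)
d = add(mul(a, b), a)      # d = a * b + a = 8.0
grads = backward([d], {})
print(grads[a], grads[b])  # 4.0 2.0
```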

class torch.distributed.autograd.context

Context object to wrap forward and backward passes when using distributed autograd. The context_id generated in the with statement is required to uniquely identify a distributed backward pass on all workers. Each worker stores metadata associated with this context_id, which is required to correctly execute a distributed autograd pass.

Example::

```python
import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc

# Assumes RPC has been initialized and a worker named "worker1" exists.
with dist_autograd.context() as context_id:
    t1 = torch.rand((3, 3), requires_grad=True)
    t2 = torch.rand((3, 3), requires_grad=True)
    loss = rpc.rpc_sync("worker1", torch.add, args=(t1, t2)).sum()
    dist_autograd.backward(context_id, [loss])
```
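The per-worker bookkeeping described for context (a unique id per with block, metadata stored under that id, and an error when an unknown id is looked up) can be modeled in a few lines. This is a single-process sketch, not PyTorch's implementation: `ToyAutogradContainer`, `accumulate` and the error message are hypothetical, and real context ids must be unique across all workers, which a local counter does not guarantee.

```python
import contextlib
import itertools


class ToyAutogradContainer:
    """Hypothetical per-worker store of distributed autograd context metadata."""

    def __init__(self):
        self._ids = itertools.count()
        self._contexts = {}  # context_id -> {"grads": {...}}

    @contextlib.contextmanager
    def context(self):
        context_id = next(self._ids)
        self._contexts[context_id] = {"grads": {}}
        try:
            yield context_id
        finally:
            # In this sketch, metadata is dropped when the `with` block exits.
            del self._contexts[context_id]

    def _lookup(self, context_id):
        if context_id not in self._contexts:
            raise RuntimeError(f"no valid autograd context with id {context_id}")
        return self._contexts[context_id]

    def accumulate(self, context_id, key, grad):
        grads = self._lookup(context_id)["grads"]
        grads[key] = grads.get(key, 0.0) + grad

    def get_gradients(self, context_id):
        return dict(self._lookup(context_id)["grads"])


container = ToyAutogradContainer()
with container.context() as cid:
    container.accumulate(cid, "t1", 1.0)
    container.accumulate(cid, "t1", 0.5)  # gradients accumulate within a context
    print(container.get_gradients(cid))   # {'t1': 1.5}

try:
    container.get_gradients(cid)          # the context has been released
except RuntimeError as err:
    print(err)
```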

torch.distributed.autograd.get_gradients(context_id: int) → Dict[Tensor, Tensor]

Retrieves a map from Tensor to the appropriate gradient for that Tensor accumulated in the provided context corresponding to the given context_id as part of the distributed autograd backward pass.

Parameters

context_id (int) – The autograd context id for which we should retrieve the gradients.

Returns

A map where the key is the Tensor and the value is the associated gradient for that Tensor.

Example::

```python
import torch
import torch.distributed.autograd as dist_autograd

with dist_autograd.context() as context_id:
    t1 = torch.rand((3, 3), requires_grad=True)
    t2 = torch.rand((3, 3), requires_grad=True)
    loss = t1 + t2
    dist_autograd.backward(context_id, [loss.sum()])
    # Gradients must be read while the context is still active.
    grads = dist_autograd.get_gradients(context_id)
    print(grads[t1])
    print(grads[t2])
```

More Information about RPC Autograd

Design Notes

The distributed autograd design note covers the design of the RPC-based distributed autograd framework that is useful for applications such as model parallel training.

The RRef design note covers the design of the RRef (Remote REFerence) protocol used to refer to values on remote workers by the framework.