Customize Process Group Backends Using Cpp Extensions


Created On: Feb 01, 2022 | Last Updated: Nov 14, 2024 | Last Verified: Nov 05, 2024

Authors: Howard Huang, Feng Tian, Shen Li, Min Si


Prerequisites: PyTorch Distributed Overview

This tutorial demonstrates how to implement a custom Backend and plug it into the PyTorch distributed package using cpp extensions. This is helpful when you need a specialized software stack for your hardware, or when you would like to experiment with new collective communication algorithms.

Basics

PyTorch collective communications power several widely adopted distributed training features, including DistributedDataParallel and ZeroRedundancyOptimizer. In order to make the same collective communication API work with different communication backends, the distributed package abstracts collective communication operations into a Backend class. Different backends can then be implemented as subclasses of Backend using preferred third-party libraries. PyTorch distributed comes with three default backends, ProcessGroupNCCL, ProcessGroupGloo, and ProcessGroupMPI. However, beyond these three backends, there are also other communication libraries (e.g., UCC, OneCCL), different types of hardware (e.g., TPU, Trainium), and emerging communication algorithms (e.g., Herring, Reduction Server). Therefore, the distributed package exposes extension APIs to allow customizing collective communication backends.
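As a point of reference before building a custom backend, the snippet below is a minimal sketch (not part of this tutorial; the rendezvous parameters are illustrative placeholders) showing that application code only chooses a backend through the backend argument, while the collective calls themselves stay the same:

import torch
import torch.distributed as dist

# A single-process group using the built-in gloo backend (ProcessGroupGloo).
# The init_method, rank, and world_size values are illustrative only.
dist.init_process_group(backend="gloo", init_method="tcp://localhost:29500",
                        rank=0, world_size=1)

t = torch.ones(4)
dist.all_reduce(t)   # the same call works unchanged on nccl, mpi, or a custom backend
print(t)

dist.destroy_process_group()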

The four steps below show how to implement a dummy Backend and use it in Python application code. Please note that this tutorial focuses on demonstrating the extension APIs rather than on developing a functioning communication backend. Hence, the dummy backend covers only a subset of the APIs (all_reduce and all_gather) and simply sets the values of tensors to 0.

Step 1: Implement a Subclass of Backend

The first step is to implement a Backend subclass that overrides the target collective communication APIs and runs the custom communication algorithm. The extension also needs to implement a Work subclass, which serves as a future of communication results and allows asynchronous execution in application code. If the extension uses third-party libraries, it can include the headers and call into the library APIs from the BackendDummy subclass. The two code snippets below present the implementation of dummy.hpp and dummy.cpp. See the dummy collectives repository for the full implementation.

// file name: dummy.hpp
#include <torch/python.h>

#include <torch/csrc/distributed/c10d/Backend.hpp>
#include <torch/csrc/distributed/c10d/Work.hpp>
#include <torch/csrc/distributed/c10d/Store.hpp>
#include <torch/csrc/distributed/c10d/Types.hpp>
#include <torch/csrc/distributed/c10d/Utils.hpp>

#include <pybind11/chrono.h>

namespace c10d {

class BackendDummy : public Backend {
  public:
    BackendDummy(int rank, int size);

    c10::intrusive_ptr<Work> allgather(
        std::vector<std::vector<at::Tensor>>& outputTensors,
        std::vector<at::Tensor>& inputTensors,
        const AllgatherOptions& opts = AllgatherOptions()) override;

    c10::intrusive_ptr<Work> allreduce(
        std::vector<at::Tensor>& tensors,
        const AllreduceOptions& opts = AllreduceOptions()) override;

    // The collective communication APIs without a custom implementation
    // will error out if invoked by application code.
};

class WorkDummy : public Work {
  public:
    WorkDummy(
        OpType opType,
        c10::intrusive_ptr<c10::ivalue::Future> future) // future of the output
        : Work(
              -1, // rank, only used by recvAnySource, irrelevant in this demo
              opType),
          future_(std::move(future)) {}
    bool isCompleted() override;
    bool isSuccess() const override;
    bool wait(std::chrono::milliseconds timeout = kUnsetTimeout) override;
    virtual c10::intrusive_ptr<c10::ivalue::Future> getFuture() override;

  private:
    c10::intrusive_ptr<c10::ivalue::Future> future_;
};
} // namespace c10d

// file name: dummy.cpp
#include "dummy.hpp"

namespace c10d {

// This is a dummy allgather that sets all output tensors to zero
// Modify the implementation to conduct real communication asynchronously
c10::intrusive_ptr<Work> BackendDummy::allgather(
        std::vector<std::vector<at::Tensor>>& outputTensors,
        std::vector<at::Tensor>& inputTensors,
        const AllgatherOptions& /* unused */) {
    for (auto& outputTensorVec : outputTensors) {
        for (auto& outputTensor : outputTensorVec) {
            outputTensor.zero_();
        }
    }

    auto future = c10::make_intrusive<c10::ivalue::Future>(
        c10::ListType::create(c10::ListType::create(c10::TensorType::get())));
    future->markCompleted(c10::IValue(outputTensors));
    return c10::make_intrusive<WorkDummy>(OpType::ALLGATHER, std::move(future));
}

// This is a dummy allreduce that sets all output tensors to zero
// Modify the implementation to conduct real communication asynchronously
c10::intrusive_ptr<Work> BackendDummy::allreduce(
        std::vector<at::Tensor>& tensors,
        const AllreduceOptions& opts) {
    for (auto& tensor : tensors) {
        tensor.zero_();
    }

    auto future = c10::make_intrusive<c10::ivalue::Future>(
        c10::ListType::create(c10::TensorType::get()));
    future->markCompleted(c10::IValue(tensors));
    return c10::make_intrusive<WorkDummy>(OpType::ALLREDUCE, std::move(future));
}
} // namespace c10d
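On the application side, the Work object constructed by the backend is what an asynchronous collective hands back to the caller. The sketch below is an illustration rather than tutorial code; it assumes a process group is already initialized (for example with a built-in backend, or with the dummy backend once Steps 2-4 are complete), and shows how the WorkDummy methods declared above surface in Python:

import torch
import torch.distributed as dist

x = torch.ones(4)
work = dist.all_reduce(x, async_op=True)  # returns a Work handle created by the backend
work.wait()                               # corresponds to WorkDummy::wait
fut = work.get_future()                   # corresponds to WorkDummy::getFuture
print(fut.done(), x)                      # the dummy backend marks its future completed immediately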

Step 2: Expose The Extension Python APIs

The backend constructors are called from the Python side, so the extension also needs to expose the constructor APIs to Python. This can be done by adding the following methods. In this example, store and timeout are ignored by the BackendDummy instantiation method, as they are not used in this dummy implementation. However, real-world extensions should consider using the store to perform rendezvous and supporting the timeout argument.

// file name: dummy.hpp
class BackendDummy : public Backend {
    ...
    <Step 1 code>
    ...

static c10::intrusive_ptr<Backend> createBackendDummy(
    const c10::intrusive_ptr<::c10d::Store>& store,
    int rank,
    int size,
    const std::chrono::duration<float>& timeout);

static void BackendDummyConstructor() __attribute__((constructor)) {
    py::object module = py::module::import("torch.distributed");
    py::object register_backend =
        module.attr("Backend").attr("register_backend");
    // torch.distributed.Backend.register_backend will add `dummy` as a
    // new valid backend.
    register_backend("dummy", py::cpp_function(createBackendDummy));
}

}

// file name: dummy.cpp
c10::intrusive_ptr<Backend> BackendDummy::createBackendDummy(
        const c10::intrusive_ptr<::c10d::Store>& /* unused */,
        int rank,
        int size,
        const std::chrono::duration<float>& /* unused */) {
    return c10::make_intrusive<BackendDummy>(rank, size);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("createBackendDummy", &BackendDummy::createBackendDummy);
}
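The store argument that createBackendDummy ignores is how a real backend would perform rendezvous: init_process_group creates a Store (for example a TCPStore) and passes it to the registered constructor, and ranks can then use it as a shared key-value service to exchange bootstrap information. A minimal Python sketch of that idea (the key, value, and port below are hypothetical, not from this tutorial):

import torch.distributed as dist

# A TCPStore hosted by this process; other ranks would connect with is_master=False.
store = dist.TCPStore("localhost", 29500, world_size=1, is_master=True)

store.set("rank0_addr", "10.0.0.1:12345")   # e.g., rank 0 publishes its bootstrap address
print(store.get("rank0_addr"))              # other ranks read it back during setup

# The same store can also be handed to init_process_group explicitly:
# dist.init_process_group("dummy", store=store, rank=0, world_size=1)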

Step 3: Build The Custom Extension

Now the extension source code files are ready. We can then use cpp extensions to build them. To do that, create a setup.py file that prepares the paths and commands, then call python setup.py develop to install the extension.

If the extension depends on third-party libraries, you can also pass library_dirs and libraries to the cpp extension APIs; a sketch follows the setup.py below. See the torch-ucc project as a real-world example.

# file name: setup.py
import os
import sys
import torch
from setuptools import setup
from torch.utils import cpp_extension

sources = ["src/dummy.cpp"]
include_dirs = [f"{os.path.dirname(os.path.abspath(__file__))}/include/"]

if torch.cuda.is_available():
    module = cpp_extension.CUDAExtension(
        name = "dummy_collectives",
        sources = sources,
        include_dirs = include_dirs,
    )
else:
    module = cpp_extension.CppExtension(
        name = "dummy_collectives",
        sources = sources,
        include_dirs = include_dirs,
    )

setup(
    name = "Dummy-Collectives",
    version = "0.0.1",
    ext_modules = [module],
    cmdclass={'build_ext': cpp_extension.BuildExtension}
)
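If the extension links against a third-party communication library as mentioned above, the same setup.py can pass library_dirs and libraries through the cpp extension APIs. The following is a hedged sketch; the library name mycomm and its paths are placeholders, not part of this tutorial:

# file name: setup.py (variant with a hypothetical third-party dependency)
import os
from setuptools import setup
from torch.utils import cpp_extension

this_dir = os.path.dirname(os.path.abspath(__file__))

module = cpp_extension.CppExtension(
    name = "dummy_collectives",
    sources = ["src/dummy.cpp"],
    include_dirs = [f"{this_dir}/include/", "/opt/mycomm/include"],  # placeholder header path
    library_dirs = ["/opt/mycomm/lib"],                              # placeholder library path
    libraries = ["mycomm"],                                          # links -lmycomm
)

setup(
    name = "Dummy-Collectives",
    version = "0.0.1",
    ext_modules = [module],
    cmdclass = {'build_ext': cpp_extension.BuildExtension}
)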

Step 4: Use The Extension in Application

After installation, you can conveniently use the dummy backend when calling init_process_group as if it were a builtin backend.

We can specify dispatching based on tensor device type by changing the backend argument of init_process_group. By specifying cpu:gloo,cuda:dummy as the backend argument, collectives on CPU tensors are dispatched to the gloo backend and collectives on CUDA tensors are dispatched to the dummy backend.

To send all tensors to the dummy backend, we can simply specify dummy as the backend argument.

import os

import torch
# importing dummy_collectives makes torch.distributed recognize `dummy`
# as a valid backend.
import dummy_collectives

import torch.distributed as dist

os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29500'

# Alternatively:
# dist.init_process_group("dummy", rank=0, world_size=1)
dist.init_process_group("cpu:gloo,cuda:dummy", rank=0, world_size=1)

# this goes through gloo
x = torch.ones(6)
dist.all_reduce(x)
print(f"cpu allreduce: {x}")

# this goes through dummy
if torch.cuda.is_available():
    y = x.cuda()
    dist.all_reduce(y)
    print(f"cuda allreduce: {y}")

    try:
        dist.broadcast(y, 0)
    except RuntimeError:
        print("got RuntimeError when calling broadcast")