Deferred Module Initialization — torchdistX 0.2.0 documentation
TL;DR
The Deferred Module Initialization feature consists of a deferred_init() function that constructs Module instances without allocating storage for their tensors, and the accompanying materialize_module() and materialize_tensor() functions that can fully or partially materialize modules constructed by deferred_init(). The feature is meant to be used when a module is too large memory-wise or too expensive computationally to construct on a single machine, but needs to be inspected for various reasons before being initialized.
Problem
With ever-increasing model sizes, it is becoming common for models to exceed the memory or compute capacity of a single machine or accelerator. Training such models therefore requires some sharding (a.k.a. partitioning) strategy to distribute parts of the model across different computing nodes. However, the techniques used to apply these strategies, such as 3D parallelism, often need access to the model architecture to decide on the optimal strategy. This creates a chicken-and-egg problem: the model has to be constructed before it can be inspected, but it is too large to be constructed on a single machine before it is sharded.
Automated parallelism libraries (e.g. FSDP, DeepSpeed) either ignore this problem completely, meaning they expect the model to fit on a single machine, or they offer rudimentary workarounds that partially overcome it. For instance, they sequentially initialize model parameters while sharding them on the fly based on a predefined memory-size threshold. The limitation of such workarounds, however, is that these libraries cannot see the whole architecture of the model, which would enable them to make smarter sharding decisions.
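To make the limitation concrete, here is a minimal pure-Python sketch of threshold-based sequential sharding. It illustrates the general idea only and is not how FSDP or DeepSpeed implement it; the `shard_sequentially` helper and the parameter sizes are hypothetical. Parameters are visited in construction order, and a new shard is opened whenever the next parameter would push the current one past the threshold, so every decision is made without knowing what comes later.

```python
from typing import List, Tuple


def shard_sequentially(param_sizes: List[Tuple[str, int]], threshold: int) -> List[List[str]]:
    """Greedily pack parameters into shards in construction order.

    A new shard is opened whenever adding the next parameter would exceed
    `threshold` bytes. Only the parameters seen so far are known when each
    decision is made, mimicking on-the-fly sharding during initialization.
    """
    shards: List[List[str]] = [[]]
    current = 0
    for name, size in param_sizes:
        if shards[-1] and current + size > threshold:
            shards.append([])
            current = 0
        shards[-1].append(name)
        current += size
    return shards


# Hypothetical parameter sizes (in bytes), listed in construction order.
params = [("embed", 60), ("attn", 30), ("mlp", 30), ("norm1", 10), ("norm2", 10)]
print(shard_sequentially(params, threshold=100))
# [['embed', 'attn'], ['mlp', 'norm1', 'norm2']]
```

With these made-up sizes the greedy pass produces shards of 90 and 50 bytes, although a balanced 70/70 split exists (`embed` + `norm1` versus the rest), which a planner that sees the whole model up front could find.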
What is Deferred Module Initialization?
Deferred Module Initialization addresses the problem above by offering three functions. deferred_init() is a non-intrusive function that lets users defer the initialization of a Module by skipping storage allocation for its parameters and buffers while keeping a record of the operations performed on them in an in-memory graph. materialize_module() and materialize_tensor() are the accompanying functions that materialize (i.e. initialize) tensors or modules constructed within a previous deferred_init() call by replaying the operations recorded at that time.
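The record-and-replay idea can be illustrated with a deliberately tiny pure-Python model. This is not torchdistX's implementation; the `FakeTensor` class, its `record` and `materialize` methods, and the `zeros` helper are hypothetical names used only to show the concept: a fake object carries a shape but no data, operations on it are remembered rather than executed, and materialization allocates real storage and replays them in order.

```python
from typing import Callable, List


class FakeTensor:
    """Hypothetical stand-in for a fake tensor: it has a shape but no data.

    Operations applied to it are appended to a list so that they can be
    replayed later on real storage.
    """

    def __init__(self, shape: tuple, init_fn: Callable[[tuple], List[float]]):
        self.shape = shape
        self._init_fn = init_fn  # how to allocate and fill the real storage
        self._ops: List[Callable[[List[float]], List[float]]] = []

    def record(self, op: Callable[[List[float]], List[float]]) -> "FakeTensor":
        # Nothing is computed here; the operation is only remembered.
        self._ops.append(op)
        return self

    def materialize(self) -> List[float]:
        # Allocate real storage, then replay the recorded operations in order.
        data = self._init_fn(self.shape)
        for op in self._ops:
            data = op(data)
        return data


def zeros(shape: tuple) -> List[float]:
    n = 1
    for dim in shape:
        n *= dim
    return [0.0] * n


# "Construct" a 2x3 weight without allocating it, then record two element-wise ops.
w = FakeTensor((2, 3), zeros)
w.record(lambda d: [x + 1.0 for x in d]).record(lambda d: [x * 0.5 for x in d])
print(w.materialize())  # [0.5, 0.5, 0.5, 0.5, 0.5, 0.5]
```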
API
Initialization
As mentioned above, deferred_init() is the “entry point” of the API and has the following signature:
torchdistx.deferred_init.deferred_init(module_fn, *args, **kwargs)
Defers the initialization of a Module.
This function forces all tensors constructed within module_fn to be fake while also recording all operations performed on them. The modules and tensors returned from module_fn can later be instantiated using the materialize_tensor() and materialize_module() functions.
Parameters
- module_fn (Callable[[...], torch.nn.modules.module.Module]) – A callable that takes an arbitrary number of arguments and returns a Module instance.
- args – The positional arguments to be passed to module_fn.
- kwargs – The keyword arguments to be passed to module_fn.
Return type
torch.nn.modules.module.Module
Warning
The operations performed on the parameters and buffers of a module are only recorded while inside deferred_init(). Avoid making changes to a module after it is returned from deferred_init(); otherwise it cannot be correctly materialized.
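Why this matters can be seen in a toy model where recording is switched on only for the duration of the construction call. The names here (`Recorder`, `FakeParam`, `toy_deferred_init`, `build`) are hypothetical, and the sketch assumes a late update is silently dropped; the actual failure mode in torchdistX may differ. It only shows that an update made after the call returns is not part of what gets replayed.

```python
from typing import Callable, List


class Recorder:
    """Toy global switch: operations are recorded only while `active` is True."""

    active = False


class FakeParam:
    def __init__(self, value: float):
        self._init_value = value
        self._ops: List[Callable[[float], float]] = []

    def apply(self, op: Callable[[float], float]) -> None:
        # Outside the deferred context the op is lost: the fake has no storage
        # to apply it to, and nothing records it for later replay.
        if Recorder.active:
            self._ops.append(op)

    def materialize(self) -> float:
        value = self._init_value
        for op in self._ops:
            value = op(value)
        return value


def toy_deferred_init(fn: Callable[[], FakeParam]) -> FakeParam:
    # Enable recording only while the constructor runs.
    Recorder.active = True
    try:
        return fn()
    finally:
        Recorder.active = False


def build() -> FakeParam:
    p = FakeParam(1.0)
    p.apply(lambda v: v * 3)  # recorded: happens inside toy_deferred_init
    return p


p = toy_deferred_init(build)
p.apply(lambda v: v + 100)  # NOT recorded: happens after deferral has ended
print(p.materialize())  # 3.0, the "+ 100" update is missing
```

In this model the remedy is to move every modification into `build()`, i.e. into the callable passed to the deferred call.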
Note
The graph structure generated by deferred_init() is fairly simple, but it holds information specifically meant to materialize in-memory tensors as if they had been initialized without deferral. In that sense, its implementation and purpose diverge from much larger and more feature-rich solutions such as torch.fx and TorchScript.
Materialization
Modules, parameters, and buffers constructed within a deferred_init() call can later be materialized using the materialize_module() and materialize_tensor() functions.
torchdistx.deferred_init.materialize_module(module, buffers_only=False, check_fn=None)
Materializes module and its descendant modules.
Parameters
- module (torch.nn.modules.module.Module) – The module instance to materialize.
- buffers_only (bool) – A boolean value indicating whether to materialize the buffer tensors only.
- check_fn (Optional[Callable[[torch.nn.modules.module.Module], bool]]) – An optional callable that takes a Module instance and returns a boolean value indicating whether to materialize it.
Return type
None
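The interaction between `buffers_only` and `check_fn` can be sketched with a toy module tree. This is an illustrative model under stated assumptions, not the torchdistX source: `ToyModule` and `toy_materialize_module` are hypothetical, tensors are reduced to booleans meaning "materialized", and the sketch assumes `check_fn` is evaluated independently for every module in the tree (descendants first), with a module's own tensors materialized only if it passes.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional


@dataclass
class ToyModule:
    """Hypothetical stand-in for torch.nn.Module; each tensor is a flag that
    records whether it has been materialized."""

    name: str
    params: Dict[str, bool] = field(default_factory=dict)
    buffers: Dict[str, bool] = field(default_factory=dict)
    children: List["ToyModule"] = field(default_factory=list)


def toy_materialize_module(
    module: ToyModule,
    buffers_only: bool = False,
    check_fn: Optional[Callable[[ToyModule], bool]] = None,
) -> None:
    # Visit descendants first; each module is checked on its own.
    for child in module.children:
        toy_materialize_module(child, buffers_only, check_fn)
    if check_fn is None or check_fn(module):
        if not buffers_only:
            for key in module.params:
                module.params[key] = True
        for key in module.buffers:
            module.buffers[key] = True


norm = ToyModule("norm", params={"weight": False}, buffers={"running_mean": False})
proj = ToyModule("proj", params={"weight": False})
root = ToyModule("root", children=[norm, proj])

# Materialize only the "norm" submodule, and only its buffers.
toy_materialize_module(root, buffers_only=True, check_fn=lambda m: m.name == "norm")
print(norm.buffers, norm.params, proj.params)
# {'running_mean': True} {'weight': False} {'weight': False}
```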
torchdistx.deferred_init.materialize_tensor(tensor)
Materializes tensor.
Parameters
tensor (torch.Tensor) – The tensor instance to materialize.
Return type
torch.Tensor
Warning
Once materialized, a fake tensor holds a reference to its materialized version. To avoid memory leaks, make sure to dispose of it when it is no longer required.
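The reason is that a materialized fake effectively acts as a cache. The toy classes below (`RealTensor` and `FakeTensor`, hypothetical and not torchdistX types) show the effect: as long as the fake object is reachable, so is the materialized data, even after the caller drops its own reference, and dropping the fake releases both. The `weakref` probe only observes the lifetime and relies on CPython freeing objects as soon as their reference count reaches zero.

```python
import weakref
from typing import List, Optional


class RealTensor:
    """Hypothetical stand-in for an allocated tensor."""

    def __init__(self, data: List[float]):
        self.data = data


class FakeTensor:
    """Toy fake tensor that caches its materialized counterpart."""

    def __init__(self, numel: int):
        self.numel = numel
        self._materialized: Optional[RealTensor] = None

    def materialize(self) -> RealTensor:
        # The first call allocates; later calls return the cached object,
        # which keeps the real tensor alive as long as the fake is alive.
        if self._materialized is None:
            self._materialized = RealTensor([0.0] * self.numel)
        return self._materialized


fake = FakeTensor(4)
real = fake.materialize()
assert fake.materialize() is real  # same object on repeated calls

probe = weakref.ref(real)
del real
print(probe() is not None)  # True: the fake still references it
del fake
print(probe() is None)  # True: nothing keeps it alive any more
```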
Examples
The simplest use case is to construct a module using deferred_init() and then materialize it later, after some form of inspection, using materialize_module():
>>> import torch
>>> from torchdistx.deferred_init import deferred_init, materialize_module
>>>
>>> # Notice that `m` does not have any storage even though it appears to
>>> # be a module allocated on CPU.
>>> m = deferred_init(torch.nn.Linear, 5, 1)
>>> m.weight
Parameter containing:
tensor(..., device='cpu', requires_grad=True, fake=True)
>>>
>>> # Do some form of inspection.
>>> ...
>>>
>>> # At the end materialize the module.
>>> materialize_module(m)
>>> m.weight
Parameter containing:
tensor([[-1.4677e+24, 4.5915e-41, 1.4013e-45, 0.0000e+00, -1.4677e+24, 4.5915e-41]], requires_grad=True)
It is also possible to materialize only a subset of modules, parameters, or buffers of a large model:
>>> import torch
>>> from torchdistx.deferred_init import (
...     deferred_init,
...     materialize_module,
...     materialize_tensor,
... )
>>>
>>> class MyLargeModel(torch.nn.Module):
...     ...
...
>>> m = deferred_init(MyLargeModel)
>>>
>>> # Do some form of inspection (e.g. determine sharding strategy).
>>> ...
>>>
>>> # Only materialize `sublayer1` and `sublayer2`.
>>> materialize_module(m.sublayer1)
>>> materialize_module(m.sublayer2)
>>>
>>> # Or materialize an individual parameter or buffer.
>>> materialized_param = materialize_tensor(m.sublayer1.param1)
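One possible form of the inspection step in the example above (determining a sharding strategy) is sketched below in plain Python. It assumes the parameter shapes of each sublayer have already been collected from the fake tensors; the shapes, the sublayer names, and the `plan` helper are hypothetical. Sublayers are assigned to ranks with a simple largest-first, least-loaded heuristic, which is just one illustrative strategy and not something torchdistX provides.

```python
import heapq
from math import prod
from typing import Dict, List, Tuple

# Hypothetical parameter shapes per sublayer, assumed to have been read from
# the fake tensors of a deferred module without allocating any storage.
shapes: Dict[str, List[Tuple[int, ...]]] = {
    "sublayer1": [(1024, 1024), (1024,)],
    "sublayer2": [(4096, 1024), (4096,)],
    "sublayer3": [(1024, 4096), (1024,)],
    "sublayer4": [(1024, 1024), (1024,)],
}


def plan(shapes: Dict[str, List[Tuple[int, ...]]], world_size: int) -> Dict[int, List[str]]:
    """Assign sublayers to ranks, largest first, always to the least-loaded rank."""
    sizes = {name: sum(prod(s) for s in ss) for name, ss in shapes.items()}
    heap = [(0, rank) for rank in range(world_size)]  # (element count, rank)
    assignment: Dict[int, List[str]] = {rank: [] for rank in range(world_size)}
    for name in sorted(sizes, key=sizes.get, reverse=True):
        load, rank = heapq.heappop(heap)
        assignment[rank].append(name)
        heapq.heappush(heap, (load + sizes[name], rank))
    return assignment


assignment = plan(shapes, world_size=2)
print(assignment)
# {0: ['sublayer2', 'sublayer4'], 1: ['sublayer3', 'sublayer1']}
```

Each rank would then pass only its assigned names to materialize_module(getattr(m, name)), as in the example above; that step is omitted here because it requires torchdistX.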
deferred_init() skips storage allocation even for explicitly passed device arguments:
>>> import torch
>>> from torchdistx.deferred_init import deferred_init, materialize_module
>>>
>>> class MyModule(torch.nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.param = torch.nn.Parameter(torch.ones([3], device="cpu"))
...
>>> m = deferred_init(MyModule)
>>> m.param
Parameter containing:
tensor(..., device='cpu', size=(3,), requires_grad=True, fake=True)
>>>
>>> materialize_module(m)
>>> m.param
Parameter containing:
tensor([1., 1., 1.], requires_grad=True)
Lazy modules can be used along with deferred_init() by wrapping the module construction and the dry-run call in a single function as demonstrated below:
>>> import torch
>>> from torchdistx.deferred_init import deferred_init
>>>
>>> def MyLazyModule(out_features: int):
...     lazy_m = torch.nn.LazyLinear(out_features)
...
...     # Dry-run the module to infer the parameter and buffer shapes.
...     lazy_m(torch.ones([10, 10]))
...
...     return lazy_m
...
>>> m = deferred_init(MyLazyModule, 10)
Note, however, that deferred_init() and the materialize functions use a “best effort” approach and are not guaranteed to always succeed. See the Common Failure Patterns section below to learn more.
Common Failure Patterns
A module using an operator that is not supported by the meta backend: Internally, deferred_init() relies on the meta backend. If the module to be constructed by deferred_init() uses an operator that is not yet supported by the meta backend, the operator call will fail. Fortunately, such failures are easy to spot since the error message clearly indicates which operator was the culprit. The solution in such cases is to add meta backend support for the failing operator.
Mutable operator arguments: Almost all PyTorch operators use either primitives (e.g. integers, floating-point numbers) or tensors as parameter types. However, if an operator accepts a mutable argument other than a Tensor (e.g. a storage, blob, or future), deferred_init() deliberately fails the operation, since there is no guarantee that the argument will have the same state during materialization.
In-place updated external tensors and inference tensors: As a follow-up to mutable arguments, if a tensor constructed from external data (e.g. via torch.load() or torch.from_numpy()) is used as an argument to a meta operation within deferred_init(), its version counter is tracked, similar to Autograd. A change to the version counter, which in practice means an in-place update to the tensor, is checked during materialization and, if detected, raises an error, since it would prevent correct materialization. The rules are stricter for inference tensors: since in-place updates cannot be tracked for them, any materialization call that uses an inference tensor as an argument raises an error.
A module using tolist() or numpy() functions in its constructor: Currently, Deferred Module Initialization does not support tracing calls to the tolist() and numpy() functions. We consider this a temporary limitation and will work with the PyTorch core team to mitigate it in future releases.
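The mutable-argument and version-counter rules described in this section can be modeled together in plain Python. This is a hypothetical sketch, not torchdistX code: `ToyTensor`, `ToyGraph`, and `scale` are invented names, a plain Python list stands in for a mutable non-tensor argument such as a storage, and versions are snapshotted when a call is recorded and compared when it is replayed. Inference tensors, which have no trackable version, would simply have to be rejected in this model.

```python
from typing import Any, Callable, List, Tuple


class ToyTensor:
    """Hypothetical tensor with an Autograd-style version counter."""

    def __init__(self, data: List[float]):
        self.data = list(data)
        self.version = 0

    def add_(self, x: float) -> "ToyTensor":
        # Every in-place update bumps the version counter.
        self.data = [v + x for v in self.data]
        self.version += 1
        return self


# Immutable primitives are always safe to replay later.
PRIMITIVES = (bool, int, float, str, type(None))


class ToyGraph:
    """Records calls for later replay, enforcing both rules described above."""

    def __init__(self) -> None:
        self.calls: List[Tuple[Callable[..., Any], tuple, List[Tuple[ToyTensor, int]]]] = []

    def record(self, op: Callable[..., Any], *args: Any) -> None:
        versions = []
        for arg in args:
            if isinstance(arg, ToyTensor):
                # Tensors are allowed, but their current version is snapshotted.
                versions.append((arg, arg.version))
            elif not isinstance(arg, PRIMITIVES):
                # Any other mutable argument could change before replay: refuse it.
                raise TypeError(f"{op.__name__}: unsupported mutable argument {type(arg).__name__}")
        self.calls.append((op, args, versions))

    def replay(self) -> List[Any]:
        results = []
        for op, args, versions in self.calls:
            for tensor, seen in versions:
                if tensor.version != seen:
                    raise RuntimeError(f"{op.__name__}: argument was modified in-place after recording")
            results.append(op(*args))
        return results


def scale(t: ToyTensor, factor: float) -> List[float]:
    return [v * factor for v in t.data]


graph = ToyGraph()
src = ToyTensor([1.0, 2.0])  # e.g. data loaded from outside the deferred call
graph.record(scale, src, 2.0)
print(graph.replay())  # [[2.0, 4.0]]

src.add_(1.0)  # in-place update after recording
try:
    graph.replay()
except RuntimeError as err:
    print(err)  # scale: argument was modified in-place after recording

try:
    graph.record(scale, [1.0], 2.0)  # a list stands in for a storage/blob/future
except TypeError as err:
    print(err)  # scale: unsupported mutable argument list
```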