ZeRO (DeepSpeed 0.16.8 documentation)

The Zero Redundancy Optimizer (ZeRO) removes the memory redundancies of data-parallel training by partitioning the three model states (optimizer states, gradients, and parameters) across data-parallel processes instead of replicating them. This boosts memory efficiency compared to classic data parallelism while retaining its computational granularity and communication efficiency.

  1. ZeRO Stage 1: The optimizer states (e.g., for the Adam optimizer, the 32-bit weights and the first and second moment estimates) are partitioned across the processes, so that each process updates only its partition.
  2. ZeRO Stage 2: The reduced 16-bit gradients for updating the model weights are also partitioned such that each process retains only the gradients corresponding to its portion of the optimizer states.
  3. ZeRO Stage 3: The 16-bit model parameters are partitioned across the processes. ZeRO-3 will automatically collect and partition them during the forward and backward passes.
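To make the effect of each stage concrete, the following back-of-the-envelope sketch estimates per-GPU model-state memory using the analysis from the ZeRO paper: with mixed-precision Adam, each parameter costs 2 bytes of fp16 weights, 2 bytes of fp16 gradients, and K = 12 bytes of fp32 optimizer state (master weights, momentum, and variance). The function name is ours, not a DeepSpeed API, and the estimate ignores activations, temporary buffers, and fragmentation.

```python
# Per-GPU model-state memory under mixed-precision Adam, following the
# ZeRO paper's accounting: 2 bytes/param fp16 weights, 2 bytes/param fp16
# gradients, K bytes/param fp32 optimizer state (K = 12 for Adam).

def model_state_bytes_per_gpu(num_params: float, dp_degree: int, stage: int, k: int = 12) -> float:
    """Return the bytes of model states held by one data-parallel rank.

    Ignores activations, communication buffers, and fragmentation.
    """
    if stage not in (0, 1, 2, 3):
        raise ValueError("stage must be 0, 1, 2, or 3")
    params = 2 * num_params
    grads = 2 * num_params
    optim = k * num_params
    if stage >= 1:
        optim /= dp_degree   # stage 1: optimizer states partitioned
    if stage >= 2:
        grads /= dp_degree   # stage 2: gradients partitioned as well
    if stage >= 3:
        params /= dp_degree  # stage 3: parameters partitioned as well
    return params + grads + optim


if __name__ == "__main__":
    for s in range(4):
        gb = model_state_bytes_per_gpu(7.5e9, 64, s) / 1e9
        print(f"stage {s}: {gb:.2f} GB")
```

For a 7.5B-parameter model on 64 GPUs this gives roughly 120 GB, 31.4 GB, 16.6 GB, and 1.9 GB for stages 0 through 3, in line with the figures reported in the ZeRO paper.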

In addition, ZeRO-3 includes the infinity offload engine to form ZeRO-Infinity ([paper](https://arxiv.org/abs/2104.07857)), which can offload all model states to both CPU and NVMe memory for huge memory savings.

For a deep dive into our algorithms, please see our papers on ZeRO, ZeRO-Offload, and ZeRO-Infinity.

Note

DeepSpeed first included offloading capabilities with ZeRO-Offload, a system for offloading optimizer and gradient states to CPU memory within ZeRO-2. ZeRO-Infinity is the next generation of offloading capabilities and is available with ZeRO-3. ZeRO-Infinity has all of the savings of ZeRO-Offload, can additionally offload the model weights, and makes more effective use of bandwidth and of overlapping computation with communication.

Getting Started

If you are new to DeepSpeed, check out our Getting Started page.

Once you are training with DeepSpeed, enabling ZeRO-3 offload is as simple as enabling it in your DeepSpeed configuration. Below are a few examples of ZeRO-3 configurations. Please see our config guide for a complete list of configuration and performance-tuning options.

ZeRO Configurations

All the settings for DeepSpeed ZeRO are set with DeepSpeedZeroConfig. The dictionary provided under the zero_optimization entry of the main DeepSpeed configuration dict is parsed and validated with this class. Sub-configurations for parameter offload and optimizer offload settings are parsed by DeepSpeedZeroOffloadParamConfig and DeepSpeedZeroOffloadOptimizerConfig.
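As an illustration of that nesting, the sketch below builds a ZeRO-3 configuration as a plain Python dict: everything under "zero_optimization" is what DeepSpeedZeroConfig validates, and the nested "offload_param" and "offload_optimizer" dicts are the sub-configurations. The option names come from the reference on this page; the values are arbitrary examples, and train_batch_size is a top-level DeepSpeed setting outside the ZeRO section.

```python
import json

# Illustrative ZeRO-3 configuration; values are examples, not recommendations.
ds_config = {
    "train_batch_size": 32,
    "fp16": {"enabled": True},
    "zero_optimization": {             # parsed by DeepSpeedZeroConfig
        "stage": 3,
        "overlap_comm": True,
        "offload_param": {             # parsed by DeepSpeedZeroOffloadParamConfig
            "device": "cpu",
            "pin_memory": True,
        },
        "offload_optimizer": {         # parsed by DeepSpeedZeroOffloadOptimizerConfig
            "device": "cpu",
            "pin_memory": True,
        },
    },
}

# The same content serialized as JSON, e.g. for a ds_config.json file.
config_json = json.dumps(ds_config, indent=2)
print(config_json)
```

The dict can be passed directly to the engine, for example `deepspeed.initialize(model=model, config=ds_config)`, or saved as a JSON file with the same structure.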

class deepspeed.runtime.zero.config.DeepSpeedZeroConfig

Sets parameters for ZeRO optimizations.

stage: ZeroStageEnum = 0

Chooses the ZeRO optimizer stage. Stages 0, 1, 2, and 3 refer to disabled, optimizer state partitioning, optimizer+gradient state partitioning, and optimizer+gradient+parameter partitioning, respectively.

contiguous_gradients: bool = True

Copies the gradients to a contiguous buffer as they are produced. Avoids memory fragmentation during the backward pass.

reduce_scatter: bool = True

Uses reduce or reduce-scatter instead of allreduce to average gradients.

reduce_bucket_size: int = 500,000,000

Number of elements reduced/allreduced at a time. Limits the memory required for gradient reduction with large model sizes.

use_multi_rank_bucket_allreduce: bool = True

Combines the reduce buckets of the different ranks and performs a single All-Reduce instead of multiple Reduce ops. This is useful when a small model is scaled across many GPUs, which would otherwise shrink the message size of each Reduce.

allgather_partitions: bool = True

Chooses between an allgather collective and a series of broadcast collectives to gather updated parameters from all the GPUs at the end of each step.

allgather_bucket_size: int = 500,000,000

Number of elements allgathered at a time. Limits the memory required for the allgather for large model sizes.
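Both bucket settings count elements, not bytes. The sketch below converts a bucket size into bytes and counts how many collectives are needed to move a given number of elements. It assumes 2-byte (fp16/bf16) elements; the real element size depends on the communication dtype, and DeepSpeed may keep more than one bucket in flight, so actual peak usage can be higher. The helper names are illustrative only.

```python
# Rough sizing helpers for communication buckets (illustrative, not DeepSpeed code).
# Assumption: 2 bytes per element (fp16/bf16) unless told otherwise.

def bucket_bytes(bucket_elements: int, bytes_per_element: int = 2) -> int:
    """Bytes needed to stage one bucket of bucket_elements elements."""
    return bucket_elements * bytes_per_element


def num_buckets(total_elements: int, bucket_elements: int) -> int:
    """Number of collectives needed to move total_elements, bucket_elements at a time."""
    if bucket_elements <= 0:
        raise ValueError("bucket_elements must be positive")
    return -(-total_elements // bucket_elements)  # ceiling division


if __name__ == "__main__":
    default_bucket = 500_000_000  # default reduce/allgather bucket size on this page
    print(bucket_bytes(default_bucket) / 1e9, "GB per bucket at 2 bytes/element")
    print(num_buckets(7_000_000_000, default_bucket), "buckets for 7B gradient elements")
```

With the defaults, one bucket holds about 1 GB of 16-bit data; halving the bucket size halves that buffer but doubles the number of collectives.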

overlap_comm: Optional[bool] = None

Attempts to overlap the reduction of the gradients with backward computation.

load_from_fp32_weights: bool = True

Boolean indicating whether to initialize fp32 master weights from fp32 copies in the checkpoint (no precision loss) or from the model’s fp16 copies (with precision loss). This can be used to initialize optimizer state even when the checkpoint is missing optimizer state.

elastic_checkpoint: bool = False

Enables loading a checkpoint that was saved by a job with a different GPU count. No longer supported.

offload_param: Optional[DeepSpeedZeroOffloadParamConfig] = None

Enables offloading of model parameters to CPU or NVMe. This frees up GPU memory for larger models or batch sizes. Valid only with stage 3. Expects a dictionary containing values for DeepSpeedZeroOffloadParamConfig.

offload_optimizer: Optional[DeepSpeedZeroOffloadOptimizerConfig] = None

Enables offloading of optimizer state to CPU or NVMe, and of optimizer computation to CPU. This frees up GPU memory for larger models or batch sizes. Valid for ZeRO stages 1, 2, and 3. Expects a dictionary containing values for DeepSpeedZeroOffloadOptimizerConfig.

sub_group_size: int = 1,000,000,000

Tile size for parameter processing to fit massive models (with trillions of parameters). Used by ZeRO3-Offload and ZeRO-Infinity.

cpu_offload_param: Optional[bool] = None

Deprecated, please use offload_param.

cpu_offload_use_pin_memory: Optional[bool] = None

Deprecated, please use offload_param or offload_optimizer.

cpu_offload: Optional[bool] = None

Deprecated, please use offload_optimizer.

prefetch_bucket_size: int = 50,000,000 (alias 'stage3_prefetch_bucket_size')

Maximum number of parameter elements to fetch ahead of use. Used by ZeRO3, ZeRO3-Offload, ZeRO-Infinity, and ZeRO-Inference.

param_persistence_threshold: int = 100,000 (alias 'stage3_param_persistence_threshold')

Do not partition parameters smaller than this threshold. Smaller values use less memory, but can greatly increase communication (especially latency-bound messages).

model_persistence_threshold: int = sys.maxsize (alias 'stage3_model_persistence_threshold')

Maximum number of parameter elements that can be persisted in GPU memory and not partitioned. This imposes an upper bound on the number of unpartitioned parameters resulting from the param_persistence_threshold setting. Used by ZeRO3-Offload, ZeRO-Infinity, and ZeRO-Inference.
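The two persistence thresholds interact: the first decides which parameters are small enough to stay whole, and the second caps the total size of those whole parameters. The toy model below illustrates that reading of the descriptions above. It is not DeepSpeed's implementation; the visiting order, the strict "smaller than" comparison, and the function name are our assumptions.

```python
import sys
from typing import List, Tuple


def select_persistent(param_numels: List[int],
                      param_persistence_threshold: int = 100_000,
                      model_persistence_threshold: int = sys.maxsize) -> Tuple[List[int], int]:
    """Toy model of the two persistence thresholds (not DeepSpeed's code).

    Returns the indices of parameters kept unpartitioned and their total
    element count. Parameters are visited in order; a parameter smaller than
    param_persistence_threshold persists only while the running total stays
    within model_persistence_threshold.
    """
    persistent: List[int] = []
    total = 0
    for i, numel in enumerate(param_numels):
        if numel >= param_persistence_threshold:
            continue  # large parameter: partitioned as usual
        if total + numel > model_persistence_threshold:
            continue  # persisting it would exceed the model-wide budget
        persistent.append(i)
        total += numel
    return persistent, total


if __name__ == "__main__":
    # e.g. a bias, a large weight matrix, a norm weight, a small embedding, a mid-size weight
    sizes = [1024, 4096 * 4096, 4096, 50_000, 200_000]
    print(select_persistent(sizes))
    print(select_persistent(sizes, model_persistence_threshold=10_000))
```

With the default thresholds, the three small parameters stay whole; tightening model_persistence_threshold to 10,000 elements drops the 50,000-element one because it would exceed the budget.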

max_live_parameters: int = 1,000,000,000 (alias 'stage3_max_live_parameters')

The maximum number of parameters resident per GPU before releasing. Smaller values use less memory, but perform more communication.

max_reuse_distance: int = 1,000,000,000 (alias 'stage3_max_reuse_distance')

Do not release a parameter if it will be reused within this threshold of parameters. Smaller values use less memory, but perform more communication.

gather_16bit_weights_on_model_save: bool = False (alias 'stage3_gather_16bit_weights_on_model_save')

Consolidates the weights before saving the model with save_16bit_model(). Since the weights are partitioned across GPUs, they aren’t part of state_dict, so when this option is enabled, save_16bit_model() automatically gathers the weights and then saves the 16-bit model weights.

module_granularity_threshold: int = 0 (alias 'stage3_module_granularity_threshold')

The granularity of a module is defined as the ratio parameter_count / (1 + descendant_count). ZeRO3 classifies modules with a granularity below the threshold as fine-grained and treats them as integral units during parameter fetching. This reduces host overhead and the separate allgather overhead that hooks introduce for fine-grained layers when fetching parameters.
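The granularity formula can be tried out on a toy module tree. The sketch below assumes that parameter_count includes the parameters of all descendants and that descendant_count counts every module below the given one; it also assumes that once a module is classified as fine-grained, its whole subtree forms one unit. The Mod class and function names are hypothetical stand-ins, not DeepSpeed APIs.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Mod:
    """Toy stand-in for a module: its own parameter count plus child modules."""
    name: str
    own_params: int
    children: List["Mod"] = field(default_factory=list)


def parameter_count(m: Mod) -> int:
    """Parameters of the module including all of its descendants (assumption)."""
    return m.own_params + sum(parameter_count(c) for c in m.children)


def descendant_count(m: Mod) -> int:
    """Number of modules below m in the tree (children, grandchildren, ...)."""
    return sum(1 + descendant_count(c) for c in m.children)


def granularity(m: Mod) -> float:
    return parameter_count(m) / (1 + descendant_count(m))


def fine_grained_units(m: Mod, threshold: float) -> List[str]:
    """Names of the outermost modules whose granularity is below threshold.

    A fine-grained module is treated as one unit, so its subtree is not
    examined further (an assumption about how the units are formed).
    """
    if granularity(m) < threshold:
        return [m.name]
    units: List[str] = []
    for c in m.children:
        units.extend(fine_grained_units(c, threshold))
    return units


if __name__ == "__main__":
    big = Mod("big_linear", 1_000_000)
    tiny_block = Mod("tiny_block", 0, [Mod(f"tiny{i}", 100) for i in range(10)])
    root = Mod("root", 0, [big, tiny_block])
    print(granularity(tiny_block), fine_grained_units(root, threshold=500))
```

Here tiny_block has 1,000 parameters spread over 11 modules (granularity about 91), so with a threshold of 500 it would be fetched as a single unit instead of through ten separate hooks, while big_linear keeps its normal per-module handling. The default threshold of 0 classifies nothing as fine-grained.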

use_all_reduce_for_fetch_params: bool = False (alias 'stage3_use_all_reduce_for_fetch_params')

Uses an all_reduce op when fetching module parameters at stage 3. This improves performance by reducing the overhead of concatenation and slicing on the host.

stage3_gather_fp16_weights_on_model_save: bool = False

Deprecated, please use gather_16bit_weights_on_model_save.

ignore_unused_parameters: bool = True

Unused parameters in modules may be unexpected in static networks, but could be normal in dynamic networks. This controls whether training should terminate with an error message when unused parameters are detected. It is set to True by default, which means unused parameters are ignored and training continues. Currently only used in stage 2.

legacy_stage1: bool = False

For backward compatibility, enables the old ZeRO stage 1 implementation. Use at your own risk; it will be deprecated soon.

round_robin_gradients: bool = False

Stage 1 and 2 optimization for CPU offloading that parallelizes gradient copying to CPU memory among ranks by fine-grained gradient partitioning. The performance benefit grows with the number of gradient accumulation steps (more copying between optimizer steps) or with the GPU count (increased parallelism).

zero_hpz_partition_size: int = 1

Number of ranks in the secondary group for ZeRO parameter partitioning.

zero_quantized_weights: bool = False

Boolean indicating whether to quantize ZeRO parameters (weights) for efficient all_gather communication.

zero_quantized_nontrainable_weights: bool = False

Boolean indicating whether to quantize non-trainable ZeRO parameters (weights) for efficient memory usage and communication. Unlike zero_quantized_weights, which stores the weights in their original precision and only performs quantization during communication, this flag stores the weights in quantized precision. This is useful for LoRA training.

zero_quantized_gradients: bool = False

Boolean indicating whether to use quantized ZeRO gradients for efficient all_2_all_reduce communication.
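The quantization flags above trade precision for communication volume. As a rough illustration of the idea only, the sketch below implements generic symmetric int8 blockwise quantization in pure Python: each block of values gets its own scale, so int8 codes (1 byte each) can replace 16-bit values, roughly halving the payload apart from the per-block scales. The scheme details (symmetric, one scale per block, round-to-nearest) are our assumptions; this is not DeepSpeed's CUDA implementation.

```python
from typing import List, Tuple


def quantize_blockwise(values: List[float], block_size: int) -> Tuple[List[int], List[float]]:
    """Symmetric int8 quantization with one scale per block of values."""
    q: List[int] = []
    scales: List[float] = []
    for start in range(0, len(values), block_size):
        block = values[start:start + block_size]
        max_abs = max(abs(v) for v in block)
        scale = max_abs / 127 if max_abs > 0 else 1.0  # maps the block into [-127, 127]
        scales.append(scale)
        q.extend(max(-127, min(127, round(v / scale))) for v in block)
    return q, scales


def dequantize_blockwise(q: List[int], scales: List[float], block_size: int) -> List[float]:
    """Recover approximate values from int8 codes and per-block scales."""
    return [code * scales[i // block_size] for i, code in enumerate(q)]


if __name__ == "__main__":
    vals = [0.5, -1.0, 0.25, 2.0, -2.0, 0.0]
    codes, scales = quantize_blockwise(vals, block_size=3)
    print(codes, scales)
    print(dequantize_blockwise(codes, scales, block_size=3))
```

Per-block scales keep one outlier from degrading the precision of the whole tensor; the reconstruction error of each element is at most half its block's scale.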

zeropp_loco_param: Optional[Dict[str, Any]] = None

This dictionary contains parameters for using LoCo-Zero++, with two key parameters:
  - err_beta: A coefficient for the moving average of quantization errors before and after gradient computation. It ranges between 0 and 1, with a default value of 0.8.
  - reset_T: The number of steps after which the moving-average error buffer is cleared. The default value is 1024.

These parameters can be adjusted based on performance needs. Example configuration in the DeepSpeed config: "zeropp_loco_param": { "err_beta": 0.8, "reset_T": 1024 }. See the LoCo paper for more details: (https://arxiv.org/abs/2407.04480).

mics_shard_size: int = -1

mics_hierarchical_params_gather: bool = False

memory_efficient_linear: bool = True

Use the memory-efficient linear implementation, for stage 3.

pipeline_loading_checkpoint: bool = False

override_module_apply: bool = True

Override the nn.Module apply function, for stage 3.

log_trace_cache_warnings: bool = False

Whether to log warnings from the trace cache, such as invalidation events.

class deepspeed.runtime.zero.config.DeepSpeedZeroOffloadParamConfig

Set options for parameter offload. Valid only with stage 3.

device: OffloadDeviceEnum = 'none'

Device memory to offload model parameters to. Supported options are cpu and nvme.

nvme_path: Optional[Path] = None

Filesystem path for the NVMe device used for parameter offloading.

buffer_count: int = 5

Number of buffers in the buffer pool for parameter offloading to NVMe.

buffer_size: int = 100,000,000

Size of buffers in the buffer pool for parameter offloading to NVMe.

max_in_cpu: int = 1,000,000,000

Number of parameter elements to maintain in CPU memory when offloading to NVMe is enabled.

pin_memory: bool = False

Offload to page-locked CPU memory. This could boost throughput at the cost of extra memory overhead.

class deepspeed.runtime.zero.config.DeepSpeedZeroOffloadOptimizerConfig

Set options for optimizer offload. Valid with stages 1, 2, and 3.

device: OffloadDeviceEnum = 'none'

Device memory to offload optimizer state to. Supported options are cpu and nvme. Optimizer computation is offloaded to CPU regardless of the device option.

nvme_path: Optional[Path] = None

Filesystem path for the NVMe device used for optimizer state offloading.

buffer_count: int = 4

Number of buffers in the buffer pool for optimizer state offloading to NVMe. This should be at least the number of states maintained per parameter by the optimizer. For example, the Adam optimizer has 4 states (parameter, gradient, momentum, and variance).

pin_memory: bool = False

Offload to page-locked CPU memory. This could boost throughput at the cost of extra memory overhead.

pipeline_read: bool = False

For tile-based optimizer step processing, overlap the read of the next tile with computation of the current tile. Used in ZeRO-Infinity.

pipeline_write: bool = False

For tile-based optimizer step processing, overlap the write of the previous tile with computation of the current tile.

fast_init: bool = False

Enable fast optimizer initialization when offloading to NVMe.

ratio: float = 1.0

Fraction of the optimizer states offloaded to CPU Adam. Only valid with ZeRO stage 3.

Example ZeRO-3 Configurations

  1. Use ZeRO to partition the optimizer states (stage 1), gradients (stage 2), and parameters (stage 3).

```json
{
    "zero_optimization": {
        "stage": 3
    },
    "fp16": {
        "enabled": true
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": 0.001,
            "betas": [
                0.8,
                0.999
            ],
            "eps": 1e-8,
            "weight_decay": 3e-7
        }
    },
    ...
}
```

  2. Additionally offload the optimizer states and computations to the CPU with ZeRO-Infinity.

```json
{
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu"
        }
    },
    ...
}
```

  3. Save even more memory by offloading parameters to CPU memory.

```json
{
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu"
        },
        "offload_param": {
            "device": "cpu"
        }
    },
    ...
}
```

  4. Save even MORE memory by offloading to NVMe (if available on your system):

```json
{
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "nvme",
            "nvme_path": "/nvme_data"
        },
        "offload_param": {
            "device": "nvme",
            "nvme_path": "/nvme_data"
        }
    },
    ...
}
```

MiCS Configurations

All MiCS configurations are set with DeepSpeedZeroConfig. MiCS assumes ZeRO stage 3 optimization is enabled. For now, MiCS has two configuration fields: mics_shard_size and mics_hierarchical_params_gather. mics_shard_size controls how many devices are used for partitioning the model states. mics_hierarchical_params_gather controls whether parameters are gathered in a two-stage hierarchical fashion during the forward computation. It is useful when model states are partitioned across multiple nodes and the cross-node bandwidth is slow. By default it is turned off.

Example MiCS Configurations

  1. Use MiCS to partition the model states (including optimizer states, gradients, and parameters). The following config example partitions the model states across eight devices and assumes the eight devices are located within a single node (mics_hierarchical_params_gather is false).

```json
{
    "zero_optimization": {
        "stage": 3,
        "mics_shard_size": 8,
        "mics_hierarchical_params_gather": false
    },
    ...
}
```
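To picture what mics_shard_size means for a cluster, the sketch below splits the ranks into shard groups (each group holds one full copy of the model states, partitioned across its members) and the matching replication groups (ranks holding the same shard in different groups). It assumes shard groups are formed from consecutive ranks, which keeps an 8-rank group on one 8-GPU node when ranks are numbered node by node; that layout and the function names are our assumptions, not DeepSpeed's documented behavior.

```python
from typing import List


def mics_shard_groups(world_size: int, shard_size: int) -> List[List[int]]:
    """Split ranks into consecutive groups of shard_size ranks each (assumed layout).

    Each group partitions one full copy of the model states among its members;
    different groups replicate each other.
    """
    if shard_size <= 0 or world_size % shard_size != 0:
        raise ValueError("world_size must be a positive multiple of shard_size")
    return [list(range(start, start + shard_size))
            for start in range(0, world_size, shard_size)]


def replication_groups(world_size: int, shard_size: int) -> List[List[int]]:
    """Ranks that hold the same shard in different shard groups."""
    return [list(range(offset, world_size, shard_size)) for offset in range(shard_size)]


if __name__ == "__main__":
    # Two 8-GPU nodes with mics_shard_size = 8
    print(mics_shard_groups(16, 8))
    print(replication_groups(16, 8))
```

With two 8-GPU nodes and a shard size of 8, parameter gathers stay inside a node, and only the replication groups (for example ranks 0 and 8) span the slower cross-node link.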

Assumptions

DeepSpeed automatically coordinates the collection (i.e., all-gather), partitioning (i.e., scatter), and offloading of parameters at the granularity of (sub)module forward() methods. The backward pass is handled similarly. This strategy has two underlying assumptions:

  1. The forward and backward passes of submodules must individually fit in device memory. If this is not the case, deepspeed.zero.TiledLinear implements memory-centric tiling and works with ZeRO-3 to break linear layers into a sequence of smaller submodules that can fit in memory.
  2. A module’s parameters are only accessed within its own __init__ and forward() methods. Otherwise, DeepSpeed must be instructed to collect and re-partition the parameter. See Manual Parameter Coordination for manually coordinating parameters.

Constructing Massive Models

ZeRO-3 enables massive models whose parameters exceed the memory of individual nodes in a system. For the typical case of training without model parallelism, you can simply allocate your model within our context manager:

```python
import deepspeed

with deepspeed.zero.Init():
    model = MyLargeModel()
```

Manual Parameter Coordination

Most models require no modification to be trained with ZeRO-3. However, in some cases one may need to access model weights outside of the training loop, or to share weights across submodules during training. DeepSpeed has several mechanisms to coordinate partitioned weights for ZeRO-3.

Gathering Parameters

DeepSpeed provides mechanisms for collecting (or gathering) a partitioned parameter.

Some models partitioned with deepspeed.zero.Init may need to access a module’s weights outside of the class constructor or its forward() method. We refer to these weights as external parameters, since they are accessed outside of the module that created them. To do so, use deepspeed.zero.GatheredParameters or deepspeed.zero.register_external_parameter().

Registering External Parameters

ZeRO-3 will automatically collect and partition the model parameters as they are needed during the forward and backward passes. However, in some cases a parameter may be used outside of its module’s forward pass. We call these external parameters. ZeRO-3 can coordinate these parameters if they are registered, either automatically or manually.

Note

Since version 0.3.15, DeepSpeed includes automatic external parameter discovery and registration to support the most common cases. Parameters can still be registered manually if they cannot be detected automatically.

DeepSpeed can automatically detect the following external parameter scenarios:

  1. Parameter access: consider the following pattern common in language models such as GPT:

```python
import torch


class LanguageModel(torch.nn.Module):
    ...
    def forward(self, inputs):
        embeds = self.embeddings(inputs)
        ...
        logits = compute_logits(output, self.embeddings.weight)
        ...
```

The tensor embeddings.weight is used in both embeddings.forward() and compute_logits(). We call embeddings.weight an external parameter because it is used in the training loop outside of its owning module’s forward pass.

  2. Returning a parameter:

```python
import torch


class CustomLinear(torch.nn.Linear):
    def forward(self, *input):
        output = super().forward(*input)
        return output, self.bias
```

CustomLinear returns both an output and its own bias parameter. DeepSpeed will detect the external bias parameter and register it with submodules that use CustomLinear.

Overriding Module.apply

A convenient mechanism for customizing model initialization is Module.apply. With ZeRO stage 3, Module.apply implementations must account for parameter partitioning by zero.Init during model initialization. The default behavior of ZeRO stage 3 is to automatically handle this issue by overriding Module.apply to ensure that parameters are gathered before access by Module.apply. The benefit of this approach is development convenience, since users are saved the burden of manual parameter coordination in Module.apply. However, the downside is slow model initialization, since all the model parameters (e.g., billions) are gathered even though the common usage of Module.apply is to customize a few parameters. Developers can disable this default behavior by setting the override_module_apply configuration knob to False, for faster model initialization at the cost of manually handling partitioned parameters in their Module.apply implementations.

Memory-Centric Tiling

To reduce the working memory requirements of DL training for large models, ZeRO-Infinity includes a technique called memory-centric tiling, which exploits the data fetch and release pattern of ZeRO-3 to break a large operator into smaller tiles that can be executed sequentially. When combined with ZeRO-3, the parameters and gradients of each tile can be fetched and released one at a time, reducing the working memory by a factor proportional to the number of tiles. Therefore, ZeRO-Infinity can support operators of arbitrary sizes without refactoring them for model parallelism to fit in limited GPU memory.
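The core idea of tiling can be shown without any framework: a large matrix-vector product is computed one block of output rows at a time, so only one tile of the weight needs to be live at once, and the concatenated result equals the untiled one. The sketch below is a plain-Python illustration; the slicing and del stand in for ZeRO-3 fetching and releasing a tile, and the function names are ours. deepspeed.zero.TiledLinear, mentioned above, applies this idea to real linear layers; consult its API reference for its splitting options.

```python
from typing import List

Matrix = List[List[float]]


def matvec(weight: Matrix, x: List[float]) -> List[float]:
    """Untiled reference: every row of weight is used at once."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


def tiled_matvec(weight: Matrix, x: List[float], out_splits: int) -> List[float]:
    """Compute weight @ x one tile of output rows at a time.

    Only one tile's rows are live at a time, so the weight working set
    shrinks roughly by a factor of out_splits.
    """
    rows = len(weight)
    tile = -(-rows // out_splits)  # ceiling division: rows per tile
    out: List[float] = []
    for start in range(0, rows, tile):
        weight_tile = weight[start:start + tile]  # stand-in for "fetch this tile"
        out.extend(matvec(weight_tile, x))
        del weight_tile  # stand-in for "release this tile"
    return out


if __name__ == "__main__":
    W = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [1.0, 0.0, -1.0]]
    x = [1.0, 0.5, -1.0]
    print(matvec(W, x), tiled_matvec(W, x, out_splits=2))
```

Because each output row depends only on its own weight row, splitting along the output dimension needs no extra communication between tiles; splitting the input dimension would instead require summing partial results.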

Debugging

Debugging ZeRO training is complicated by the partitioning of parameters, gradients, and optimizer states. Because of this partitioning, none of these three groups of tensors (model states) can be accessed normally. To overcome this, DeepSpeed provides the following routines for accessing individual model states in both their partitioned (local) and unpartitioned (full) forms.

Important: to access the unpartitioned (full) form, these utilities must be called by all processes participating in the training, even if you only use the result in the main process. If any process does not participate, these utilities will hang waiting for all processes to send their contribution.

Additionally, these routines return correct data only in specific phases of training. For example, gradients are valid after backward and before step, while optimizer states and fp32 master weights are updated by step.

deepspeed.utils.safe_get_full_fp32_param(param)

Assemble and return the fp32 parameter of a low-precision (e.g., fp16) parameter.

Parameters

param (torch.nn.Parameter) – A model parameter

deepspeed.utils.safe_get_full_grad(param)

Assemble and return the full gradient of a low-precision (e.g., fp16) parameter. The returned data type is the one used for gradient accumulation. This is usually the param data type, but could also be different (e.g., bf16 param training with fp32 gradient accumulation).

Parameters

param (torch.nn.Parameter) – A model parameter

deepspeed.utils.safe_get_full_optimizer_state(param, optim_state_key)

Assemble and return the fp32 optimizer state of a low-precision (e.g., fp16) parameter.

Parameters

param (torch.nn.Parameter) – A model parameter
optim_state_key (string) – Key value of optimizer state (e.g., exp_avg in Adam optimizer)

deepspeed.utils.safe_get_local_fp32_param(param)

Get the local partition of a ZeRO-3 partitioned parameter in fp32 precision.

Parameters

param (torch.nn.Parameter) – A model parameter

deepspeed.utils.safe_get_local_grad(param)

Get the local gradient partition of a ZeRO-3 partitioned parameter. The returned data type is the one used for gradient accumulation. This is usually the param data type, but could also be different (e.g., bf16 param training with fp32 gradient accumulation).

Parameters

param (torch.nn.Parameter) – A model parameter

deepspeed.utils.safe_get_local_optimizer_state(param, optim_state_key)

Get the local optimizer state partition of a ZeRO-3 partitioned parameter in fp32 precision.

Parameters

param (torch.nn.Parameter) – A model parameter
optim_state_key (string) – Key value of optimizer state (e.g., exp_avg in Adam optimizer)

These routines can be used in a training loop as shown in the following snippet.

```python
from deepspeed.utils import safe_get_full_fp32_param, safe_get_full_grad, safe_get_full_optimizer_state
from deepspeed.utils import safe_get_local_fp32_param, safe_get_local_grad, safe_get_local_optimizer_state

backward(loss)
# [...]
for n, lp in model.named_parameters():
    # 1. Access the full states
    # 1.1) Gradient lookup
    # For ZeRO-1 and ZeRO-2, gradient lookup must be called after backward and before step.
    # For ZeRO-3, gradient lookup must be called after backward.
    hp_grad = safe_get_full_grad(lp)

    # 1.2) fp32 and optimizer states can probably be accessed anywhere in the
    # training loop, but they are only updated by `step`.
    hp = safe_get_full_fp32_param(lp)
    exp_avg = safe_get_full_optimizer_state(lp, "exp_avg")
    exp_avg_sq = safe_get_full_optimizer_state(lp, "exp_avg_sq")

    # 2. Access the local states (ZeRO-3)
    # For ZeRO-3, all of the parameters, gradients, and optimizer states are partitioned,
    # and each process can access its corresponding local state.
    local_hp = safe_get_local_fp32_param(lp)
    local_hp_grad = safe_get_local_grad(lp)
    local_exp_avg = safe_get_local_optimizer_state(lp, "exp_avg")
    local_exp_avg_sq = safe_get_local_optimizer_state(lp, "exp_avg_sq")

# [...]
optimizer.step()
```

Modifying Partitioned States

Sometimes, a user may want to modify parameters, gradients, or optimizer states outside of the regular training loop. This is currently difficult in ZeRO training because of partitioning. To overcome that, DeepSpeed provides the following routines for modifying the fp32 master parameters and the fp32 optimizer states.

deepspeed.utils.safe_set_full_fp32_param(param, value)

Update the partitioned fp32 parameter of a low-precision (e.g., fp16) parameter.

Parameters

param (torch.nn.Parameter) – A model parameter
value (torch.Tensor) – New value of the fp32 parameter

deepspeed.utils.safe_set_full_optimizer_state(param, value, optim_state_key)

Update the partitioned fp32 optimizer state of a low-precision (e.g., fp16) parameter.

Parameters

param (torch.nn.Parameter) – A model parameter
value (torch.Tensor) – New value of the optimizer state
optim_state_key (string) – Key value of optimizer state (e.g., exp_avg in Adam optimizer)

deepspeed.utils.safe_set_full_grad(param, value)

Update the partitioned gradient of a low-precision (e.g., fp16) parameter. To avoid precision issues, the update value should have the data type used for gradient accumulation.

Parameters

param (torch.nn.Parameter) – A model parameter
value (torch.Tensor) – New value of the gradient

deepspeed.utils.safe_set_local_fp32_param(param, value)

Update the local partition of a ZeRO-3 partitioned parameter.

Parameters

param (torch.nn.Parameter) – A model parameter
value (torch.Tensor) – New value of the local parameter partition

deepspeed.utils.safe_set_local_grad(param, value)

Update the local gradient partition of a ZeRO-3 partitioned parameter. To avoid precision issues, the update value should have the data type used for gradient accumulation.

Parameters

param (torch.nn.Parameter) – A model parameter
value (torch.Tensor) – New value of the local gradient partition

deepspeed.utils.safe_set_local_optimizer_state(param, value, optim_state_key)

Update the local optimizer state partition of a ZeRO-3 partitioned parameter.

Parameters

param (torch.nn.Parameter) – A model parameter
value (torch.Tensor) – New value of the local optimizer state partition
optim_state_key (string) – Key value of optimizer state (e.g., exp_avg in Adam optimizer)

The routines for modifying parameters and optimizer states can be used at any point after initialization of the DeepSpeed engine (i.e., deepspeed.initialize()) as shown in the following snippet.

```python
import torch

from deepspeed.runtime.zero.utils import is_zero_param
from deepspeed.utils import safe_set_full_fp32_param, safe_set_full_optimizer_state
from deepspeed.utils import safe_set_local_fp32_param, safe_set_local_optimizer_state

# [...]
# Here is an example to zero all the fp32 parameters and optimizer states.
for n, lp in model.named_parameters():
    # 1. For ZeRO stage 1, 2, or 3, set the full fp32 parameter and its full optimizer states.
    zero_tensor = torch.zeros(lp.ds_shape) if is_zero_param(lp) else torch.zeros(lp.shape)

    safe_set_full_fp32_param(lp, zero_tensor)
    safe_set_full_optimizer_state(lp, zero_tensor, "exp_avg")
    safe_set_full_optimizer_state(lp, zero_tensor, "exp_avg_sq")

    # 2. For ZeRO stage 3, each process sets its local fp32 parameter partition
    # and its local optimizer states individually.
    zero_tensor_local = torch.zeros(lp.ds_tensor.shape)

    safe_set_local_fp32_param(lp, zero_tensor_local)
    safe_set_local_optimizer_state(lp, zero_tensor_local, "exp_avg")
    safe_set_local_optimizer_state(lp, zero_tensor_local, "exp_avg_sq")

# [...]
```

The routines for modifying gradients can be used after backward but before step as shown in the following snippet.

```python
import torch

from deepspeed.runtime.zero.utils import is_zero_param
from deepspeed.utils import safe_set_full_grad, safe_set_local_grad

backward(loss)
# [...]
# Here is an example of how to zero all the gradients.
for n, lp in model.named_parameters():
    # 1. For ZeRO stage 1, 2, or 3, set the full gradient.
    zero_tensor = torch.zeros(lp.ds_shape) if is_zero_param(lp) else torch.zeros(lp.shape)

    safe_set_full_grad(lp, zero_tensor)

    # 2. For ZeRO stage 3, each process sets its local gradient partition.
    zero_tensor_local = torch.zeros(lp.ds_tensor.shape)

    safe_set_local_grad(lp, zero_tensor_local)

# [...]
optimizer.step()
```

GPU Memory Management

By default, at the end of training with ZeRO stage 3, some parameters could remain unpartitioned and use up some GPU memory. This is done on purpose, as an optimization in case you resume training. If you’d like to clear out the cached parameters that use up GPU memory, you can call the empty_partition_cache method of a DeepSpeed engine.

The following code snippet illustrates this functionality.

```python
import deepspeed

with deepspeed.zero.Init():
    model = MyLargeModel()

# ds_config: your DeepSpeed configuration (dict or path to a JSON file)
ds_engine, _, _, _ = deepspeed.initialize(model=model, config=ds_config)
for batch in ...:  # iterate over your training data
    loss = ds_engine(batch)
    ds_engine.backward(loss)
    ds_engine.step()

# Free GPU memory consumed by model parameters
ds_engine.empty_partition_cache()
```

Offload States

The DeepSpeed engine maintains a set of states in device memory (e.g., CUDA memory). The following API allows you to offload these states to a different device (currently, only CPU memory is supported), reducing the memory footprint on the device.

```python
def offload_states(self,
                   include: Container[OffloadStateTypeEnum] = None,
                   device: OffloadDeviceEnum = OffloadDeviceEnum.cpu,
                   pin_memory: bool = True,
                   non_blocking: bool = False) -> None:
    """Offload the engine's states to the specified device.

    Arguments:
        include: Optional. The set of states to offload. If not provided, all states are offloaded.
        device: Optional. The device to move the ZeRO optimizer buffers to. Currently only `OffloadDeviceEnum.cpu` is supported.
        pin_memory: Optional. Whether to pin the memory of the offloaded states.
        non_blocking: Optional. Whether to offload the states asynchronously.
    """
```

You can selectively offload specific states by passing OffloadStateTypeEnum values in the include argument. OffloadStateTypeEnum is an enum that defines the states that can be offloaded; supported states include, for example, OffloadStateTypeEnum.hp_params (fp32 parameters) and OffloadStateTypeEnum.optim_states (optimizer states).

Note that offloading states comes with a trade-off between memory savings and computational overhead. This API allows states to be reloaded back into device memory when needed.

```python
def reload_states(self, non_blocking: bool = False) -> None:
    """Reload the engine states to the original device.

    Arguments:
        non_blocking: Optional. Whether to reload the states asynchronously.
    """
```

Below is an example code snippet demonstrating how to offload FP32 parameters and optimizer states to CPU memory:

```python
from deepspeed.runtime.zero.offload_config import OffloadStateTypeEnum

# Offload after forward, backward, and step
ds_engine.offload_states(include=[OffloadStateTypeEnum.hp_params, OffloadStateTypeEnum.optim_states])

# Do something requiring a lot of device memory
...

# Load states back to device memory
ds_engine.reload_states()
```

deepspeed.runtime.zero.offload_states.get_state_devices returns devices of the specified state.

```python
def get_state_devices(model, state: OffloadStateTypeEnum) -> Set[torch.device]:
    """Retrieve the devices of the specified state of the model.

    Args:
        model (DeepSpeedEngine): The model whose device allocations are to be checked.
        state (OffloadStateTypeEnum): The specific state for which the devices should be retrieved.

    Returns:
        Set[torch.device]: A set of devices of the specified state.
    """
```