FullyShardedDataParallel - PyTorch 2.7 documentation

class torch.distributed.fsdp.FullyShardedDataParallel(module, process_group=None, sharding_strategy=None, cpu_offload=None, auto_wrap_policy=None, backward_prefetch=BackwardPrefetch.BACKWARD_PRE, mixed_precision=None, ignored_modules=None, param_init_fn=None, device_id=None, sync_module_states=False, forward_prefetch=False, limit_all_gathers=True, use_orig_params=False, ignored_states=None, device_mesh=None)

A wrapper for sharding module parameters across data parallel workers.

This is inspired by Xu et al. as well as ZeRO Stage 3 from DeepSpeed. FullyShardedDataParallel is commonly shortened to FSDP.

To understand FSDP internals, refer to the FSDP Notes.

Example:

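    import torch
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    # `device_id`, `my_module`, and `x` are defined by your training script,
    # and the default process group is assumed to be initialized already.
    torch.cuda.set_device(device_id)
    sharded_module = FSDP(my_module)
    # Create the optimizer only after wrapping the module.
    optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
    x = sharded_module(x, y=3, z=torch.Tensor([1]))
    loss = x.sum()
    loss.backward()
    optim.step()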

To use FSDP, wrap your module first and then initialize your optimizer. The order matters because FSDP replaces the module's parameter variables, so an optimizer constructed before wrapping would reference parameters that FSDP no longer uses.
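One way to picture why the parameter variables change: with the default use_orig_params=False, FSDP flattens each wrapped module's parameters into a single flat parameter and, under a sharded strategy, keeps only this rank's slice of it. The pure-Python sketch below models only that flatten, pad, and shard arithmetic, with lists standing in for tensors. The helper names are hypothetical, and the sketch ignores dtypes, memory alignment, and how FSDP actually stores its shards.

```python
from typing import List, Tuple


def flatten_and_shard(
    params: List[List[float]], world_size: int
) -> Tuple[List[List[float]], int]:
    """Concatenate parameters, pad to a multiple of world_size, and chunk."""
    flat = [x for p in params for x in p]      # one flat "parameter"
    padding = (-len(flat)) % world_size         # elements needed to divide evenly
    flat = flat + [0.0] * padding
    shard_size = len(flat) // world_size
    shards = [flat[r * shard_size:(r + 1) * shard_size] for r in range(world_size)]
    return shards, padding


def unshard(shards: List[List[float]], padding: int, numels: List[int]) -> List[List[float]]:
    """Reverse the process: gather shards, drop padding, and split per parameter."""
    flat = [x for s in shards for x in s]      # stand-in for an all-gather
    flat = flat[: len(flat) - padding]
    views, offset = [], 0
    for n in numels:
        views.append(flat[offset:offset + n])
        offset += n
    return views


weight = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]  # e.g. a 2x3 weight, flattened
bias = [0.5, 0.5]
shards, padding = flatten_and_shard([weight, bias], world_size=3)
print(shards)   # [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [0.5, 0.5, 0.0]]
print(padding)  # 1
```

Each rank would hold one entry of `shards`; the original `weight` and `bias` objects are no longer what gets updated, which is why the optimizer must see the wrapped module's parameters.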

When setting up FSDP, you need to consider the destination CUDA device. If the device has an ID (dev_id), you have three options:

1. Place the module on that device.
2. Set the device using torch.cuda.set_device(dev_id).
3. Pass dev_id into the device_id constructor argument.

This ensures that the FSDP instance’s compute device is the destination device. For options 1 and 3, the FSDP initialization always occurs on GPU. For option 2, the FSDP initialization happens on the module’s current device, which may be a CPU.

If you pass sync_module_states=True, either make sure the module is already on a GPU or use the device_id argument to specify a CUDA device to which the FSDP constructor will move the module. This is necessary because sync_module_states=True requires GPU communication.

FSDP also moves the input tensors passed to the forward method onto the GPU compute device, so you do not need to move them from CPU manually.

For use_orig_params=True, ShardingStrategy.SHARD_GRAD_OP exposes the unsharded parameters, not the sharded parameters, after forward, unlike ShardingStrategy.FULL_SHARD. If you want to inspect the gradients, you can use the summon_full_params method with with_grads=True.

With limit_all_gathers=True, you may see a gap in the FSDP pre-forward where the CPU thread is not issuing any kernels. This is intentional and shows the rate limiter in effect. Synchronizing the CPU thread in that way prevents over-allocating memory for subsequent all-gathers, and it should not actually delay GPU kernel execution.
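The idea behind the rate limiter can be sketched in pure Python: the CPU thread records a pending free for each set of all-gathered parameters and, once a fixed number are outstanding, blocks on the oldest one before issuing the next all-gather. That blocking is the gap visible in a profiler trace. The limit of two and the function name are illustrative assumptions; FSDP's real limiter lives inside the library and waits on CUDA events rather than Python objects.

```python
from collections import deque
from typing import Deque, List


def simulate_rate_limiter(num_layers: int, max_inflight: int = 2) -> List[str]:
    """Return the order of CPU-side actions during a layer-by-layer forward.

    After each layer's forward, a "free" of its all-gathered parameters is
    queued. Before issuing another all-gather, the CPU blocks on the oldest
    queued free once `max_inflight` frees are pending.
    """
    pending_frees: Deque[int] = deque()
    log: List[str] = []
    for layer in range(num_layers):
        if len(pending_frees) >= max_inflight:
            oldest = pending_frees.popleft()
            log.append(f"wait for free of layer {oldest}")  # the visible CPU gap
        log.append(f"all-gather layer {layer}")
        pending_frees.append(layer)  # freed after this layer's forward
    return log


for line in simulate_rate_limiter(4):
    print(line)
```

Bounding the number of unfreed all-gathers caps how much unsharded parameter memory can be live at once, which is why the CPU-side wait does not need to stall the GPU.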

FSDP replaces managed modules’ parameters with torch.Tensor views during forward and backward computation for autograd-related reasons. If your module’s forward relies on saved references to the parameters instead of reacquiring the references each iteration, then it will not see FSDP’s newly created views, and autograd will not work correctly.
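The failure mode is ordinary Python reference semantics. In the toy sketch below, plain lists stand in for tensors, and the class and the "wrapper installs a new object" step are hypothetical stand-ins for what FSDP does: replacing an attribute is visible to code that re-reads the attribute, but not to code that saved the old object.

```python
class ToyLinear:
    """Toy layer whose weight is a plain Python list."""

    def __init__(self) -> None:
        self.weight = [1.0, 2.0]
        # Anti-pattern: cache a reference to the parameter at construction time.
        self._cached_weight = self.weight

    def forward_cached(self, x: float) -> float:
        # Uses the reference saved in __init__, so it misses any replacement.
        return sum(w * x for w in self._cached_weight)

    def forward_fresh(self, x: float) -> float:
        # Re-reads the attribute on every call and sees the current object.
        return sum(w * x for w in self.weight)


layer = ToyLinear()
# A wrapper installs a *new* object as the parameter (for FSDP, a view into
# freshly all-gathered storage).
layer.weight = [10.0, 20.0]
print(layer.forward_fresh(1.0))   # 30.0
print(layer.forward_cached(1.0))  # 3.0 (stale: still the original object)
```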

Finally, when using sharding_strategy=ShardingStrategy.HYBRID_SHARD with the sharding process group being intra-node and the replication process group being inter-node, setting NCCL_CROSS_NIC=1 can help improve the all-reduce times over the replication process group for some cluster setups.
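With HYBRID_SHARD, parameters are sharded within the sharding group and replicated across the replication group, so in the intra-node/inter-node layout described above each node holds one full copy of the model. The stdlib helper below only computes which global ranks land in each group, assuming contiguous node-major rank numbering (global rank = node index * GPUs per node + local rank); the function name is hypothetical and it does not create any process groups.

```python
from typing import List, Tuple


def hybrid_shard_groups(
    world_size: int, gpus_per_node: int
) -> Tuple[List[List[int]], List[List[int]]]:
    """Return (shard_groups, replicate_groups) for a node-major rank layout."""
    if world_size % gpus_per_node != 0:
        raise ValueError("world_size must be a multiple of gpus_per_node")
    num_nodes = world_size // gpus_per_node
    # Sharding stays inside a node: one group per node.
    shard_groups = [
        list(range(n * gpus_per_node, (n + 1) * gpus_per_node)) for n in range(num_nodes)
    ]
    # Replication crosses nodes: one group per local rank.
    replicate_groups = [
        list(range(local, world_size, gpus_per_node)) for local in range(gpus_per_node)
    ]
    return shard_groups, replicate_groups


shard, replicate = hybrid_shard_groups(world_size=8, gpus_per_node=4)
print(shard)      # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(replicate)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
```

Every replication group spans nodes, so its all-reduce traffic goes over the inter-node network, which is where an NCCL setting such as NCCL_CROSS_NIC can matter.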

Limitations

There are several limitations to be aware of when using FSDP:

Parameters

apply(fn)

Apply fn recursively to every submodule (as returned by .children()) as well as self.

Typical use includes initializing the parameters of a model (see also torch.nn.init).

Compared to torch.nn.Module.apply, this version additionally gathers the full parameters before applying fn. It should not be called from within another summon_full_params context.

Parameters

fn (Module -> None) – function to be applied to each submodule

Returns

self

Return type

Module

check_is_root()

Check if this instance is a root FSDP module.

Return type

bool

clip_grad_norm_(max_norm, norm_type=2.0)

Clip the gradient norm of all parameters.

The norm is computed over the gradients of all parameters, viewed as a single concatenated vector, and the gradients are modified in place.

Parameters

Returns

Total norm of the parameters (viewed as a single vector).

Return type

Tensor

If every FSDP instance uses NO_SHARD, meaning that no gradients are sharded across ranks, then you may directly use torch.nn.utils.clip_grad_norm_().

If at least some FSDP instance uses a sharded strategy (i.e. one other than NO_SHARD), then you should use this method instead of torch.nn.utils.clip_grad_norm_() since this method handles the fact that gradients are sharded across ranks.

The total norm returned will have the “largest” dtype across all parameters/gradients as defined by PyTorch’s type promotion semantics. For example, if all parameters/gradients use a low precision dtype, then the returned norm’s dtype will be that low precision dtype, but if there exists at least one parameter/gradient using FP32, then the returned norm’s dtype will be FP32.
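Why a sharded strategy needs this method can be sketched in pure Python: each rank only holds a slice of the gradient, so the global norm must combine per-rank partial results (a sum of |g|^p across ranks for finite p, a max for the infinity norm) before any rank can compute the clipping coefficient. Below, lists stand in for gradient shards and a Python sum/max stands in for the all-reduce; the clipping rule mirrors torch.nn.utils.clip_grad_norm_ (max_norm / (total_norm + 1e-6), capped at 1). This is a conceptual model with hypothetical helper names, not FSDP's code.

```python
import math
from typing import List


def local_norm_term(shard: List[float], norm_type: float) -> float:
    """Per-rank contribution: sum(|g|^p), or max(|g|) for the inf-norm."""
    if math.isinf(norm_type):
        return max((abs(g) for g in shard), default=0.0)
    return sum(abs(g) ** norm_type for g in shard)


def sharded_clip_grad_norm(
    shards: List[List[float]], max_norm: float, norm_type: float = 2.0
) -> float:
    """Clip gradient shards in place and return the total (global) norm."""
    terms = [local_norm_term(s, norm_type) for s in shards]
    # Stand-in for the all-reduce across ranks (MAX for inf-norm, SUM otherwise).
    if math.isinf(norm_type):
        total_norm = max(terms)
    else:
        total_norm = sum(terms) ** (1.0 / norm_type)
    clip_coef = min(max_norm / (total_norm + 1e-6), 1.0)
    for s in shards:
        for i in range(len(s)):
            s[i] *= clip_coef
    return total_norm


grads = [[3.0, 0.0], [4.0, 0.0]]  # two ranks, each holding half the gradient
total = sharded_clip_grad_norm(grads, max_norm=1.0)
print(total)  # 5.0, the norm of the full vector [3, 0, 4, 0]
```

Calling torch.nn.utils.clip_grad_norm_ independently on each rank would instead see norms of 3.0 and 4.0 and scale the two halves by different factors.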

Warning

This needs to be called on all ranks since it uses collective communications.

static flatten_sharded_optim_state_dict(sharded_optim_state_dict, model, optim)

Flatten a sharded optimizer state-dict.

The API is similar to shard_full_optim_state_dict(). The only difference is that the input sharded_optim_state_dict should be returned from sharded_optim_state_dict(). Therefore, each rank issues all-gather calls to gather the ShardedTensors.

Parameters

Returns

Refer to shard_full_optim_state_dict().

Return type

dict[str, Any]

forward(*args, **kwargs)

Run the forward pass for the wrapped module, inserting FSDP-specific pre- and post-forward sharding logic.

Return type

Any

static fsdp_modules(module, root_only=False)

Return all nested FSDP instances.

This possibly includes module itself and only includes FSDP root modules if root_only=True.
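The root_only distinction can be modeled as a tree walk in which an FSDP instance counts as a root when none of its ancestors is also an FSDP instance. The toy Node class and collect_fsdp_names helper below are hypothetical stand-ins for nn.Module and this method, meant only to show which instances each mode would return; in the example the top-level model is not wrapped, so there are two roots.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Node:
    """Toy module: a name, whether it is FSDP-wrapped, and child modules."""
    name: str
    is_fsdp: bool = False
    children: List["Node"] = field(default_factory=list)


def collect_fsdp_names(module: Node, root_only: bool = False) -> List[str]:
    """Pre-order walk collecting FSDP nodes; optionally keep only roots."""
    found: List[str] = []

    def visit(node: Node, has_fsdp_ancestor: bool) -> None:
        if node.is_fsdp and not (root_only and has_fsdp_ancestor):
            found.append(node.name)
        for child in node.children:
            visit(child, has_fsdp_ancestor or node.is_fsdp)

    visit(module, False)
    return found


model = Node("model", children=[
    Node("encoder", is_fsdp=True, children=[Node("encoder.layer0", is_fsdp=True)]),
    Node("head", is_fsdp=True),
])
print(collect_fsdp_names(model))                  # ['encoder', 'encoder.layer0', 'head']
print(collect_fsdp_names(model, root_only=True))  # ['encoder', 'head']
```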

Parameters

Returns

FSDP modules that are nested in the input module.

Return type

List[FullyShardedDataParallel]

static full_optim_state_dict(model, optim, optim_input=None, rank0_only=True, group=None)

Return the full optimizer state-dict.

Consolidates the full optimizer state on rank 0 and returns it as a dict following the convention of torch.optim.Optimizer.state_dict(), i.e. with keys "state" and "param_groups". The flattened parameters in FSDP modules contained in model are mapped back to their unflattened parameters.

This needs to be called on all ranks since it uses collective communications. However, if rank0_only=True, then the state dict is only populated on rank 0, and all other ranks return an empty dict.

Unlike torch.optim.Optimizer.state_dict(), this method uses full parameter names as keys instead of parameter IDs.

Like in torch.optim.Optimizer.state_dict(), the tensors contained in the optimizer state dict are not cloned, so there may be aliasing surprises. As a best practice, consider saving the returned optimizer state dict immediately, e.g. using torch.save().
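The aliasing caveat is the same one that applies to any container of mutable objects: if the returned dict refers to buffers that are later updated in place, the update shows through. The stdlib sketch below uses nested lists in place of tensors, and the state layout and key names are hypothetical; it shows why serializing (or deep-copying) right away gives a stable snapshot.

```python
import copy

# Hypothetical live optimizer state: one parameter with an "exp_avg" buffer.
live_state = {"state": {"layer.weight": {"exp_avg": [0.1, 0.2]}}}

aliased = {"state": dict(live_state["state"])}  # shallow: inner lists are shared
snapshot = copy.deepcopy(live_state)            # fully independent copy

# A later optimizer step updates the buffer in place.
live_state["state"]["layer.weight"]["exp_avg"][0] = 9.9

print(aliased["state"]["layer.weight"]["exp_avg"])   # [9.9, 0.2] (changed)
print(snapshot["state"]["layer.weight"]["exp_avg"])  # [0.1, 0.2]
```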

Parameters

Returns

A dict containing the optimizer state for model’s original unflattened parameters and including keys “state” and “param_groups” following the convention of torch.optim.Optimizer.state_dict(). If rank0_only=True, then nonzero ranks return an empty dict.

Return type

Dict[str, Any]

static get_state_dict_type(module)

Get the state_dict_type and the corresponding configurations for the FSDP modules rooted at module.

The target module does not have to be an FSDP module.

Returns

A StateDictSettings containing the state_dict_type and state_dict / optim_state_dict configs that are currently set.

Raises

Return type

StateDictSettings

property module: Module

Return the wrapped module.

named_buffers(*args, **kwargs)

Return an iterator over module buffers, yielding both the name of the buffer and the buffer itself.

Intercepts buffer names and removes all occurrences of the FSDP-specific flattened buffer prefix when inside the summon_full_params() context manager.

Return type

Iterator[tuple[str, torch.Tensor]]

named_parameters(*args, **kwargs)

Return an iterator over module parameters, yielding both the name of the parameter and the parameter itself.

Intercepts parameter names and removes all occurrences of the FSDP-specific flattened parameter prefix when inside the summon_full_params() context manager.

Return type

Iterator[tuple[str, torch.nn.parameter.Parameter]]

no_sync()

Disable gradient synchronizations across FSDP instances.

Within this context, gradients will be accumulated in module variables, which will later be synchronized in the first forward-backward pass after exiting the context. This should only be used on the root FSDP instance and will recursively apply to all children FSDP instances.
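The accumulate-then-sync pattern can be modeled with a toy wrapper. The ToyDataParallel class below is hypothetical and is not FSDP: each "rank" keeps a local gradient, backward passes inside no_sync() only accumulate, and the first backward after leaving the context performs a single reduction (modeled here as an average across ranks; real FSDP reduce-scatters into gradient shards, or all-reduces under NO_SHARD).

```python
from contextlib import contextmanager
from typing import Iterator, List


class ToyDataParallel:
    """Toy data-parallel wrapper with a no_sync() context manager."""

    def __init__(self, world_size: int) -> None:
        self.grads: List[float] = [0.0] * world_size
        self.sync_enabled = True
        self.num_syncs = 0

    @contextmanager
    def no_sync(self) -> Iterator[None]:
        self.sync_enabled = False
        try:
            yield
        finally:
            self.sync_enabled = True

    def backward(self, per_rank_grads: List[float]) -> None:
        # Gradients always accumulate locally across backward passes.
        for r, g in enumerate(per_rank_grads):
            self.grads[r] += g
        if self.sync_enabled:
            # Stand-in for the cross-rank gradient reduction (averaging).
            avg = sum(self.grads) / len(self.grads)
            self.grads = [avg] * len(self.grads)
            self.num_syncs += 1


wrapper = ToyDataParallel(world_size=2)
micro_batches = [[1.0, 3.0], [2.0, 2.0], [3.0, 1.0]]
with wrapper.no_sync():
    for grads in micro_batches[:-1]:
        wrapper.backward(grads)       # local accumulation, no communication
wrapper.backward(micro_batches[-1])   # first backward after exiting: one sync
print(wrapper.num_syncs, wrapper.grads)  # 1 [6.0, 6.0]
```

The communication saving comes at the memory cost noted below: until the sync, each rank holds its full accumulated gradient rather than a shard.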

Note

This likely results in higher memory usage because FSDP will accumulate the full model gradients (instead of gradient shards) until the eventual sync.

Note

When used with CPU offloading, the gradients will not be offloaded to CPU when inside the context manager. Instead, they will only be offloaded right after the eventual sync.

Return type

Generator

static optim_state_dict(model, optim, optim_state_dict=None, group=None)

Transform the state-dict of an optimizer corresponding to a sharded model.

The given state-dict can be transformed to one of three types: 1) full optimizer state_dict, 2) sharded optimizer state_dict, 3) local optimizer state_dict.

For full optimizer state_dict, all states are unflattened and not sharded. Rank-0-only and CPU-only saving can be enabled via state_dict_type() to avoid OOM.

For sharded optimizer state_dict, all states are unflattened but sharded. CPU only can be specified via state_dict_type() to further save memory.

For local state_dict, no transformation will be performed, but a state will be converted from torch.Tensor to ShardedTensor to represent its sharding nature (this is not supported yet).

Example:

    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp import StateDictType
    from torch.distributed.fsdp import FullStateDictConfig
    from torch.distributed.fsdp import FullOptimStateDictConfig

    # Save a checkpoint
    model, optim = ...
    FSDP.set_state_dict_type(
        model,
        StateDictType.FULL_STATE_DICT,
        FullStateDictConfig(rank0_only=False),
        FullOptimStateDictConfig(rank0_only=False),
    )
    state_dict = model.state_dict()
    optim_state_dict = FSDP.optim_state_dict(model, optim)
    save_a_checkpoint(state_dict, optim_state_dict)

    # Load a checkpoint
    model, optim = ...
    state_dict, optim_state_dict = load_a_checkpoint()
    FSDP.set_state_dict_type(
        model,
        StateDictType.FULL_STATE_DICT,
        FullStateDictConfig(rank0_only=False),
        FullOptimStateDictConfig(rank0_only=False),
    )
    model.load_state_dict(state_dict)
    optim_state_dict = FSDP.optim_state_dict_to_load(
        model, optim, optim_state_dict
    )
    optim.load_state_dict(optim_state_dict)

Parameters

Returns

A dict containing the optimizer state for model. The sharding of the optimizer state is based on state_dict_type.

Return type

Dict[str, Any]

static optim_state_dict_to_load(model, optim, optim_state_dict, is_named_optimizer=False, load_directly=False, group=None)

Convert an optimizer state-dict so that it can be loaded into the optimizer associated with the FSDP model.

Given an optim_state_dict that was transformed through optim_state_dict(), this converts it to the flattened optimizer state_dict that can be loaded into optim, the optimizer for model. model must be sharded by FullyShardedDataParallel.

    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp import StateDictType
    from torch.distributed.fsdp import FullStateDictConfig
    from torch.distributed.fsdp import FullOptimStateDictConfig

    # Save a checkpoint
    model, optim = ...
    FSDP.set_state_dict_type(
        model,
        StateDictType.FULL_STATE_DICT,
        FullStateDictConfig(rank0_only=False),
        FullOptimStateDictConfig(rank0_only=False),
    )
    state_dict = model.state_dict()
    original_osd = optim.state_dict()
    optim_state_dict = FSDP.optim_state_dict(
        model, optim, optim_state_dict=original_osd
    )
    save_a_checkpoint(state_dict, optim_state_dict)

    # Load a checkpoint
    model, optim = ...
    state_dict, optim_state_dict = load_a_checkpoint()
    FSDP.set_state_dict_type(
        model,
        StateDictType.FULL_STATE_DICT,
        FullStateDictConfig(rank0_only=False),
        FullOptimStateDictConfig(rank0_only=False),
    )
    model.load_state_dict(state_dict)
    optim_state_dict = FSDP.optim_state_dict_to_load(
        model, optim, optim_state_dict
    )
    optim.load_state_dict(optim_state_dict)

Parameters

Return type

dict[str, Any]

register_comm_hook(state, hook)

Register a communication hook.

This provides a flexible hook with which users can specify how FSDP aggregates gradients across multiple workers. The hook can be used to implement algorithms such as GossipGrad and gradient compression, which involve different communication strategies for parameter syncs while training with FullyShardedDataParallel.

Warning

FSDP communication hook should be registered before running an initial forward pass and only once.

Parameters

static rekey_optim_state_dict(optim_state_dict, optim_state_key_type, model, optim_input=None, optim=None)

Re-keys the optimizer state dict optim_state_dict to use the key type optim_state_key_type.

This can be used to achieve compatibility between optimizer state dicts from models with FSDP instances and ones without.

To re-key an FSDP full optimizer state dict (i.e. from full_optim_state_dict()) to use parameter IDs and be loadable to a non-wrapped model:

    wrapped_model, wrapped_optim = ...
    full_osd = FSDP.full_optim_state_dict(wrapped_model, wrapped_optim)
    nonwrapped_model, nonwrapped_optim = ...
    rekeyed_osd = FSDP.rekey_optim_state_dict(full_osd, OptimStateKeyType.PARAM_ID, nonwrapped_model)
    nonwrapped_optim.load_state_dict(rekeyed_osd)

To re-key a normal optimizer state dict from a non-wrapped model to be loadable to a wrapped model:

    nonwrapped_model, nonwrapped_optim = ...
    osd = nonwrapped_optim.state_dict()
    rekeyed_osd = FSDP.rekey_optim_state_dict(osd, OptimStateKeyType.PARAM_NAME, nonwrapped_model)
    wrapped_model, wrapped_optim = ...
    sharded_osd = FSDP.shard_full_optim_state_dict(rekeyed_osd, wrapped_model)
    wrapped_optim.load_state_dict(sharded_osd)
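Conceptually, re-keying relabels the same per-parameter state. A plain torch.optim optimizer keys its state by integer IDs assigned in the order parameters were passed to it, while FSDP's full optimizer state dict keys it by parameter name. The stdlib sketch below performs that relabeling in both directions, assuming `names` lists parameter names in exactly that ID order; the helper names are hypothetical, and the real rekey_optim_state_dict additionally deals with FSDP's flattened parameters.

```python
from typing import Any, Dict, List


def rekey_ids_to_names(osd: Dict[str, Any], names: List[str]) -> Dict[str, Any]:
    """Replace integer parameter IDs with parameter names."""
    id_to_name = dict(enumerate(names))
    return {
        "state": {id_to_name[pid]: s for pid, s in osd["state"].items()},
        "param_groups": [
            {**group, "params": [id_to_name[pid] for pid in group["params"]]}
            for group in osd["param_groups"]
        ],
    }


def rekey_names_to_ids(osd: Dict[str, Any], names: List[str]) -> Dict[str, Any]:
    """Inverse of rekey_ids_to_names."""
    name_to_id = {name: pid for pid, name in enumerate(names)}
    return {
        "state": {name_to_id[n]: s for n, s in osd["state"].items()},
        "param_groups": [
            {**group, "params": [name_to_id[n] for n in group["params"]]}
            for group in osd["param_groups"]
        ],
    }


# Hypothetical ID-keyed state, shaped like torch.optim.Optimizer.state_dict().
id_osd = {
    "state": {0: {"step": 3}, 1: {"step": 3}},
    "param_groups": [{"lr": 0.1, "params": [0, 1]}],
}
names = ["fc.weight", "fc.bias"]  # same order the parameters were given to the optimizer
name_osd = rekey_ids_to_names(id_osd, names)
print(name_osd["param_groups"])  # [{'lr': 0.1, 'params': ['fc.weight', 'fc.bias']}]
```

The round trip only works if both sides agree on parameter order, which is why rekey_optim_state_dict takes the model as an argument.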

Returns

The optimizer state dict re-keyed using the parameter keys specified by optim_state_key_type.

Return type

Dict[str, Any]

static scatter_full_optim_state_dict(full_optim_state_dict, model, optim_input=None, optim=None, group=None)

Scatter the full optimizer state dict from rank 0 to all other ranks.

Returns the sharded optimizer state dict on each rank. The return value is the same as shard_full_optim_state_dict(), and on rank 0, the first argument should be the return value of full_optim_state_dict().

Example:

    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    model, optim = ...
    full_osd = FSDP.full_optim_state_dict(model, optim)  # only non-empty on rank 0
    # Define new model with possibly different world size
    new_model, new_optim, new_group = ...
    sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, new_model, group=new_group)
    new_optim.load_state_dict(sharded_osd)

Note

Both shard_full_optim_state_dict() and scatter_full_optim_state_dict() may be used to get the sharded optimizer state dict to load. Assuming that the full optimizer state dict resides in CPU memory, the former requires each rank to have the full dict in CPU memory, where each rank individually shards the dict without any communication, while the latter requires only rank 0 to have the full dict in CPU memory, where rank 0 moves each shard to GPU memory (for NCCL) and communicates it to ranks appropriately. Hence, the former has higher aggregate CPU memory cost, while the latter has higher communication cost.

Parameters

Returns

The full optimizer state dict now remapped to flattened parameters instead of unflattened parameters and restricted to only include this rank’s part of the optimizer state.

Return type

Dict[str, Any]

static set_state_dict_type(module, state_dict_type, state_dict_config=None, optim_state_dict_config=None)

Set the state_dict_type of all the descendant FSDP modules of the target module.

Also takes (optional) configuration for the model’s and optimizer’s state dict. The target module does not have to be an FSDP module. If the target module is an FSDP module, its state_dict_type will also be changed.

Note

This API should be called for only the top-level (root) module.

Note

This API enables users to transparently use the conventional state_dict API to take model checkpoints in cases where the root FSDP module is wrapped by another nn.Module. For example, the following will ensure state_dict is called on all non-FSDP instances, while dispatching into sharded_state_dict implementation for FSDP:

Example:

    model = DDP(FSDP(...))
    FSDP.set_state_dict_type(
        model,
        StateDictType.SHARDED_STATE_DICT,
        state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
        # Use the optim config subclass that matches SHARDED_STATE_DICT.
        optim_state_dict_config=ShardedOptimStateDictConfig(offload_to_cpu=True),
    )
    param_state_dict = model.state_dict()
    optim_state_dict = FSDP.optim_state_dict(model, optim)

Parameters

Returns

A StateDictSettings that include the previous state_dict type and configuration for the module.

Return type

StateDictSettings

static shard_full_optim_state_dict(full_optim_state_dict, model, optim_input=None, optim=None)

Shard a full optimizer state-dict.

Remaps the state in full_optim_state_dict to flattened parameters instead of unflattened parameters and restricts to only this rank’s part of the optimizer state. The first argument should be the return value of full_optim_state_dict().

Example:

    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    model, optim = ...
    full_osd = FSDP.full_optim_state_dict(model, optim)
    torch.save(full_osd, PATH)
    # Define new model with possibly different world size
    new_model, new_optim = ...
    full_osd = torch.load(PATH)
    sharded_osd = FSDP.shard_full_optim_state_dict(full_osd, new_model)
    new_optim.load_state_dict(sharded_osd)

Note

Both shard_full_optim_state_dict() and scatter_full_optim_state_dict() may be used to get the sharded optimizer state dict to load. Assuming that the full optimizer state dict resides in CPU memory, the former requires each rank to have the full dict in CPU memory, where each rank individually shards the dict without any communication, while the latter requires only rank 0 to have the full dict in CPU memory, where rank 0 moves each shard to GPU memory (for NCCL) and communicates it to ranks appropriately. Hence, the former has higher aggregate CPU memory cost, while the latter has higher communication cost.

Parameters

Returns

The full optimizer state dict now remapped to flattened parameters instead of unflattened parameters and restricted to only include this rank’s part of the optimizer state.

Return type

Dict[str, Any]

static sharded_optim_state_dict(model, optim, group=None)

Return the optimizer state-dict in its sharded form.

The API is similar to full_optim_state_dict() but this API chunks all non-zero-dimension states to ShardedTensor to save memory. This API should only be used when the model state_dict is derived with the context manager with state_dict_type(SHARDED_STATE_DICT):.

For the detailed usage, refer to full_optim_state_dict().

Warning

The returned state dict contains ShardedTensor and cannot be directly used by the regular optim.load_state_dict.

Return type

dict[str, Any]

static state_dict_type(module, state_dict_type, state_dict_config=None, optim_state_dict_config=None)

Set the state_dict_type of all the descendant FSDP modules of the target module.

This context manager has the same functions as set_state_dict_type(). See set_state_dict_type() for details.

Example:

    model = DDP(FSDP(...))
    with FSDP.state_dict_type(
        model,
        StateDictType.SHARDED_STATE_DICT,
    ):
        checkpoint = model.state_dict()

Parameters

Return type

Generator

static summon_full_params(module, recurse=True, writeback=True, rank0_only=False, offload_to_cpu=False, with_grads=False)

Expose full params for FSDP instances with this context manager.

Can be useful after forward/backward for a model to get the params for additional processing or checking. It can take a non-FSDP module and will summon full params for all contained FSDP modules as well as their children, depending on the recurse argument.

Note

This can be used on inner FSDPs.

Note

This cannot be used within a forward or backward pass, nor can forward and backward be started from within this context.

Note

Parameters will revert to their local shards after the context manager exits; storage behavior is the same as in forward.

Note

The full parameters can be modified, but only the portion corresponding to the local param shard will persist after the context manager exits (unless writeback=False, in which case changes will be discarded). When FSDP does not shard the parameters (currently only when world_size == 1 or with the NO_SHARD config), the modification is persisted regardless of writeback.
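The writeback rule can be modeled from a single rank's point of view: the rank sees the whole parameter, but on exit it keeps only its own slice. The toy sketch below uses lists in place of flat parameter shards, ignores padding, and does not cover the unsharded case (world_size == 1 or NO_SHARD) where edits always persist; the function name is hypothetical.

```python
from typing import List


def summon_and_edit(
    shards: List[List[float]], rank: int, index: int, value: float, writeback: bool = True
) -> List[float]:
    """Return `rank`'s shard after editing element `index` of the full parameter."""
    full = [x for s in shards for x in s]  # stand-in for the all-gather
    full[index] = value                    # edit made on this rank's full copy
    if not writeback:
        return list(shards[rank])          # changes are discarded
    start = sum(len(s) for s in shards[:rank])
    return full[start:start + len(shards[rank])]  # only the local slice persists


shards = [[1.0, 2.0], [3.0, 4.0]]  # world_size=2, a 4-element parameter
# Rank 0 edits element 3, which lives in rank 1's shard: the edit is lost.
print(summon_and_edit(shards, rank=0, index=3, value=0.0))  # [1.0, 2.0]
# Rank 1 makes the same edit: it falls in its own shard, so it persists.
print(summon_and_edit(shards, rank=1, index=3, value=0.0))  # [3.0, 0.0]
```

In practice this means an edit persists only if the rank that owns that shard also makes it, so the same modification should be applied on every rank.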

Note

This method works on modules which are not FSDP themselves but may contain multiple independent FSDP units. In that case, the given arguments will apply to all contained FSDP units.

Warning

Note that rank0_only=True in conjunction with writeback=True is not currently supported and will raise an error. This is because model parameter shapes would be different across ranks within the context, and writing to them can lead to inconsistency across ranks when the context is exited.

Warning

Note that offload_to_cpu and rank0_only=False will result in full parameters being redundantly copied to CPU memory for GPUs that reside on the same machine, which may incur the risk of CPU OOM. It is recommended to use offload_to_cpu with rank0_only=True.

Parameters

Return type

Generator

class torch.distributed.fsdp.BackwardPrefetch(value)

This configures explicit backward prefetching, which improves throughput by enabling communication and computation overlap in the backward pass at the cost of slightly increased memory usage.

For more technical context: For a single process group using NCCL backend, any collectives, even if issued from different streams, contend for the same per-device NCCL stream, which implies that the relative order in which the collectives are issued matters for overlapping. The two backward prefetching values correspond to different issue orders.

class torch.distributed.fsdp.ShardingStrategy(value)

This specifies the sharding strategy to be used for distributed training by FullyShardedDataParallel.

class torch.distributed.fsdp.MixedPrecision(param_dtype=None, reduce_dtype=None, buffer_dtype=None, keep_low_precision_grads=False, cast_forward_inputs=False, cast_root_forward_inputs=True, _module_classes_to_ignore=(<class 'torch.nn.modules.batchnorm._BatchNorm'>, ))

This configures FSDP-native mixed precision training.

Variables

Note

This API is experimental and subject to change.

Note

Only floating point tensors are cast to their specified dtypes.

Note

In summon_full_params, parameters are forced to full precision, but buffers are not.

Note

Layer norm and batch norm accumulate in float32 even when their inputs are in a low precision like float16 or bfloat16. Disabling FSDP’s mixed precision for those norm modules only means that the affine parameters are kept in float32. However, this incurs separate all-gathers and reduce-scatters for those norm modules, which may be inefficient, so if the workload permits, the user should prefer to still apply mixed precision to those modules.

Note

By default, if the user passes a model with any _BatchNorm modules and specifies an auto_wrap_policy, then the batch norm modules will have FSDP applied to them separately with mixed precision disabled. See the _module_classes_to_ignore argument.

Note

MixedPrecision has cast_root_forward_inputs=True and cast_forward_inputs=False by default. For the root FSDP instance, its cast_root_forward_inputs takes precedence over its cast_forward_inputs. For non-root FSDP instances, their cast_root_forward_inputs values are ignored. The default setting is sufficient for the typical case where each FSDP instance has the same MixedPrecision configuration and only needs to cast inputs to the param_dtype at the beginning of the model’s forward pass.

Note

For nested FSDP instances with different MixedPrecision configurations, we recommend setting individual cast_forward_inputs values to configure casting inputs or not before each instance’s forward. In such a case, since the casts happen before each FSDP instance’s forward, a parent FSDP instance should have its non-FSDP submodules run before its FSDP submodules to avoid the activation dtype being changed due to a different MixedPrecision configuration.

Example:

    model = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
    model[1] = FSDP(
        model[1],
        mixed_precision=MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True),
    )
    model = FSDP(
        model,
        mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, cast_forward_inputs=True),
    )

The above shows a working example. On the other hand, if model[1] were replaced with model[0], meaning that the submodule using a different MixedPrecision ran its forward first, then model[1] would incorrectly see float16 activations instead of bfloat16 ones.
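The ordering rule can be checked with a small dtype trace. The stdlib sketch below is a deliberate simplification with hypothetical names (Step, find_dtype_mismatches): dtypes are strings, each module's output keeps its input dtype, the root instance casts the model input to its param_dtype, an FSDP instance with cast_forward_inputs=True recasts its input, and a non-FSDP submodule runs with its parent's param_dtype without recasting anything.

```python
from typing import List, NamedTuple, Optional


class Step(NamedTuple):
    name: str
    param_dtype: str               # dtype of the parameters used in this forward
    cast_inputs_to: Optional[str]  # set only for FSDP instances with cast_forward_inputs=True


def find_dtype_mismatches(root_cast: str, steps: List[Step]) -> List[str]:
    """Trace the activation dtype through modules that run in order."""
    activation = root_cast  # the root instance casts the model input first
    mismatches: List[str] = []
    for step in steps:
        if step.cast_inputs_to is not None:
            activation = step.cast_inputs_to  # this FSDP instance recasts its input
        if activation != step.param_dtype:
            mismatches.append(f"{step.name} sees {activation}, expects {step.param_dtype}")
        # Simplification: each module's output keeps its input dtype.
    return mismatches


# Working order: the parent's own (bfloat16) layer runs before the float16 child.
working = [Step("model[0]", "bfloat16", None), Step("model[1]", "float16", "float16")]
# Problematic order: the float16 child runs first and changes the activation dtype.
broken = [Step("model[0]", "float16", "float16"), Step("model[1]", "bfloat16", None)]
print(find_dtype_mismatches("bfloat16", working))  # []
print(find_dtype_mismatches("bfloat16", broken))   # ['model[1] sees float16, expects bfloat16']
```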

class torch.distributed.fsdp.CPUOffload(offload_params=False)

This configures CPU offloading.

Variables

offload_params (bool) – This specifies whether to offload parameters to CPU when not involved in computation. If True, then this offloads gradients to CPU as well, meaning that the optimizer step runs on CPU.

class torch.distributed.fsdp.StateDictConfig(offload_to_cpu=False)

StateDictConfig is the base class for all state_dict configuration classes. Users should instantiate a child class (e.g. FullStateDictConfig) in order to configure settings for the corresponding state_dict type supported by FSDP.

Variables

offload_to_cpu (bool) – If True, then FSDP offloads the state dict values to CPU, and if False, then FSDP keeps them on GPU. (Default: False)

class torch.distributed.fsdp.FullStateDictConfig(offload_to_cpu=False, rank0_only=False)

FullStateDictConfig is a config class meant to be used with StateDictType.FULL_STATE_DICT. We recommend enabling both offload_to_cpu=True and rank0_only=True when saving full state dicts to save GPU memory and CPU memory, respectively. This config class is meant to be used via the state_dict_type() context manager as follows:

    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp import FullStateDictConfig, StateDictType

    fsdp = FSDP(model, auto_wrap_policy=...)
    cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(fsdp, StateDictType.FULL_STATE_DICT, cfg):
        state = fsdp.state_dict()

state will be empty on nonzero ranks and will contain CPU tensors on rank 0.

To reload a checkpoint for inference, fine-tuning, transfer learning, etc.:

    model = model_fn()  # Initialize model in preparation for wrapping with FSDP
    if dist.get_rank() == 0:
        # Load checkpoint only on rank 0 to avoid memory redundancy
        state_dict = torch.load("my_checkpoint.pt")
        model.load_state_dict(state_dict)
    # All ranks initialize FSDP module as usual. The sync_module_states argument
    # communicates loaded checkpoint states from rank 0 to the rest of the world.
    fsdp = FSDP(
        model,
        device_id=torch.cuda.current_device(),
        auto_wrap_policy=...,
        sync_module_states=True,
    )
    # After this point, all ranks have FSDP model with loaded checkpoint.

Variables

rank0_only (bool) – If True, then only rank 0 saves the full state dict, and nonzero ranks save an empty dict. If False, then all ranks save the full state dict. (Default: False)

class torch.distributed.fsdp.ShardedStateDictConfig(offload_to_cpu=False, _use_dtensor=False)

ShardedStateDictConfig is a config class meant to be used with StateDictType.SHARDED_STATE_DICT.

Variables

_use_dtensor (bool) – If True, then FSDP saves the state dict values as DTensor, and if False, then FSDP saves them as ShardedTensor. (Default: False)

Warning

_use_dtensor is a private field of ShardedStateDictConfig and it is used by FSDP to determine the type of state dict values. Users should not manually modify _use_dtensor.

class torch.distributed.fsdp.LocalStateDictConfig(offload_to_cpu: bool = False)

class torch.distributed.fsdp.OptimStateDictConfig(offload_to_cpu=True)

OptimStateDictConfig is the base class for all optim_state_dict configuration classes. Users should instantiate a child class (e.g. FullOptimStateDictConfig) in order to configure settings for the corresponding optim_state_dict type supported by FSDP.

Variables

offload_to_cpu (bool) – If True, then FSDP offloads the state dict’s tensor values to CPU, and if False, then FSDP keeps them on the original device (which is GPU unless parameter CPU offloading is enabled). (Default: True)

class torch.distributed.fsdp.FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False)

Variables

rank0_only (bool) – If True, then only rank 0 saves the full state dict, and nonzero ranks save an empty dict. If False, then all ranks save the full state dict. (Default: False)

class torch.distributed.fsdp.ShardedOptimStateDictConfig(offload_to_cpu=True, _use_dtensor=False)

ShardedOptimStateDictConfig is a config class meant to be used with StateDictType.SHARDED_STATE_DICT.

Variables

_use_dtensor (bool) – If True, then FSDP saves the state dict values as DTensor, and if False, then FSDP saves them as ShardedTensor. (Default: False)

Warning

_use_dtensor is a private field of ShardedOptimStateDictConfig and it is used by FSDP to determine the type of state dict values. Users should not manually modify _use_dtensor.

class torch.distributed.fsdp.LocalOptimStateDictConfig(offload_to_cpu: bool = False)

class torch.distributed.fsdp.StateDictSettings(state_dict_type: torch.distributed.fsdp.api.StateDictType, state_dict_config: torch.distributed.fsdp.api.StateDictConfig, optim_state_dict_config: torch.distributed.fsdp.api.OptimStateDictConfig)