FSDPStrategy — mmengine 0.10.7 documentation

class mmengine._strategy.FSDPStrategy(*, model_wrapper=None, skip_init_weights=False, state_dict_cfg='local', activation_checkpointing=None, **kwargs)[source]

Support training model with FullyShardedDataParallel (FSDP).

Keyword Arguments:

model_wrapper (dict, optional) – Config for wrapping the model with FullyShardedDataParallel. Defaults to None.

skip_init_weights (bool) – Whether to skip initializing the model weights when building it. If True, load_checkpoint() must be called to fill the weights before training. Defaults to False.

state_dict_cfg (str or dict) – Configures how state dicts are saved and loaded: 'local' keeps a sharded state dict per rank, while 'full' gathers the unsharded state dict on rank 0. Defaults to 'local'.

activation_checkpointing (dict, optional) – Config for applying activation checkpointing to submodules of the wrapped model. Defaults to None.

Parameters:

**kwargs – Other keyword arguments passed to the base strategy.
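The strategy is typically selected through a runner config, but it can also be constructed directly. A minimal sketch, assuming the defaults documented above; both keyword values chosen here are illustrative, not prescriptive:

```python
from mmengine._strategy import FSDPStrategy

# Minimal sketch; both keyword values below are illustrative choices.
strategy = FSDPStrategy(
    state_dict_cfg='full',    # gather an unsharded state dict on rank 0
    skip_init_weights=True,   # build with empty weights; load_checkpoint() must follow
)
```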

build_model(model)[source]

Build model.

If skip_init_weights is True, the model will be built with empty weights, which means load_checkpoint() must be called to fill the weights before training.

Parameters:

model (nn.Module or dict) – An nn.Module object or a dict to build an nn.Module object. If model is an nn.Module object, it is returned as is.

Returns:

The model built from model.

Return type:

nn.Module
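A minimal sketch of both accepted inputs; the dict form assumes a model class registered in mmengine's MODELS registry, and the name used here is hypothetical:

```python
import torch.nn as nn

# An nn.Module is returned as is.
model = strategy.build_model(nn.Linear(8, 2))

# A config dict is resolved through the registry; 'MyModel' is hypothetical.
# model = strategy.build_model(dict(type='MyModel'))
```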

build_optim_wrapper(optim_wrapper, model=None)[source]

Support sharding the optimizer state dict given a built optimizer or optim_wrapper.

See specific usage in BaseStrategy.build_optim_wrapper().

Parameters:

optim_wrapper (Optimizer, OptimWrapper or dict) – A built optimizer, a built optim wrapper, or a dict used to build one.

model (nn.Module, optional) – The model whose optimizer state should be sharded. Defaults to None.

Return type:

BaseOptimWrapper
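A hedged sketch of building the wrapper from a config dict; the optimizer settings are placeholders, and model is assumed to be the module built above:

```python
# Build a sharded optim wrapper from a config dict; passing model lets the
# strategy shard the optimizer state over the model's parameters.
optim_wrapper = strategy.build_optim_wrapper(
    dict(type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)),
    model=model,
)
```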

load_checkpoint(filename, **kwargs)[source]

Load checkpoint from given filename.

Note

If state_dict_type is local, the filename should be a directory containing the rank{i}.pth files.

Parameters:

filename (str) – Accepts a local filepath, URL, torchvision://xxx, or open-mmlab://xxx.

Keyword Arguments:

**kwargs – Keyword arguments passed through to the base strategy's load_checkpoint().

Return type:

dict
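A short sketch of the path convention described in the note; the paths are hypothetical:

```python
# With state_dict_type 'full', filename is a single checkpoint file.
ckpt = strategy.load_checkpoint('work_dirs/epoch_0.pth')

# With state_dict_type 'local', filename is instead a directory containing
# rank{i}.pth shards, and each rank loads its own shard.
```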

load_model_state_dict(state_dict, *, strict=False, revise_keys=[('^module.', '')])[source]

Load model state from dict.

Warning

revise_keys is not supported yet.

Parameters:

state_dict (dict) – The model state dict to load.

strict (bool) – Whether to strictly enforce that the keys of state_dict match the keys of the model. Defaults to False.

revise_keys (list) – (pattern, replacement) pairs for revising keys in state_dict (not supported yet, see the warning above). Defaults to [('^module.', '')].

Return type:

None
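A minimal round-trip sketch, assuming a strategy whose model is already built; revise_keys is left at its default since it is not supported yet:

```python
# Load back a state dict produced by model_state_dict().
state_dict = strategy.model_state_dict()
strategy.load_model_state_dict(state_dict, strict=True)
```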

load_optim_state_dict(state_dict)[source]

Load optimizer state from dict.

Parameters:

state_dict (dict) – The optimizer state dict. If state_dict_type is full, state_dict could be the result of optimizer.state_dict().

Return type:

None
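A sketch for the full case; the file path is hypothetical:

```python
import torch

# In 'full' mode, a plain optimizer.state_dict() saved earlier can be
# loaded back into the optimizer with sharded parameters.
full_osd = torch.load('optim_full.pth')
strategy.load_optim_state_dict(full_osd)
```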

model_state_dict()[source]

Get model state dict based on the state_dict_type.

If state_dict_type is full, the model state dict will be the same as that of the original unsharded model.

If state_dict_type is local and use_orig_params is True in model_wrapper, the keys of the state dict will be the same as those of the original unsharded model, but the values will be sharded.

If state_dict_type is local and use_orig_params is False in model_wrapper, the flattened and sharded state dict will be returned.

See more details in the official API documentation.

Return type:

dict
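A sketch of gathering and persisting the state dict; note that gathering a full state dict in PyTorch FSDP is a collective operation, so every rank should enter the call even if only rank 0 keeps the result (the output path is hypothetical):

```python
import torch

# Gather the state dict according to state_dict_type.
state_dict = strategy.model_state_dict()
torch.save(state_dict, 'model_full.pth')  # typically done on rank 0 only
```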

optim_state_dict()[source]

Get optimizer state dict based on the state_dict_type.

If state_dict_type is full, the optimizer state dict can be loaded by the original unsharded optimizer.

Otherwise, the optimizer state dict could only be loaded by the optimizer with sharded parameters.

Note

Even in full mode, the optimizer state dict is not identical to that of the original optimizer, although it can be loaded correctly.

See more details in the official API documentation.

Return type:

dict
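A companion sketch to the one above (the output path is hypothetical):

```python
import torch

# In 'full' mode the gathered dict can later be restored through the
# original unsharded optimizer's load_state_dict().
osd = strategy.optim_state_dict()
torch.save(osd, 'optim_full.pth')
```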

save_checkpoint(filename, *, save_optimizer=True, save_param_scheduler=True, extra_ckpt=None, callback=None)[source]

Save checkpoint to given filename.

If state_dict_type is full, the checkpoint will only be saved on rank 0. The structure of the saved checkpoint is the same as the one saved by DDPStrategy.

If state_dict_type is local, each rank will save the sharded state dict to a directory, which means the saved structure will look like this:

```
epoch_0.pth
├── rank0.pth
├── rank1.pth
├── ...
└── rank8.pth
```

Parameters:

filename (str) – The checkpoint path. With state_dict_type local it is treated as a directory holding the per-rank shards (see above).

Keyword Arguments:

save_optimizer (bool) – Whether to save the optimizer state. Defaults to True.

save_param_scheduler (bool) – Whether to save the parameter scheduler state. Defaults to True.

extra_ckpt (dict, optional) – Extra data to bundle into the checkpoint. Defaults to None.

callback (callable, optional) – Callback invoked on the checkpoint before it is saved. Defaults to None.

Return type:

None
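A closing sketch tying the pieces together; the contents of extra_ckpt are a hypothetical payload:

```python
# With 'local' state_dict_type this produces the rank{i}.pth directory
# layout shown above; with 'full' it writes a single rank-0 checkpoint.
strategy.save_checkpoint(
    'work_dirs/epoch_0.pth',
    save_optimizer=True,
    extra_ckpt=dict(meta=dict(epoch=0)),  # hypothetical extra payload
)
```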