Strategy — PyTorch Lightning 2.5.1.post0 documentation
class lightning.pytorch.strategies.Strategy(accelerator=None, checkpoint_io=None, precision_plugin=None)[source]¶
Bases: ABC
Base class for all strategies that change the behaviour of the training, validation and test loop.
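For orientation, below is a minimal sketch of what implementing this interface looks like. The class name MyToyStrategy is hypothetical and the no-op collectives assume a single process; Lightning's own SingleDeviceStrategy is the real-world analogue.

```python
import torch
from lightning.pytorch.strategies import Strategy


class MyToyStrategy(Strategy):
    """Hypothetical single-process strategy: all collectives are no-ops."""

    def __init__(self, device="cpu", **kwargs):
        super().__init__(**kwargs)
        self._device = torch.device(device)

    @property
    def root_device(self):
        return self._device

    @property
    def is_global_zero(self):
        return True  # only one process, so it is always rank zero

    def model_to_device(self):
        # self.model is set once the Trainer calls connect()
        self.model.to(self.root_device)

    def all_gather(self, tensor, group=None, sync_grads=False):
        return tensor  # nothing to gather from other processes

    def barrier(self, name=None):
        pass  # no other processes to wait for

    def broadcast(self, obj, src=0):
        return obj  # the single process already holds the object

    def reduce(self, tensor, group=None, reduce_op="mean"):
        return tensor  # reduction over one process is the identity
```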
abstract all_gather(tensor, group=None, sync_grads=False)[source]¶
Perform an all_gather on all processes.
Parameters:
- tensor¶ (Tensor) – the tensor to all_gather
- group¶ (Optional[Any]) – the process group to gather results from
- sync_grads¶ (bool) – flag that allows users to synchronize gradients for all_gather op
Return type:
Tensor
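As a usage sketch (module and tensor shapes are illustrative), LightningModule.all_gather() delegates to the active strategy's all_gather(), stacking each process's tensor along a new leading dimension:

```python
import torch
import lightning.pytorch as pl


class GatherExample(pl.LightningModule):
    def validation_step(self, batch, batch_idx):
        local_preds = torch.randn(8)  # stand-in for real per-process output
        gathered = self.all_gather(local_preds, sync_grads=False)
        # under DDP with N processes: gathered.shape == (N, 8);
        # on a single device the tensor comes back unchanged
        return gathered.mean()
```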
backward(closure_loss, optimizer, *args, **kwargs)[source]¶
Forwards backward-calls to the precision plugin.
Parameters:
- closure_loss¶ (Tensor) – a tensor holding the loss value to backpropagate
- optimizer¶ (Optional[Optimizer]) – An optional optimizer that gets passed down to the precision plugin’s backward
- *args¶ (Any) – Positional arguments that get passed down to the precision plugin’s backward, intended as arguments for the actual function that performs the backward, like backward().
- **kwargs¶ (Any) – Keyword arguments for the same purpose as *args.
Return type:
Tensor
abstract barrier(name=None)[source]¶
Synchronizes all processes, blocking them until the whole group enters this function.
Parameters:
name¶ (Optional[str]) – an optional name to pass into barrier.
Return type:
None
batch_to_device(batch, device=None, dataloader_idx=0)[source]¶
Moves the batch to the correct device.
The returned batch is of the same type as the input batch, just having all tensors on the correct device.
Parameters:
- batch¶ (Any) – The batch of samples to move to the correct device
- device¶ (Optional[device]) – The target device
- dataloader_idx¶ (int) – The index of the dataloader to which the batch belongs.
Return type:
Any
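A minimal, runnable sketch using SingleDeviceStrategy: the container structure is preserved, only the tensors inside it are moved, and non-tensor entries pass through untouched.

```python
import torch
from lightning.pytorch.strategies import SingleDeviceStrategy

strategy = SingleDeviceStrategy(device="cpu")
batch = {
    "images": torch.zeros(4, 3, 32, 32),
    "labels": torch.tensor([0, 1, 2, 3]),
    "ids": ["a", "b", "c", "d"],  # non-tensor entries are left as-is
}
moved = strategy.batch_to_device(batch, device=torch.device("cpu"))
assert moved["images"].device.type == "cpu"
assert moved["ids"] == ["a", "b", "c", "d"]
```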
abstract broadcast(obj, src=0)[source]¶
Broadcasts an object to all processes.
Parameters:
- obj¶ (TypeVar(TBroadcast)) – the object to broadcast
- src¶ (int) – source rank
Return type:
TypeVar(TBroadcast)
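For illustration (runnable on a single process, where broadcast is the identity): under a distributed strategy every rank would receive the value held by the src rank.

```python
from lightning.pytorch.strategies import SingleDeviceStrategy

strategy = SingleDeviceStrategy(device="cpu")
config = {"lr": 1e-3, "seed": 42}
# every process ends up with rank 0's copy of the object;
# with one process this simply returns the object
synced = strategy.broadcast(config, src=0)
assert synced == config
```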
connect(model)[source]¶
Called by the Trainer to connect the strategy with the model.
Return type:
None
lightning_module_state_dict()[source]¶
Returns model state.
Return type:
dict[str, Any]
model_sharded_context()[source]¶
Provide a hook to create modules in a distributed-aware context. This is useful when we’d like to shard the model instantly, which can save memory and initialization time for extremely large models.
Returns: Model parallel context.
Return type:
Generator
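A sketch of the intended usage: instantiate the model inside the context so a sharding strategy (for example FSDP) can intercept module creation; with a single-device strategy the context is a passthrough.

```python
import torch.nn as nn
from lightning.pytorch.strategies import SingleDeviceStrategy

strategy = SingleDeviceStrategy(device="cpu")
with strategy.model_sharded_context():
    # a sharding-aware strategy could shard these parameters as they
    # are created instead of materializing the full model first
    model = nn.Linear(1024, 1024)
```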
abstract model_to_device()[source]¶
Moves the model to the correct device.
Return type:
None
on_exception(exception)[source]¶
Called when the trainer execution is interrupted by an exception.
Return type:
None
on_predict_end()[source]¶
Called when predict ends.
Return type:
None
on_predict_start()[source]¶
Called when predict begins.
Return type:
None
on_test_end()[source]¶
Called when test ends.
Return type:
None
on_test_start()[source]¶
Called when test begins.
Return type:
None
on_train_batch_start(batch, batch_idx)[source]¶
Called in the training loop before anything happens for that batch.
Return type:
None
on_train_end()[source]¶
Called when train ends.
Return type:
None
on_train_start()[source]¶
Called when train begins.
Return type:
None
on_validation_end()[source]¶
Called when validation ends.
Return type:
None
on_validation_start()[source]¶
Called when validation begins.
Return type:
None
optimizer_state(optimizer)[source]¶
Returns state of an optimizer.
Allows for syncing/collating optimizer state from processes in custom strategies.
Return type:
dict[str, Tensor]
optimizer_step(optimizer, closure, model=None, **kwargs)[source]¶
Performs the actual optimizer step.
Parameters:
- optimizer¶ (Optimizer) – the optimizer performing the step
- closure¶ (Callable[[], Any]) – closure calculating the loss value
- model¶ (Union[LightningModule, Module, None]) – reference to the model, optionally defining optimizer step related hooks
- **kwargs¶ (Any) – Keyword arguments to optimizer.step
Return type:
Any
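As a sketch of customization (the subclass and the print are illustrative only), a strategy can wrap optimizer_step while delegating the real work, which the base class forwards to the precision plugin:

```python
from lightning.pytorch.strategies import SingleDeviceStrategy


class VerboseStepStrategy(SingleDeviceStrategy):
    def optimizer_step(self, optimizer, closure, model=None, **kwargs):
        # delegate to the base implementation (the precision plugin
        # performs the actual step), then add behaviour around it
        output = super().optimizer_step(optimizer, closure, model=model, **kwargs)
        print(f"stepped {type(optimizer).__name__}")
        return output
```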
post_backward(closure_loss)[source]¶
Run after precision plugin executes backward.
Return type:
None
post_training_step()[source]¶
This hook is deprecated.
Override training_step() instead.
Return type:
None
pre_backward(closure_loss)[source]¶
Run before precision plugin executes backward.
Return type:
None
predict_step(*args, **kwargs)[source]¶
The actual predict step.
See predict_step() for more details.
Return type:
Any
process_dataloader(dataloader)[source]¶
Wraps the dataloader if necessary.
Parameters:
dataloader¶ (object) – iterable. Ideally of type: torch.utils.data.DataLoader
Return type:
object
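A sketch of an override (hypothetical subclass): strategies use this hook to inspect or rewrap user dataloaders before iteration begins.

```python
from torch.utils.data import DataLoader
from lightning.pytorch.strategies import SingleDeviceStrategy


class InspectingStrategy(SingleDeviceStrategy):
    def process_dataloader(self, dataloader):
        if isinstance(dataloader, DataLoader):
            print(f"dataloader with batch_size={dataloader.batch_size}")
        return dataloader  # return the (possibly rewrapped) iterable
```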
abstract reduce(tensor, group=None, reduce_op='mean')[source]¶
Reduces the given tensor (e.g. across GPUs/processes).
Parameters:
- tensor¶ (Union[Tensor, Any]) – the tensor to sync and reduce
- group¶ (Optional[Any]) – the process group to reduce
- reduce_op¶ (Union[ReduceOp, str, None]) – the reduction operation. Defaults to ‘mean’. Can also be a string ‘sum’ or ReduceOp.
Return type:
Union[Tensor, Any]
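A runnable single-process sketch (where reduction is the identity); under a multi-process strategy every rank would receive the mean across ranks:

```python
import torch
from lightning.pytorch.strategies import SingleDeviceStrategy

strategy = SingleDeviceStrategy(device="cpu")
local_loss = torch.tensor(0.25)  # this process's value
avg_loss = strategy.reduce(local_loss, reduce_op="mean")
# with N processes under DDP, avg_loss would be the mean of all N values
```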
reduce_boolean_decision(decision, all=True)[source]¶
Reduce a boolean decision across all processes.
Return type:
bool
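A sketch of the common use, synchronizing a stopping decision across ranks: with all=True the result is the logical AND of every process's vote; with all=False it is the OR.

```python
from lightning.pytorch.strategies import SingleDeviceStrategy

strategy = SingleDeviceStrategy(device="cpu")
local_vote = True  # e.g. "this rank thinks training should stop"
should_stop = strategy.reduce_boolean_decision(local_vote, all=True)
# True only if *every* rank voted True (trivially so on one process)
```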
remove_checkpoint(filepath)[source]¶
Remove checkpoint filepath from the filesystem.
Parameters:
filepath¶ (Union[str, Path]) – Path to checkpoint
Return type:
None
save_checkpoint(checkpoint, filepath, storage_options=None)[source]¶
Save model/training states as a checkpoint file through state-dump and file-write.
Parameters:
- checkpoint¶ (dict[str, Any]) – dict containing model and trainer state
- filepath¶ (Union[str, Path]) – write-target file’s path
- storage_options¶ (Optional[Any]) – parameter for how to save to storage, passed to CheckpointIO plugin
Return type:
None
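A sketch of how a subclass might guard checkpoint writes. The base implementation already saves only on the global-zero rank via the CheckpointIO plugin; the hypothetical subclass below just makes that pattern explicit.

```python
from lightning.pytorch.strategies import SingleDeviceStrategy


class ExplicitRankZeroSave(SingleDeviceStrategy):
    def save_checkpoint(self, checkpoint, filepath, storage_options=None):
        # only the global-zero process writes, so N ranks don't race
        # on the same file (the base class behaves the same way)
        if self.is_global_zero:
            super().save_checkpoint(checkpoint, filepath, storage_options=storage_options)
```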
setup(trainer)[source]¶
Sets up the accelerator, plugins and initializes the optimizers (if needed).
Parameters:
trainer¶ (Trainer) – the trainer instance
Return type:
None
setup_environment()[source]¶
Setup any processes or distributed connections.
This is called before the LightningModule/DataModule setup hook which allows the user to access the accelerator environment before setup is complete.
Return type:
None
setup_optimizers(trainer)[source]¶
Creates optimizers and schedulers.
Parameters:
trainer¶ (Trainer) – the Trainer these optimizers should be connected to
Return type:
None
setup_precision_plugin()[source]¶
Attaches the precision plugin to the strategy.
Return type:
None
teardown()[source]¶
This method is called to tear down the training process.
It is the right place to release memory and free other resources.
Return type:
None
tensor_init_context(empty_init=None)[source]¶
Controls how tensors get created (device, dtype).
Parameters:
empty_init¶ (Optional[bool]) – Whether to initialize the model with empty weights (uninitialized memory). If None, the strategy will decide. Some strategies may not support all options.
Return type:
Generator
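A runnable sketch: allocating a large layer with empty_init=True skips the random weight initialization, which is worthwhile when a checkpoint will overwrite the parameters anyway.

```python
import torch.nn as nn
from lightning.pytorch.strategies import SingleDeviceStrategy

strategy = SingleDeviceStrategy(device="cpu")
with strategy.tensor_init_context(empty_init=True):
    # parameters are allocated but deliberately left uninitialized
    model = nn.Linear(4096, 4096)
# load real weights, e.g. model.load_state_dict(...), before using the model
```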
test_step(*args, **kwargs)[source]¶
The actual test step.
See test_step() for more details.
Return type:
Union[Tensor, Mapping[str, Any], None]
training_step(*args, **kwargs)[source]¶
The actual training step.
See training_step() for more details.
Return type:
Union[Tensor, Mapping[str, Any], None]
validation_step(*args, **kwargs)[source]¶
The actual validation step.
See validation_step() for more details.
Return type:
Union[Tensor, Mapping[str, Any], None]
property handles_gradient_accumulation: bool¶
Whether the strategy handles gradient accumulation internally.
abstract property is_global_zero: bool¶
Whether the current process is the rank zero process not only on the local node, but for all nodes.
property lightning_module: Optional[LightningModule]¶
Returns the pure LightningModule without potential wrappers.
property lightning_restore_optimizer: bool¶
Override to disable Lightning restoring optimizers/schedulers.
This is useful for strategies which manage restoring optimizers/schedulers.
property model: Optional[Module]¶
Returns the potentially wrapped LightningModule.
property restore_checkpoint_after_setup: bool¶
Override to delay restoring from checkpoint till after the setup phase has completed. This is useful when the strategy requires all the setup hooks to run before loading checkpoint.
Returns: If True, restore checkpoint after strategy setup.
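A sketch of how a custom strategy might override these restore-related properties (hypothetical subclass; only strategies that genuinely manage their own restoration should do this):

```python
from lightning.pytorch.strategies import SingleDeviceStrategy


class SelfRestoringStrategy(SingleDeviceStrategy):
    @property
    def lightning_restore_optimizer(self):
        return False  # this strategy reloads optimizer state itself

    @property
    def restore_checkpoint_after_setup(self):
        return True  # wait for all setup hooks before loading the checkpoint
```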
abstract property root_device: device¶
Returns the root device.