MixedPrecision — PyTorch Lightning 2.6.0 documentation
class lightning.pytorch.plugins.precision.MixedPrecision(precision, device, scaler=None)[source]¶
Bases: Precision
Plugin for Automatic Mixed Precision (AMP) training with torch.autocast.
Parameters:
- precision¶ (Literal['16-mixed', 'bf16-mixed']) – Whether to use torch.float16 ('16-mixed') or torch.bfloat16 ('bf16-mixed').
- device¶ (str) – The device for torch.autocast.
- scaler¶ (Optional[GradScaler]) – An optional torch.cuda.amp.GradScaler to use.
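For orientation, a minimal usage sketch (not part of the reference itself), assuming a single CUDA device; passing precision="16-mixed" to the Trainer is the usual shorthand for the same setup:

```python
import torch
from lightning.pytorch import Trainer
from lightning.pytorch.plugins.precision import MixedPrecision

# Construct the plugin explicitly and hand it to the Trainer via `plugins`.
scaler = torch.cuda.amp.GradScaler()  # optional; one is created for "16-mixed" if omitted
plugin = MixedPrecision(precision="16-mixed", device="cuda", scaler=scaler)
trainer = Trainer(accelerator="gpu", devices=1, plugins=[plugin])
```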
clip_gradients(optimizer, clip_val=0.0, gradient_clip_algorithm=GradClipAlgorithmType.NORM)[source]¶
Clips the gradients.
Return type: None
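In normal use this method is invoked by the Trainer rather than called directly. A sketch of the equivalent Trainer configuration; under AMP, Lightning unscales the gradients before they are clipped:

```python
from lightning.pytorch import Trainer

# These settings are routed through the active precision plugin's clip_gradients().
trainer = Trainer(
    precision="16-mixed",
    gradient_clip_val=1.0,           # forwarded as clip_val
    gradient_clip_algorithm="norm",  # GradClipAlgorithmType.NORM; "value" is the alternative
)
```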
forward_context()[source]¶
Enable autocast context.
Return type: Generator
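A sketch of what the context enables, assuming device="cuda" and precision="16-mixed"; the plugin wraps the forward pass in torch.autocast along these lines:

```python
import torch

a = torch.randn(8, 8, device="cuda")
b = torch.randn(8, 8, device="cuda")
with torch.autocast(device_type="cuda", dtype=torch.float16):
    out = a @ b  # eligible ops (matmuls, convolutions, ...) run in float16
assert out.dtype == torch.float16
loss = out.float().sum()  # reductions are safer in float32 outside the context
```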
load_state_dict(state_dict)[source]¶
Called when loading a checkpoint; implement to reload the precision plugin state given the precision plugin state_dict.
Parameters:
state_dict¶ (dict[str, Any]) – the precision plugin state returned by state_dict.
Return type: None
optimizer_step(optimizer, model, closure, **kwargs)[source]¶
Hook to run the optimizer step.
Return type: Any
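A rough sketch of the scaled step, following the standard torch.cuda.amp recipe rather than Lightning's exact implementation (which adds bookkeeping around the closure and handles skipped steps):

```python
import torch

scaler = torch.cuda.amp.GradScaler()

def optimizer_step(optimizer, closure):
    # The closure runs the forward pass and the (scaled) backward pass.
    loss = closure()
    scaler.step(optimizer)  # unscales gradients; skips the step on inf/NaN
    scaler.update()         # adjusts the loss scale for the next iteration
    return loss
```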
pre_backward(tensor, module)[source]¶
Runs before precision plugin executes backward.
Parameters:
- tensor¶ (Tensor) – The tensor that will be used for backpropagation.
- module¶ (LightningModule) – The module that was involved in producing the tensor and whose parameters need the gradients.
Return type: Tensor
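Under '16-mixed' this hook is where loss scaling happens. A sketch following the standard AMP recipe (the scaler calls are the standard torch API; the tiny loss is a stand-in for a model's loss):

```python
import torch

scaler = torch.cuda.amp.GradScaler()
w = torch.randn(4, requires_grad=True, device="cuda")
loss = (w ** 2).sum()        # stand-in for the model's loss
scaled = scaler.scale(loss)  # multiply the loss by the current scale factor
scaled.backward()            # gradients on w are scaled; they are unscaled at step time
```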
state_dict()[source]¶
Called when saving a checkpoint; implement to generate the precision plugin state_dict.
Return type: dict[str, Any]
Returns:
A dictionary containing precision plugin state.
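A sketch of the checkpoint round-trip for state_dict() and load_state_dict(), assuming '16-mixed' so that (by my reading) the state carries the GradScaler's contents:

```python
from lightning.pytorch.plugins.precision import MixedPrecision

plugin = MixedPrecision(precision="16-mixed", device="cuda")
state = plugin.state_dict()  # e.g. the scaler's current scale and growth tracker

# ... later, when restoring from a checkpoint ...
restored = MixedPrecision(precision="16-mixed", device="cuda")
restored.load_state_dict(state)  # reinstates the saved scaler state
```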