LightningOptimizer — PyTorch Lightning 2.5.1.post0 documentation
class lightning.pytorch.core.optimizer.LightningOptimizer(optimizer)[source]¶
Bases: object
This class wraps the user's optimizers and correctly handles the backward and optimizer_step logic across accelerators, AMP, and accumulate_grad_batches.
Note: The purpose of this wrapper is only to define new methods and redirect the .step() call. The wrapper's internal state (__dict__) is not kept in sync with the internal state of the original optimizer, but the Trainer never relies on the wrapper's internal state.
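For orientation, a minimal sketch of where the wrapper shows up (the module, layer size, and loss below are illustrative assumptions, not part of this API): under manual optimization, self.optimizers() returns LightningOptimizer instances wrapping whatever configure_optimizers() produced, and the underlying torch.optim optimizer remains reachable through the wrapper's optimizer attribute.

import torch
from lightning.pytorch import LightningModule


class LitModel(LightningModule):  # illustrative module, not part of the API
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # enable manual optimization
        self.layer = torch.nn.Linear(32, 2)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def training_step(self, batch, batch_idx):
        # self.optimizers() hands back a LightningOptimizer wrapping the
        # SGD instance above (use_pl_optimizer=True is the default).
        opt = self.optimizers()
        raw_sgd = opt.optimizer  # the wrapped torch.optim.SGD

        # batch is assumed to be a float tensor of shape (N, 32)
        loss = self.layer(batch).sum()
        opt.zero_grad()
        self.manual_backward(loss)
        opt.step()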
step(closure=None, **kwargs)[source]¶
Performs a single optimization step (parameter update).
Parameters:
- closure (Optional[Callable[[], Any]]) – An optional optimizer closure.
- kwargs (Any) – Any additional arguments to the optimizer.step() call.
Return type:
Any
Returns:
The output from the step call, which is generally the output of the closure execution (see the sketch after the examples below).
Example:
# Scenario for a GAN using manual optimization
def training_step(self, batch, batch_idx):
    opt_gen, opt_dis = self.optimizers()

    ...

    # compute generator loss
    loss_gen = self.compute_generator_loss(...)
    # zero_grad needs to be called before backward
    opt_gen.zero_grad()
    self.manual_backward(loss_gen)
    opt_gen.step()

    # compute discriminator loss
    loss_dis = self.compute_discriminator_loss(...)
    # zero_grad needs to be called before backward
    opt_dis.zero_grad()
    self.manual_backward(loss_dis)
    opt_dis.step()
# A more advanced example using gradient accumulation
def training_step(self, batch, batch_idx):
    opt_gen, opt_dis = self.optimizers()

    ...
    accumulated_grad_batches = batch_idx % 2 == 0

    # compute generator loss
    def closure_gen():
        loss_gen = self.compute_generator_loss(...)
        self.manual_backward(loss_gen)
        if accumulated_grad_batches:
            opt_gen.zero_grad()

    with opt_gen.toggle_model(sync_grad=accumulated_grad_batches):
        opt_gen.step(closure=closure_gen)

    # compute discriminator loss
    def closure_dis():
        loss_dis = self.compute_discriminator_loss(...)
        self.manual_backward(loss_dis)
        if accumulated_grad_batches:
            opt_dis.zero_grad()

    with opt_dis.toggle_model(sync_grad=accumulated_grad_batches):
        opt_dis.step(closure=closure_dis)
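Because step() passes back the closure's output, the loss computed inside a closure can be surfaced without recomputing it. A hedged sketch (compute_loss is a hypothetical helper, and the exact return value can vary with the strategy or precision plugin in use):

def training_step(self, batch, batch_idx):
    opt = self.optimizers()

    def closure():
        loss = self.compute_loss(batch)  # hypothetical helper
        opt.zero_grad()
        self.manual_backward(loss)
        return loss  # becomes the return value of opt.step()

    # step() generally returns the closure's output, so the loss can be
    # logged directly from the step call.
    loss = opt.step(closure=closure)
    self.log("train_loss", loss)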
toggle_model(sync_grad=True)[source]¶
This function is a helper for advanced users.
Considering the current optimizer as A and all other optimizers as B, toggling means that every parameter owned exclusively by B (i.e. not also optimized by A) has requires_grad set to False.
When performing gradient accumulation, there is no need to perform gradient synchronization during the accumulation phase. Setting sync_grad to False skips this synchronization and improves performance.
Return type:
Generator
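To make the toggling behavior concrete, here is a sketch under assumed names (self.gen, self.dis, and compute_generator_loss are illustrative, mirroring the GAN example above): while the generator optimizer is toggled, parameters owned only by the discriminator optimizer have requires_grad set to False, so the generator backward pass cannot write gradients into them.

def training_step(self, batch, batch_idx):
    # Assumes self.gen and self.dis submodules with one optimizer each,
    # returned from configure_optimizers() in that order.
    opt_gen, opt_dis = self.optimizers()

    with opt_gen.toggle_model():
        # Parameters exclusive to opt_dis (i.e. self.dis) now have
        # requires_grad == False for the duration of this block.
        loss_gen = self.compute_generator_loss(batch)  # illustrative helper
        opt_gen.zero_grad()
        self.manual_backward(loss_gen)
        opt_gen.step()
    # On exit, the previous requires_grad flags are restored.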