LightningOptimizer — PyTorch Lightning 2.5.1.post0 documentation

class lightning.pytorch.core.optimizer.LightningOptimizer(optimizer)[source]

Bases: object

This class wraps the user's optimizers and correctly handles the backward and optimizer_step logic across accelerators, AMP, and accumulate_grad_batches.

Note: The purpose of this wrapper is only to define new methods and redirect the .step() call. The internal state __dict__ is not kept in sync with the internal state of the original optimizer, but the Trainer never relies on the internal state of the wrapper.
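For context, here is a minimal manual-optimization sketch (the module, layer, and loss names are illustrative, not part of this API) showing where the wrapper appears: in manual optimization, self.optimizers() returns LightningOptimizer instances wrapping whatever configure_optimizers() returned, and calling .step() on them is routed through the Trainer's precision and strategy handling.

import torch
import lightning.pytorch as pl


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # use manual optimization
        self.layer = torch.nn.Linear(32, 2)

    def configure_optimizers(self):
        # return a plain torch optimizer; the Trainer wraps it in a LightningOptimizer
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()  # LightningOptimizer wrapping the SGD instance
        # opt.optimizer gives access to the underlying torch.optim.SGD if needed
        loss = self.layer(batch).sum()
        opt.zero_grad()
        self.manual_backward(loss)
        opt.step()  # redirected through the Trainer's precision/strategy logic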

step(closure=None, **kwargs)[source]

Performs a single optimization step (parameter update).

Parameters:

closure (Optional[Callable[[], Any]]) – An optional optimizer closure.

kwargs (Any) – Any additional arguments passed to the optimizer.step() call.

Return type:

Any

Returns:

The output from the step call, which is generally the output of the closure execution.

Example:

Scenario for a GAN using manual optimization

def training_step(self, batch, batch_idx):
    opt_gen, opt_dis = self.optimizers()

    ...

    # compute generator loss
    loss_gen = self.compute_generator_loss(...)
    # zero_grad needs to be called before backward
    opt_gen.zero_grad()
    self.manual_backward(loss_gen)
    opt_gen.step()

    # compute discriminator loss
    loss_dis = self.compute_discriminator_loss(...)

    # zero_grad needs to be called before backward
    opt_dis.zero_grad()
    self.manual_backward(loss_dis)
    opt_dis.step()

A more advanced example

def training_step(self, batch, batch_idx):
    opt_gen, opt_dis = self.optimizers()

    ...
    accumulated_grad_batches = batch_idx % 2 == 0

    # compute generator loss
    def closure_gen():
        loss_gen = self.compute_generator_loss(...)
        self.manual_backward(loss_gen)
        if accumulated_grad_batches:
            opt_gen.zero_grad()

    with opt_gen.toggle_model(sync_grad=accumulated_grad_batches):
        opt_gen.step(closure=closure_gen)

    # compute discriminator loss
    def closure_dis():
        loss_dis = self.compute_discriminator_loss(...)
        self.manual_backward(loss_dis)
        if accumulated_grad_batches:
            opt_dis.zero_grad()

    with opt_dis.toggle_model(sync_grad=accumulated_grad_batches):
        opt_dis.step(closure=closure_dis)

toggle_model(sync_grad=True)[source]

This function is just a helper for advanced users.

Consider the current optimizer as A and all other optimizers as B. Toggling means that all parameters belonging to B but not to A will have requires_grad set to False.

When performing gradient accumulation, there is no need to perform grad synchronization during the accumulation phase. Setting sync_grad to False will block this synchronization and improve performance.

Return type:

Generator[None, None, None]
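
A small sketch (the self.discriminator submodule and compute_generator_loss helper are hypothetical names, and the assertion is purely illustrative) of what toggling implies: while the generator optimizer is toggled, parameters that belong only to the discriminator optimizer have requires_grad set to False, and the previous flags are restored when the context exits.

def training_step(self, batch, batch_idx):
    opt_gen, opt_dis = self.optimizers()

    # generator update: discriminator-only parameters are frozen inside the context
    with opt_gen.toggle_model():
        assert all(not p.requires_grad for p in self.discriminator.parameters())
        loss_gen = self.compute_generator_loss(batch)
        opt_gen.zero_grad()
        self.manual_backward(loss_gen)
        opt_gen.step()
    # on exit, requires_grad flags are restored to their previous values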