LightningModule - PyTorch Lightning 2.5.1.post0 documentation

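A LightningModule organizes your PyTorch code into 6 sections:

- Initialization (__init__ and setup())
- Train Loop (training_step())
- Validation Loop (validation_step())
- Test Loop (test_step())
- Prediction Loop (predict_step())
- Optimizers and LR Schedulers (configure_optimizers())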

When you convert to use Lightning, the code is not abstracted, just organized. All the other code that's not in the LightningModule has been automated for you by the Trainer.

net = MyLightningModuleNet()
trainer = Trainer()
trainer.fit(net)

There are no .cuda() or .to(device) calls required. Lightning does these for you.

# don't do in Lightning
x = torch.Tensor(2, 3)
x = x.cuda()
x = x.to(device)

# do this instead
x = x  # leave it alone!

# or to init a new tensor
new_x = torch.Tensor(2, 3)
new_x = new_x.to(x)

When running under a distributed strategy, Lightning handles the distributed sampler for you by default.

# Don't do in Lightning...
data = MNIST(...)
sampler = DistributedSampler(data)
DataLoader(data, sampler=sampler)

# do this instead
data = MNIST(...)
DataLoader(data)

A LightningModule is a torch.nn.Module but with added functionality. Use it as such!

net = Net.load_from_checkpoint(PATH)
net.freeze()
out = net(x)

Thus, to use Lightning, you just need to organize your code, which takes about 30 minutes (and let's be real, you probably should do that anyway).


Starter Example

Here are the only required methods.

import lightning as L
import torch

from lightning.pytorch.demos import Transformer


class LightningTransformer(L.LightningModule):
    def __init__(self, vocab_size):
        super().__init__()
        self.model = Transformer(vocab_size=vocab_size)

    def forward(self, inputs, target):
        return self.model(inputs, target)

    def training_step(self, batch, batch_idx):
        inputs, target = batch
        output = self(inputs, target)
        loss = torch.nn.functional.nll_loss(output, target.view(-1))
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.model.parameters(), lr=0.1)

Which you can train by doing:

from lightning.pytorch.demos import WikiText2
from torch.utils.data import DataLoader

dataset = WikiText2()
dataloader = DataLoader(dataset)
model = LightningTransformer(vocab_size=dataset.vocab_size)

trainer = L.Trainer(fast_dev_run=100)
trainer.fit(model=model, train_dataloaders=dataloader)

The LightningModule has many convenient methods, but the core ones you need to know about are:

| Name | Description |
| --- | --- |
| __init__ and setup() | Define initialization here |
| forward() | To run data through your model only (separate from training_step) |
| training_step() | The complete training step |
| validation_step() | The complete validation step |
| test_step() | The complete test step |
| predict_step() | The complete prediction step |
| configure_optimizers() | Define optimizers and LR schedulers |

Training

Training Loop

To activate the training loop, override the training_step() method.

class LightningTransformer(L.LightningModule):
    def __init__(self, vocab_size):
        super().__init__()
        self.model = Transformer(vocab_size=vocab_size)

    def training_step(self, batch, batch_idx):
        inputs, target = batch
        output = self.model(inputs, target)
        loss = torch.nn.functional.nll_loss(output, target.view(-1))
        return loss

Under the hood, Lightning does the following (pseudocode):

# enable gradient calculation
torch.set_grad_enabled(True)

for batch_idx, batch in enumerate(train_dataloader):
    loss = training_step(batch, batch_idx)

    # clear gradients
    optimizer.zero_grad()

    # backward
    loss.backward()

    # update parameters
    optimizer.step()
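To see why the order matters, here is a dependency-free sketch of the same loop on a one-parameter model y = w * x. TinySGD and the backward closure are hypothetical stand-ins for torch.optim.SGD and autograd. They are used only to mirror the zero_grad(), backward(), step() sequence. Gradients accumulate into optimizer.grads, which is why they must be cleared before each backward pass.

```python
class TinySGD:
    """Hypothetical stand-in for torch.optim.SGD over a dict of scalar parameters."""

    def __init__(self, params, lr):
        self.params = params
        self.lr = lr
        self.grads = {name: 0.0 for name in params}

    def zero_grad(self):
        # like optimizer.zero_grad(): reset the accumulated gradients
        for name in self.grads:
            self.grads[name] = 0.0

    def step(self):
        # like optimizer.step(): plain gradient-descent update
        for name, grad in self.grads.items():
            self.params[name] -= self.lr * grad


def training_step(params, batch):
    """Squared-error loss for one (x, y) sample plus a closure playing the role of backward."""
    x, y = batch
    err = params["w"] * x - y
    loss = err * err

    def backward(grads):
        # d(loss)/dw = 2 * err * x, accumulated (added) the way autograd fills .grad
        grads["w"] += 2 * err * x

    return loss, backward


params = {"w": 0.0}
optimizer = TinySGD(params, lr=0.05)
train_dataloader = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]  # samples of y = 2 * x

for epoch in range(20):
    for batch_idx, batch in enumerate(train_dataloader):
        loss, backward = training_step(params, batch)
        optimizer.zero_grad()      # clear gradients
        backward(optimizer.grads)  # backward
        optimizer.step()           # update parameters

print(round(params["w"], 4))  # 2.0
```

The learned w converges to 2.0, the slope of the training data. If zero_grad() were skipped, each update would also include every earlier gradient.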

Train Epoch-level Metrics

If you want to calculate epoch-level metrics and log them, use log().

def training_step(self, batch, batch_idx):
    inputs, target = batch
    output = self.model(inputs, target)
    loss = torch.nn.functional.nll_loss(output, target.view(-1))

    # logs metrics for each training_step,
    # and the average across the epoch, to the progress bar and logger
    self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
    return loss

The log() method automatically reduces the requested metrics across a complete epoch and devices. Here’s the pseudocode of what it does under the hood:

outs = []
for batch_idx, batch in enumerate(train_dataloader):
    # forward
    loss = training_step(batch, batch_idx)
    outs.append(loss.detach())

    # clear gradients
    optimizer.zero_grad()
    # backward
    loss.backward()
    # update parameters
    optimizer.step()

# note: in reality, we do this incrementally, instead of keeping all outputs in memory
epoch_metric = torch.mean(torch.stack(outs))
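The incremental reduction mentioned in the comment above can be sketched as a running sum. This is a simplified illustration, not Lightning's internal code. It assumes each step is weighted by its batch size, which is what log()'s batch_size argument feeds into. RunningMean is a hypothetical name. With equal batch sizes it gives the same result as torch.mean(torch.stack(outs)), without keeping every output in memory.

```python
class RunningMean:
    """Hypothetical sketch of an incremental, batch-size-weighted epoch mean."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, value, batch_size=1):
        # keep a running weighted sum instead of storing every step output
        self.total += value * batch_size
        self.count += batch_size

    def compute(self):
        return self.total / self.count


metric = RunningMean()
for loss, batch_size in [(1.0, 2), (2.0, 2), (3.0, 1)]:
    metric.update(loss, batch_size)

print(metric.compute())  # (1*2 + 2*2 + 3*1) / 5 = 1.8
```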

Train Epoch-level Operations

If you need to make use of all the outputs from each training_step(), override the on_train_epoch_end() method.

class LightningTransformer(L.LightningModule):
    def __init__(self, vocab_size):
        super().__init__()
        self.model = Transformer(vocab_size=vocab_size)
        self.training_step_outputs = []

    def training_step(self, batch, batch_idx):
        inputs, target = batch
        output = self.model(inputs, target)
        loss = torch.nn.functional.nll_loss(output, target.view(-1))
        preds = ...
        self.training_step_outputs.append(preds)
        return loss

    def on_train_epoch_end(self):
        all_preds = torch.stack(self.training_step_outputs)
        # do something with all preds
        ...
        self.training_step_outputs.clear()  # free memory

Validation

Validation Loop

To activate the validation loop while training, override the validation_step() method.

class LightningTransformer(L.LightningModule):
    def validation_step(self, batch, batch_idx):
        inputs, target = batch
        output = self.model(inputs, target)
        loss = torch.nn.functional.nll_loss(output, target.view(-1))
        self.log("val_loss", loss)

Under the hood, Lightning does the following (pseudocode):

# ...
for batch_idx, batch in enumerate(train_dataloader):
    loss = model.training_step(batch, batch_idx)
    loss.backward()
    # ...

    if validate_at_some_point:
        # disable grads + batchnorm + dropout
        torch.set_grad_enabled(False)
        model.eval()

        # ----------------- VAL LOOP ---------------
        for val_batch_idx, val_batch in enumerate(val_dataloader):
            val_out = model.validation_step(val_batch, val_batch_idx)
        # ----------------- VAL LOOP ---------------

        # enable grads + batchnorm + dropout
        torch.set_grad_enabled(True)
        model.train()

You can also run just the validation loop on your validation dataloaders by overriding validation_step() and calling validate().

model = LightningTransformer(vocab_size=dataset.vocab_size)
trainer = L.Trainer()
trainer.validate(model, dataloaders=dataloader)

Note

It is recommended to validate on a single device to ensure each sample/batch gets evaluated exactly once. This helps make sure benchmarking for research papers is done the right way. Otherwise, in a multi-device setting, samples could be duplicated when DistributedSampler is used, e.g. with strategy="ddp". It replicates some samples on some devices so that all devices have the same batch size when inputs are uneven.
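The duplication comes from how the sampler pads the index list. The sketch below is a dependency-free re-implementation, written from our understanding of torch.utils.data.DistributedSampler with shuffle=False and drop_last=False. distributed_indices is a hypothetical helper, not a PyTorch API. With 10 samples on 4 devices, the list is padded to 12, so samples 0 and 1 are evaluated twice.

```python
import math


def distributed_indices(dataset_len, num_replicas, rank):
    """Indices one rank would see (sketch of DistributedSampler, shuffle=False, drop_last=False)."""
    indices = list(range(dataset_len))
    # round the total up to a multiple of num_replicas
    total_size = math.ceil(dataset_len / num_replicas) * num_replicas
    padding = total_size - dataset_len
    if padding and indices:
        # pad by repeating indices from the start so every rank gets the same count
        repeats = math.ceil(padding / len(indices))
        indices += (indices * repeats)[:padding]
    # each rank takes every num_replicas-th index, starting at its own rank
    return indices[rank:total_size:num_replicas]


for rank in range(4):
    print(rank, distributed_indices(10, num_replicas=4, rank=rank))
# 0 [0, 4, 8]
# 1 [1, 5, 9]
# 2 [2, 6, 0]
# 3 [3, 7, 1]
```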

Validation Epoch-level Metrics

If you need to make use of all the outputs from each validation_step(), override the on_validation_epoch_end() method. Note that this method is called before on_train_epoch_end().

class LightningTransformer(L.LightningModule):
    def __init__(self, vocab_size):
        super().__init__()
        self.model = Transformer(vocab_size=vocab_size)
        self.validation_step_outputs = []

    def validation_step(self, batch, batch_idx):
        inputs, target = batch
        output = self.model(inputs, target)
        loss = torch.nn.functional.nll_loss(output, target.view(-1))
        pred = ...
        self.validation_step_outputs.append(pred)
        return pred

    def on_validation_epoch_end(self):
        all_preds = torch.stack(self.validation_step_outputs)
        # do something with all preds
        ...
        self.validation_step_outputs.clear()  # free memory
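The ordering note above can be checked with a plain-Python simulation. simulated_fit_epoch and HookRecorder are hypothetical. The sketch assumes the default of one validation run at the end of each training epoch. It models only the relative order of the hooks, not Lightning's actual loop. It also shows why validation_step_outputs is cleared: without clear(), a second epoch's on_validation_epoch_end would also see the first epoch's outputs.

```python
class HookRecorder:
    """Hypothetical module that records hook calls and accumulates step outputs."""

    def __init__(self):
        self.calls = []
        self.validation_step_outputs = []

    def training_step(self, batch, batch_idx):
        self.calls.append("training_step")

    def validation_step(self, batch, batch_idx):
        self.calls.append("validation_step")
        self.validation_step_outputs.append(batch * 10)  # stand-in "prediction"

    def on_validation_epoch_end(self):
        total = sum(self.validation_step_outputs)
        self.calls.append(f"on_validation_epoch_end({total})")
        self.validation_step_outputs.clear()  # free memory, as in the example above

    def on_train_epoch_end(self):
        self.calls.append("on_train_epoch_end")


def simulated_fit_epoch(module, train_batches, val_batches):
    # assumption: validation runs once, after the last training batch
    # and before on_train_epoch_end
    for batch_idx, batch in enumerate(train_batches):
        module.training_step(batch, batch_idx)
    for batch_idx, batch in enumerate(val_batches):
        module.validation_step(batch, batch_idx)
    module.on_validation_epoch_end()
    module.on_train_epoch_end()


module = HookRecorder()
simulated_fit_epoch(module, train_batches=[1, 2], val_batches=[3, 4])
print(module.calls)
# ['training_step', 'training_step', 'validation_step', 'validation_step',
#  'on_validation_epoch_end(70)', 'on_train_epoch_end']
```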

Testing

Test Loop

The process for enabling a test loop is the same as for enabling a validation loop (see the section above): override the test_step() method.

The only difference is that the test loop is only called when test() is used.

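model = LightningTransformer(vocab_size=dataset.vocab_size)
dataloader = DataLoader(dataset)
trainer = L.Trainer()
trainer.fit(model=model, train_dataloaders=dataloader)

# ckpt_path="best" loads the best checkpoint from the fit above;
# if a model is passed without ckpt_path, its current weights are tested
trainer.test(model, dataloaders=dataloader, ckpt_path="best")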

There are two ways to call test():

# call after training
trainer = L.Trainer()
trainer.fit(model=model, train_dataloaders=dataloader)

# automatically loads the best weights from the previous run
trainer.test(dataloaders=test_dataloaders)

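# or call with pretrained model
dataset = WikiText2()
# LightningTransformer does not call save_hyperparameters(), so vocab_size must be passed again
model = LightningTransformer.load_from_checkpoint(PATH, vocab_size=dataset.vocab_size)
test_dataloader = DataLoader(dataset)
trainer = L.Trainer()
trainer.test(model, dataloaders=test_dataloader)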

Note

WikiText2 is used in a manner that does not create a train, test, val split. This is done for illustrative purposes only. A proper split can be created in lightning.pytorch.core.LightningModule.setup() or lightning.pytorch.core.LightningDataModule.setup().

Note

As with validation, it is recommended to test on a single device so that each sample/batch is evaluated exactly once. In a multi-device setting with DistributedSampler (e.g. strategy="ddp"), some samples may be replicated on some devices so that all devices have the same batch size.


Inference

Prediction Loop

By default, the predict_step() method runs the forward() method. To customize this behavior, override the predict_step() method.

For example, let's override predict_step:

class LightningTransformer(L.LightningModule):
    def __init__(self, vocab_size):
        super().__init__()
        self.model = Transformer(vocab_size=vocab_size)

    def predict_step(self, batch):
        inputs, target = batch
        return self.model(inputs, target)

Under the hood, Lightning does the following (pseudocode):

# disable grads + batchnorm + dropout
torch.set_grad_enabled(False)
model.eval()
all_preds = []

for batch_idx, batch in enumerate(predict_dataloader):
    pred = model.predict_step(batch, batch_idx)
    all_preds.append(pred)

There are two ways to call predict():

# call after training
trainer = L.Trainer()
trainer.fit(model=model, train_dataloaders=dataloader)

# automatically loads the best weights from the previous run
predictions = trainer.predict(dataloaders=predict_dataloader)

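# or call with pretrained model
dataset = WikiText2()
# LightningTransformer does not call save_hyperparameters(), so vocab_size must be passed again
model = LightningTransformer.load_from_checkpoint(PATH, vocab_size=dataset.vocab_size)
predict_dataloader = DataLoader(dataset)
trainer = L.Trainer()
predictions = trainer.predict(model, dataloaders=predict_dataloader)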

Inference in Research

If you want to perform inference with the system, you can add a forward method to the LightningModule.

Note

When using forward, you are responsible for calling eval() and using the no_grad() context manager.

class LightningTransformer(L.LightningModule):
    def __init__(self, vocab_size):
        super().__init__()
        self.model = Transformer(vocab_size=vocab_size)

    def forward(self, batch):
        inputs, target = batch
        return self.model(inputs, target)

    def training_step(self, batch, batch_idx):
        inputs, target = batch
        output = self.model(inputs, target)
        loss = torch.nn.functional.nll_loss(output, target.view(-1))
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.model.parameters(), lr=0.1)

model = LightningTransformer(vocab_size=dataset.vocab_size)

model.eval()
with torch.no_grad():
    batch = dataloader.dataset[0]
    pred = model(batch)

The advantage of adding a forward is that in complex systems, you can do a much more involved inference procedure, such as text generation:

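class Seq2Seq(L.LightningModule):
    def forward(self, x):
        # assumes self.embedding and self.encoder are defined in __init__;
        # calling self(x) here would recurse into forward forever
        embeddings = self.embedding(x)
        hidden_states = self.encoder(embeddings)
        for h in hidden_states:
            # decode
            ...
        return decoded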

If you want to scale your inference, you should use predict_step().

class Autoencoder(L.LightningModule):
    def forward(self, x):
        return self.decoder(x)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        # this calls forward
        return self(batch)


data_module = ...
model = Autoencoder()
trainer = Trainer(accelerator="gpu", devices=2)
trainer.predict(model, data_module)

Inference in Production

For cases like production, you might want to iterate over different models inside a LightningModule.

import torch
import torch.nn.functional as F
from torchmetrics.functional import accuracy


class ClassificationTask(L.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, acc = self._shared_eval_step(batch, batch_idx)
        metrics = {"val_acc": acc, "val_loss": loss}
        self.log_dict(metrics)
        return metrics

    def test_step(self, batch, batch_idx):
        loss, acc = self._shared_eval_step(batch, batch_idx)
        metrics = {"test_acc": acc, "test_loss": loss}
        self.log_dict(metrics)
        return metrics

    def _shared_eval_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        # recent torchmetrics versions require the task type; num_classes is read from the logits
        acc = accuracy(y_hat, y, task="multiclass", num_classes=y_hat.shape[-1])
        return loss, acc

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        x, y = batch
        y_hat = self.model(x)
        return y_hat

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=0.02)

Then pass in any arbitrary model to be fit with this task:

for model in [resnet50(), vgg16(), BidirectionalRNN()]:
    task = ClassificationTask(model)

    trainer = Trainer(accelerator="gpu", devices=2)
    trainer.fit(task, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

Tasks can be arbitrarily complex, such as implementing GAN training, self-supervised learning, or even RL.

class GANTask(L.LightningModule):
    def __init__(self, generator, discriminator):
        super().__init__()
        self.generator = generator
        self.discriminator = discriminator

    ...

When used like this, the model can be separated from the Task and thus used in production without needing to keep it in a LightningModule.

The following example shows how you can run inference in the Python runtime:

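task = ClassificationTask(model)
trainer = Trainer(accelerator="gpu", devices=2)
trainer.fit(task, train_dataloader, val_dataloader)
trainer.save_checkpoint("best_model.ckpt")

# use model after training or load weights and drop into the production system
# (ClassificationTask does not call save_hyperparameters(), so the wrapped model is passed back in)
task = ClassificationTask.load_from_checkpoint("best_model.ckpt", model=model)
model = task.model  # a plain torch.nn.Module; ClassificationTask itself defines no forward()
x = ...
model.eval()
with torch.no_grad():
    y_hat = model(x)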

Check out the Inference in Production guide to learn about the possible ways to perform inference in production.


Save Hyperparameters

Often we train many versions of a model. You might share that model or come back to it a few months later, at which point it is very useful to know how that model was trained (e.g. what learning rate, neural network, etc.).

Lightning has a standardized way of saving the information for you in checkpoints and YAML files. The goal here is to improve readability and reproducibility.

save_hyperparameters

Use save_hyperparameters() within your LightningModule's __init__ method. It enables Lightning to store all the provided arguments under the self.hparams attribute. These hyperparameters are also stored within the model checkpoint, which simplifies model re-instantiation after training.

class LitMNIST(L.LightningModule):
    def __init__(self, layer_1_dim=128, learning_rate=1e-2):
        super().__init__()
        # call this to save (layer_1_dim=128, learning_rate=1e-2) to the checkpoint
        self.save_hyperparameters()

        # equivalent
        self.save_hyperparameters("layer_1_dim", "learning_rate")

        # Now possible to access layer_1_dim from hparams
        self.hparams.layer_1_dim

In addition, loggers that support it will automatically log the contents of self.hparams.

Excluding hyperparameters

By default, every parameter of the __init__ method is considered a hyperparameter of the LightningModule. However, some parameters need to be excluded from saving, for example when they are not serializable. Those parameters must then be provided again when reloading the LightningModule. In this case, exclude them explicitly:

class LitMNIST(L.LightningModule):
    def __init__(self, loss_fx, generator_network, layer_1_dim=128):
        super().__init__()
        self.layer_1_dim = layer_1_dim
        self.loss_fx = loss_fx

        # call this to save only (layer_1_dim=128) to the checkpoint
        self.save_hyperparameters("layer_1_dim")

        # equivalent
        self.save_hyperparameters(ignore=["loss_fx", "generator_network"])
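Under the hood, save_hyperparameters() reads the arguments that were passed to __init__. The sketch below imitates that idea with the standard inspect module. It is not Lightning's implementation. HParamsSketch and LitMNISTSketch are hypothetical, and they store hparams as a plain dict rather than Lightning's attribute dict. The sketch covers only the explicit-names and ignore forms shown above; the real method also accepts, for example, a single Namespace or dict.

```python
import inspect


class HParamsSketch:
    """Hypothetical, simplified stand-in for save_hyperparameters()."""

    def save_hyperparameters(self, *names, ignore=()):
        # read the arguments of the caller's frame, i.e. the subclass __init__
        frame = inspect.currentframe().f_back
        try:
            arg_names, _, _, local_vars = inspect.getargvalues(frame)
        finally:
            del frame  # avoid keeping a reference cycle to the frame
        init_args = {n: local_vars[n] for n in arg_names if n != "self"}
        if names:  # explicit names: keep only those
            init_args = {n: v for n, v in init_args.items() if n in names}
        if isinstance(ignore, str):
            ignore = [ignore]
        self.hparams = {n: v for n, v in init_args.items() if n not in ignore}


class LitMNISTSketch(HParamsSketch):
    def __init__(self, loss_fx=None, layer_1_dim=128, learning_rate=1e-2):
        # loss_fx stands in for a non-serializable argument that is excluded
        self.save_hyperparameters(ignore=["loss_fx"])


print(LitMNISTSketch(layer_1_dim=64).hparams)
# {'layer_1_dim': 64, 'learning_rate': 0.01}
```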

load_from_checkpoint

LightningModules that have hyperparameters automatically saved with save_hyperparameters() can conveniently be loaded and instantiated directly from a checkpoint with load_from_checkpoint():

# all hyperparameters were saved, so no extra arguments are needed
model = LitMNIST.load_from_checkpoint(PATH)

# keyword arguments override saved hyperparameter values
model = LitMNIST.load_from_checkpoint(PATH, learning_rate=1e-3)

If parameters were excluded, they need to be provided at the time of loading:

# the excluded parameters were loss_fx and generator_network
model = LitMNIST.load_from_checkpoint(PATH, loss_fx=torch.nn.SomeOtherLoss, generator_network=MyGenerator())
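load_from_checkpoint() combines the hyperparameters stored in the checkpoint with the keyword arguments you pass. The sketch below models only that merge. load_from_checkpoint_sketch and LitMNISTSketch are hypothetical. The checkpoint is a plain dict, and its "hyper_parameters" and "state_dict" keys are our assumption about the on-disk layout. Weight loading is reduced to storing the dict. Keyword arguments take precedence over saved values, and anything excluded via ignore must be supplied this way.

```python
def load_from_checkpoint_sketch(cls, checkpoint, **kwargs):
    """Hypothetical simplification: rebuild a module from saved hparams plus overrides."""
    # saved hyperparameters first; caller-supplied kwargs take precedence
    init_kwargs = {**checkpoint["hyper_parameters"], **kwargs}
    model = cls(**init_kwargs)
    model.state = checkpoint["state_dict"]  # stand-in for load_state_dict()
    return model


class LitMNISTSketch:
    def __init__(self, loss_fx, layer_1_dim=128):
        self.loss_fx = loss_fx
        self.layer_1_dim = layer_1_dim


# loss_fx was excluded from the saved hyperparameters, so it must be supplied
checkpoint = {"hyper_parameters": {"layer_1_dim": 64}, "state_dict": {"l1.weight": [0.0]}}
model = load_from_checkpoint_sketch(LitMNISTSketch, checkpoint, loss_fx="nll")
print(model.layer_1_dim, model.loss_fx)  # 64 nll
```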


Child Modules

Research projects tend to test different approaches to the same dataset. This is very easy to do in Lightning with inheritance.

For example, imagine we now want to train an AutoEncoder to use as a feature extractor for images. The only things that change in the LitAutoEncoder model are the init, forward, training, validation and test step.

class Encoder(torch.nn.Module):
    ...


class Decoder(torch.nn.Module):
    ...


class AutoEncoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        return self.decoder(self.encoder(x))

class LitAutoEncoder(L.LightningModule):
    def __init__(self, auto_encoder):
        super().__init__()
        self.auto_encoder = auto_encoder
        self.metric = torch.nn.MSELoss()

    def forward(self, x):
        return self.auto_encoder.encoder(x)

    def training_step(self, batch, batch_idx):
        x, _ = batch
        x_hat = self.auto_encoder(x)
        loss = self.metric(x, x_hat)
        return loss

    def validation_step(self, batch, batch_idx):
        self._shared_eval(batch, batch_idx, "val")

    def test_step(self, batch, batch_idx):
        self._shared_eval(batch, batch_idx, "test")

    def _shared_eval(self, batch, batch_idx, prefix):
        x, _ = batch
        x_hat = self.auto_encoder(x)
        loss = self.metric(x, x_hat)
        self.log(f"{prefix}_loss", loss)

And we can train this using the Trainer:

auto_encoder = AutoEncoder()
lightning_module = LitAutoEncoder(auto_encoder)
trainer = Trainer()
trainer.fit(lightning_module, train_dataloader, val_dataloader)

And remember that the forward method should define the practical use of a LightningModule. In this case, we want to use the LitAutoEncoder to extract image representations:

some_images = torch.rand(32, 1, 28, 28)
representations = lightning_module(some_images)


LightningModule API

Methods

all_gather

LightningModule.all_gather(data, group=None, sync_grads=False)

Gather tensors or collections of tensors from multiple processes.

This method needs to be called on all processes and the tensors need to have the same shape across all processes, otherwise your program will stall forever.

Parameters:

Return type:

Union[Tensor, dict, list, tuple]

Returns:

A tensor of shape (world_size, batch, …), or if the input was a collection the output will also be a collection with tensors of this shape. For the special case where world_size is 1, no additional dimension is added to the tensor(s).

configure_callbacks

LightningModule.configure_callbacks()

Configure model-specific callbacks. When the model gets attached, e.g. when .fit() or .test() gets called, the list or callback returned here will be merged with the list of callbacks passed to the Trainer's callbacks argument. If a callback returned here has the same type as one or several callbacks already present in the Trainer's callbacks list, it takes priority and replaces them. In addition, Lightning makes sure ModelCheckpoint callbacks run last.

Return type:

Union[Sequence[Callback], Callback]

Returns:

A callback or a list of callbacks which will extend the list of callbacks in the Trainer.

Example:

def configure_callbacks(self):
    early_stop = EarlyStopping(monitor="val_acc", mode="max")
    checkpoint = ModelCheckpoint(monitor="val_loss")
    return [early_stop, checkpoint]
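The merge rule described above can be sketched in a few lines. The callback classes here are empty placeholders that only borrow Lightning's names. merge_callbacks is a hypothetical helper that follows the documented behavior (same-type replacement, ModelCheckpoint last) using exact type matching. It is not Lightning's implementation.

```python
class Callback:
    """Hypothetical placeholder base class (not Lightning's Callback)."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return f"{type(self).__name__}({self.name})"


# placeholders that only borrow Lightning's callback names
class EarlyStopping(Callback): ...
class ModelCheckpoint(Callback): ...
class LearningRateMonitor(Callback): ...


def merge_callbacks(trainer_callbacks, model_callbacks):
    """Sketch of the documented merge rule for configure_callbacks()."""
    model_types = {type(cb) for cb in model_callbacks}
    # model callbacks replace trainer callbacks of the same type
    merged = [cb for cb in trainer_callbacks if type(cb) not in model_types]
    merged += model_callbacks
    # ModelCheckpoint callbacks are moved to the end so they run last
    others = [cb for cb in merged if not isinstance(cb, ModelCheckpoint)]
    checkpoints = [cb for cb in merged if isinstance(cb, ModelCheckpoint)]
    return others + checkpoints


trainer_cbs = [ModelCheckpoint("trainer"), EarlyStopping("trainer"), LearningRateMonitor("trainer")]
model_cbs = [EarlyStopping("model")]
print(merge_callbacks(trainer_cbs, model_cbs))
# [LearningRateMonitor(trainer), EarlyStopping(model), ModelCheckpoint(trainer)]
```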

configure_optimizers

LightningModule.configure_optimizers()

Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you’d need one. But in the case of GANs or similar you might have multiple. Optimization with multiple optimizers only works in the manual optimization mode.

Return type:

Union[Optimizer, Sequence[Optimizer], tuple[Sequence[Optimizer], Sequence[Union[LRScheduler, ReduceLROnPlateau, LRSchedulerConfig]]], OptimizerConfig, OptimizerLRSchedulerConfig, Sequence[OptimizerConfig], Sequence[OptimizerLRSchedulerConfig], None]

Returns:

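Any of these 6 options:

- Single optimizer.
- List or Tuple of optimizers.
- Two lists: the first contains the optimizers, the second the LR schedulers (or lr_scheduler_config dictionaries).
- Dictionary, with an "optimizer" key and (optionally) an "lr_scheduler" key whose value is a single LR scheduler or an lr_scheduler_config.
- List or Tuple of such dictionaries, one per optimizer.
- None, in which case fit runs without an optimizer.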

The lr_scheduler_config is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.

lr_scheduler_config = {
    # REQUIRED: The scheduler instance
    "scheduler": lr_scheduler,
    # The unit of the scheduler's step size, could also be 'step'.
    # 'epoch' updates the scheduler on epoch end whereas 'step'
    # updates it after an optimizer update.
    "interval": "epoch",
    # How many epochs/steps should pass between calls to
    # scheduler.step(). 1 corresponds to updating the learning
    # rate after every epoch/step.
    "frequency": 1,
    # Metric to monitor for schedulers like ReduceLROnPlateau
    "monitor": "val_loss",
    # If set to True, will enforce that the value specified 'monitor'
    # is available when the scheduler is updated, thus stopping
    # training if not found. If set to False, it will only produce a warning
    "strict": True,
    # If using the LearningRateMonitor callback to monitor the
    # learning rate progress, this keyword can be used to specify
    # a custom logged name
    "name": None,
}

When there are schedulers in which the .step() method is conditioned on a value, such as the torch.optim.lr_scheduler.ReduceLROnPlateau scheduler, Lightning requires that the lr_scheduler_config contains the keyword "monitor" set to the metric name that the scheduler should be conditioned on.

# The ReduceLROnPlateau scheduler requires a monitor
def configure_optimizers(self):
    optimizer = Adam(...)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": ReduceLROnPlateau(optimizer, ...),
            "monitor": "metric_to_track",
            # how often the metric is updated; if "monitor" references validation
            # metrics, set this to a multiple of "trainer.check_val_every_n_epoch"
            "frequency": 1,
        },
    }

# In the case of two optimizers, only one using the ReduceLROnPlateau scheduler
def configure_optimizers(self):
    optimizer1 = Adam(...)
    optimizer2 = SGD(...)
    scheduler1 = ReduceLROnPlateau(optimizer1, ...)
    scheduler2 = LambdaLR(optimizer2, ...)
    return (
        {
            "optimizer": optimizer1,
            "lr_scheduler": {
                "scheduler": scheduler1,
                "monitor": "metric_to_track",
            },
        },
        {"optimizer": optimizer2, "lr_scheduler": scheduler2},
    )

Metrics can be made available to monitor by simply logging them using self.log('metric_to_track', metric_val) in your LightningModule.
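The sketch below shows how an lr_scheduler_config relates to the keys listed above and to the "monitor" requirement. normalize_scheduler_config and should_step are hypothetical helpers, not Lightning APIs. The sketch assumes that "monitor" has no default (the "val_loss" in the listing is taken as an example value). It also assumes that frequency counts completed epochs or steps, as the listing's comment describes.

```python
DEFAULT_SCHEDULER_CONFIG = {
    "interval": "epoch",
    "frequency": 1,
    "monitor": None,  # assumption: no default metric
    "strict": True,
    "name": None,
}


def normalize_scheduler_config(config, needs_monitor=False):
    """Hypothetical helper: fill in defaults and apply the documented checks."""
    if "scheduler" not in config:
        raise ValueError('lr_scheduler_config must contain the key "scheduler"')
    merged = {**DEFAULT_SCHEDULER_CONFIG, **config}
    if merged["interval"] not in ("epoch", "step"):
        raise ValueError(f'"interval" must be "epoch" or "step", got {merged["interval"]!r}')
    if needs_monitor and merged["monitor"] is None:
        # e.g. ReduceLROnPlateau: its .step() needs a metric value
        raise ValueError('schedulers like ReduceLROnPlateau require a "monitor" key')
    return merged


def should_step(completed, frequency):
    """True when scheduler.step() is due after `completed` epochs/steps (1-based)."""
    return completed % frequency == 0


config = normalize_scheduler_config({"scheduler": "plateau", "monitor": "val_loss"}, needs_monitor=True)
print(config["interval"], config["frequency"], config["strict"])  # epoch 1 True
print([n for n in range(1, 7) if should_step(n, 2)])  # [2, 4, 6]
```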

Note

Some things to know:

forward

LightningModule.forward(*args, **kwargs)

Same as torch.nn.Module.forward().

Parameters:

Return type:

Any

Returns:

Your model’s output

freeze

LightningModule.freeze()

Freeze all params for inference.

Example:

model = MyLightningModule(...)
model.freeze()

Return type:

None

log

LightningModule.log(name, value, prog_bar=False, logger=None, on_step=None, on_epoch=None, reduce_fx='mean', enable_graph=False, sync_dist=False, sync_dist_group=None, add_dataloader_idx=True, batch_size=None, metric_attribute=None, rank_zero_only=False)

Log a key, value pair.

Example:

self.log('train_loss', loss)

The default behavior per hook is documented here: Automatic Logging.

Parameters:

Return type:

None

log_dict

LightningModule.log_dict(dictionary, prog_bar=False, logger=None, on_step=None, on_epoch=None, reduce_fx='mean', enable_graph=False, sync_dist=False, sync_dist_group=None, add_dataloader_idx=True, batch_size=None, rank_zero_only=False)

Log a dictionary of values at once.

Example:

values = {'loss': loss, 'acc': acc, ..., 'metric_n': metric_n}
self.log_dict(values)

Parameters:

Return type:

None

lr_schedulers

LightningModule.lr_schedulers()

Returns the learning rate scheduler(s) that are being used during training. Useful for manual optimization.

Return type:

Union[None, list[Union[LRScheduler, ReduceLROnPlateau]], LRScheduler, ReduceLROnPlateau]

Returns:

A single scheduler, or a list of schedulers in case multiple ones are present, or None if no schedulers were returned in configure_optimizers().

manual_backward

LightningModule.manual_backward(loss, *args, **kwargs)

Call this directly from your training_step() when doing optimization manually. By using this, Lightning can ensure that all the proper scaling gets applied when using mixed precision.

See manual optimization for more examples.

Example:

def training_step(...):
    opt = self.optimizers()
    loss = ...
    opt.zero_grad()
    # automatically applies scaling, etc...
    self.manual_backward(loss)
    opt.step()

Parameters:

Return type:

None
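Why does scaling matter? In float16, very small gradients underflow to zero. The dependency-free sketch below uses Python's struct half-precision format ("e") to show the idea behind loss scaling. The loss, and therefore the gradients, is multiplied by a large factor before the 16-bit backward pass, then divided back out in full precision before the optimizer step. The sketch illustrates the concept only. The scale value 2**16 is an assumption that matches the default initial scale of PyTorch's GradScaler, and real scalers also adjust the factor dynamically.

```python
import struct


def to_fp16(value):
    """Round a Python float to IEEE 754 half precision and back."""
    return struct.unpack("<e", struct.pack("<e", value))[0]


grad = 1e-8        # a small gradient, representable in fp32
scale = 2.0 ** 16  # assumed loss-scale factor

unscaled_fp16 = to_fp16(grad)          # underflows to 0.0 in fp16
scaled_fp16 = to_fp16(grad * scale)    # representable after scaling
recovered = scaled_fp16 / scale        # unscale in full precision before the step

print(unscaled_fp16)                        # 0.0
print(abs(recovered - grad) / grad < 1e-3)  # True
```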

optimizers

LightningModule.optimizers(use_pl_optimizer=True)

Returns the optimizer(s) that are being used during training. Useful for manual optimization.

Parameters:

use_pl_optimizer (bool) – If True, will wrap the optimizer(s) in a LightningOptimizer for automatic handling of precision, profiling, and counting of step calls for proper logging and checkpointing. It specifically wraps the step method, and custom optimizers that don't have this method are not supported.

Return type:

Union[Optimizer, LightningOptimizer, _FabricOptimizer, list[Optimizer], list[LightningOptimizer], list[_FabricOptimizer]]

Returns:

A single optimizer, or a list of optimizers in case multiple ones are present.

print

LightningModule.print(*args, **kwargs)

Prints only from process 0. Use this in any distributed mode to log only once.

Parameters:

Return type:

None

Example:

def forward(self, x):
    self.print(x, 'in forward')

predict_step

LightningModule.predict_step(*args, **kwargs)

Step function called during predict(). By default, it calls forward(). Override to add any processing logic.

predict_step() is used to scale inference across multiple devices.

To prevent an OOM error, it is possible to use the BasePredictionWriter callback to write the predictions to disk or a database after each batch or on epoch end.

The BasePredictionWriter should be used when running with a spawn-based strategy, such as Trainer(strategy="ddp_spawn") or training on 8 TPU cores with Trainer(accelerator="tpu", devices=8), because predictions won't be returned in that case.

Parameters:

Return type:

Any

Returns:

Predicted output (optional).

Example

class MyModel(LightningModule):

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        return self(batch)


dm = ...
model = MyModel()
trainer = Trainer(accelerator="gpu", devices=2)
predictions = trainer.predict(model, dm)

save_hyperparameters

LightningModule.save_hyperparameters(*args, ignore=None, frame=None, logger=True)

Save arguments to hparams attribute.

Parameters:

Return type:

None

Example:

from lightning.pytorch.core.mixins import HyperparametersMixin

class ManuallyArgsModel(HyperparametersMixin):
    def __init__(self, arg1, arg2, arg3):
        super().__init__()
        # manually assign arguments
        self.save_hyperparameters('arg1', 'arg3')

    def forward(self, *args, **kwargs):
        ...

model = ManuallyArgsModel(1, 'abc', 3.14)
model.hparams
# "arg1": 1
# "arg3": 3.14

from lightning.pytorch.core.mixins import HyperparametersMixin

class AutomaticArgsModel(HyperparametersMixin):
    def __init__(self, arg1, arg2, arg3):
        super().__init__()
        # equivalent automatic
        self.save_hyperparameters()

    def forward(self, *args, **kwargs):
        ...

model = AutomaticArgsModel(1, 'abc', 3.14)
model.hparams
# "arg1": 1
# "arg2": abc
# "arg3": 3.14

from argparse import Namespace
from lightning.pytorch.core.mixins import HyperparametersMixin

class SingleArgModel(HyperparametersMixin):
    def __init__(self, params):
        super().__init__()
        # manually assign single argument
        self.save_hyperparameters(params)

    def forward(self, *args, **kwargs):
        ...

model = SingleArgModel(Namespace(p1=1, p2='abc', p3=3.14))
model.hparams
# "p1": 1
# "p2": abc
# "p3": 3.14

from lightning.pytorch.core.mixins import HyperparametersMixin

class ManuallyArgsModel(HyperparametersMixin):
    def __init__(self, arg1, arg2, arg3):
        super().__init__()
        # pass argument(s) to ignore as a string or in a list
        self.save_hyperparameters(ignore='arg2')

    def forward(self, *args, **kwargs):
        ...

model = ManuallyArgsModel(1, 'abc', 3.14)
model.hparams
# "arg1": 1
# "arg3": 3.14

toggle_optimizer

LightningModule.toggle_optimizer(optimizer)

Makes sure only the gradients of the current optimizer’s parameters are calculated in the training step to prevent dangling gradients in multiple-optimizer setup.

It works with untoggle_optimizer() to make sure param_requires_grad_state is properly reset.

Parameters:

optimizer (Union[Optimizer, LightningOptimizer]) – The optimizer to toggle.

Return type:

None

test_step

LightningModule.test_step(*args, **kwargs)

Operates on a single batch of data from the test set. In this step you’d normally generate examples or calculate anything of interest such as accuracy.

Parameters:

Return type:

Union[Tensor, Mapping[str, Any], None]

Returns:

A loss Tensor, a dict, or None (see the return type above).

# if you have one test dataloader:
def test_step(self, batch, batch_idx): ...

# if you have multiple test dataloaders:
def test_step(self, batch, batch_idx, dataloader_idx=0): ...

Examples:

CASE 1: A single test dataset

def test_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'test_loss': loss, 'test_acc': test_acc})

If you pass in multiple test dataloaders, test_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.

CASE 2: multiple test dataloaders

def test_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    ...

Note

If you don’t need to test you don’t need to implement this method.

Note

When the test_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of the test epoch, the model goes back to training mode and gradients are enabled.

to_onnx

LightningModule.to_onnx(file_path, input_sample=None, **kwargs)

Saves the model in ONNX format.

Parameters:

Return type:

None

Example:

class SimpleModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(in_features=64, out_features=4)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))


model = SimpleModel()
input_sample = torch.randn(1, 64)
model.to_onnx("export.onnx", input_sample, export_params=True)

to_torchscript

LightningModule.to_torchscript(file_path=None, method='script', example_inputs=None, **kwargs)

By default compiles the whole model to a ScriptModule. If you want to use tracing, please provide the argument method='trace' and make sure that either the example_inputs argument is provided, or the model has example_input_array set. If you would like to customize the modules that are scripted, you should override this method. If you want to return multiple modules, we recommend using a dictionary.

Parameters:

Note

Example:

import torch
from lightning.pytorch import LightningModule

class SimpleModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(in_features=64, out_features=4)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

model = SimpleModel()
# script the model (the default method) and save it to model.pt
model.to_torchscript(file_path="model.pt")

# trace the model instead; passing file_path already saves the result
model.to_torchscript(
    file_path="model_trace.pt", method="trace", example_inputs=torch.randn(1, 64)
)

Return type:

Union[ScriptModule, dict[str, ScriptModule]]

Returns:

This LightningModule as a torchscript, regardless of whether file_path is defined or not.

training_step

LightningModule.training_step(*args, **kwargs)

Here you compute and return the training loss, along with any additional metrics, e.g. for the progress bar or logger.

Parameters:

Return type:

Union[Tensor, Mapping[str, Any], None]

Returns:

In this step you’d normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific.

Example:

def training_step(self, batch, batch_idx):
    x, y, z = batch
    out = self.encoder(x)
    loss = self.loss(out, x)
    return loss

To use multiple optimizers, you can switch to ‘manual optimization’ and control their stepping:

def __init__(self):
    super().__init__()
    self.automatic_optimization = False

# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx):
    opt1, opt2 = self.optimizers()

    # do training_step with encoder
    ...
    opt1.step()
    # do training_step with decoder
    ...
    opt2.step()

Note

When accumulate_grad_batches > 1, the loss returned here will be automatically normalized by accumulate_grad_batches internally.
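The normalization exists because accumulation sums gradients over N micro-batches before a single optimizer step. Dividing each returned loss by N divides its gradient by N, so the accumulated gradient equals the gradient of the mean loss over all N micro-batches. The sketch below checks this for a hypothetical one-parameter model (y ≈ w*x, mean squared error), assuming equal-sized micro-batches whose losses are means. The function names are illustrative and do not belong to Lightning.

```python
def loss_and_grad(w, batch):
    """Mean squared error of y ~ w*x over one micro-batch, and its derivative w.r.t. w."""
    n = len(batch)
    loss = sum((w * x - y) ** 2 for x, y in batch) / n
    grad = sum(2 * (w * x - y) * x for x, y in batch) / n
    return loss, grad


def accumulated_grad(w, micro_batches):
    """Sum the gradients of loss / N over N micro-batches, as gradient accumulation does."""
    n = len(micro_batches)
    total = 0.0
    for batch in micro_batches:
        _, grad = loss_and_grad(w, batch)
        # Scaling the loss by 1/N scales its gradient by 1/N as well.
        total += grad / n
    return total
```

For two micro-batches of two samples each, `accumulated_grad` matches the gradient of the mean loss over all four samples. Without the 1/N factor, the result would be twice as large.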

unfreeze

LightningModule.unfreeze()

Unfreeze all parameters for training.

model = MyLightningModule(...)
model.unfreeze()

Return type:

None

untoggle_optimizer

LightningModule.untoggle_optimizer(optimizer)

Resets the state of required gradients that were toggled with toggle_optimizer().

Parameters:

optimizer (Union[Optimizer, LightningOptimizer]) – The optimizer to untoggle.

Return type:

None

validation_step

LightningModule.validation_step(*args, **kwargs)

Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest, such as accuracy.

Parameters:

Return type:

Union[Tensor, Mapping[str, Any], None]

Returns:

if you have one val dataloader:

def validation_step(self, batch, batch_idx): ...

if you have multiple val dataloaders:

def validation_step(self, batch, batch_idx, dataloader_idx=0): ...

Examples:

CASE 1: A single validation dataset

def validation_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'val_loss': loss, 'val_acc': val_acc})

If you pass in multiple val dataloaders, validation_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.

CASE 2: multiple validation dataloaders

def validation_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    ...

Note

If you don’t need to validate, you don’t need to implement this method.

Note

When the validation_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled.


Properties

These are properties available in a LightningModule.

current_epoch

The number of epochs run.

def training_step(self, batch, batch_idx):
    if self.current_epoch == 0:
        ...

device

The device the module is on. Use it to keep your code device agnostic.

def training_step(self, batch, batch_idx):
    z = torch.rand(2, 3, device=self.device)

global_rank

The global_rank is the index of the current process across all nodes and devices. Lightning performs some operations, such as logging and weight checkpointing, only when global_rank=0. You usually do not need to use this property, but it is useful to know how to access it if needed.

def training_step(self, batch, batch_idx):
    if self.global_rank == 0:
        # do something only once across all the nodes
        ...

global_step

The number of optimizer steps taken (does not reset each epoch). This includes multiple optimizers (if enabled).

def training_step(self, batch, batch_idx):
    self.logger.experiment.log_image(..., step=self.global_step)

hparams

The arguments passed to LightningModule.__init__() and saved by calling save_hyperparameters() can be accessed through the hparams attribute.

from torch.optim import Adam

def __init__(self, learning_rate):
    super().__init__()
    self.save_hyperparameters()

def configure_optimizers(self):
    return Adam(self.parameters(), lr=self.hparams.learning_rate)

logger

The current logger being used (TensorBoard or another supported logger).

def training_step(self, batch, batch_idx):
    # the generic logger (same no matter if tensorboard or other supported logger)
    self.logger

    # the particular logger
    tensorboard_logger = self.logger.experiment

loggers

The list of loggers currently being used by the Trainer.

def training_step(self, batch, batch_idx):
    # List of Logger objects
    loggers = self.loggers
    for logger in loggers:
        logger.log_metrics({"foo": 1.0})

local_rank

The local_rank is the index of the current process across all the devices for the current node. You usually do not need to use this property, but it is useful to know how to access it if needed. For example, if using 10 machines (or nodes), the GPU at index 0 on each machine has local_rank = 0.

def training_step(self, batch, batch_idx):
    if self.local_rank == 0:
        # do something only once across each node
        ...

precision

The type of precision used:

def training_step(self, batch, batch_idx):
    if self.precision == "16-true":
        ...

trainer

A reference to the Trainer.

def training_step(self, batch, batch_idx):
    max_steps = self.trainer.max_steps
    any_flag = self.trainer.any_flag  # placeholder for any other Trainer attribute

prepare_data_per_node

If set to True, prepare_data() is called on LOCAL_RANK=0 of every node. If set to False, it is only called on NODE_RANK=0, LOCAL_RANK=0.

class LitModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.prepare_data_per_node = True

automatic_optimization

When set to False, Lightning does not automate the optimization process. This means you are responsible for handling your optimizers. However, we do take care of precision and any accelerators used.

See manual optimization for details.

def __init__(self):
    super().__init__()
    self.automatic_optimization = False

def training_step(self, batch, batch_idx):
    opt = self.optimizers(use_pl_optimizer=True)

    loss = ...
    opt.zero_grad()
    self.manual_backward(loss)
    opt.step()

Manual optimization is most useful for research topics like reinforcement learning, sparse coding, and GAN research. It is required when you are using 2+ optimizers because with automatic optimization, you can only use one optimizer.

def __init__(self):
    super().__init__()
    self.automatic_optimization = False

def training_step(self, batch, batch_idx):
    # access your optimizers; use_pl_optimizer=True (the default) returns
    # LightningOptimizer wrappers, use_pl_optimizer=False the raw optimizers
    opt_a, opt_b = self.optimizers(use_pl_optimizer=True)

    gen_loss = ...
    opt_a.zero_grad()
    self.manual_backward(gen_loss)
    opt_a.step()

    disc_loss = ...
    opt_b.zero_grad()
    self.manual_backward(disc_loss)
    opt_b.step()

example_input_array

Set and access example_input_array, which represents a single example batch of inputs.

def __init__(self):
    super().__init__()
    self.example_input_array = ...
    self.generator = ...

def on_train_epoch_end(self):
    # generate some images using the example_input_array
    gen_images = self.generator(self.example_input_array)


Hooks

This is the pseudocode to describe the structure of fit(). The inputs and outputs of each function are not represented for simplicity. Please check each function’s API reference for more information.

# runs on every device: devices can be GPUs, TPUs, ...
def fit(self):
    configure_callbacks()

    if local_rank == 0:
        prepare_data()

    setup("fit")
    configure_model()
    configure_optimizers()

    on_fit_start()

    # the sanity check runs here

    on_train_start()
    for epoch in epochs:
        fit_loop()
    on_train_end()

    on_fit_end()
    teardown("fit")


def fit_loop():
    torch.set_grad_enabled(True)

    on_train_epoch_start()

    for batch_idx, batch in enumerate(train_dataloader()):
        on_train_batch_start()

        on_before_batch_transfer()
        transfer_batch_to_device()
        on_after_batch_transfer()

        out = training_step()

        on_before_zero_grad()
        optimizer_zero_grad()

        on_before_backward()
        backward()
        on_after_backward()

        on_before_optimizer_step()
        configure_gradient_clipping()
        optimizer_step()

        on_train_batch_end(out, batch, batch_idx)

        if should_check_val:
            val_loop()

    on_train_epoch_end()


def val_loop():
    on_validation_model_eval()  # calls `model.eval()`
    torch.set_grad_enabled(False)

    on_validation_start()
    on_validation_epoch_start()

    for batch_idx, batch in enumerate(val_dataloader()):
        on_validation_batch_start(batch, batch_idx)

        batch = on_before_batch_transfer(batch)
        batch = transfer_batch_to_device(batch)
        batch = on_after_batch_transfer(batch)

        out = validation_step(batch, batch_idx)

        on_validation_batch_end(out, batch, batch_idx)

    on_validation_epoch_end()
    on_validation_end()

    # set up for train
    on_validation_model_train()  # calls `model.train()`
    torch.set_grad_enabled(True)

backward

LightningModule.backward(loss, *args, **kwargs)

Called to perform backward on the loss returned in training_step(). Override this hook with your own implementation if you need to.

Parameters:

loss (Tensor) – The loss tensor returned by training_step(). If gradient accumulation is used, the loss here holds the normalized value (scaled by 1 / accumulation steps).

Return type:

None

Example:

def backward(self, loss):
    loss.backward()

on_before_backward

LightningModule.on_before_backward(loss)

Called before loss.backward().

Parameters:

loss (Tensor) – Loss divided by number of batches for gradient accumulation and scaled if using AMP.

Return type:

None

on_after_backward

LightningModule.on_after_backward()

Called after loss.backward() and before optimizers are stepped.

Return type:

None

Note

If using native AMP, the gradients will not be unscaled at this point. Use the on_before_optimizer_step if you need the unscaled gradients.
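This worked sketch shows why the note matters. It assumes a constant loss scale S = 2**16 (real dynamic scalers change S over time). Scaling the loss multiplies every gradient by S, so gradients read in on_after_backward() are S times too large until they are divided by S. The function names are hypothetical, and the sketch uses no PyTorch.

```python
SCALE = 2.0 ** 16  # assumed constant loss scale; dynamic scalers adjust it during training


def grad_of_loss(w, x, y, scale=1.0):
    """d/dw of scale * (w*x - y)**2 for a single sample."""
    return scale * 2 * (w * x - y) * x


def unscale(grad, scale):
    """What unscaling does before on_before_optimizer_step: divide by the scale."""
    return grad / scale
```

For w=0.5, x=2.0, y=3.0, the true gradient is -8.0. The scaled gradient seen after backward is -8.0 * 65536, and unscaling recovers -8.0.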

on_before_zero_grad

LightningModule.on_before_zero_grad(optimizer)

Called after training_step() and before optimizer.zero_grad().

Called in the training loop after taking an optimizer step and before zeroing grads. Good place to inspect weight information with weights updated.

This is where it is called:

for optimizer in optimizers:
    out = training_step(...)

    model.on_before_zero_grad(optimizer)  # < ---- called here
    optimizer.zero_grad()

    backward()

Parameters:

optimizer (Optimizer) – The optimizer for which grads should be zeroed.

Return type:

None

on_fit_start

LightningModule.on_fit_start()

Called at the very beginning of fit.

If on DDP, it is called on every process.

Return type:

None

on_fit_end

LightningModule.on_fit_end()

Called at the very end of fit.

If on DDP, it is called on every process.

Return type:

None

on_load_checkpoint

LightningModule.on_load_checkpoint(checkpoint)

Called by Lightning to restore your model. If you saved something with on_save_checkpoint() this is your chance to restore this.

Parameters:

checkpoint (dict[str, Any]) – Loaded checkpoint

Return type:

None

Example:

def on_load_checkpoint(self, checkpoint):
    # 99% of the time you don't need to implement this method
    self.something_cool_i_want_to_save = checkpoint['something_cool_i_want_to_save']

Note

Lightning auto-restores global step, epoch, and train state including amp scaling. There is no need for you to restore anything regarding training.

on_save_checkpoint

LightningModule.on_save_checkpoint(checkpoint)

Called by Lightning when saving a checkpoint to give you a chance to store anything else you might want to save.

Parameters:

checkpoint (dict[str, Any]) – The full checkpoint dictionary before it gets dumped to a file. Implementations of this hook can insert additional data into this dictionary.

Return type:

None

Example:

def on_save_checkpoint(self, checkpoint):
    # 99% of use cases you don't need to implement this method
    checkpoint['something_cool_i_want_to_save'] = my_cool_picklable_object

Note

Lightning saves all aspects of training (epoch, global step, etc…) including amp scaling. There is no need for you to store anything about training.
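The contract between the two hooks is simple. on_save_checkpoint() adds keys to the checkpoint dictionary in place, and on_load_checkpoint() reads the same keys back. The sketch below imitates this round trip without Lightning. It uses pickle as a stand-in for Lightning's own checkpoint I/O, and the class, the `vocab` attribute and the `save`/`load`/`round_trip` helpers are all hypothetical.

```python
import os
import pickle
import tempfile


class CheckpointingModule:
    """Stand-in exposing the two checkpoint hooks with Lightning's signatures."""

    def __init__(self):
        self.vocab = {}  # hypothetical extra state that is not part of the weights

    def on_save_checkpoint(self, checkpoint):
        # Insert extra, picklable data into the checkpoint dictionary in place.
        checkpoint["vocab"] = dict(self.vocab)

    def on_load_checkpoint(self, checkpoint):
        # Restore whatever on_save_checkpoint stored.
        self.vocab = checkpoint["vocab"]


def save(module, path):
    checkpoint = {"epoch": 3}  # stands in for the dictionary the Trainer builds
    module.on_save_checkpoint(checkpoint)
    with open(path, "wb") as f:
        pickle.dump(checkpoint, f)


def load(module, path):
    with open(path, "rb") as f:
        checkpoint = pickle.load(f)
    module.on_load_checkpoint(checkpoint)
    return checkpoint


def round_trip(vocab):
    """Save one module's extra state, then restore it into a fresh module."""
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "toy.ckpt")
        source = CheckpointingModule()
        source.vocab = vocab
        save(source, path)
        target = CheckpointingModule()
        checkpoint = load(target, path)
    return target, checkpoint
```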

load_from_checkpoint

LightningModule.load_from_checkpoint(checkpoint_path, map_location=None, hparams_file=None, strict=None, **kwargs)

Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint it stores the arguments passed to __init__ in the checkpoint under "hyper_parameters".

Any arguments specified through **kwargs will override args stored in "hyper_parameters".
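The override rule works like a dictionary merge in which caller-supplied keyword arguments win over the saved "hyper_parameters". The sketch below assumes a checkpoint represented as a plain dict. `resolve_init_kwargs`, `ToyModel` and `from_checkpoint_dict` are hypothetical stand-ins, not Lightning's internal implementation. As with load_from_checkpoint, the loader is a classmethod called on the class.

```python
def resolve_init_kwargs(checkpoint, **overrides):
    """Merge saved __init__ arguments with caller overrides; overrides win."""
    saved = checkpoint.get("hyper_parameters", {})
    return {**saved, **overrides}


class ToyModel:
    def __init__(self, num_layers, lr):
        self.num_layers = num_layers
        self.lr = lr

    @classmethod
    def from_checkpoint_dict(cls, checkpoint, **overrides):
        # Hypothetical stand-in for load_from_checkpoint: rebuild from merged arguments.
        return cls(**resolve_init_kwargs(checkpoint, **overrides))
```

For example, `ToyModel.from_checkpoint_dict({"hyper_parameters": {"num_layers": 4, "lr": 0.1}}, num_layers=128)` builds a model with 128 layers and keeps the saved learning rate.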

Parameters:

Return type:

Self

Returns:

LightningModule instance with loaded weights and hyperparameters (if available).

Note

load_from_checkpoint is a class method. You should call it on your LightningModule class, not on a LightningModule instance, or a TypeError will be raised.

Note

To ensure all layers can be loaded from the checkpoint, this function calls configure_model() directly after instantiating the model if this hook is overridden in your LightningModule. However, load_from_checkpoint does not support loading sharded checkpoints, and you may run out of memory if the model is too large. In this case, consider loading through the Trainer via .fit(ckpt_path=...).

Example:

load weights without mapping ...

model = MyLightningModule.load_from_checkpoint('path/to/checkpoint.ckpt')

or load weights mapping all weights from GPU 1 to GPU 0 ...

map_location = {'cuda:1': 'cuda:0'}
model = MyLightningModule.load_from_checkpoint(
    'path/to/checkpoint.ckpt',
    map_location=map_location
)

or load weights and hyperparameters from separate files.

model = MyLightningModule.load_from_checkpoint(
    'path/to/checkpoint.ckpt',
    hparams_file='/path/to/hparams_file.yaml'
)

override some of the params with new values

model = MyLightningModule.load_from_checkpoint(
    PATH,
    num_layers=128,
    pretrained_ckpt_path=NEW_PATH,
)

predict

pretrained_model.eval()
pretrained_model.freeze()
y_hat = pretrained_model(x)

on_train_start

LightningModule.on_train_start()

Called at the beginning of training after sanity check.

Return type:

None

on_train_end

LightningModule.on_train_end()

Called at the end of training before logger experiment is closed.

Return type:

None

on_validation_start

LightningModule.on_validation_start()

Called at the beginning of validation.

Return type:

None

on_validation_end

LightningModule.on_validation_end()

Called at the end of validation.

Return type:

None

on_test_batch_start

LightningModule.on_test_batch_start(batch, batch_idx, dataloader_idx=0)

Called in the test loop before anything happens for that batch.

Parameters:

Return type:

None

on_test_batch_end

LightningModule.on_test_batch_end(outputs, batch, batch_idx, dataloader_idx=0)

Called in the test loop after the batch.

Parameters:

Return type:

None

on_test_epoch_start

LightningModule.on_test_epoch_start()

Called in the test loop at the very beginning of the epoch.

Return type:

None

on_test_epoch_end

LightningModule.on_test_epoch_end()

Called in the test loop at the very end of the epoch.

Return type:

None

on_test_start

LightningModule.on_test_start()

Called at the beginning of testing.

Return type:

None

on_test_end

LightningModule.on_test_end()

Called at the end of testing.

Return type:

None

on_predict_batch_start

LightningModule.on_predict_batch_start(batch, batch_idx, dataloader_idx=0)

Called in the predict loop before anything happens for that batch.

Parameters:

Return type:

None

on_predict_batch_end

LightningModule.on_predict_batch_end(outputs, batch, batch_idx, dataloader_idx=0)

Called in the predict loop after the batch.

Parameters:

Return type:

None

on_predict_epoch_start

LightningModule.on_predict_epoch_start()

Called at the beginning of predicting.

Return type:

None

on_predict_epoch_end

LightningModule.on_predict_epoch_end()

Called at the end of predicting.

Return type:

None

on_predict_start

LightningModule.on_predict_start()

Called at the beginning of predicting.

Return type:

None

on_predict_end

LightningModule.on_predict_end()

Called at the end of predicting.

Return type:

None

on_train_batch_start

LightningModule.on_train_batch_start(batch, batch_idx)

Called in the training loop before anything happens for that batch.

If you return -1 here, you will skip training for the rest of the current epoch.
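The -1 contract can be illustrated with a framework-free toy loop, in which returning -1 from the hook ends the current epoch early. `run_epoch` and `stop_after_two` are hypothetical stand-ins. In Lightning, the Trainer owns the loop and skips the remaining batches of that epoch.

```python
def run_epoch(batches, on_train_batch_start):
    """Toy training loop: a -1 from the hook ends the epoch early."""
    trained = []
    for batch_idx, batch in enumerate(batches):
        if on_train_batch_start(batch, batch_idx) == -1:
            break  # skip the rest of this epoch
        trained.append(batch_idx)  # stands in for training_step and the optimizer step
    return trained


def stop_after_two(batch, batch_idx):
    # Hypothetical hook: ask to skip the rest of the epoch once two batches have run.
    return -1 if batch_idx >= 2 else None
```

Any other return value, including the implicit None, lets training continue normally.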

Parameters:

Return type:

Optional[int]

on_train_batch_end

LightningModule.on_train_batch_end(outputs, batch, batch_idx)

Called in the training loop after the batch.

Parameters:

Return type:

None

Note

The value outputs["loss"] here will be the normalized value w.r.t accumulate_grad_batches of the loss returned from training_step.

on_train_epoch_start

LightningModule.on_train_epoch_start()

Called in the training loop at the very beginning of the epoch.

Return type:

None

on_train_epoch_end

LightningModule.on_train_epoch_end()

Called in the training loop at the very end of the epoch.

To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the LightningModule and access them in this hook:

import lightning as L
import torch

class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.training_step_outputs = []

    def training_step(self, batch, batch_idx):
        loss = ...
        self.training_step_outputs.append(loss)
        return loss

    def on_train_epoch_end(self):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(self.training_step_outputs).mean()
        self.log("training_epoch_mean", epoch_mean)
        # free up the memory
        self.training_step_outputs.clear()

Return type:

None

on_validation_batch_start

LightningModule.on_validation_batch_start(batch, batch_idx, dataloader_idx=0)

Called in the validation loop before anything happens for that batch.

Parameters:

Return type:

None

on_validation_batch_end

LightningModule.on_validation_batch_end(outputs, batch, batch_idx, dataloader_idx=0)

Called in the validation loop after the batch.

Parameters:

Return type:

None

on_validation_epoch_start

LightningModule.on_validation_epoch_start()

Called in the validation loop at the very beginning of the epoch.

Return type:

None

on_validation_epoch_end

LightningModule.on_validation_epoch_end()

Called in the validation loop at the very end of the epoch.

Return type:

None

configure_model

LightningModule.configure_model()

Hook to create modules in a strategy and precision aware context.

This is particularly useful when using sharded strategies (FSDP and DeepSpeed), where we’d like to shard the model instantly to save memory and initialization time. For non-sharded strategies, you can choose to override this hook or to initialize your model under the init_module() context manager.

This hook is called during each of fit/val/test/predict stages in the same process, so ensure that implementation of this hook is idempotent, i.e., after the first time the hook is called, subsequent calls to it should be a no-op.
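A common way to make the hook idempotent is an early return once the layers exist. The sketch below shows the guard on a plain-Python stand-in. `LazyModule`, its `layer` attribute and the `build_count` counter are hypothetical, and the placeholder dict stands where a real layer such as `nn.Linear(...)` would be created.

```python
class LazyModule:
    """Stand-in showing an idempotent configure_model (no Lightning import)."""

    def __init__(self):
        self.layer = None
        self.build_count = 0  # only here to show that the build happens once

    def configure_model(self):
        # The hook may run once per stage (fit, validate, test, predict) in the
        # same process, so return early if the layers already exist.
        if self.layer is not None:
            return
        self.layer = {"weights": [0.0] * 4}  # placeholder for e.g. nn.Linear(...)
        self.build_count += 1
```

Calling configure_model() several times leaves the same layer object in place and builds it only once.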

Return type:

None

on_validation_model_eval

LightningModule.on_validation_model_eval()

Called when the validation loop starts.

The validation loop by default calls .eval() on the LightningModule before it starts. Override this hook to change the behavior. See also on_validation_model_train().

Return type:

None

on_validation_model_train

LightningModule.on_validation_model_train()

Called when the validation loop ends.

The validation loop by default restores the training mode of the LightningModule to what it was before starting validation. Override this hook to change the behavior. See also on_validation_model_eval().
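The save-and-restore behaviour described above can be mimicked with a toy class that has a `training` flag. This is a plain-Python illustration, not Lightning's implementation. `ModeTrackingModule` and its `_mode_before_val` attribute are hypothetical, and `train(mode)` / `eval()` only follow the torch.nn.Module naming.

```python
class ModeTrackingModule:
    """Toy module with a training flag, mimicking save-and-restore around validation."""

    def __init__(self, training=True):
        self.training = training
        self._mode_before_val = None

    def train(self, mode=True):
        self.training = mode

    def eval(self):
        self.train(False)

    def on_validation_model_eval(self):
        # Remember the current mode, then switch to eval mode.
        self._mode_before_val = self.training
        self.eval()

    def on_validation_model_train(self):
        # Restore whatever mode was active before validation started.
        self.train(self._mode_before_val)
```

A module that was already in eval mode before validation stays in eval mode afterwards, instead of being forced back into training mode.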

Return type:

None

on_test_model_eval

LightningModule.on_test_model_eval()

Called when the test loop starts.

The test loop by default calls .eval() on the LightningModule before it starts. Override this hook to change the behavior. See also on_test_model_train().

Return type:

None

on_test_model_train

LightningModule.on_test_model_train()

Called when the test loop ends.

The test loop by default restores the training mode of the LightningModule to what it was before starting testing. Override this hook to change the behavior. See also on_test_model_eval().

Return type:

None

on_before_optimizer_step

LightningModule.on_before_optimizer_step(optimizer)

Called before optimizer.step().

If using gradient accumulation, the hook is called once the gradients have been accumulated. See: accumulate_grad_batches.

If using AMP, the gradients will be unscaled before calling this hook. See the PyTorch automatic mixed precision docs for more information on the scaling of gradients.

If clipping gradients, the gradients will not have been clipped yet.

Parameters:

optimizer (Optimizer) – Current optimizer being used.

Return type:

None

Example:

def on_before_optimizer_step(self, optimizer):
    # example to inspect gradient information in tensorboard
    if self.trainer.global_step % 25 == 0:  # don't make the tf file huge
        for k, v in self.named_parameters():
            if v.grad is not None:  # frozen parameters have no gradient
                self.logger.experiment.add_histogram(
                    tag=k, values=v.grad, global_step=self.trainer.global_step
                )

configure_gradient_clipping

LightningModule.configure_gradient_clipping(optimizer, gradient_clip_val=None, gradient_clip_algorithm=None)

Perform gradient clipping for the optimizer parameters. Called before optimizer_step().

Parameters:

Return type:

None

Example:

def configure_gradient_clipping(self, optimizer, gradient_clip_val, gradient_clip_algorithm):
    # Implement your own custom logic to clip gradients
    # You can call `self.clip_gradients` with your settings:
    self.clip_gradients(
        optimizer,
        gradient_clip_val=gradient_clip_val,
        gradient_clip_algorithm=gradient_clip_algorithm
    )

optimizer_step

LightningModule.optimizer_step(epoch, batch_idx, optimizer, optimizer_closure=None)

Override this method to adjust the default way the Trainer calls the optimizer.

By default, Lightning calls step() and zero_grad() as shown in the example. This method (and zero_grad()) won’t be called during the accumulation phase when Trainer(accumulate_grad_batches != 1). Overriding this hook has no benefit with manual optimization.
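The toy function below lists the batch indices after which optimizer_step() would run under gradient accumulation. It assumes a step every accumulate_grad_batches batches, and it assumes that a leftover, incomplete window at the end of the epoch is still stepped. Treat that last rule as an assumption to check against your Lightning version. The function name is hypothetical.

```python
def optimizer_step_batches(num_batches, accumulate_grad_batches):
    """Return the batch indices after which the optimizer would step (toy model)."""
    steps = []
    for batch_idx in range(num_batches):
        window_full = (batch_idx + 1) % accumulate_grad_batches == 0
        is_last = batch_idx == num_batches - 1
        # Assumption: a final, incomplete accumulation window is still flushed.
        if window_full or is_last:
            steps.append(batch_idx)
    return steps
```

With 10 batches and accumulate_grad_batches=4, steps happen after batches 3, 7 and 9. With accumulate_grad_batches=1, a step happens after every batch.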

Parameters:

Return type:

None

Examples:

def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
    # Add your custom logic to run directly before `optimizer.step()`

    optimizer.step(closure=optimizer_closure)

    # Add your custom logic to run directly after `optimizer.step()`

optimizer_zero_grad

LightningModule.optimizer_zero_grad(epoch, batch_idx, optimizer)

Override this method to change the default behaviour of optimizer.zero_grad().

Parameters:

Return type:

None

Examples:

DEFAULT

def optimizer_zero_grad(self, epoch, batch_idx, optimizer):
    optimizer.zero_grad()

Set gradients to None instead of zero to improve performance (not required on torch>=2.0.0).

def optimizer_zero_grad(self, epoch, batch_idx, optimizer):
    optimizer.zero_grad(set_to_none=True)

See torch.optim.Optimizer.zero_grad() for the explanation of the above example.

prepare_data

LightningModule.prepare_data()

Use this to download and prepare data. Downloading and saving data with multiple processes (distributed settings) will result in corrupted data. Lightning ensures this method is called only within a single process, so you can safely add your downloading logic within.

Return type:

None

Warning

DO NOT set state on the model (use setup instead), since this is NOT called on every device.

Example:

def prepare_data(self):
    # good
    download_data()
    tokenize()
    etc()

    # bad
    self.split = data_split
    self.some_state = some_other_state()

In a distributed environment, prepare_data can be called in two ways (using prepare_data_per_node):

  1. Once per node. This is the default and is only called on LOCAL_RANK=0.
  2. Once in total. Only called on GLOBAL_RANK=0.
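The two settings can be summarized as a rank filter. The sketch below assumes the usual layout in which global ranks are numbered node by node (global_rank = node_rank * devices_per_node + local_rank). The function name is hypothetical, and no distributed runtime is involved.

```python
def prepare_data_callers(num_nodes, devices_per_node, prepare_data_per_node=True):
    """List (node_rank, local_rank, global_rank) of the processes that run prepare_data."""
    callers = []
    for node_rank in range(num_nodes):
        for local_rank in range(devices_per_node):
            # Assumed standard rank layout: ranks are numbered node by node.
            global_rank = node_rank * devices_per_node + local_rank
            if prepare_data_per_node:
                selected = local_rank == 0  # once per node
            else:
                selected = global_rank == 0  # once in total (NODE_RANK=0, LOCAL_RANK=0)
            if selected:
                callers.append((node_rank, local_rank, global_rank))
    return callers
```

With 2 nodes of 4 devices each, the per-node setting selects global ranks 0 and 4. The once-in-total setting selects only global rank 0, which suits a shared file system.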

Example:

DEFAULT

called once per node on LOCAL_RANK=0 of that node

class LitDataModule(LightningDataModule):
    def __init__(self):
        super().__init__()
        self.prepare_data_per_node = True

call on GLOBAL_RANK=0 (great for shared file systems)

class LitDataModule(LightningDataModule):
    def __init__(self):
        super().__init__()
        self.prepare_data_per_node = False

This is called before requesting the dataloaders:

model.prepare_data()
initialize_distributed()
model.setup(stage)
model.train_dataloader()
model.val_dataloader()
model.test_dataloader()
model.predict_dataloader()

setup

LightningModule.setup(stage)

Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.

Parameters:

stage (str) – either 'fit', 'validate', 'test', or 'predict'

Return type:

None

Example:

class LitModel(...):
    def __init__(self):
        super().__init__()
        self.l1 = None

    def prepare_data(self):
        download_data()
        tokenize()

        # don't do this
        self.something = ...

    def setup(self, stage):
        data = load_data(...)
        self.l1 = nn.Linear(28, data.num_classes)

teardown

LightningModule.teardown(stage)

Called at the end of fit (train + validate), validate, test, or predict.

Parameters:

stage (str) – either 'fit', 'validate', 'test', or 'predict'

Return type:

None

train_dataloader

LightningModule.train_dataloader()

An iterable or collection of iterables specifying training samples.

For more information about multiple dataloaders, see this section.

The dataloader you return will not be reloaded unless you set reload_dataloaders_every_n_epochs to a positive integer.

For data processing, use the following pattern:

- download in prepare_data()
- process and split in setup()

However, the above are only necessary for distributed processing.

Return type:

Any

Warning

do not assign state in prepare_data

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

val_dataloader

LightningModule.val_dataloader()

An iterable or collection of iterables specifying validation samples.

For more information about multiple dataloaders, see this section.

The dataloader you return will not be reloaded unless you set reload_dataloaders_every_n_epochs to a positive integer.

It’s recommended that all data downloads and preparation happen in prepare_data().

Return type:

Any

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Note

If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.

test_dataloader

LightningModule.test_dataloader()

An iterable or collection of iterables specifying test samples.

For more information about multiple dataloaders, see this section.

For data processing, use the following pattern:

- download in prepare_data()
- process and split in setup()

However, the above are only necessary for distributed processing.

Return type:

Any

Warning

do not assign state in prepare_data

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Note

If you don’t need a test dataset and a test_step(), you don’t need to implement this method.

predict_dataloader

LightningModule.predict_dataloader()

An iterable or collection of iterables specifying prediction samples.

For more information about multiple dataloaders, see this section.

It’s recommended that all data downloads and preparation happen in prepare_data().

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Return type:

Any

Returns:

A torch.utils.data.DataLoader or a sequence of them specifying prediction samples.

transfer_batch_to_device

LightningModule.transfer_batch_to_device(batch, device, dataloader_idx)

Override this hook if your DataLoader returns tensors wrapped in a custom data structure.

The data types listed below (and any arbitrary nesting of them) are supported out of the box:

- torch.Tensor or anything that implements .to(...)
- list
- dict
- tuple

For anything else, you need to define how the data is moved to the target device (CPU, GPU, TPU, …).
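Roughly speaking, the built-in transfer walks nested lists, tuples and dicts and calls .to(device) on every leaf that has that method. Leaves of other types are returned unchanged, which is why a custom container is not moved unless you override the hook. The toy below imitates that recursion so it runs without PyTorch. `FakeTensor`, `CustomBatch` and `move_to_device` are hypothetical stand-ins, not Lightning's implementation.

```python
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor that records which device it lives on."""

    def __init__(self, value, device="cpu"):
        self.value = value
        self.device = device

    def to(self, device):
        return FakeTensor(self.value, device)


class CustomBatch:
    """Hypothetical custom container without a .to() method."""

    def __init__(self, samples, targets):
        self.samples = samples
        self.targets = targets


def move_to_device(data, device):
    """Recursively move every leaf that has .to(); leave other leaves unchanged."""
    if isinstance(data, dict):
        return {key: move_to_device(value, device) for key, value in data.items()}
    if isinstance(data, (list, tuple)):
        # Plain lists and tuples only; a namedtuple would need type(data)(*items).
        return type(data)(move_to_device(item, device) for item in data)
    if hasattr(data, "to"):
        return data.to(device)
    return data  # e.g. strings, ints, or custom objects such as CustomBatch
```

The nested dict and tuple below are moved leaf by leaf, while a CustomBatch comes back untouched, with its tensors still on "cpu".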

Note

This hook should only transfer the data and not modify it, nor should it move the data to any device other than the one passed in as argument (unless you know what you are doing). To check the current state of execution of this hook, you can use self.trainer.training/testing/validating/predicting so that you can add different logic as per your requirement.

Parameters:

Return type:

Any

Returns:

A reference to the data on the new device.

Example:

def transfer_batch_to_device(self, batch, device, dataloader_idx):
    if isinstance(batch, CustomBatch):
        # move all tensors in your custom data structure to the device
        batch.samples = batch.samples.to(device)
        batch.targets = batch.targets.to(device)
    elif dataloader_idx == 0:
        # skip device transfer for the first dataloader or anything you wish
        pass
    else:
        batch = super().transfer_batch_to_device(batch, device, dataloader_idx)
    return batch

See also

- move_data_to_device()
- apply_to_collection()

on_before_batch_transfer

LightningModule.on_before_batch_transfer(batch, dataloader_idx)

Override to alter or apply batch augmentations to your batch before it is transferred to the device.

Note

To check the current state of execution of this hook, you can use self.trainer.training/testing/validating/predicting so that you can add different logic as per your requirement.

Parameters:

Return type:

Any

Returns:

A batch of data

Example:

def on_before_batch_transfer(self, batch, dataloader_idx):
    batch['x'] = transforms(batch['x'])
    return batch

See also

- on_after_batch_transfer()
- transfer_batch_to_device()

on_after_batch_transfer

LightningModule.on_after_batch_transfer(batch, dataloader_idx)

Override to alter or apply batch augmentations to your batch after it is transferred to the device.

Note

To check the current state of execution of this hook, you can use self.trainer.training/testing/validating/predicting so that you can add different logic as per your requirement.

Parameters:

Return type:

Any

Returns:

A batch of data

Example:

def on_after_batch_transfer(self, batch, dataloader_idx):
    batch['x'] = gpu_transforms(batch['x'])
    return batch

See also