BatchSizeFinder - PyTorch Lightning 2.5.1.post0 documentation

class lightning.pytorch.callbacks.BatchSizeFinder(mode='power', steps_per_trial=3, init_val=2, max_trials=25, batch_arg_name='batch_size')

Bases: Callback

Finds the largest batch size a given model supports before it hits an out-of-memory (OOM) error.

All you need to do is add it as a callback inside Trainer and call trainer.{fit,validate,test,predict}. Internally, it calls the respective step function steps_per_trial times for each batch size until one of the batch sizes generates an OOM error.
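The search loop described above can be sketched without a GPU. The following is a simplified, standard-library-only simulation of a doubling search, not Lightning's implementation. The names `fake_step`, `power_search`, and `memory_limit` are hypothetical stand-ins, and a plain `MemoryError` takes the place of a real OOM error.

```python
# Simplified simulation of a doubling batch size search.
# `fake_step` and `memory_limit` are hypothetical; they are not part of Lightning.


def fake_step(batch_size: int, memory_limit: int) -> None:
    """Stand-in for one training step; raises MemoryError past the limit."""
    if batch_size > memory_limit:
        raise MemoryError(f"simulated OOM at batch_size={batch_size}")


def power_search(init_val: int = 2, steps_per_trial: int = 3,
                 max_trials: int = 25, memory_limit: int = 100) -> int:
    """Double the batch size until a trial fails; return the last size that worked."""
    batch_size = init_val
    largest_ok = 0  # 0 means even the initial size did not fit
    for _ in range(max_trials):
        try:
            # Run a few steps per candidate size, as the callback does.
            for _ in range(steps_per_trial):
                fake_step(batch_size, memory_limit)
        except MemoryError:
            break  # this size failed, keep the previous one
        largest_ok = batch_size
        batch_size *= 2
    return largest_ok


result = power_search(memory_limit=100)
print(result)
```

In this sketch the search stops either at the first simulated OOM or after `max_trials` successful trials, whichever comes first.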

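Parameters:

mode (str): Search strategy for updating the batch size. 'power' keeps doubling the batch size until an OOM error occurs. 'binsearch' doubles until the first OOM error, then binary-searches between the last successful batch size and the one that failed. Default: 'power'.

steps_per_trial (int): Number of steps to run with each candidate batch size. Default: 3.

init_val (int): Initial batch size to start the search with. Default: 2.

max_trials (int): Maximum number of batch size increases before the search stops. Default: 25.

batch_arg_name (str): Name of the attribute that stores the batch size. Default: 'batch_size'.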
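The `mode` argument in the signature defaults to 'power'; Lightning also accepts 'binsearch', which refines the result after the first failure. As a rough illustration of that idea, here is a standard-library-only sketch. It is not Lightning's code: `fits`, `binsearch`, and `memory_limit` are hypothetical names, and a simple comparison replaces a real OOM check.

```python
# Simplified simulation of a doubling search followed by a binary search.
# `fits` and `memory_limit` are hypothetical; they are not part of Lightning.


def fits(batch_size: int, memory_limit: int) -> bool:
    """Stand-in for 'a trial at this batch size ran without OOM'."""
    return batch_size <= memory_limit


def binsearch(init_val: int = 2, max_trials: int = 25,
              memory_limit: int = 100) -> int:
    """Double until the first failure, then bisect between the last
    successful size (low) and the smallest failed size (high)."""
    low, high = 0, None  # high stays None until a trial fails
    size = init_val
    for _ in range(max_trials):
        if fits(size, memory_limit):
            low = size
            if high is None:
                size *= 2  # still in the doubling phase
                continue
        else:
            high = size
        if high - low <= 1:
            break  # the gap is closed, low is the answer
        size = (low + high) // 2
    return low


print(binsearch(memory_limit=100))
```

Compared with plain doubling, this variant can settle on a size between two powers of two, at the cost of extra trials.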

Example:

1. Customize the BatchSizeFinder callback to run at different epochs. This feature is useful while fine-tuning models since you can't always use the same batch size after unfreezing the backbone.

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import BatchSizeFinder


class FineTuneBatchSizeFinder(BatchSizeFinder):
    def __init__(self, milestones, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Epochs at which the batch size search is run again
        self.milestones = milestones

    def on_fit_start(self, *args, **kwargs):
        # Skip the default search at the start of fit
        return

    def on_train_epoch_start(self, trainer, pl_module):
        # Search at epoch 0 and again at each milestone epoch
        if trainer.current_epoch in self.milestones or trainer.current_epoch == 0:
            self.scale_batch_size(trainer, pl_module)


trainer = Trainer(callbacks=[FineTuneBatchSizeFinder(milestones=(5, 10))])
trainer.fit(...)

Example:

2. Run batch size finder for validate/test/predict.

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import BatchSizeFinder


class EvalBatchSizeFinder(BatchSizeFinder):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def on_fit_start(self, *args, **kwargs):
        # Skip the default search at the start of fit
        return

    def on_test_start(self, trainer, pl_module):
        # Run the search when testing begins instead
        self.scale_batch_size(trainer, pl_module)


trainer = Trainer(callbacks=[EvalBatchSizeFinder()])
trainer.test(...)

on_fit_start(trainer, pl_module)

Called when fit begins.

Return type:

None

on_predict_start(trainer, pl_module)

Called when prediction begins.

Return type:

None

on_test_start(trainer, pl_module)

Called when the test begins.

Return type:

None

on_validation_start(trainer, pl_module)

Called when the validation loop begins.

Return type:

None

setup(trainer, pl_module, stage=None)

Called when fit, validate, test, predict, or tune begins.

Return type:

None