BatchSizeFinder — PyTorch Lightning 2.5.1.post0 documentation
class lightning.pytorch.callbacks.BatchSizeFinder(mode='power', steps_per_trial=3, init_val=2, max_trials=25, batch_arg_name='batch_size')[source]¶
Bases: Callback
Finds the largest batch size supported by a given model before encountering an out of memory (OOM) error.
All you need to do is add it as a callback inside Trainer and call trainer.{fit,validate,test,predict}. Internally, it calls the respective step function steps_per_trial times for each batch size until one of the batch sizes generates an OOM error.
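For example, a minimal sketch of attaching the callback (here model is a placeholder for your LightningModule, assumed to expose a batch_size hyperparameter):

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import BatchSizeFinder

# `model` is a placeholder LightningModule that exposes a `batch_size` hyperparameter.
trainer = Trainer(callbacks=[BatchSizeFinder(mode="power", init_val=2, max_trials=25)])
trainer.fit(model)  # the largest working batch size is written back to the `batch_arg_name` attribute
```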
Parameters:
- mode¶ (str) –
search strategy to update the batch size:
  - 'power': keep multiplying the batch size by 2 until an OOM error is encountered.
  - 'binsearch': initially keep multiplying by 2; after encountering an OOM error, do a binary search between the last successful batch size and the batch size that failed.
- steps_per_trial¶ (int) – number of steps to run with a given batch size. Ideally, 1 should be enough to test whether an OOM error occurs; in practice, however, a few are needed.
- init_val¶ (int) – initial batch size to start the search with.
- max_trials¶ (int) – maximum number of batch size increases to attempt before the algorithm terminates.
- batch_arg_name¶ (str) –
name of the attribute that stores the batch size. It is expected that the user has provided a model or datamodule that has a hyperparameter with that name. We will look for this attribute name in the following places (see the sketch after this list):
  - model
  - model.hparams
  - trainer.datamodule (the datamodule passed to the tune method)
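For instance, a sketch of a datamodule that exposes the attribute the finder looks for (LitDataModule and train_dataset are hypothetical names, not part of this API):

```python
from lightning.pytorch import LightningDataModule
from torch.utils.data import DataLoader


class LitDataModule(LightningDataModule):
    def __init__(self, batch_size=2):
        super().__init__()
        self.batch_size = batch_size  # matches the default batch_arg_name="batch_size"

    def train_dataloader(self):
        # `self.batch_size` is updated in place while the finder searches.
        return DataLoader(train_dataset, batch_size=self.batch_size)  # train_dataset is a placeholder
```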
Example:
1. Customize the BatchSizeFinder callback to run at different epochs. This feature is
useful while fine-tuning models since you can't always use the same batch size after
unfreezing the backbone.
```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import BatchSizeFinder


class FineTuneBatchSizeFinder(BatchSizeFinder):
    def __init__(self, milestones, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.milestones = milestones

    def on_fit_start(self, *args, **kwargs):
        return  # skip the default search at the start of fit

    def on_train_epoch_start(self, trainer, pl_module):
        # Re-run the search at the first epoch and at each milestone epoch.
        if trainer.current_epoch in self.milestones or trainer.current_epoch == 0:
            self.scale_batch_size(trainer, pl_module)


trainer = Trainer(callbacks=[FineTuneBatchSizeFinder(milestones=(5, 10))])
trainer.fit(...)
```
Example:
2. Run batch size finder for validate/test/predict.
```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import BatchSizeFinder


class EvalBatchSizeFinder(BatchSizeFinder):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def on_fit_start(self, *args, **kwargs):
        return  # do nothing during fit

    def on_test_start(self, trainer, pl_module):
        self.scale_batch_size(trainer, pl_module)


trainer = Trainer(callbacks=[EvalBatchSizeFinder()])
trainer.test(...)
```
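The same pattern can be adapted for validate or predict by overriding on_validation_start or on_predict_start instead and calling scale_batch_size there.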
on_fit_start(trainer, pl_module)[source]¶
Called when fit begins.
Return type: None
on_predict_start(trainer, pl_module)[source]¶
Called when the predict loop begins.
Return type: None
on_test_start(trainer, pl_module)[source]¶
Called when the test loop begins.
Return type: None
on_validation_start(trainer, pl_module)[source]¶
Called when the validation loop begins.
Return type: None
setup(trainer, pl_module, stage=None)[source]¶
Called when fit, validate, test, predict, or tune begins.
Return type: None