TQDMProgressBar — PyTorch Lightning 2.5.1.post0 documentation
class lightning.pytorch.callbacks.TQDMProgressBar(refresh_rate=1, process_position=0, leave=False)[source]
Bases: ProgressBar
This is the default progress bar used by Lightning. It prints to stdout using the tqdm package and shows up to four different bars:
- sanity check progress: the progress during the sanity check run
- train progress: shows the training progress. It will pause if validation starts and will resume when it ends, and also accounts for multiple validation runs during training when val_check_interval is used.
- validation progress: only visible during validation; shows total progress over all validation datasets.
- test progress: only active when testing; shows total progress over all test datasets.
For infinite datasets, the progress bar never ends.
If you want to customize the default tqdm progress bars used by Lightning, you can override specific methods of the callback class and pass your custom implementation to the Trainer.
Example

```python
from lightning.pytorch import Trainer

class LitProgressBar(TQDMProgressBar):
    def init_validation_tqdm(self):
        bar = super().init_validation_tqdm()
        bar.set_description('running validation ...')
        return bar

bar = LitProgressBar()
trainer = Trainer(callbacks=[bar])
```
Parameters:
- refresh_rate (int) – Determines at which rate (in number of batches) the progress bars get updated. Set it to 0 to disable the display.
- process_position (int) – Set this to a value greater than 0 to offset the progress bars by this many lines. This is useful when you have progress bars defined elsewhere and want to show all of them together.
- leave (bool) – If set to True, leaves the finished progress bar in the terminal at the end of the epoch. Default: False.
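For instance, a minimal sketch of wiring these parameters together (only the Trainer arguments relevant here are shown):

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import TQDMProgressBar

# Refresh the bars every 10 batches and keep the finished bar on screen
# at the end of each epoch; refresh_rate=0 would disable the display.
bar = TQDMProgressBar(refresh_rate=10, leave=True)
trainer = Trainer(callbacks=[bar])
```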
disable()[source]
You should provide a way to disable the progress bar.
Return type: None

enable()[source]
You should provide a way to enable the progress bar. The Trainer will call this in e.g. pre-training routines like the learning rate finder to temporarily enable and disable the training progress bar.
Return type: None
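The Trainer drives these hooks itself, but they can also be called manually through a reference to the callback. A minimal sketch, assuming a LightningModule named `model` is defined elsewhere:

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import TQDMProgressBar

bar = TQDMProgressBar()
trainer = Trainer(callbacks=[bar])

bar.disable()            # bars created after this point render no output
trainer.validate(model)  # model is assumed to be defined elsewhere
bar.enable()             # restore normal display for subsequent runs
```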
init_predict_tqdm()[source]
Override this to customize the tqdm bar for predicting.
Return type: Tqdm
init_sanity_tqdm()[source]
Override this to customize the tqdm bar for the validation sanity run.
Return type: Tqdm
init_test_tqdm()[source]
Override this to customize the tqdm bar for testing.
Return type: Tqdm
init_train_tqdm()[source]
Override this to customize the tqdm bar for training.
Return type: Tqdm
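For example, a minimal sketch that tweaks the training bar's appearance; bar_format is a standard tqdm option, not a Lightning-specific one:

```python
from lightning.pytorch.callbacks import TQDMProgressBar

class CustomTrainBar(TQDMProgressBar):
    def init_train_tqdm(self):
        bar = super().init_train_tqdm()
        # The returned object is a regular tqdm instance, so standard
        # tqdm attributes such as bar_format can be adjusted here.
        bar.bar_format = "{l_bar}{bar:20}{r_bar}"
        return bar
```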
init_validation_tqdm()[source]
Override this to customize the tqdm bar for validation.
Return type: Tqdm
on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)[source]
Called when the predict batch ends.
Return type: None
on_predict_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)[source]
Called when the predict batch begins.
Return type: None
on_predict_end(trainer, pl_module)[source]
Called when prediction ends.
Return type: None
on_predict_start(trainer, pl_module)[source]
Called when prediction begins.
Return type: None
on_sanity_check_end(*_)[source]
Called when the validation sanity check ends.
Return type: None
on_sanity_check_start(*_)[source]
Called when the validation sanity check starts.
Return type: None
on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)[source]
Called when the test batch ends.
Return type: None
on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)[source]
Called when the test batch begins.
Return type: None
on_test_end(trainer, pl_module)[source]
Called when the test ends.
Return type: None
on_test_start(trainer, pl_module)[source]
Called when the test begins.
Return type: None
on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)[source]
Called when the train batch ends.
Return type: None
Note
The value outputs["loss"] here will be the normalized value w.r.t. accumulate_grad_batches of the loss returned from training_step.
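To make the normalization concrete, here is a sketch of a callback that observes the value; with accumulate_grad_batches=4, a training_step returning a loss of 1.0 shows up here as 0.25:

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import Callback

class LossInspector(Callback):
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # outputs["loss"] has already been divided by accumulate_grad_batches,
        # so with the Trainer below it is one quarter of the raw step loss.
        print(f"normalized loss: {outputs['loss']:.4f}")

trainer = Trainer(accumulate_grad_batches=4, callbacks=[LossInspector()])
```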
on_train_end(*_)[source]
Called when the train ends.
Return type: None
on_train_epoch_end(trainer, pl_module)[source]
Called when the train epoch ends.
To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the lightning.pytorch.core.LightningModule and access them in this hook:

```python
class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.training_step_outputs = []

    def training_step(self):
        loss = ...
        self.training_step_outputs.append(loss)
        return loss


class MyCallback(L.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
        pl_module.log("training_epoch_mean", epoch_mean)
        # free up the memory
        pl_module.training_step_outputs.clear()
```

Return type: None
on_train_epoch_start(trainer, *_)[source]
Called when the train epoch begins.
Return type: None
on_train_start(*_)[source]
Called when the train begins.
Return type: None
on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)[source]
Called when the validation batch ends.
Return type: None
on_validation_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)[source]
Called when the validation batch begins.
Return type: None
on_validation_end(trainer, pl_module)[source]
Called when the validation loop ends.
Return type: None
on_validation_start(trainer, pl_module)[source]
Called when the validation loop begins.
Return type: None
print(*args, sep=' ', **kwargs)[source]
You should provide a way to print without breaking the progress bar.
Return type: None
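In practice you rarely call this method directly: LightningModule.print routes messages through the active progress bar, so they appear above the bar instead of corrupting it. A minimal sketch, where compute_loss is a hypothetical helper assumed to be defined elsewhere:

```python
import lightning.pytorch as L

class MyModel(L.LightningModule):
    def training_step(self, batch, batch_idx):
        loss = self.compute_loss(batch)  # hypothetical helper
        # self.print defers to the progress bar's print (tqdm.write under
        # the hood), so the message does not break the running bar.
        self.print(f"batch {batch_idx}: loss={loss:.3f}")
        return loss
```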