Customize the progress bar — PyTorch Lightning 2.5.1.post0 documentation
Lightning supports two different types of progress bars (tqdm and rich). TQDMProgressBar is used by default, but you can override it by passing a custom TQDMProgressBar or RichProgressBar to the callbacks argument of the Trainer.
You could also use the ProgressBar class to implement your own progress bar.
TQDMProgressBar
The TQDMProgressBar uses the tqdm library internally and is the default progress bar used by Lightning. It prints to stdout
and shows up to four different bars:
- sanity check progress: the progress during the sanity check run
- train progress: shows the training progress. It will pause if validation starts and will resume when it ends, and also accounts for multiple validation runs during training when val_check_interval is used.
- validation progress: only visible during validation; shows total progress over all validation datasets.
- test progress: only active when testing; shows total progress over all test datasets.
For infinite datasets, the progress bar never ends.
You can change the refresh_rate (the number of batches between progress bar updates) of the TQDMProgressBar like this:
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import TQDMProgressBar
trainer = Trainer(callbacks=[TQDMProgressBar(refresh_rate=10)])
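To make the effect of refresh_rate concrete, the following pure-Python sketch models which batches would trigger a redraw. It is a simplified model, not Lightning's actual code: it assumes the bar redraws whenever the number of completed batches is a multiple of refresh_rate, and once more on the final batch. The function names should_refresh and refresh_points are hypothetical.

```python
def should_refresh(batches_done, total_batches, refresh_rate):
    """Simplified model of a refresh decision (an assumption, not library code)."""
    if refresh_rate <= 0:
        # Assumption: a non-positive rate means the bar is never redrawn.
        return False
    return batches_done % refresh_rate == 0 or batches_done == total_batches


def refresh_points(total_batches, refresh_rate):
    """List the batch counts at which the modeled bar would redraw."""
    return [
        n
        for n in range(1, total_batches + 1)
        if should_refresh(n, total_batches, refresh_rate)
    ]


# With 25 batches and refresh_rate=10, the modeled bar redraws three times.
print(refresh_points(25, 10))  # [10, 20, 25]
```

A larger refresh_rate means fewer redraws, which can help when terminal output is slow or when logs are captured to a file.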
Note
The smoothing option has no effect when using the default implementation of TQDMProgressBar, as the progress bar is updated using the bar.refresh() method instead of bar.update(). This can cause the progress bar to become desynchronized with the actual progress. To avoid this issue, you can use the bar.update() method instead, but this may require customizing the TQDMProgressBar class.
By default, the training progress bar is reset (overwritten) at each new epoch. If you want a new progress bar to be displayed at the end of every epoch, set TQDMProgressBar.leave to True.
trainer = Trainer(callbacks=[TQDMProgressBar(leave=True)])
If you want to customize the default TQDMProgressBar used by Lightning, you can override specific methods of the callback class and pass your custom implementation to the Trainer.
class LitProgressBar(TQDMProgressBar):
    def init_validation_tqdm(self):
        bar = super().init_validation_tqdm()
        bar.set_description("running validation...")
        return bar
trainer = Trainer(callbacks=[LitProgressBar()])
RichProgressBar
Rich is a Python library for rich text and beautiful formatting in the terminal. To use the RichProgressBar as your progress bar, first install the package:
pip install rich
Then configure the callback and pass it to the Trainer:
from lightning.pytorch.callbacks import RichProgressBar
trainer = Trainer(callbacks=[RichProgressBar()])
Customize the theme for your RichProgressBar like this:
from lightning.pytorch.callbacks import RichProgressBar
from lightning.pytorch.callbacks.progress.rich_progress import RichProgressBarTheme
# create your own theme!
progress_bar = RichProgressBar(
    theme=RichProgressBarTheme(
        description="green_yellow",
        progress_bar="green1",
        progress_bar_finished="green1",
        progress_bar_pulse="#6206E0",
        batch_progress="green_yellow",
        time="grey82",
        processing_speed="grey82",
        metrics="grey82",
        metrics_text_delimiter="\n",
        metrics_format=".3e",
    )
)
trainer = Trainer(callbacks=progress_bar)
You can customize the components used within RichProgressBar by overriding the configure_columns() method.
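from lightning.pytorch.callbacks import RichProgressBar
from rich.progress import TextColumn
custom_column = TextColumn("[progress.description]Custom Rich Progress Bar!")
class CustomRichProgressBar(RichProgressBar):
    def configure_columns(self, trainer):
        # Replace the default columns with the single custom column.
        return [custom_column]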
progress_bar = CustomRichProgressBar()
If you want a new progress bar to be displayed at the end of every epoch, enable RichProgressBar.leave by passing True:
from lightning.pytorch.callbacks import RichProgressBar
trainer = Trainer(callbacks=[RichProgressBar(leave=True)])
See also
- RichProgressBar docs.
- RichModelSummary docs to customize the model summary table.
- Rich library.
Note
The progress bar is enabled by default in the Trainer. To disable it, pass enable_progress_bar=False:
trainer = Trainer(enable_progress_bar=False)