Tensorboard Integration — Stable Baselines3 2.6.1a0 documentation
Basic Usage
To use Tensorboard with Stable Baselines3, you simply need to pass the location of the log folder to the RL agent:
from stable_baselines3 import A2C
model = A2C("MlpPolicy", "CartPole-v1", verbose=1, tensorboard_log="./a2c_cartpole_tensorboard/")
model.learn(total_timesteps=10_000)
You can also define a custom logging name when training (by default, it is the algorithm name):
from stable_baselines3 import A2C
model = A2C("MlpPolicy", "CartPole-v1", verbose=1, tensorboard_log="./a2c_cartpole_tensorboard/")
model.learn(total_timesteps=10_000, tb_log_name="first_run")
# Pass reset_num_timesteps=False to continue the training curve in tensorboard
# By default, it will create a new curve
# Keep tb_log_name constant to have continuous curve (see note below)
model.learn(total_timesteps=10_000, tb_log_name="second_run", reset_num_timesteps=False)
model.learn(total_timesteps=10_000, tb_log_name="third_run", reset_num_timesteps=False)
Note
If you specify a different tb_log_name in subsequent runs, you will get split graphs. If you want them to be continuous, you must keep the same tb_log_name (see issue #975). If your graphs still end up split by other means, simply put the tensorboard log files into the same folder.
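For example, a minimal sketch of the continuous variant (the log name "run" is illustrative): with a constant tb_log_name and reset_num_timesteps=False, the timestep counter keeps increasing across calls and logging continues in the same folder, so the runs line up as one curve:

model.learn(total_timesteps=10_000, tb_log_name="run")
# Same log name, no counter reset: the curve continues where the previous call ended
model.learn(total_timesteps=10_000, tb_log_name="run", reset_num_timesteps=False)
model.learn(total_timesteps=10_000, tb_log_name="run", reset_num_timesteps=False)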
Once the learn function is called, you can monitor the RL agent during or after training with the following bash command:
tensorboard --logdir ./a2c_cartpole_tensorboard/
Note
You can find explanations about the logger output and names in the Logger section.
You can also add past logging folders:
tensorboard --logdir ./a2c_cartpole_tensorboard/;./ppo2_cartpole_tensorboard/
It will display information such as the episode reward (when using a Monitor wrapper), the model losses, and other parameters unique to some models.
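For example, the episode reward only appears once episodes are recorded by a Monitor wrapper. A minimal sketch of wrapping an environment explicitly (illustrative; Stable Baselines3 can also add this wrapper for you when you pass an environment id):

import gymnasium as gym

from stable_baselines3 import A2C
from stable_baselines3.common.monitor import Monitor

# The Monitor wrapper records episode returns and lengths, which TensorBoard
# then displays as rollout/ep_rew_mean and rollout/ep_len_mean
env = Monitor(gym.make("CartPole-v1"))
model = A2C("MlpPolicy", env, verbose=1, tensorboard_log="./a2c_cartpole_tensorboard/")
model.learn(total_timesteps=10_000)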
Logging More Values
Using a callback, you can easily log more values with TensorBoard. Here is a simple example of how to log both an additional tensor and an arbitrary scalar value:
import numpy as np

from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import BaseCallback

model = SAC("MlpPolicy", "Pendulum-v1", tensorboard_log="/tmp/sac/", verbose=1)


class TensorboardCallback(BaseCallback):
    """
    Custom callback for plotting additional values in tensorboard.
    """

    def __init__(self, verbose=0):
        super().__init__(verbose)

    def _on_step(self) -> bool:
        # Log scalar value (here a random variable)
        value = np.random.random()
        self.logger.record("random_value", value)
        return True


model.learn(50000, callback=TensorboardCallback())
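The example above logs a scalar; logging a tensor works the same way. A hedged sketch (the callback name and tag are illustrative, and this assumes a Stable Baselines3 version whose TensorBoard writer renders tensors as histograms):

import torch as th

from stable_baselines3.common.callbacks import BaseCallback


class TensorLoggingCallback(BaseCallback):
    def _on_step(self) -> bool:
        # Log a tensor: the TensorBoard backend displays it as a histogram;
        # the other writers are excluded since they cannot render tensors
        self.logger.record(
            "train/random_tensor",
            th.randn(100),
            exclude=("stdout", "log", "json", "csv"),
        )
        return True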
Note
If you want to log values more often than the default to tensorboard, you can manually call self.logger.dump(self.num_timesteps) in a callback (see issue #506).
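A minimal sketch of such a callback (the 100-step frequency and class name are illustrative):

from stable_baselines3.common.callbacks import BaseCallback


class FrequentDumpCallback(BaseCallback):
    def _on_step(self) -> bool:
        # Flush the current log buffer to all outputs every 100 steps,
        # instead of waiting for the default end-of-rollout dump
        if self.n_calls % 100 == 0:
            self.logger.dump(self.num_timesteps)
        return True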
Logging Images
TensorBoard supports periodic logging of image data, which helps evaluate agents at various stages during training.
Warning
To support image logging, pillow must be installed; otherwise, TensorBoard ignores the image and logs a warning.
Here is an example of how to render an image to TensorBoard at regular intervals:
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import Image

model = SAC("MlpPolicy", "Pendulum-v1", tensorboard_log="/tmp/sac/", verbose=1)


class ImageRecorderCallback(BaseCallback):
    def __init__(self, verbose=0):
        super().__init__(verbose)

    def _on_step(self):
        image = self.training_env.render(mode="rgb_array")
        # "HWC" specifies the dataformat of the image, here channel last
        # (H for height, W for width, C for channel)
        # See https://pytorch.org/docs/stable/tensorboard.html
        # for supported formats
        self.logger.record("trajectory/image", Image(image, "HWC"), exclude=("stdout", "log", "json", "csv"))
        return True


model.learn(50000, callback=ImageRecorderCallback())
Logging Figures/Plots
TensorBoard supports periodic logging of figures/plots created with matplotlib, which helps evaluate agents at various stages during training.
Warning
To support figure logging, matplotlib must be installed; otherwise, TensorBoard ignores the figure and logs a warning.
Here is an example of how to store a plot in TensorBoard at regular intervals:
import numpy as np
import matplotlib.pyplot as plt

from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import Figure

model = SAC("MlpPolicy", "Pendulum-v1", tensorboard_log="/tmp/sac/", verbose=1)


class FigureRecorderCallback(BaseCallback):
    def __init__(self, verbose=0):
        super().__init__(verbose)

    def _on_step(self):
        # Plot values (here a random variable)
        figure = plt.figure()
        figure.add_subplot().plot(np.random.random(3))
        # Close the figure after logging it
        self.logger.record("trajectory/figure", Figure(figure, close=True), exclude=("stdout", "log", "json", "csv"))
        plt.close()
        return True


model.learn(50000, callback=FigureRecorderCallback())
Logging Videos
TensorBoard supports periodic logging of video data, which helps evaluate agents at various stages during training.
Warning
To support video logging, moviepy must be installed; otherwise, TensorBoard ignores the video and logs a warning.
Here is an example of how to render an episode and log the resulting video to TensorBoard at regular intervals:
from typing import Any, Dict

import gymnasium as gym
import torch as th
import numpy as np

from stable_baselines3 import A2C
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.logger import Video


class VideoRecorderCallback(BaseCallback):
    def __init__(self, eval_env: gym.Env, render_freq: int, n_eval_episodes: int = 1, deterministic: bool = True):
        """
        Records a video of an agent's trajectory traversing eval_env
        and logs it to TensorBoard

        :param eval_env: A gym environment from which the trajectory is recorded
        :param render_freq: Render the agent's trajectory every render_freq calls of the callback.
        :param n_eval_episodes: Number of episodes to render
        :param deterministic: Whether to use deterministic or stochastic policy
        """
        super().__init__()
        self._eval_env = eval_env
        self._render_freq = render_freq
        self._n_eval_episodes = n_eval_episodes
        self._deterministic = deterministic

    def _on_step(self) -> bool:
        if self.n_calls % self._render_freq == 0:
            screens = []

            def grab_screens(_locals: Dict[str, Any], _globals: Dict[str, Any]) -> None:
                """
                Renders the environment in its current state, recording the screen in the captured `screens` list

                :param _locals: A dictionary containing all local variables of the callback's scope
                :param _globals: A dictionary containing all global variables of the callback's scope
                """
                # We expect `render()` to return a uint8 array with values in [0, 255] or a float array
                # with values in [0, 1], as described in
                # https://pytorch.org/docs/stable/tensorboard.html#torch.utils.tensorboard.writer.SummaryWriter.add_video
                screen = self._eval_env.render()
                # PyTorch uses CxHxW vs HxWxC gym (and tensorflow) image convention
                screens.append(screen.transpose(2, 0, 1))

            evaluate_policy(
                self.model,
                self._eval_env,
                callback=grab_screens,
                n_eval_episodes=self._n_eval_episodes,
                deterministic=self._deterministic,
            )
            self.logger.record(
                "trajectory/video",
                Video(th.from_numpy(np.asarray([screens])), fps=40),
                exclude=("stdout", "log", "json", "csv"),
            )
        return True


model = A2C("MlpPolicy", "CartPole-v1", tensorboard_log="runs/", verbose=1)
video_recorder = VideoRecorderCallback(gym.make("CartPole-v1", render_mode="rgb_array"), render_freq=5000)
model.learn(total_timesteps=int(5e4), callback=video_recorder)
Logging Hyperparameters
TensorBoard supports logging of hyperparameters in its HPARAMS tab, which helps to compare training runs across agents.
Warning
To display hyperparameters in the HPARAMS section, a metric_dict must be given (as well as a hparam_dict).
Here is an example of how to save hyperparameters in TensorBoard:
from stable_baselines3 import A2C
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import HParam


class HParamCallback(BaseCallback):
    """
    Saves the hyperparameters and metrics at the start of the training, and logs them to TensorBoard.
    """

    def _on_training_start(self) -> None:
        hparam_dict = {
            "algorithm": self.model.__class__.__name__,
            "learning rate": self.model.learning_rate,
            "gamma": self.model.gamma,
        }
        # define the metrics that will appear in the `HPARAMS` Tensorboard tab by referencing their tag
        # TensorBoard will find & display metrics from the `SCALARS` tab
        metric_dict = {
            "rollout/ep_len_mean": 0,
            "train/value_loss": 0.0,
        }
        self.logger.record(
            "hparams",
            HParam(hparam_dict, metric_dict),
            exclude=("stdout", "log", "json", "csv"),
        )

    def _on_step(self) -> bool:
        return True


model = A2C("MlpPolicy", "CartPole-v1", tensorboard_log="runs/", verbose=1)
model.learn(total_timesteps=int(5e4), callback=HParamCallback())
Directly Accessing The Summary Writer
If you would like to log arbitrary data (in one of the formats supported by PyTorch), you can get direct access to the underlying SummaryWriter in a callback:
Warning
This method is not recommended and should only be used by advanced users.
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import TensorBoardOutputFormat

model = SAC("MlpPolicy", "Pendulum-v1", tensorboard_log="/tmp/sac/", verbose=1)


class SummaryWriterCallback(BaseCallback):
    def _on_training_start(self):
        self._log_freq = 1000  # log every 1000 calls

        output_formats = self.logger.output_formats
        # Save reference to tensorboard formatter object
        # note: the failure case (no formatter found) is not handled here, should be done with try/except.
        self.tb_formatter = next(formatter for formatter in output_formats if isinstance(formatter, TensorBoardOutputFormat))

    def _on_step(self) -> bool:
        if self.n_calls % self._log_freq == 0:
            # You can have access to info from the env using self.locals.
            # for instance, when using one env (index 0 of locals["infos"]):
            # lap_count = self.locals["infos"][0]["lap_count"]
            # self.tb_formatter.writer.add_scalar("train/lap_count", lap_count, self.num_timesteps)

            self.tb_formatter.writer.add_text("direct_access", "this is a value", self.num_timesteps)
            self.tb_formatter.writer.flush()
        return True


model.learn(50000, callback=SummaryWriterCallback())
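The same writer handle accepts any other torch.utils.tensorboard method. As a hedged follow-up sketch (the class name and tag prefix are illustrative), subclassing the callback above to log weight histograms of the policy network:

class WeightHistogramCallback(SummaryWriterCallback):
    def _on_step(self) -> bool:
        if self.n_calls % self._log_freq == 0:
            # Log a histogram of every policy parameter via the raw SummaryWriter
            for name, param in self.model.policy.named_parameters():
                self.tb_formatter.writer.add_histogram(f"weights/{name}", param, self.num_timesteps)
            self.tb_formatter.writer.flush()
        return True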