Migrating from Stable-Baselines (Stable Baselines3 2.6.1a0 documentation)

This is a guide to migrate from Stable-Baselines (SB2) to Stable-Baselines3 (SB3).

It also references the main changes.

Overview

Overall, Stable-Baselines3 (SB3) keeps the high-level API of Stable-Baselines (SB2). Most of the changes are internal and aim at greater consistency. Because of the backend change from TensorFlow to PyTorch, the internal code is much more readable and easier to debug, at the cost of some speed (dynamic graph vs. static graph, see Issue #90). However, the algorithms were extensively benchmarked on Atari games and continuous control PyBullet envs (see Issue #48 and Issue #49), so you should not expect a performance drop when switching from SB2 to SB3.

How to migrate?

In most cases, replacing "from stable_baselines" with "from stable_baselines3" will be sufficient. Some files were moved to the common folder (see below), which could result in import errors. Some algorithms were removed because of their complexity, to improve the maintainability of the project. We recommend reading this guide carefully to understand all the changes that were made. You can also take a look at the rl-zoo3 and compare its imports to those of the rl-zoo of SB2 for a concrete example of a successful migration.
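As a rough illustration of that first step, the package rename can be applied mechanically. The sketch below uses a hypothetical helper, rewrite_sb2_imports, which is not part of SB3. It only renames the package in a piece of source text; moved files, renamed classes, and removed algorithms still need manual review.

```python
import re

# Matches the SB2 package name as a whole word, so code that already
# uses "stable_baselines3" is left untouched.
_SB2_PACKAGE = re.compile(r"\bstable_baselines\b")


def rewrite_sb2_imports(source: str) -> str:
    """Rename the SB2 package to the SB3 package in a piece of source code.

    Hypothetical helper for illustration only: it does not fix imports
    of files that were moved or of algorithms that were removed.
    """
    return _SB2_PACKAGE.sub("stable_baselines3", source)


old_code = "from stable_baselines.common.vec_env import DummyVecEnv\n"
new_code = rewrite_sb2_imports(old_code)
print(new_code)
```

Running such a rewrite on a copy of the code base and then fixing the remaining import errors one by one is one possible workflow, not a prescribed one.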

Note

If you experience a massive slowdown after switching to PyTorch, you may need to adjust the number of threads used, with torch.set_num_threads(1) or OMP_NUM_THREADS=1; see issue #122 and issue #90.
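A minimal sketch of both options from the note, placed at the very top of a training script. The value 1 is taken from the note; whether it is the best setting depends on the machine. The try/except around the torch import is only there so the snippet stays importable where PyTorch is absent, and can be dropped in a real SB3 project.

```python
import os

# Set the environment variable before importing torch, so the OpenMP
# runtime sees it when it starts up.
os.environ["OMP_NUM_THREADS"] = "1"

try:
    import torch
except ImportError:  # assumption: allow running this sketch without PyTorch
    torch = None

if torch is not None:
    # Limit the number of threads PyTorch uses for intra-op parallelism.
    torch.set_num_threads(1)
```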

Breaking Changes

You can take a look at the issue about SB3 implementation design for more details.

Moved Files

Utility functions are no longer exported from the common module; you should import them with their absolute path, e.g.:

from stable_baselines3.common.env_util import make_atari_env, make_vec_env
from stable_baselines3.common.utils import set_random_seed

instead of from stable_baselines3.common import make_atari_env

Changes and renaming in parameters

Base-class (all algorithms)

Policies

A2C

Warning

The PyTorch implementation of RMSprop differs from TensorFlow's, which leads to different and potentially more unstable results. Use the stable_baselines3.common.sb2_compat.rmsprop_tf_like.RMSpropTFLike optimizer to match the results of the TensorFlow implementation. This can be done through policy_kwargs: A2C(policy_kwargs=dict(optimizer_class=RMSpropTFLike, optimizer_kwargs=dict(eps=1e-5)))

PPO

Warning

In SB2, nminibatches gave a different batch size depending on the number of environments: batch_size = (n_steps * n_envs) // nminibatches
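The formula above can be evaluated directly to find the batch size an SB2 configuration was effectively using. The helper name sb2_effective_batch_size and the numbers below are illustrative assumptions, not values taken from this guide.

```python
def sb2_effective_batch_size(n_steps: int, n_envs: int, nminibatches: int) -> int:
    """Batch size implied by SB2's nminibatches, per the formula above."""
    return (n_steps * n_envs) // nminibatches


# Illustrative values only: the same nminibatches gives a different
# batch size as soon as the number of environments changes.
n_steps, nminibatches = 128, 4
for n_envs in (1, 8):
    batch_size = sb2_effective_batch_size(n_steps, n_envs, nminibatches)
    print(f"n_envs={n_envs}: batch_size={batch_size}")
```

The resulting value can then be passed as the batch_size argument of PPO in SB3, where it no longer changes with the number of environments.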

PPO default hyperparameters are the ones tuned for continuous control environments. We recommend taking a look at the RL Zoo for hyperparameters tuned for Atari games.

DQN

Only the vanilla DQN is implemented right now but extensions will follow. Default hyperparameters are taken from the Nature paper, except for the optimizer and learning rate that were taken from Stable Baselines defaults.

DDPG

DDPG now follows the same interface as SAC/TD3. For state/reward normalization, you should use VecNormalize as for all other algorithms.

SAC/TD3

SAC/TD3 now accept any number of critics, e.g. policy_kwargs=dict(n_critics=3), instead of exactly two as before.

Note

SAC/TD3 default hyperparameters (including network architecture) now match the ones from the original papers. DDPG is using TD3 defaults.

SAC

SAC implementation matches the latest version of the original implementation: it uses two Q function networks and two target Q function networks instead of two Q function networks and one Value function network (SB2 implementation, first version of the original implementation). Despite this change, no change in performance should be expected.

Note

The SAC predict() method now has deterministic=False by default, for consistency. To match SB2 behavior, you need to explicitly pass deterministic=True.
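One way to keep the SB2 behavior in migrated code is to make the flag explicit in a single place. The wrapper below is a hypothetical convenience (predict_like_sb2 is not an SB3 function); it assumes only that the model exposes the predict() method described above, returning an (action, state) pair.

```python
def predict_like_sb2(model, observation):
    """Return the action chosen with deterministic=True, as SB2's SAC did by default.

    Hypothetical helper: `model` is any object whose predict() accepts a
    `deterministic` keyword and returns an (action, state) pair.
    """
    action, _state = model.predict(observation, deterministic=True)
    return action


# Usage sketch, assuming `model` is a trained SAC model and `obs` an observation:
# action = predict_like_sb2(model, obs)
```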

HER

The HER implementation now only supports online sampling of the new goals. This is done in a vectorized version. The goal selection strategy RANDOM is no longer supported.

New logger API

Internal Changes

Please read the Developer Guide section.

New Features (SB3 vs SB2)