Performance check (Continuous Actions) · Issue #48 · DLR-RM/stable-baselines3
Check that the algorithms reach the expected performance.
This was already done prior to v0.5 for the gSDE paper, but since we have made big changes, it is good to check again.
SB2 vs SB3 (TensorFlow Stable-Baselines vs PyTorch Stable-Baselines3)
- A2C (6 seeds), plots: a2c.pdf, a2c_ant.pdf, a2c_half.pdf, a2c_hopper.pdf, a2c_walker.pdf
- PPO (6 seeds), plots: ppo.pdf, ant_ppo.pdf, half_ppo.pdf, hopper_ppo.pdf, ppo_walker.pdf
- SAC (3 seeds), plots: sac.pdf, sac_ant.pdf, sac_half.pdf, sac_hopper.pdf, sac_walker.pdf
- TD3 (3 seeds), plots: td3.pdf, td3_ant.pdf, td3_half.pdf, td3_hopper.pdf, td3_walker.pdf
See https://paperswithcode.com/paper/generalized-state-dependent-exploration-for for the scores that should be reached after 1M steps (off-policy) or 2M steps (on-policy).
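As a rough sanity check on the size of this benchmark, the run grid implied by the list above can be enumerated in Python. This is a minimal sketch, assuming every run uses the step budget quoted for the reference scores (2M steps on-policy, 1M steps off-policy) and the four environments from the results tables; `experiment_grid` and `step_budget` are hypothetical helpers, not part of SB3 or the RL Zoo.

```python
from itertools import product

# Seed counts from the list above; step budgets ASSUMED from the reference
# scores (on-policy: 2M steps, off-policy: 1M steps).
SEEDS = {"a2c": 6, "ppo": 6, "sac": 3, "td3": 3}
ON_POLICY = {"a2c", "ppo"}
# Short environment names as used in the results tables (not full Gym IDs).
ENVS = ["HalfCheetah", "Ant", "Hopper", "Walker2D"]


def step_budget(algo):
    """Return the assumed training budget for one run of `algo`."""
    return 2_000_000 if algo in ON_POLICY else 1_000_000


def experiment_grid():
    """Yield one (algo, env, seed, n_timesteps) tuple per training run."""
    for algo, n_seeds in SEEDS.items():
        for env, seed in product(ENVS, range(n_seeds)):
            yield algo, env, seed, step_budget(algo)


runs = list(experiment_grid())
total_steps = sum(n_timesteps for *_, n_timesteps in runs)
print(f"{len(runs)} runs, {total_steps:,} environment steps per framework")
# 72 runs, 120,000,000 environment steps per framework
```

Under these assumptions, the SB2 vs SB3 comparison needs this grid twice, once per framework.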
Test environments: PyBullet envs
Tested with version 0.8.0 (feat/perf-check branch in the two zoos)
SB3 commit hash: cceffd5
rl-zoo commit hash: 99f7dd0321c5beea1a0d775ad6bc043d41f3e2db
Environments | A2C (SB2) | A2C (SB3) | PPO (SB2) | PPO (SB3) | SAC (SB2) | SAC (SB3) | TD3 (SB2) | TD3 (SB3) |
---|---|---|---|---|---|---|---|---|
HalfCheetah | 1859 +/- 161 | 2003 +/- 54 | 2186 +/- 260 | 1976 +/- 479 | 2833 +/- 21 | 2757 +/- 53 | 2530 +/- 141 | 2774 +/- 35 |
Ant | 2155 +/- 237 | 2286 +/- 72 | 2383 +/- 284 | 2364 +/- 120 | 3349 +/- 60 | 3146 +/- 35 | 3368 +/- 125 | 3305 +/- 43 |
Hopper | 1457 +/- 75 | 1627 +/- 158 | 1166 +/- 287 | 1567 +/- 339 | 2391 +/- 238 | 2422 +/- 168 | 2542 +/- 79 | 2429 +/- 126 |
Walker2D | 689 +/- 59 | 577 +/- 65 | 1117 +/- 121 | 1230 +/- 147 | 2202 +/- 45 | 2184 +/- 54 | 1686 +/- 584 | 2063 +/- 185 |
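To see which SB2/SB3 differences stand out from seed-to-seed variation, the table can be parsed and each pair of cells checked for overlapping ranges. This is a minimal sketch, assuming the +/- value is a spread across seeds (for example a standard deviation; the issue does not state which statistic it is). The overlap rule |m1 - m2| <= s1 + s2 is a crude heuristic, not a statistical test, and `parse_cell`, `parse_rows` and `overlaps` are hypothetical helpers.

```python
import re

ALGOS = ["A2C", "PPO", "SAC", "TD3"]
# Rows copied verbatim from the SB2 vs SB3 table; for each algorithm the
# SB2 column comes first, then the SB3 column.
ROWS = """\
HalfCheetah | 1859 +/- 161 | 2003 +/- 54 | 2186 +/- 260 | 1976 +/- 479 | 2833 +/- 21 | 2757 +/- 53 | 2530 +/- 141 | 2774 +/- 35 |
Ant | 2155 +/- 237 | 2286 +/- 72 | 2383 +/- 284 | 2364 +/- 120 | 3349 +/- 60 | 3146 +/- 35 | 3368 +/- 125 | 3305 +/- 43 |
Hopper | 1457 +/- 75 | 1627 +/- 158 | 1166 +/- 287 | 1567 +/- 339 | 2391 +/- 238 | 2422 +/- 168 | 2542 +/- 79 | 2429 +/- 126 |
Walker2D | 689 +/- 59 | 577 +/- 65 | 1117 +/- 121 | 1230 +/- 147 | 2202 +/- 45 | 2184 +/- 54 | 1686 +/- 584 | 2063 +/- 185 |
"""
CELL = re.compile(r"(-?\d+(?:\.\d+)?)\s*\+/-\s*(\d+(?:\.\d+)?)")


def parse_cell(cell):
    """Turn '1859 +/- 161' into the pair (1859.0, 161.0)."""
    match = CELL.fullmatch(cell.strip())
    if match is None:
        raise ValueError(f"not a 'mean +/- spread' cell: {cell!r}")
    return float(match.group(1)), float(match.group(2))


def parse_rows(text):
    """Map (env, algo) to ((sb2_mean, sb2_spread), (sb3_mean, sb3_spread))."""
    table = {}
    for line in text.strip().splitlines():
        # Drop the empty field produced by the trailing '|'.
        fields = [f.strip() for f in line.split("|") if f.strip()]
        env, cells = fields[0], [parse_cell(c) for c in fields[1:]]
        for i, algo in enumerate(ALGOS):
            table[(env, algo)] = (cells[2 * i], cells[2 * i + 1])
    return table


def overlaps(a, b):
    """True if the ranges mean +/- spread of a and b intersect."""
    (m1, s1), (m2, s2) = a, b
    return abs(m1 - m2) <= s1 + s2


table = parse_rows(ROWS)
flagged = [key for key, (sb2, sb3) in table.items() if not overlaps(sb2, sb3)]
for env, algo in flagged:
    (sb2_mean, _), (sb3_mean, _) = table[(env, algo)]
    print(f"{env} {algo}: SB2 {sb2_mean:.0f} vs SB3 {sb3_mean:.0f}")
```

With these numbers, only HalfCheetah SAC, HalfCheetah TD3 and Ant SAC fall outside the overlap rule; every other SB2/SB3 pair overlaps.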
Generalized State-Dependent Exploration (gSDE)
- gSDE paper: https://arxiv.org/abs/2005.05719
See https://paperswithcode.com/paper/generalized-state-dependent-exploration-for for the scores that should be reached after 1M steps (off-policy) or 2M steps (on-policy).
- On-policy (2M steps, 6 seeds), plots: gsde_onpolicy.pdf, gsde_onpolicy_ant.pdf, gsde_onpolicy_half.pdf, gsde_onpolicy_hopper.pdf, gsde_onpolicy_walker.pdf
- Off-policy (1M steps, 3 seeds), plots: gsde_off_policy.pdf, gsde_offpolicy_ant.pdf, gsde_offpolicy_half.pdf, gsde_offpolicy_hopper.pdf, gsde_offpolicy_walker.pdf
SB3 commit hash: b948b7f
rl-zoo commit hash: b56c1470c9a958c196f60e814de893050e2469f0
Environments | A2C (Gaussian) | A2C (gSDE) | PPO (Gaussian) | PPO (gSDE) | SAC (Gaussian) | SAC (gSDE) | TD3 (Gaussian) | TD3 (gSDE) |
---|---|---|---|---|---|---|---|---|
HalfCheetah | 2003 +/- 54 | 2032 +/- 122 | 1976 +/- 479 | 2826 +/- 45 | 2757 +/- 53 | 2984 +/- 202 | 2774 +/- 35 | 2592 +/- 84 |
Ant | 2286 +/- 72 | 2443 +/- 89 | 2364 +/- 120 | 2782 +/- 76 | 3146 +/- 35 | 3102 +/- 37 | 3305 +/- 43 | 3345 +/- 39 |
Hopper | 1627 +/- 158 | 1561 +/- 220 | 1567 +/- 339 | 2512 +/- 21 | 2422 +/- 168 | 2262 +/- 1 | 2429 +/- 126 | 2515 +/- 67 |
Walker2D | 577 +/- 65 | 839 +/- 56 | 1230 +/- 147 | 2019 +/- 64 | 2184 +/- 54 | 2136 +/- 67 | 2063 +/- 185 | 1814 +/- 395 |
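The effect of gSDE can be summarized as a relative change of the mean score per algorithm and environment. A minimal sketch that compares means only (the +/- spreads are dropped), with the values copied from the Gaussian vs gSDE table; `relative_change` is a hypothetical helper.

```python
ALGOS = ["A2C", "PPO", "SAC", "TD3"]
# (Gaussian mean, gSDE mean) per environment and algorithm, copied from the
# Gaussian vs gSDE table.
MEANS = {
    "HalfCheetah": {"A2C": (2003, 2032), "PPO": (1976, 2826), "SAC": (2757, 2984), "TD3": (2774, 2592)},
    "Ant": {"A2C": (2286, 2443), "PPO": (2364, 2782), "SAC": (3146, 3102), "TD3": (3305, 3345)},
    "Hopper": {"A2C": (1627, 1561), "PPO": (1567, 2512), "SAC": (2422, 2262), "TD3": (2429, 2515)},
    "Walker2D": {"A2C": (577, 839), "PPO": (1230, 2019), "SAC": (2184, 2136), "TD3": (2063, 1814)},
}


def relative_change(gaussian_mean, gsde_mean):
    """Percent change of the gSDE mean relative to the Gaussian baseline."""
    return 100.0 * (gsde_mean - gaussian_mean) / gaussian_mean


for algo in ALGOS:
    changes = [relative_change(*MEANS[env][algo]) for env in MEANS]
    print(algo, "  ".join(f"{env}: {c:+.1f}%" for env, c in zip(MEANS, changes)))
```

On these numbers the PPO change is positive on all four environments (roughly +18% on Ant up to +64% on Walker2D), while SAC and TD3 are mixed.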
DDPG
Uses the TD3 hyperparameters as a base, with some minor adjustments (lr, batch_size) for stability.
6 seeds, 1M steps.
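Deriving the DDPG settings from TD3 amounts to copying the TD3 hyperparameters and overriding only the learning rate and batch size. A minimal sketch, assuming hyperparameters are held in a plain dict keyed by SB3 constructor argument names (`learning_rate`, `batch_size`); `ddpg_from_td3` is a hypothetical helper, and every numeric value below is a placeholder, not a setting used for these results.

```python
import copy

# The issue reports adjusting only these two settings ("lr, batch_size"),
# spelled here with SB3's constructor argument names.
ADJUSTABLE = {"learning_rate", "batch_size"}


def ddpg_from_td3(td3_params, **overrides):
    """Return a DDPG config: a copy of the TD3 config with lr/batch size overridden."""
    unexpected = set(overrides) - ADJUSTABLE
    if unexpected:
        raise ValueError(f"only {sorted(ADJUSTABLE)} may be overridden, got {sorted(unexpected)}")
    params = copy.deepcopy(td3_params)  # leave the TD3 config untouched
    params.update(overrides)
    return params


# HYPOTHETICAL placeholder values, for illustration only.
td3_base = {
    "learning_rate": 1e-3,
    "batch_size": 100,
    "gamma": 0.99,
    "policy_kwargs": {"net_arch": [400, 300]},
}
ddpg_params = ddpg_from_td3(td3_base, learning_rate=7e-4, batch_size=256)
print(ddpg_params)
```

Copying with `copy.deepcopy` keeps nested entries such as `policy_kwargs` from being shared between the two configs, and rejecting other keys limits the derived config to the adjustments reported here.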
Environments | DDPG (Gaussian) |
---|---|
HalfCheetah | 2272 +/- 69 |
Ant | 1651 +/- 407 |
Hopper | 1201 +/- 211 |
Walker2D | 882 +/- 186 |