Match performance with stable-baselines (discrete case) by Miffyli · Pull Request #110 · DLR-RM/stable-baselines3 (original) (raw)
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service andprivacy statement. We’ll occasionally send you account related emails.
Already on GitHub?Sign in to your account
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.Learn more about bidirectional Unicode characters
[ Show hidden characters]({{ revealButtonHref }})
This PR will be done when stable-baselines3 agent performance matches stable-baselines in discrete envs. Will be tested on discrete control tasks and Atari environments.
PS: Sorry about the confusing branch-name.
Changes
- Fix storing correct dones ([bug] on-policy rollout collects current "dones" instead of last "dones" #105, credits to AndyShih12)
- Fix number of filters in NatureCNN
- Add
common.sb2_compat.RMSpropTFLike
, which is a modification of RMSprop that matches TF version, and is required for matching performance in A2C.
TODO
- Match performance of A2C and PPO.
- A2C Cartpole matches (mostly, see this. Averaged over 10 random seeds for both. Requires the TF-like RMSprop, and even still in the very end SB3 seems more unstable.)
- A2C Atari matches (mostly, see sb2 and sb3. Original sb3 result here. Three random seeds, each line separate run (ignore legend). Using TF-like RMSprop. Performance and stability mostly matches, except sb2 has sudden spike in performance in Q*Bert. Something to do with stability in distributions?)
- PPO Cartpole (using rl-zoo parameters, see learning curves, averaged over 20 random seeds)
- PPO Atari (mostly, see sb2 and sb3 results (shaded curves averaged over two seeds). Q*Bert still seems to have an edge on SB2 for unknown reasons)
- Check and match performance of DQN. Seems ok. See following learning curves, each curve is an average over three random seeds:
atari_spaceinvaders.pdf
atari_qbert.pdf
atari_breakout.pdf
atari_pong.pdf - Check if "dones" fix can (and should) be moved to computing GAE side.
Write docs on how to match A2C and PPO settings to stable-baselines ("moving from stable-baselines"). There are some important quirks to note here.Move this to migration guide PR Migration Guide #123 .
Types of changes
- Bug fix (non-breaking change which fixes an issue)
- New feature (non-breaking change which adds functionality)
- Breaking change (fix or feature that would cause existing functionality to change)
- Documentation (update in the documentation)
Checklist:
- I've read the CONTRIBUTION guide (required)
- I have updated the changelog accordingly (required).
- My change requires a change to the documentation.
- I have updated the tests accordingly (required for a bug fix or a new feature).
- I have updated the documentation accordingly.
- I have reformatted the code using
make format
(required) - I have checked the codestyle using
make check-codestyle
andmake lint
(required) - I have ensured
make pytest
andmake type
both pass. (required)
@@ -74,7 +74,7 @@ def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 512): |
---|
nn.ReLU(), |
nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0), |
nn.ReLU(), |
nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=0), |
nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0), |
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thinking about that, we need to double check VecFrameStack
, even though it is the same as SB2.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sadly (luckily? =) ) it did not fix the issues yet. SB3 is still consistently worse in a few of the Atari games I have tested. I am in the process of step-by-step comparisons, will see how that goes.
Edit: Ah yes, stacking on the wrong channels?
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
one thing that may change is the optimizer implementation and default parameters, for the initialization, I think (at least I tried) to reproduce what was done in SB2.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
my question was more what is the fps we want to reach? (what did you have with SB2?)
Hmm I do not have conclusive numbers just yet because I have been running many experiments on same system and can not guarantee fair comparisons, but I think PyTorch variants are about 10% slower with Atari games and 25% slower on toy environments. The latter required the OMP_NUM_THREADS tuning. This sounds reasonable, given the non-compiled nature of PyTorch and the fact the code has not been optimized much yet.
Yes, the issue was that nminibatches lead to different mini-batchsize depending on the number of environments
Ah alright. I will write big notes about this on the "moving from stable-baselines" docs :)
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One major change in parameters is the use of
batch_size=64
rather thannminibatches=4
in PPO. Using such small batch-size made things very slow FPS-wise, but in some cases sped up the learning. I will focus more on these running-speed things in an another PR.
I would like to add that we may be able to gain a non minuscule speedup by avoiding single data stores but instead storing a whole batch at once.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I started documenting the migration here ;)
#123
I would like to add that we may be able to gain a non minuscule speedup by avoiding single data stores but instead storing a while batch at once.
?
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Mistyped, I meant that if we store a whole batch at once, we should get a sizeable speedup over storing one transition at a time.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
still not sure what you mean...
I am wondering: are you using clip_range_vf
for PPO?
The behavior of this parameter changed between SB2 and SB3 (i'm currently documenting all thoses changes in a new branch).
I am wondering: are you using clip_range_vf for PPO?
The behavior of this parameter changed between SB2 and SB3 (i'm currently documenting all thoses changes in a new branch).
I am using the parameters from rl-zoo for Atari PPO runs, where vf clipping is disabled and I set cliprange_vf=-1
for sb2 and clip_range_vf=None
for sb3, which I understood is the matching behaviour (disables it).
Some progression with A2C: With CartPole you get very similar learning curves (below, averaged over 10 random seeds) with rl-zoo parameters, after you update the PyTorch RMSProp to match TF's implementation. Turns out PyTorch RMSProp does things a little bit different, and these are crucial for stable learning like shown. These changes require a new optimizer or changes to PyTorch code, so should we include a modified RMSProp in stable-baselines3 like done here in another repo? We could include this as an additional optimizer and instruct to use it if one wants to replicate sb2 results, but we could also consider making it default RMSProp optimizer because of its (apparent) stability.
A2C seems to check out mostly (see the original post with plots) with the fixed RMSprop that is now included under sb2_compat
. If this approach is ok, I can write notes in docs about this RMSprop with A2C and do same in #123 .
A2C seems to check out mostly (see the original post with plots)
including Atari games?
If this approach is ok, I can write notes in docs about this RMSprop with A2C and do same in #123 .
Sounds reasonable, I don't see any better solution... The only thing is which default should we use?
(I will re-run a quick continuous control benchmark with the updated RMSProp, depending on that we will see)
including Atari games?
Yup! See the original post with plots. To me they seem "close enough" (with this limited amount of runs), except for Q*Bert which at end gets a sudden boost in performance in sb2. I will be checking PPO next and see if there is something common to A2C and PPO the is derp.
Sounds reasonable, I don't see any better solution... The only thing is which default should we use?
TF variant seems more stable and pytorch-image-models repo guys also say they have had better success with it. I'd personally go with that one by default.
(I will re-run a quick continuous control benchmark with the updated RMSProp, depending on that we will see)
Remember to set the parameters manually! I forgot this first time around ^^
policy_kwargs["optimizer_class"] = RMSpropTFLike
policy_kwargs["optimizer_kwargs"] = dict(alpha=0.99, eps=1e-5, weight_decay=0)
After a quick run on Bullet envs, th.optim.RMSProp
yield better performances for continuous control...
Mean final reward over 3 seeds on HalfCheetahBulletEnv-v0
:
tf-rmsprop: 1192
torch-rmsprop: 1912
How is the stability, though? I noticed tf.optim.RMSprop
learns faster with its bigger gradients but does not seem to converge so easily, while the TF-variant learns slower but is more stable (see the plots I have above).
Edit: In the light of these results we could keep the original enabled by default, though, and instruct people to use the TF-variant if they are experiencing unstable learning.
How is the stability, though?
A bit unstable at the beginning.
Edit: In the light of these results we could keep the original enabled by default, though, and instruct people to use the TF-variant if they are experiencing unstable learning.
Yes, and add the tf-version as default in the zoo for Atari?
see the plots I have above
I only see the plots where the two are similar.
See the original post with plots.
In the original post, I only see ppo plots...
Yes, and add the tf-version as default in the zoo for Atari?
Works for me 👍
I only see the plots where the two are similar.
In the original post, I only see ppo plots...
Hmm there should be four A2C plots in total under "TODO" heading: A2C cartpole comparisons (with rmsprop fixes), sb2 and sb3 Atari results for A2C and sb3 Atari results without rmsprop fix.
Hmm there should be four A2C plots in total under "TODO" heading: A2C cartpole comparisons (with rmsprop fixes), sb2 and sb3 Atari results for A2C and sb3 Atari results without rmsprop fix.
🙈 I was looking at the issue, not the PR...
Ran some more Atari PPO runs and now the performance seems to match (see the original post for plots). SB3 seems to be consistently lower than SB2 but nothing seems horribly broken. Q*Bert has an edge on SB2 for some reason with both PPO and A2C. I will be re-running experiments with more seeds, but that will take time. @araffin could you comment on the learning curves and tell what you think about the results?
Ran some more Atari PPO runs and now the performance seems to match (see the original post for plots). SB3 seems to be consistently lower than SB2 but nothing seems horribly broken
Do you know if the ADAM implementation is the same for A2C/PPO?
Do you know if the ADAM implementation is the same for A2C/PPO?
Quick googling and skimming over the codes they seem to match, and also the A2C experiments matched with Adam (equally unstable :D), so I believe that part checks out.
And how many random seeds did you try?
For me, it looks good ;) I cannot spot dramatic performance drop, and ppo matches ppo2 performance for the continuous case (I will do a check again though before 1.0 release).
And how many random seeds did you try?
Each of the curves is slightly different setup but, in general, tend to have the same result (see Figure 5 here, where we have five random seeds per curve). I.e. you can treat each curve as separate run with different random seed. But I will run some more for better conclusion.
Actually DQN uses Adam for optimizing, and it has been using it since stable-baselinse2, while (I think) the original implementation used rmsprop. It might be worth of trying out what happens if you change the optimizer to stabler rmsprop, as Adam made things unstable with PPO.
On sidenote: I ran Pong on sb3 DQN and was not able to get any improvement while sb2 learns it quickly (inside ~2M steps). I thought sb3 DQN was able to learn Pong, tho? Using parameters from rl-zoo, minus prioritized memory etc.
On sidenote: I ran Pong on sb3 DQN and was not able to get any improvement while sb2 learns it quickly (inside ~2M steps). I thought sb3 DQN was able to learn Pong, tho? Using parameters from rl-zoo, minus prioritized memory etc.
It was but not as good as expected... SB2 DQN has nothing to do with vanilla DQN...
It was but not as good as expected... SB2 DQN has nothing to do with vanilla DQN...
To clarify to others: araffin referred to the fact how, by default, SB2 DQN has bunch of modifications enabled (Double-Q, Dueling). Those were disabled for those runs.
I ran more experiments with Atari with the recent hotfix #132 . The learning curves are included in the main post and match mostly. While not perfect I can not tell if issue is in lack of random seeds used (three is rather low), and in any case I do not have the compute to run enough training runs to debug deeper if something differs.
I ran more experiments with Atari with the recent hotfix #132 . The learning curves are included in the main post and match mostly. While not perfect I can not tell if issue is in lack of random seeds used (three is rather low), and in any case I do not have the compute to run enough training runs to debug deeper if something differs.
Looks good, no? SB3 DQN has even slightly better performance on one and I'm pretty sure SB3 DQN is faster than SB2, no?
Btw, which hyperparameters did you use? (Please update the defaults if you used different ones)
Otherwise, it looks like it is ready to merge, no?
The last_done
fix is only 3 lines of code...
Looks good, no? SB3 DQN has even slightly better performance on one and I'm pretty sure SB3 DQN is faster than SB2, no?
Btw, which hyperparameters did you use? (Please update the defaults if you used different ones)
Preferably I would want to performance match in both good and bad (i.e. not better or worse) just to keep consistent results, but that'd still require a lot of work ^^. I used the hyperparameters from sb2 rl-zoo, plus disabling all the DQN improvements for SB2. I am not quite sure what you mean by "update defaults".
I used the hyperparameters from sb2 rl-zoo, plus disabling all the DQN improvements for SB2. I am not quite sure what you mean by "update defaults".
I meant updating the default hyperparameters. The current ones are from the DQN nature paper and therefore do no correspond to your benchmark. The main differences are the buffer size and the final value of the exploration rate.
I would update the default with the one you used ;) (but we know they work)
I meant updating the default hyperparameters. The current ones are from the DQN nature paper and therefore do no correspond to your benchmark. The main differences are the buffer size and the final value of the exploration rate.
Hmm I would those values from the original paper, as this is what users would expect when seeing "DQN". I do not think these parameters I used are the best (do not learn fastest / stablest), but I needed the replay-buffer size at the very least to be able to fit multiple experiments at same time on same machine.
Miffyli marked this pull request as ready for review
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, very impressive and valuable detective work =)
araffin deleted the review/pg-performance branch
This was referenced
Aug 5, 2020
Thank you for your hard work on this to investigate and align the performance!
This PR is currently referenced in the Atari Results section of the documentation here: https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html
Regarding the learning curves, would you please be able to clarify,
- What Atari environment versions these are? e.g. NoFrameskip-v4
- If the learning curves apply to the agent reward (after preprocessing normalization) or the agent score? (This was an issue in Performance Check (Discrete actions) #49)
- What code would be needed to replicate these results the code referenced? The documentation indicates it would of the form:
python train.py --algo ppo --env $ENV_ID --eval-episodes 10 --eval-freq 10000
- In the PPO learning curves, what settings "Full," "Minimal," and "Multi-discrete" refer to?
Thank you for your help
Thanks for the kind words!
I ran these experiments using a different code base from zoo (one I was most familiar at the time), so replicating results exactly might be bit tricky.
- Yes, NoFrameskip-v4, but with all the default Atari wrappers.
- If you refer to Use Monitor episode reward/length for evaluate_policy #220, yes, these results were obtained before that fix (i.e. learning curves use modified rewards)
- Looking at the code I used, that command should use same hyperparameters and wrappers I used. Note that these results use monitor files for creating learning curves which includes exploration, so DQN results will look different from zoo's. PPO/A2C should be fine.
- Those are different action-spaces explored in this work (the code I used to run these experiments). "Minimal" is the default you have with Atari envs, "Full" is where you always have access to all actions despite not used by the game and "Multi-discrete" is "Full" but where button press and joystick movement are separate into two different discrete actions.