Some questions on GPU based sampling
February 21, 2025, 4:55pm
Hi all,
Up to now, I have only used CPU-based sampling and was somewhat hesitant to set up an environment that supports GPU-based sampling.
I found the article "Set up JAX sampling with GPUs in PyMC", which gives some hints about setting up such an environment. However, the article is not recent and might not reflect current versions and developments.
I just ordered a new PC with an RTX 5070 Ti, which as far as I know requires CUDA 12.0.
I now have the following questions:
- Which sampling library is currently the “state of the art” for GPU-based sampling with PyMC? JAX? numpyro? Others?
- To your knowledge, is CUDA 12.0 supported by these samplers?
- If so, do you expect any problems with a GPU as new as the RTX 5070 Ti?
I plan to use GPU support in a Windows 11 environment. However, the JAX documentation suggests that this is still not possible natively, so I would have to use WSL as described in the article mentioned above. Is there any way to use a GPU-based sampler natively on Windows?
Thanks for sharing your experiences. I will report back as soon as the PC arrives, probably in 2-3 weeks.
Best regards
Matthias
jovank February 22, 2025, 12:08am
I started using GPU-based sampling with PyMC 5.20.1, even though it was possible before that as well. I have tried it on two different GPUs: an old one with 4 GB of VRAM and a GeForce RTX 3090. I am running Ubuntu 24 in WSL on Windows. This is what worked for me:
- Install numpyro:
  pip install numpyro
- Install the JAX backend with CUDA (yes, CUDA 12 is supported by numpyro and JAX):
  pip install -U "jax[cuda12]"
- Specify the numpyro sampler in pm.sample:

```python
import pymc as pm

with model:  # your existing PyMC model
    trace = pm.sample(
        draws=1000,
        chains=4,
        cores=1,  # I set this to 1 because I have one GPU
        nuts_sampler="numpyro",
        nuts_sampler_kwargs={
            "chain_method": "vectorized",  # runs all chains concurrently
            # "parallel" is useful if you have more devices
        },
        initvals=...,  # your initial values
    )
```
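Before pointing PyMC at the GPU, it is worth checking that JAX actually sees it after installing jax[cuda12]. The rest of the sketch is a toy illustration (not numpyro's code) of what chain_method="vectorized" roughly amounts to: numpyro batches all chains with jax.vmap on a single device, whereas "parallel" maps chains over several devices with jax.pmap. chain_step is a hypothetical stand-in for one sampler step.

```python
import jax
import jax.numpy as jnp

# Which platform will JAX use? "gpu" if the CUDA plugin from jax[cuda12] is picked up.
print(jax.default_backend(), jax.devices())

def chain_step(position, key):
    """Hypothetical stand-in for one sampler step of a single chain."""
    return position + 0.1 * jax.random.normal(key, position.shape)

n_chains, n_dim = 4, 3
positions = jnp.zeros((n_chains, n_dim))
keys = jax.random.split(jax.random.PRNGKey(0), n_chains)  # one PRNG key per chain

# "vectorized": a single batched computation for all chains on one device (vmap).
# "parallel" would instead place chains on separate devices (pmap).
new_positions = jax.vmap(chain_step)(positions, keys)
print(new_positions.shape)
```

With one GPU, vectorizing keeps all chains in one batched kernel launch per step, which is why cores=1 is sufficient here.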
For some reason, I never managed to get the blackjax sampler running.
Comparing speed, drawing 1000 samples for the model I am working on took:
- 4 x 10 mins = 40 mins on the GeForce RTX 3090
- 4 x 1 hour = 4 hours on the older GPU
- 12 hours on my laptop with an Intel Core i7-1065G7 CPU (surely it is faster on better CPUs)
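Putting the reported timings side by side as speed-ups relative to the laptop (plain arithmetic on the numbers above, nothing re-measured):

```python
# Wall-clock times reported above, in minutes (4 chains, 1000 draws each).
timings_min = {
    "GeForce RTX 3090": 4 * 10,
    "older 4 GB GPU": 4 * 60,
    "laptop i7-1065G7 (CPU)": 12 * 60,
}

baseline = timings_min["laptop i7-1065G7 (CPU)"]
for device, minutes in timings_min.items():
    # Speed-up = laptop time / device time.
    print(f"{device:24s} {minutes:4d} min  speed-up vs laptop: {baseline / minutes:.0f}x")
```

So the RTX 3090 was about 18x faster than the laptop CPU and about 6x faster than the older GPU for this particular model.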
aseyboldt February 22, 2025, 12:58am
I am a bit biased, but I would recommend using nutpie instead of numpyro or blackjax. In my experience it outperforms both of them, on CPU and GPU, but especially on the GPU.
You do have to select the JAX backend when compiling the model, though:
```python
import nutpie

compiled = nutpie.compile_pymc_model(model, backend="jax")  # model: your PyMC model
trace = nutpie.sample(compiled)
```
You can also choose whether PyTensor or JAX computes the necessary derivatives by passing gradient_backend="jax" or gradient_backend="pytensor".
The same works through pm.sample(nuts_sampler="nutpie", nuts_sampler_kwargs=dict(backend="jax")).
krum_sv November 4, 2025, 10:32pm
I just stumbled across this thread and learned that nutpie can also be used with the JAX backend, which gave me a great speed-up!
So far, I’ve been switching back and forth between numpyro and nutpie to see which sampler works better for which model…
However, since I mostly work with scan()-based time-series models (similar to this), I ran into the following error during compilation several times:
NotImplementedError: Scan with MIT-MOT (gradients of scan) cannot yet be converted to JAX
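The solution was to also switch the gradient backend to JAX manually. So this seems to be the most effective sampler setup for me at the moment:

```python
import pymc as pm
from pymc.model.transform.optimization import freeze_dims_and_data
with freeze_dims_and_data(model):  # copy of the model with dims/data frozen to constants
    pm.sample(nuts_sampler='nutpie', nuts_sampler_kwargs={'backend': 'jax', 'gradient_backend': 'jax'})
```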
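As I understand the error, the gradient of a PyTensor scan is itself a scan with MIT-MOT taps, and it is that gradient graph that cannot be converted to JAX. With gradient_backend='jax', only the forward log-density is converted and JAX differentiates it itself. A minimal plain-JAX sketch of the latter, assuming a toy Gaussian AR(1) likelihood (ar1_logp is hypothetical, not the model from the thread) written with jax.lax.scan, differentiated by jax.grad, and checked against the closed-form derivative:

```python
import jax
import jax.numpy as jnp

def ar1_logp(phi, y):
    """Unnormalized Gaussian AR(1) log-likelihood, accumulated with a scan over time."""
    def step(carry, pair):
        y_prev, y_curr = pair
        resid = y_curr - phi * y_prev
        return carry - 0.5 * resid**2, None

    init = jnp.zeros((), dtype=y.dtype)  # scalar accumulator with the data's dtype
    total, _ = jax.lax.scan(step, init, (y[:-1], y[1:]))
    return total

# JAX differentiates through the scan directly; no separate gradient scan is converted.
ar1_grad = jax.grad(ar1_logp)

y = jnp.array([0.5, 0.2, -0.1, 0.3, 0.4])
g = ar1_grad(0.7, y)

# Closed form: d/dphi of -0.5 * sum(resid^2) = sum(resid * y_prev).
g_analytic = jnp.sum((y[1:] - 0.7 * y[:-1]) * y[:-1])
print(g, g_analytic)
```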
P.S.: I also passed compile_kwargs={'mode': 'JAX'} to sample_prior_predictive() and sample_posterior_predictive(), which also gave me a significant speed-up there.
Nice! These are indeed the recommended settings for time-series modeling.
We have a PR to fix the MIT-MOT issue, though, so hopefully that will soon be a forgotten bad memory.
ccyang November 6, 2025, 9:28am
This setting does not seem to work for BART. Is there a similar setting to speed up BART?
Thanks!
You should try the pymc-bart-rs package developed by @gstechschulte. It’s super fast.
ckrapu November 7, 2025, 12:59pm
What methods would folks recommend for sampling with very large models for which the trace size exceeds the GPU memory?
I see that one may store the trace on disk using CPU-based sampling, but it’s not clear if this is feasible using BlackJAX or similar.
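Whether a trace fits in GPU (or host) memory can be judged with a back-of-the-envelope estimate: chains × draws × stored values per draw × bytes per value. A minimal sketch; trace_size_gb is a hypothetical helper, it assumes float64 storage by default, and it ignores sampler statistics and container overhead.

```python
def trace_size_gb(chains, draws, n_values, bytes_per_value=8):
    """Rough in-memory size of a posterior trace in GB (1 GB = 1e9 bytes).

    n_values: total number of scalar entries stored per draw, across all
    variables (including deterministics). bytes_per_value=8 assumes float64;
    use 4 for float32.
    """
    return chains * draws * n_values * bytes_per_value / 1e9

print(trace_size_gb(4, 1000, 1_000_000))     # float64
print(trace_size_gb(4, 1000, 1_000_000, 4))  # float32
```

A model with a million stored values per draw therefore needs roughly 32 GB for 4 × 1000 float64 draws, more than most single GPUs offer.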
I imagine it doesn’t work with non-default backends, because PyMC doesn’t control what happens between steps in whatever package we hand the sampling off to. I think there is some work in this direction in nutpie; see for example here.
Actually, you can (partially) sample on the GPU with the PyMC sampler if you pass compile_kwargs={'mode': 'JAX'}. I say partially because this compiles the model's logp_dlogp function to JAX, which uses the GPU by default, but the actual HMC steps and adaptation still run in pure Python. As a result, you should get access to local storage.
Some notes about this:
- You will likely have to set mp_ctx='forkserver', otherwise JAX will deadlock.
- PyMC isn’t doing any smart pmap to delegate work between multiple GPUs. Actually, it’s doing the dumbest thing possible, so I bet cores > 1 will not work the way you hope.
- The return type of a PyTensor function compiled in JAX mode is a JAX array. I have run into issues with this in the past that required local patching; PyMC needs some work to be more robust here. If you asked me to bet, I’d bet that the store-to-disk logic doesn’t account for this.
As always, PRs are welcome.
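The split described in this post (model logp and gradient compiled to JAX, sampler loop in Python) can be sketched in plain JAX. This is a toy illustration, not PyMC's implementation: logp is a hypothetical standard-normal density, and leapfrog is a hand-written integrator standing in for PyMC's NUTS step. The np.asarray calls reflect the return-type caveat: the jitted function returns JAX arrays, and converting them to NumPy on the host is one generic way to feed code that expects np.ndarray.

```python
import jax
import jax.numpy as jnp
import numpy as np

def logp(theta):
    """Hypothetical toy model: standard-normal log-density (up to a constant)."""
    return -0.5 * jnp.sum(theta**2)

# Compiled once; runs on the default JAX device (the GPU if one is visible).
logp_dlogp = jax.jit(jax.value_and_grad(logp))

def leapfrog(theta, momentum, step_size, n_steps):
    """Plain-Python leapfrog loop: only logp_dlogp runs inside JAX."""
    _, grad = logp_dlogp(theta)
    for _ in range(n_steps):
        momentum = momentum + 0.5 * step_size * np.asarray(grad)  # JAX array -> NumPy
        theta = theta + step_size * momentum
        _, grad = logp_dlogp(theta)
        momentum = momentum + 0.5 * step_size * np.asarray(grad)
    return theta, momentum

def hamiltonian(theta, p):
    """Potential energy (-logp) plus kinetic energy."""
    return -float(logp_dlogp(theta)[0]) + 0.5 * float(np.sum(p**2))

theta0, p0 = np.array([1.0, -0.5]), np.array([0.3, 0.2])
theta1, p1 = leapfrog(theta0, p0, step_size=0.1, n_steps=20)
print(hamiltonian(theta0, p0), hamiltonian(theta1, p1))  # nearly equal
```

Every step moves the gradient between device and host, which is why this only counts as partial GPU sampling; fully JAX-based samplers keep the whole loop on the device.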
aseyboldt November 7, 2025, 6:21pm
You can also use nutpie. It doesn’t require the whole trace to fit in GPU memory, or even in CPU memory if you store the trace in a zarr archive during sampling.
With the jax backend, you can also use it to sample on the GPU.