JAX sampling for Bayesian neural network

I am working with a version of the Bayesian neural network model in this example notebook. How can I change it to sample on the GPU using a JAX backend?

You would have to use pm.sample(..., nuts_sampler="numpyro") while running on a machine with a GPU that JAX can see.

Just so I understand correctly: I will have to use an MCMC-based sampling approach rather than the ADVI inference approach used in the tutorial?