JAX sampling for Bayesian neural network
I am working with a version of the Bayesian neural network model in this example notebook. How can I change it to sample on the GPU using a JAX backend?
You would have to use pm.sample(..., nuts_sampler="numpyro") on a machine with an available GPU.
Just so I understand correctly: I will have to use an MCMC-based sampling approach, and not the ADVI inference approach from the tutorial?