AttributeError: 'numpy.ndarray' object has no attribute 'at' when sampling the "LKJ Cholesky Covariance Priors for Multivariate Normal Models" example with numpyro or blackjax

I wanted to try out sampling with numpyro on an AMD GPU using a model based on this example:
https://www.pymc.io/projects/examples/en/latest/howto/LKJ.html
It works with the PyMC and nutpie samplers, but not with the JAX-based samplers (numpyro, blackjax).
The code I run looks like the following:

import pymc as pm
import numpy


def main():
    """Fit the LKJ Cholesky-covariance MvNormal example with a JAX NUTS sampler.

    Simulates 10,000 correlated 2-D Gaussian draws from known means, standard
    deviations and correlation. It then builds the model from the PyMC LKJ
    how-to (an LKJ prior on the Cholesky factor, a Normal prior on the mean,
    and an MvNormal likelihood) and samples it with ``nuts_sampler="numpyro"``.

    NOTE(review): per the traceback reported with this script, ``pm.sample``
    raises ``AttributeError: 'numpy.ndarray' object has no attribute 'at'``
    inside PyMC's jittered-initial-point logp check
    (``_get_batched_jittered_initial_points`` -> ``_init_jitter``), before
    any sampling starts. The failing frame is library code, not this script.
    Confirm against the PyMC / PyTensor issue trackers.
    """
    seed = 8927
    n_obs = 10000
    rng = numpy.random.default_rng(seed)

    # Ground-truth parameters of the simulated data.
    true_mu = numpy.array([1.0, -2.0])
    true_sigmas = numpy.array([0.7, 1.5])
    true_rho = numpy.array([[1.0, -0.4], [-0.4, 1.0]])

    # Covariance = D @ R @ D with D = diag(sigmas).
    scale = numpy.diag(true_sigmas)
    true_cov = scale @ true_rho @ scale
    print(true_cov)

    data = rng.multivariate_normal(true_mu, true_cov, size=n_obs)
    print(type(data))

    coords = {
        "axis": ["y", "z"],
        "axis_bis": ["y", "z"],
        "obs_id": numpy.arange(n_obs),
    }
    # Options forwarded to the external (JAX) NUTS sampler.
    jax_kwargs = {"chain_method": "vectorized", "postprocessing_backend": "gpu"}
    # Attach the model's dimension labels to the LKJ-derived deterministics.
    posterior_dims = {"chol_stds": ["axis"], "chol_corr": ["axis", "axis_bis"]}

    with pm.Model(coords=coords):
        # LKJCholeskyCov also returns the correlation matrix and the standard
        # deviations; only the Cholesky factor is used below.
        chol, _corr, _stds = pm.LKJCholeskyCov(
            "chol", n=2, eta=2.0, sd_dist=pm.Exponential.dist(1.0, shape=2)
        )
        pm.Deterministic("cov", chol.dot(chol.T), dims=("axis", "axis_bis"))
        mu = pm.Normal("mu", 0.0, sigma=1.5, dims="axis")
        pm.MvNormal("obs", mu, chol=chol, observed=data, dims=("obs_id", "axis"))
        idata = pm.sample(
            draws=1000,
            tune=1000,
            chains=4,
            cores=1,
            mp_ctx="forkserver",
            progressbar=False,
            nuts_sampler="numpyro",
            nuts_sampler_kwargs=jax_kwargs,
            idata_kwargs={"dims": posterior_dims},
        )


if __name__ == "__main__":
    # Run the example only when this file is executed directly, not on import.
    main()

This is the conda env.yaml I’m using:

name: test_env
channels:
  - conda-forge
dependencies:
  - ipykernel
  - ipywidgets
  - jupyter
  - jupyterlab
  - numpy
  - pip
  - pymc
  - python=3.12.7
  - numpyro
  - pip:
    - https://github.com/ROCm/jax/releases/download/rocm-jax-v0.4.35/jaxlib-0.4.35-cp312-cp312-manylinux_2_28_x86_64.whl
    - https://github.com/ROCm/jax/releases/download/rocm-jax-v0.4.35/jax_rocm60_pjrt-0.4.35-py3-none-manylinux_2_28_x86_64.whl 
    - https://github.com/ROCm/jax/releases/download/rocm-jax-v0.4.35/jax_rocm60_plugin-0.4.35-cp312-cp312-manylinux_2_28_x86_64.whl
    - https://github.com/ROCm/jax/archive/refs/tags/rocm-jax-v0.4.35.tar.gz
    - ml-dtypes==0.4.0
variables:
  ROCM_PATH: /opt/rocm-6.2.1
  LLVM_PATH: /opt/rocm-6.2.1/llvm
  ENABLE_PJRT_COMPATIBILITY: 1

When I sample with PyMC's default NUTS sampler, everything works fine.
However, the numpyro sampler fails with the following error:

Traceback (most recent call last):
  File "/home/eichberg/test_pymc/test_jax.py", line 46, in <module>
    main()
  File "/home/eichberg/test_pymc/test_jax.py", line 29, in main
    idata = pm.sample(
            ^^^^^^^^^^
  File "/home/eichberg/Programs/miniforge3/envs/test_env/lib/python3.12/site-packages/pymc/sampling/mcmc.py", line 809, in sample
    return _sample_external_nuts(
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/eichberg/Programs/miniforge3/envs/test_env/lib/python3.12/site-packages/pymc/sampling/mcmc.py", line 396, in _sample_external_nuts
    idata = pymc_jax.sample_jax_nuts(
            ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/eichberg/Programs/miniforge3/envs/test_env/lib/python3.12/site-packages/pymc/sampling/jax.py", line 651, in sample_jax_nuts
    initial_points = _get_batched_jittered_initial_points(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/eichberg/Programs/miniforge3/envs/test_env/lib/python3.12/site-packages/pymc/sampling/jax.py", line 245, in _get_batched_jittered_initial_points
    initial_points = _init_jitter(
                     ^^^^^^^^^^^^^
  File "/home/eichberg/Programs/miniforge3/envs/test_env/lib/python3.12/site-packages/pymc/sampling/mcmc.py", line 1482, in _init_jitter
    point_logp = model_logp_fn(point)
                 ^^^^^^^^^^^^^^^^^^^^
  File "/home/eichberg/Programs/miniforge3/envs/test_env/lib/python3.12/site-packages/pymc/sampling/jax.py", line 243, in eval_logp_initial_point
    return logp_fn(point.values())
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/eichberg/Programs/miniforge3/envs/test_env/lib/python3.12/site-packages/pymc/sampling/jax.py", line 155, in logp_fn_wrap
    return logp_fn(*x)[0]
           ^^^^^^^^^^^
  File "/tmp/tmp0o4bd774", line 7, in jax_funcified_fgraph
    tensor_variable_2 = incsubtensor(chol_cholesky_cov_packed_, tensor_variable_1, tensor_constant)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/eichberg/Programs/miniforge3/envs/test_env/lib/python3.12/site-packages/pytensor/link/jax/dispatch/subtensor.py", line 70, in incsubtensor
    return jax_fn(x, indices, y)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/eichberg/Programs/miniforge3/envs/test_env/lib/python3.12/site-packages/pytensor/link/jax/dispatch/subtensor.py", line 58, in jax_fn
    return x.at[indices].set(y)
           ^^^^
AttributeError: 'numpy.ndarray' object has no attribute 'at'

The error persists in each of these cases:
- I remove `nuts_sampler_kwargs`.
- I switch the sampler to blackjax.
- I use upstream JAX instead of ROCm JAX, so it runs on the CPU.
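Is this a bug in PyMC, PyTensor, JAX, etc., or did I do something wrong?
And if it's a bug, which package is at fault?

My guess is that it is related to JIT compilation. The traceback shows the failure happens in PyMC's initial-point jitter check (`_init_jitter`), before sampling starts. At that point the JAX-translated logp graph is called directly on the NumPy initial-point values. `.at[...]` exists only on JAX arrays, not on NumPy arrays.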