jax __cuda_array_interface__ not working · Issue #16440 · jax-ml/jax

Description

import numpy as np
import jax.numpy as jnp
from numba import cuda
import cupy


@cuda.jit
def _sum_nomask(w, res, ind):
    start = cuda.grid(1)
    stride = cuda.gridsize(1)
    n = w.shape[0]
    tot = 0.0
    for i in range(start, n, stride):
        if w[i] > 0:
            tot += w[i]

    cuda.atomic.add(res, ind, tot)


if __name__ == "__main__":
    arr = cupy.arange(10000, dtype=np.float32)
    res = cupy.zeros(2, dtype=np.float32)
    _sum_nomask[500, 32](arr, res, 0)
    print("cupy:", res)

    arr = jnp.arange(10000, dtype=jnp.float32)
    res = jnp.zeros(2, dtype=jnp.float32)
    _sum_nomask[500, 32](arr, res, 0)
    print("jax:", res)

I get

$ python test_jax_numba.py
cupy: [49995040. 0.]
Traceback (most recent call last):
  File "/home/mrbecker/test_jax_numba.py", line 29, in <module>
    _sum_nomask[500, 32](arr, res, 0)
  File "/home/mrbecker/miniforge3/envs/jax-numba/lib/python3.11/site-packages/numba/cuda/dispatcher.py", line 542, in __call__
    return self.dispatcher.call(args, self.griddim, self.blockdim,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mrbecker/miniforge3/envs/jax-numba/lib/python3.11/site-packages/numba/cuda/dispatcher.py", line 676, in call
    kernel = _dispatcher.Dispatcher._cuda_call(self, *args)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mrbecker/miniforge3/envs/jax-numba/lib/python3.11/site-packages/numba/cuda/dispatcher.py", line 683, in _compile_for_args
    argtypes = [self.typeof_pyval(a) for a in args]
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mrbecker/miniforge3/envs/jax-numba/lib/python3.11/site-packages/numba/cuda/dispatcher.py", line 683, in <listcomp>
    argtypes = [self.typeof_pyval(a) for a in args]
                ^^^^^^^^^^^^^^^^^^^^
  File "/home/mrbecker/miniforge3/envs/jax-numba/lib/python3.11/site-packages/numba/cuda/dispatcher.py", line 690, in typeof_pyval
    return typeof(val, Purpose.argument)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mrbecker/miniforge3/envs/jax-numba/lib/python3.11/site-packages/numba/core/typing/typeof.py", line 33, in typeof
    ty = typeof_impl(val, c)
         ^^^^^^^^^^^^^^^^^^^
  File "/home/mrbecker/miniforge3/envs/jax-numba/lib/python3.11/functools.py", line 909, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mrbecker/miniforge3/envs/jax-numba/lib/python3.11/site-packages/numba/core/typing/typeof.py", line 46, in typeof_impl
    tp = _typeof_buffer(val, c)
         ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mrbecker/miniforge3/envs/jax-numba/lib/python3.11/site-packages/numba/core/typing/typeof.py", line 69, in _typeof_buffer
    m = memoryview(val)
        ^^^^^^^^^^^^^^^
BufferError: INVALID_ARGUMENT: Python buffer protocol is only defined for CPU buffers.
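For what it's worth, the failure seems to be that Numba does not recognize the JAX array as a CUDA device array, so it falls back to the CPU buffer protocol and hits the BufferError above. A quick probe with numba.cuda.is_cuda_array (the public helper Numba provides for checking the CUDA Array Interface) illustrates the difference; this is just a sketch:

import cupy
import jax.numpy as jnp
from numba import cuda

cp_arr = cupy.arange(10, dtype=cupy.float32)
jx_arr = jnp.arange(10, dtype=jnp.float32)

for name, a in [("cupy", cp_arr), ("jax", jx_arr)]:
    try:
        # True if Numba sees a __cuda_array_interface__ on the object.
        print(name, cuda.is_cuda_array(a))
    except Exception as e:
        # With jax/jaxlib 0.4.12 the JAX array is not usable here; the
        # attribute access itself errors, as shown below.
        print(name, "error:", e)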

Further, if you try to access the __cuda_array_interface__ attribute directly, another error comes up. (This may be a red herring.)

In [1]: import jax

In [2]: a = jax.numpy.arange(10)

In [3]: a
Out[3]: Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)

In [4]: a.__cuda_array_interface__
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
TypeError: Unregistered type : absl::lts_20230125::StatusOr<dict>

The above exception was the direct cause of the following exception:

TypeError                                 Traceback (most recent call last)
Cell In[4], line 1
----> 1 a.__cuda_array_interface__

TypeError: Unable to convert function return value to a Python type! The signature was
    (arg0: xla::PyArray) -> absl::lts_20230125::StatusOr<dict>
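As a stopgap, a possible workaround is to route the JAX array through DLPack and let CuPy supply the __cuda_array_interface__ view that Numba expects. This is only a sketch, not a verified recipe: the kernel (_double) is illustrative, and it assumes a CuPy recent enough that cupy.from_dlpack accepts the DLPack capsule produced by jax.dlpack.to_dlpack.

import cupy
import jax
import jax.dlpack
import jax.numpy as jnp
from numba import cuda

@cuda.jit
def _double(x, out):
    i = cuda.grid(1)
    if i < x.shape[0]:
        out[i] = 2.0 * x[i]

arr = jnp.arange(10000, dtype=jnp.float32)
arr.block_until_ready()  # be conservative about JAX's async dispatch

# Zero-copy handoff: JAX array -> DLPack capsule -> CuPy view.
# The CuPy view exposes __cuda_array_interface__, so Numba can type it.
arr_view = cupy.from_dlpack(jax.dlpack.to_dlpack(arr))
out = cupy.zeros_like(arr_view)

_double[40, 256](arr_view, out)

# Bring the result back into JAX, again via DLPack.
res = jax.dlpack.from_dlpack(out.toDlpack())
print(res[:5])

This only sidesteps the problem, though; having __cuda_array_interface__ work directly on jax.Array would be much nicer.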

What jax/jaxlib version are you using?

jax and jaxlib 0.4.12

Which accelerator(s) are you using?

GPU

Additional system info

Linux with an NVIDIA GPU

NVIDIA GPU info

$ nvidia-smi
Thu Jun 15 14:35:41 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.182.03   Driver Version: 470.182.03   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Quadro GV100        Off  | 00000000:2D:00.0 Off |                  Off |
| 49%   57C    P0    39W / 250W |      0MiB / 32508MiB |      3%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+