jax __cuda_array_interface__ not working · Issue #16440 · jax-ml/jax
Description
import numpy as np
import jax.numpy as jnp
from numba import cuda
import cupy


@cuda.jit
def _sum_nomask(w, res, ind):
    start = cuda.grid(1)
    stride = cuda.gridsize(1)
    n = w.shape[0]
    tot = 0.0
    for i in range(start, n, stride):
        if w[i] > 0:
            tot += w[i]

    cuda.atomic.add(res, ind, tot)


if __name__ == "__main__":
    arr = cupy.arange(10000, dtype=np.float32)
    res = cupy.zeros(2, dtype=np.float32)
    _sum_nomask[500, 32](arr, res, 0)
    print("cupy:", res)

    arr = jnp.arange(10000, dtype=jnp.float32)
    res = jnp.zeros(2, dtype=jnp.float32)
    _sum_nomask[500, 32](arr, res, 0)
    print("jax:", res)
I get
$ python test_jax_numba.py
cupy: [49995040. 0.]
Traceback (most recent call last):
  File "/home/mrbecker/test_jax_numba.py", line 29, in <module>
    _sum_nomask[500, 32](arr, res, 0)
  File "/home/mrbecker/miniforge3/envs/jax-numba/lib/python3.11/site-packages/numba/cuda/dispatcher.py", line 542, in __call__
    return self.dispatcher.call(args, self.griddim, self.blockdim,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mrbecker/miniforge3/envs/jax-numba/lib/python3.11/site-packages/numba/cuda/dispatcher.py", line 676, in call
    kernel = _dispatcher.Dispatcher._cuda_call(self, *args)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mrbecker/miniforge3/envs/jax-numba/lib/python3.11/site-packages/numba/cuda/dispatcher.py", line 683, in _compile_for_args
    argtypes = [self.typeof_pyval(a) for a in args]
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mrbecker/miniforge3/envs/jax-numba/lib/python3.11/site-packages/numba/cuda/dispatcher.py", line 683, in <listcomp>
    argtypes = [self.typeof_pyval(a) for a in args]
                ^^^^^^^^^^^^^^^^^^^^
  File "/home/mrbecker/miniforge3/envs/jax-numba/lib/python3.11/site-packages/numba/cuda/dispatcher.py", line 690, in typeof_pyval
    return typeof(val, Purpose.argument)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mrbecker/miniforge3/envs/jax-numba/lib/python3.11/site-packages/numba/core/typing/typeof.py", line 33, in typeof
    ty = typeof_impl(val, c)
         ^^^^^^^^^^^^^^^^^^^
  File "/home/mrbecker/miniforge3/envs/jax-numba/lib/python3.11/functools.py", line 909, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mrbecker/miniforge3/envs/jax-numba/lib/python3.11/site-packages/numba/core/typing/typeof.py", line 46, in typeof_impl
    tp = _typeof_buffer(val, c)
         ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mrbecker/miniforge3/envs/jax-numba/lib/python3.11/site-packages/numba/core/typing/typeof.py", line 69, in _typeof_buffer
    m = memoryview(val)
        ^^^^^^^^^^^^^^^
BufferError: INVALID_ARGUMENT: Python buffer protocol is only defined for CPU buffers.
Further, if you try to access the __cuda_array_interface__ attribute directly, a different error comes up. (This may be a red herring.)
In [1]: import jax
In [2]: a = jax.numpy.arange(10)
In [3]: a
Out[3]: Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)
In [4]: a.__cuda_array_interface__
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
TypeError: Unregistered type : absl::lts_20230125::StatusOr<dict>
The above exception was the direct cause of the following exception:
TypeError Traceback (most recent call last)
Cell In[4], line 1
----> 1 a.__cuda_array_interface__
TypeError: Unable to convert function return value to a Python type! The signature was
(arg0: xla::PyArray) -> absl::lts_20230125::StatusOr<dict>
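The attribute access above fails with a TypeError rather than an AttributeError, so a plain hasattr() check would not hide it: hasattr() only swallows AttributeError. The following is a minimal sketch of a defensive probe, not something taken from JAX or Numba. The helper name probe_cuda_array_interface and the two stand-in classes are hypothetical and exist only so the example runs without a GPU; the required keys are the ones the CUDA Array Interface specification lists as mandatory.

```python
from typing import Any, Optional

# Keys the CUDA Array Interface specification marks as required.
REQUIRED_KEYS = ("shape", "typestr", "data", "version")


def probe_cuda_array_interface(obj: Any) -> Optional[dict]:
    """Return obj.__cuda_array_interface__ if it is a usable dict, else None.

    Hypothetical helper for illustration; not part of JAX, Numba, or CuPy.
    """
    try:
        iface = obj.__cuda_array_interface__
    except AttributeError:
        # The attribute is not defined at all.
        return None
    except Exception as exc:
        # The attribute exists but evaluating it failed, as in the session above.
        print(f"__cuda_array_interface__ raised {type(exc).__name__}: {exc}")
        return None
    if not isinstance(iface, dict) or any(key not in iface for key in REQUIRED_KEYS):
        return None
    return iface


class BrokenArray:
    """Stand-in (assumption) that mimics the failing property access."""

    @property
    def __cuda_array_interface__(self):
        raise TypeError("Unable to convert function return value to a Python type!")


class WorkingArray:
    """Stand-in (assumption) exporting a well-formed interface dict with a fake pointer."""

    __cuda_array_interface__ = {
        "shape": (10,),
        "typestr": "<i4",
        "data": (0, False),  # (device pointer, read-only flag); 0 is a placeholder
        "version": 2,
    }


print(probe_cuda_array_interface(object()))
print(probe_cuda_array_interface(BrokenArray()))
print(probe_cuda_array_interface(WorkingArray()))
```

Running the same probe on a real jax Array would show whether the installed jaxlib exports the interface at all or fails while building it, which helps separate the two errors reported here.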
What jax/jaxlib version are you using?
jax and jaxlib 0.4.12
Which accelerator(s) are you using?
GPU
Additional system info
linux w/ nvidia
NVIDIA GPU info
$ nvidia-smi
Thu Jun 15 14:35:41 2023
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.182.03   Driver Version: 470.182.03   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Quadro GV100        Off  | 00000000:2D:00.0 Off |                  Off |
| 49%   57C    P0    39W / 250W |      0MiB / 32508MiB |      3%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+