Add jnp.ufunc and jnp.frompyfunc by jakevdp · Pull Request #17054 · jax-ml/jax

This refactors some of the code from #9529 with a narrower scope. The two new public objects are jnp.ufunc and jnp.frompyfunc.

Notably, while NumPy's np.frompyfunc carries severe performance penalties (it is implemented in terms of Python-side loops), JAX's jnp.frompyfunc can take advantage of JAX transformations like vmap and jit, making the resulting operations reasonably efficient (a brief sketch of this composition follows the example session below).

In [1]: import jax.numpy as jnp

In [2]: def scalar_add(x, y):
   ...:     # emphasize that only scalar tracers will be passed to this function.
   ...:     assert jnp.shape(x) == jnp.shape(y) == ()
   ...:     return x + y
   ...:

In [3]: add = jnp.frompyfunc(scalar_add, nin=2, nout=1, identity=0)

In [4]: add
Out[4]: <jnp.ufunc 'scalar_add'>

In [5]: x = jnp.arange(5)

In [6]: indices = jnp.array([1, 1, 3])

In [7]: add(x, 1)  # Standard broadcasting
Out[7]: Array([1, 2, 3, 4, 5], dtype=int32)

In [8]: add.outer(x, x)
Out[8]:
Array([[0, 1, 2, 3, 4],
       [1, 2, 3, 4, 5],
       [2, 3, 4, 5, 6],
       [3, 4, 5, 6, 7],
       [4, 5, 6, 7, 8]], dtype=int32)

In [9]: add.reduce(x)  # reduce() method applies a binary reduction.
Out[9]: Array(10, dtype=int32)

In [10]: add.accumulate(x)  # accumulate() method is a cumulative reduction.
Out[10]: Array([ 0,  1,  3,  6, 10], dtype=int32)

In [11]: add.at(x, indices, 10, inplace=False)  # at() method is similar to JAX's ndarray.at
Out[11]: Array([ 0, 21,  2, 13,  4], dtype=int32)

In [12]: add.reduceat(x, indices)  # reduction between indices
Out[12]: Array([1, 3, 7], dtype=int32)
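To illustrate the performance point above, here is a minimal sketch (not part of this PR) of how the wrapped callable and its methods would compose with jit and vmap; it reuses scalar_add, add, and x from the session and assumes the ufunc methods trace like ordinary JAX functions:

import jax
import jax.numpy as jnp

# add = jnp.frompyfunc(scalar_add, nin=2, nout=1, identity=0); x = jnp.arange(5)

jit_add = jax.jit(add)        # the wrapped callable is an ordinary callable, so jit applies
jit_add(x, 1)                 # expected: Array([1, 2, 3, 4, 5], dtype=int32)

jax.vmap(add)(x, x)           # vmap maps the broadcasted call over a batch axis
                              # expected: Array([0, 2, 4, 6, 8], dtype=int32)

jax.jit(add.reduce)(x)        # ufunc methods should also be traceable
                              # expected: Array(10, dtype=int32)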

Once this lands, we should explore converting the NumPy ufunc wrappers in the jax.numpy namespace into jnp.ufunc objects, so that ufunc methods can be used directly (a hypothetical sketch follows).
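For concreteness, a hypothetical sketch of what that could look like; the name add_ufunc and the lambda-based construction are stand-ins, not code from this PR:

import jax.numpy as jnp

# Today the effect can be approximated by wrapping a scalar addition:
add_ufunc = jnp.frompyfunc(lambda a, b: a + b, nin=2, nout=1, identity=0)

x = jnp.arange(5)
add_ufunc.reduce(x)   # expected: Array(10, dtype=int32); with the follow-up work,
                      # this could eventually be spelled jnp.add.reduce(x)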

Additionally, there are a few places where efficiency could be improved. For example, for some binary functions we should be able to substitute faster implementations for the scan-based approaches: e.g. jnp.add could use jnp.sum for reduce, jnp.cumsum for accumulate, jnp.scatter_add for at, and jnp.segment_sum for reduceat (sketched below). But the current implementation is a good baseline that implements the desired behavior for arbitrary functions.
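To make the intended substitutions concrete, here is a minimal sketch of the equivalences that paragraph describes, reusing add, x, and indices from the session above; the use of x.at[...].add as the scatter-add equivalent is an illustrative assumption, not code from this PR:

import jax.numpy as jnp

x = jnp.arange(5)
indices = jnp.array([1, 1, 3])

# For addition, the generic scan-based methods should agree with the specialized ops:
assert add.reduce(x) == jnp.sum(x)                        # reduce     ~ sum
assert (add.accumulate(x) == jnp.cumsum(x)).all()         # accumulate ~ cumsum
assert (add.at(x, indices, 10, inplace=False)
        == x.at[indices].add(10)).all()                   # at         ~ indexed scatter-add
# reduceat could similarly be mapped onto a segment sum with segment ids
# derived from `indices`.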