Add jnp.ufunc and jnp.frompyfunc by jakevdp · Pull Request #17054 · jax-ml/jax
This refactors some of the code from #9529 with a narrower scope. The two new public objects are `jnp.ufunc` and `jnp.frompyfunc`.
Notably, while NumPy's `np.frompyfunc` has severe performance penalties (it is implemented in terms of Python-side loops), JAX's `jnp.frompyfunc` can take advantage of JAX transformations like `vmap` and `jit` to make the resulting operations reasonably efficient.
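To illustrate why this can be efficient, here is a rough sketch of the idea: broadcast the inputs, flatten them, map the scalar function over the flat axis with `jax.vmap`, and compile the whole thing with `jax.jit`. The helper `elementwise_call` is hypothetical and is not how `jnp.frompyfunc` is actually implemented; it only shows that a function written for scalars can be vectorized by a transformation rather than by a Python loop.

```python
import jax
import jax.numpy as jnp

def scalar_add(x, y):
    # Only ever receives 0-d (scalar) inputs.
    assert jnp.shape(x) == jnp.shape(y) == ()
    return x + y

def elementwise_call(scalar_fn, *args):
    # Hypothetical helper (not JAX's implementation):
    # 1. broadcast all inputs to a common shape,
    # 2. flatten them to 1-D,
    # 3. vmap the scalar function over that single axis,
    # 4. restore the broadcast shape.
    args = jnp.broadcast_arrays(*args)
    shape = args[0].shape
    flat = [a.ravel() for a in args]
    out = jax.vmap(scalar_fn)(*flat)
    return out.reshape(shape)

# jit compiles the vmapped computation into a single XLA program.
fast_add = jax.jit(lambda x, y: elementwise_call(scalar_add, x, y))

print(fast_add(jnp.arange(5), 1))
```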
```python
In [1]: import jax.numpy as jnp

In [2]: def scalar_add(x, y):
   ...:     # emphasize that only scalar tracers will be passed to this function.
   ...:     assert jnp.shape(x) == jnp.shape(y) == ()
   ...:     return x + y
   ...:

In [3]: add = jnp.frompyfunc(scalar_add, nin=2, nout=1, identity=0)

In [4]: add
Out[4]: <jnp.ufunc 'scalar_add'>

In [5]: x = jnp.arange(5)

In [6]: indices = jnp.array([1, 1, 3])

In [7]: add(x, 1)  # Standard broadcasting
Out[7]: Array([1, 2, 3, 4, 5], dtype=int32)

In [8]: add.outer(x, x)
Out[8]:
Array([[0, 1, 2, 3, 4],
       [1, 2, 3, 4, 5],
       [2, 3, 4, 5, 6],
       [3, 4, 5, 6, 7],
       [4, 5, 6, 7, 8]], dtype=int32)

In [9]: add.reduce(x)  # reduce() method applies binary reduction.
Out[9]: Array(10, dtype=int32)

In [10]: add.accumulate(x)  # accumulate() method is cumulative reduction
Out[10]: Array([ 0,  1,  3,  6, 10], dtype=int32)

In [11]: add.at(x, indices, 10, inplace=False)  # at() method is similar to JAX's ndarray.at
Out[11]: Array([ 0, 21,  2, 13,  4], dtype=int32)

In [12]: add.reduceat(x, indices)  # reduction between indices
Out[12]: Array([1, 3, 7], dtype=int32)
```
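As a further hedged example, assuming a JAX version that ships `jnp.frompyfunc`, a ufunc built from a scalar multiply (the name `scalar_mul` and the wrapper function are illustrative) can be called from inside a user-level `jax.jit`-compiled function, so its methods are traced and compiled together with the surrounding code:

```python
import jax
import jax.numpy as jnp

def scalar_mul(x, y):
    # Called with scalar inputs only, like scalar_add in the session above.
    return x * y

# identity=1 is the multiplicative identity, analogous to identity=0 for addition.
mul = jnp.frompyfunc(scalar_mul, nin=2, nout=1, identity=1)

@jax.jit
def product_and_running_product(v):
    # Both ufunc methods are traced as part of this jitted function.
    return mul.reduce(v), mul.accumulate(v)

total, running = product_and_running_product(jnp.arange(1, 6))
print(total, running)
```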
Once this lands, we should explore turning the wrappers of NumPy ufuncs in the `jax.numpy` namespace into `jnp.ufunc` objects, so that ufunc methods (e.g. `jnp.add.reduce`) can be used directly.
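Until then, one possible stopgap (a sketch, not something this PR adds) is to wrap an existing `jax.numpy` function with `jnp.frompyfunc` to obtain the ufunc methods. Here `jnp.maximum` is wrapped with an assumed identity of `-inf`, which only makes sense for floating-point inputs:

```python
import jax.numpy as jnp

# Hypothetical stopgap: expose ufunc methods for an existing jax.numpy function.
# identity=-inf is an assumption and is only meaningful for floating-point data.
maximum = jnp.frompyfunc(jnp.maximum, nin=2, nout=1, identity=-jnp.inf)

data = jnp.array([1.0, 3.0, 2.0, 5.0, 4.0])
print(maximum.reduce(data))      # overall maximum
print(maximum.accumulate(data))  # running maximum
```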
Additionally, there are a few places where efficiency could be improved. For some binary functions, we should be able to substitute faster implementations for the generic `scan`-based approaches. For example, `jnp.add` could use `jnp.sum` for `reduce`, `jnp.cumsum` for `accumulate`, a scatter-add (e.g. `x.at[indices].add(...)`) for `at`, and `jax.ops.segment_sum` for `reduceat`. Still, the current implementation is a good baseline that implements the desired behavior for arbitrary functions.
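The correspondence between the generic `scan`-based methods and the specialized alternatives can be checked directly. The sketch below compares each `add` method against a faster counterpart. The `reduceat` comparison builds segment ids with `jnp.searchsorted` and assumes the start indices are sorted, unique, and begin at 0; repeated or unsorted indices follow different NumPy rules and would need extra handling.

```python
import jax
import jax.numpy as jnp

def scalar_add(x, y):
    return x + y

# Generic scan-based ufunc, built the same way as in the example session.
add = jnp.frompyfunc(scalar_add, nin=2, nout=1, identity=0)

x = jnp.arange(6)
at_indices = jnp.array([1, 1, 3])  # repeated index: both updates accumulate
seg_starts = jnp.array([0, 2, 5])  # assumed sorted, unique, starting at 0

# Segment id of each position = index of the last start <= that position.
segment_ids = jnp.searchsorted(seg_starts, jnp.arange(x.size), side="right") - 1

# Pairs of (generic ufunc method, specialized alternative).
results = {
    "reduce": (add.reduce(x), jnp.sum(x)),
    "accumulate": (add.accumulate(x), jnp.cumsum(x)),
    "at": (add.at(x, at_indices, 10, inplace=False), x.at[at_indices].add(10)),
    "reduceat": (
        add.reduceat(x, seg_starts),
        jax.ops.segment_sum(x, segment_ids, num_segments=seg_starts.size),
    ),
}

for name, (generic, fast) in results.items():
    print(name, generic, fast, bool(jnp.array_equal(generic, fast)))
```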