jax.numpy.ufunc — JAX documentation
jax.numpy.ufunc
class jax.numpy.ufunc(func, /, nin, nout, *, name=None, nargs=None, identity=None, call=None, reduce=None, accumulate=None, at=None, reduceat=None)
Universal functions which operate element-by-element on arrays.
JAX implementation of numpy.ufunc.
This is a class for JAX-backed implementations of NumPy’s ufunc APIs. Most users will never need to instantiate ufunc, but rather will use the pre-defined ufuncs in jax.numpy.
For constructing your own ufuncs, see jax.numpy.frompyfunc().
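As a minimal sketch of jax.numpy.frompyfunc(), assume a hypothetical scalar function `_times` (plain multiplication, chosen only so the results are easy to check by hand). The returned object is a ufunc, so it gets the element-wise call plus the reduce(), accumulate(), and outer() methods described in the examples below. Passing `identity=1` is our own choice, justified because `_times(x, 1) == x` for every `x`.

```python
import jax.numpy as jnp

# Hypothetical scalar binary function; frompyfunc vectorizes it element-wise.
def _times(a, b):
    return a * b

# identity=1 because _times(x, 1) == x for every x.
times = jnp.frompyfunc(_times, nin=2, nout=1, identity=1)

x = jnp.array([1, 2, 3, 4])
doubled = times(x, 2)            # element-wise call with a broadcast scalar
product = times.reduce(x)        # 1 * 2 * 3 * 4
running = times.accumulate(x)    # running product
table = times.outer(x, x)        # 4x4 table of pairwise products

print(doubled, product, running)
print(table)
```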
Examples
Universal functions are functions that apply element-wise to broadcasted arrays, but they also come with a number of extra attributes and methods.
As an example, consider the function jax.numpy.add. The object acts as a function that applies addition to broadcasted arrays in an element-wise manner:
```python
>>> x = jnp.array([1, 2, 3, 4, 5])
>>> jnp.add(x, 1)
Array([2, 3, 4, 5, 6], dtype=int32)
```
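To make the "broadcasted arrays" part concrete, here is a small sketch that adds a column vector to a row vector; the array names are illustrative only. Broadcasting stretches shapes (3, 1) and (4,) to a common (3, 4) before the addition is applied element by element.

```python
import jax.numpy as jnp

col = jnp.arange(3).reshape(3, 1)   # shape (3, 1): [[0], [1], [2]]
row = jnp.arange(4)                 # shape (4,):   [0, 1, 2, 3]

# Both inputs are broadcast to shape (3, 4), then added element-wise,
# so table[i, j] == col[i, 0] + row[j].
table = jnp.add(col, row)
print(table.shape)
print(table)
```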
Each ufunc object includes a number of attributes that describe its behavior:
```python
>>> jnp.add.nin  # number of inputs
2
>>> jnp.add.nout  # number of outputs
1
>>> jnp.add.identity  # identity value, or None if no identity exists
0
```
Binary ufuncs like jax.numpy.add include a number of methods to apply the function to arrays in different manners.
The ufunc.outer() method applies the function to every pair of values drawn from the two input arrays, as in an outer product:
```python
>>> jnp.add.outer(x, x)
Array([[ 2,  3,  4,  5,  6],
       [ 3,  4,  5,  6,  7],
       [ 4,  5,  6,  7,  8],
       [ 5,  6,  7,  8,  9],
       [ 6,  7,  8,  9, 10]], dtype=int32)
```
The ufunc.reduce() method performs a reduction over the array. For example, for a one-dimensional array such as x, jnp.add.reduce() is equivalent to jax.numpy.sum():
```python
>>> jnp.add.reduce(x)
Array(15, dtype=int32)
```
The ufunc.accumulate() method performs a cumulative reduction over the array. For example, jnp.add.accumulate() is equivalent to jax.numpy.cumulative_sum():
```python
>>> jnp.add.accumulate(x)
Array([ 1,  3,  6, 10, 15], dtype=int32)
```
The ufunc.at() method applies the function at particular indices in the array; for jnp.add the computation is similar to jax.lax.scatter_add():
```python
>>> jnp.add.at(x, 0, 100, inplace=False)
Array([101,   2,   3,   4,   5], dtype=int32)
```
And the ufunc.reduceat() method performs a number of reduce operations between specified indices of an array; for jnp.add the operation is similar to jax.ops.segment_sum():
```python
>>> jnp.add.reduceat(x, jnp.array([0, 2]))
Array([ 3, 12], dtype=int32)
```
In this case, the first element is x[0:2].sum(), and the second element is x[2:].sum().
Parameters:
- func (Callable[..., Any])
- nin (int)
- nout (int)
- name (str | None)
- nargs (int | None)
- identity (Any)
- call (Callable[..., Any] | None)
- reduce (Callable[..., Any] | None)
- accumulate (Callable[..., Any] | None)
- at (Callable[..., Any] | None)
- reduceat (Callable[..., Any] | None)
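The parameter list gives only types. As a hedged sketch rather than documented usage, the constructor can be called directly with a hypothetical scalar function `_larger`, assuming that `nin`/`nout` describe its inputs and outputs, that `name` sets the ufunc's `__name__`, and that `identity` is simply recorded (left as None here because no finite value `e` satisfies `max(v, e) == v` for every `v`). The optional `call`, `reduce`, `accumulate`, `at`, and `reduceat` callables appear to let you supply specialized implementations; this sketch leaves them at their defaults.

```python
import jax.numpy as jnp

# Hypothetical scalar function: the larger of two values.
def _larger(a, b):
    return jnp.where(a > b, a, b)

# Positional nin=2, nout=1; `name` is assumed to become larger.__name__.
larger = jnp.ufunc(_larger, 2, 1, name="larger")

pairwise = larger(jnp.array([1, 5, 3]), jnp.array([4, 2, 6]))
print(larger.nin, larger.nout, larger.identity, larger.__name__)
print(pairwise)
```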
__init__(func, /, nin, nout, *, name=None, nargs=None, identity=None, call=None, reduce=None, accumulate=None, at=None, reduceat=None)
Methods
| Method | Description |
|---|---|
| __init__(func, /, nin, nout, *[, name, ...]) | |
| accumulate(a[, axis, dtype, out]) | Accumulate operation derived from binary ufunc. |
| at(a, indices[, b, inplace]) | Update elements of an array via the specified unary or binary ufunc. |
| outer(A, B, /) | Apply the function to all pairs of values in A and B. |
| reduce(a[, axis, dtype, out, keepdims, ...]) | Reduction operation derived from a binary function. |
| reduceat(a, indices[, axis, dtype, out]) | Reduce an array between specified indices via a binary ufunc. |
Attributes