jax.numpy.add — JAX documentation
jax.numpy.add
jax.numpy.add = <jnp.ufunc 'add'>
Add two arrays element-wise.
JAX implementation of numpy.add. It is a universal function and supports the additional APIs described at jax.numpy.ufunc. It also implements the + operator for JAX arrays.
Parameters:
- x – arrays to add. Must be broadcastable to a common shape.
- y – arrays to add. Must be broadcastable to a common shape.
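- args (ArrayLike) – the positional array arguments (for add, x and y).
- out (None) – only the default value None is supported.
- where (None) – only the default value None is supported.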
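To illustrate the broadcasting requirement on x and y, here is a small sketch; the names col, row and grid are illustrative only. Under NumPy-style broadcasting, shapes (3, 1) and (4,) combine to the common shape (3, 4), and jnp.broadcast_shapes reports that shape without computing the sum.

```python
import jax.numpy as jnp

col = jnp.arange(3).reshape(3, 1)  # shape (3, 1)
row = jnp.arange(4)                # shape (4,)

# The size-1 axis of col is stretched to 4, and row is treated as shape (1, 4),
# so the result has shape (3, 4) with grid[i, j] == i + j.
grid = jnp.add(col, row)

# The common shape can also be computed directly from the input shapes.
common = jnp.broadcast_shapes(col.shape, row.shape)  # (3, 4)

# Shapes such as (3,) and (4,) are not broadcast-compatible, and adding
# arrays with those shapes raises an error.
```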
Returns:
Array containing the result of the element-wise addition.
Return type:
Any
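Although the annotated return type is Any, the result is a JAX array whose dtype follows JAX's type-promotion rules. The sketch below assumes the default configuration (64-bit mode disabled); with 64-bit mode enabled, the integer and float dtypes widen accordingly.

```python
import jax.numpy as jnp

x = jnp.arange(4)  # integer array (int32 by default)

# Python scalars are weakly typed, so adding a Python int keeps x's dtype.
same = jnp.add(x, 10)

# Adding a Python float promotes the result to floating point
# (float32 by default): values [0.5, 1.5, 2.5, 3.5].
promoted = jnp.add(x, 0.5)
```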
Examples
Calling add explicitly:
>>> import jax.numpy as jnp
>>> x = jnp.arange(4)
>>> jnp.add(x, 10)
Array([10, 11, 12, 13], dtype=int32)
Calling add via the + operator:
>>> x + 10
Array([10, 11, 12, 13], dtype=int32)