jax.numpy.heaviside — JAX documentation
jax.numpy.heaviside
jax.numpy.heaviside(x1, x2, /)
Compute the Heaviside step function.
JAX implementation of numpy.heaviside.
The Heaviside step function is defined by:
\[\begin{split}\mathrm{heaviside}(x1, x2) = \begin{cases} 0, & x1 < 0\\ x2, & x1 = 0\\ 1, & x1 > 0. \end{cases}\end{split}\]
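The piecewise definition maps onto two nested jnp.where calls. The following is a minimal sketch. It uses a hypothetical helper, heaviside_sketch, which is not part of JAX. It only illustrates the three cases and does not try to reproduce every dtype rule of jnp.heaviside. It is checked against jnp.heaviside on the inputs used in the Examples section.

```python
import jax.numpy as jnp


def heaviside_sketch(x1, x2):
    """Hypothetical helper (not a JAX API): evaluates the piecewise definition."""
    x1 = jnp.asarray(x1)
    x2 = jnp.asarray(x2)
    # Inner where: 1 where x1 > 0, otherwise fall back to x2 (covers x1 == 0).
    # Outer where: overrides every entry with x1 < 0 by 0.
    return jnp.where(x1 < 0, 0.0, jnp.where(x1 > 0, 1.0, x2))


x1 = jnp.array([[-2, 0, 3],
                [5, -1, 0],
                [0, 7, -3]])
x2 = jnp.array([2, 0.5, 1])

# Compare the sketch with the library function on the same inputs.
print(jnp.array_equal(heaviside_sketch(x1, x2), jnp.heaviside(x1, x2)))
```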
Parameters:
- x1 (ArrayLike) – input array or scalar. Complex dtypes are not supported.
- x2 (ArrayLike) – scalar or array. Specifies the return values where x1 is 0. Complex dtypes are not supported. x1 and x2 must either have the same shape or be broadcast-compatible.
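Because x1 and x2 only need to be broadcast-compatible, x2 can supply a different value at zero for each row of x1. This sketch uses a column of shape (3, 1), an illustrative choice not taken from the original page. It uses jnp.broadcast_shapes to show the resulting shape.

```python
import jax.numpy as jnp

x1 = jnp.array([[-2, 0, 3],
                [5, -1, 0],
                [0, 7, -3]])          # shape (3, 3)

# A column of shape (3, 1): row i of x1 uses x2_col[i, 0] wherever x1 is 0.
x2_col = jnp.array([[10], [20], [30]])
print(jnp.broadcast_shapes(x1.shape, x2_col.shape))  # (3, 3)

out = jnp.heaviside(x1, x2_col)
print(out)
```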
Returns:
An array containing the Heaviside step function of x1, promoted to an inexact (floating-point) dtype.
Return type:
Array
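The promotion to an inexact dtype means integer inputs produce a floating-point result, while floating-point inputs keep their precision. This is a short sketch. The exact float width for integer inputs assumes JAX's default 32-bit mode; with 64-bit mode enabled it would be wider. The sketch therefore checks only that the result is inexact.

```python
import jax.numpy as jnp

# Integer inputs: the result is promoted to a floating-point dtype.
x_int = jnp.array([-3, 0, 4])              # integer dtype
out_int = jnp.heaviside(x_int, 2)          # x2 given as a Python int
print(out_int.dtype)                       # float32 in the default 32-bit mode
print(jnp.issubdtype(out_int.dtype, jnp.inexact))  # True

# Inputs that are already floating point keep their precision.
x_half = jnp.array([-1.0, 0.0, 2.0], dtype=jnp.float16)
out_half = jnp.heaviside(x_half, 0.5)
print(out_half.dtype)                      # float16
```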
Examples
>>> import jax.numpy as jnp
>>> x1 = jnp.array([[-2, 0, 3],
...                 [5, -1, 0],
...                 [0, 7, -3]])
>>> x2 = jnp.array([2, 0.5, 1])
>>> jnp.heaviside(x1, x2)
Array([[0. , 0.5, 1. ],
       [1. , 0. , 1. ],
       [2. , 1. , 0. ]], dtype=float32)
>>> jnp.heaviside(x1, 0.5)
Array([[0. , 0.5, 1. ],
       [1. , 0. , 0.5],
       [0.5, 1. , 0. ]], dtype=float32)
>>> jnp.heaviside(-3, x2)
Array([0., 0., 0.], dtype=float32)
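Since this function is described as the JAX implementation of numpy.heaviside, one quick sanity check is to run both on the same inputs. This is a minimal sketch. The values should agree, while the dtypes may differ because NumPy returns float64 here and JAX defaults to 32-bit floats.

```python
import numpy as np
import jax.numpy as jnp

x1 = np.array([[-2, 0, 3], [5, -1, 0], [0, 7, -3]])
x2 = np.array([2, 0.5, 1])

expected = np.heaviside(x1, x2)   # NumPy reference result
result = jnp.heaviside(x1, x2)    # JAX result (accepts NumPy arrays as ArrayLike)

# Values agree; only the default floating-point width differs.
print(np.array_equal(np.asarray(result), expected))
print(expected.dtype, result.dtype)
```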