jax.numpy.heaviside — JAX documentation

jax.numpy.heaviside

jax.numpy.heaviside(x1, x2, /)

Compute the heaviside step function.

JAX implementation of numpy.heaviside.

The heaviside step function is defined by:

\[\begin{split}\mathrm{heaviside}(x1, x2) = \begin{cases} 0, & x1 < 0\\ x2, & x1 = 0\\ 1, & x1 > 0. \end{cases}\end{split}\]
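
The piecewise definition above can be written out directly with nested `jnp.where` calls. The following is a minimal sketch for real, finite inputs; `heaviside_reference` is a hypothetical helper introduced only for illustration and is not part of JAX, nor is it claimed to be how `jnp.heaviside` is implemented.

```python
import jax.numpy as jnp


def heaviside_reference(x1, x2):
    """Hypothetical helper (not part of JAX): the piecewise definition via jnp.where."""
    x1 = jnp.asarray(x1)
    x2 = jnp.asarray(x2)
    # x1 < 0 -> 0; x1 > 0 -> 1; otherwise (x1 == 0) take the matching value of x2.
    return jnp.where(x1 < 0, 0.0, jnp.where(x1 > 0, 1.0, x2))


x1 = jnp.array([-2, 0, 3])
x2 = jnp.array([2.0, 0.5, 1.0])

sketch = heaviside_reference(x1, x2)
expected = jnp.heaviside(x1, x2)

print(sketch)
# Check that the sketch agrees with jnp.heaviside on this input.
print(bool(jnp.array_equal(sketch, expected)))
```

Only the middle element of `x1` is zero, so only the middle element of `x2` (0.5) appears in the result; the other entries of `x2` are ignored.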

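Parameters:

x1 (ArrayLike) – input array or scalar.

x2 (ArrayLike) – scalar or array giving the value returned where x1 is 0. x1 and x2 must have the same shape or be broadcast-compatible.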

Returns:

An array containing the heaviside step function of x1, evaluated elementwise, with the result promoted to an inexact (floating-point) dtype.

Return type:

Array
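
To illustrate the promotion to an inexact dtype, the sketch below passes integer values for both arguments and inspects the dtype of the result. The input values are arbitrary and chosen only for illustration. The exact floating-point width depends on the JAX configuration, so the check only tests for a floating dtype rather than a specific one.

```python
import jax.numpy as jnp

x1 = jnp.array([-1, 0, 2])   # integer input array
out = jnp.heaviside(x1, 1)   # x2 is an integer scalar as well

# The input dtype is an integer type, but the output is floating point.
print(x1.dtype, out.dtype)

is_float = bool(jnp.issubdtype(out.dtype, jnp.floating))
print(is_float)
```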

Examples

>>> import jax.numpy as jnp
>>> x1 = jnp.array([[-2, 0, 3],
...                 [5, -1, 0],
...                 [0, 7, -3]])
>>> x2 = jnp.array([2, 0.5, 1])
>>> jnp.heaviside(x1, x2)
Array([[0. , 0.5, 1. ],
       [1. , 0. , 1. ],
       [2. , 1. , 0. ]], dtype=float32)
>>> jnp.heaviside(x1, 0.5)
Array([[0. , 0.5, 1. ],
       [1. , 0. , 0.5],
       [0.5, 1. , 0. ]], dtype=float32)
>>> jnp.heaviside(-3, x2)
Array([0., 0., 0.], dtype=float32)