jax.numpy.piecewise (JAX documentation)

jax.numpy.piecewise

jax.numpy.piecewise(x, condlist, funclist, *args, **kw)

Evaluate a function defined piecewise across the domain.

JAX implementation of numpy.piecewise(), in terms of jax.lax.switch().
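To make the relationship with jax.lax.switch() concrete, here is a minimal sketch of how a piecewise evaluation can be phrased in terms of it. This is not JAX's actual implementation: the helper name piecewise_via_switch is hypothetical, it assumes funclist has exactly one more entry than condlist (the last entry being the default), and when several conditions hold it lets the last one win, a tie-breaking rule chosen only for this sketch.

```python
import jax
import jax.numpy as jnp
from jax import lax


def piecewise_via_switch(x, condlist, funclist):
    """Hypothetical, simplified piecewise built on lax.switch (not JAX's source).

    Assumes len(funclist) == len(condlist) + 1; the last entry is the default
    used wherever no condition holds. Entries may be callables or scalars.
    """
    x = jnp.asarray(x)
    # Start every element at the default branch index.
    idx = jnp.full(x.shape, len(condlist), dtype=jnp.int32)
    # Each True condition overwrites the index; later conditions win
    # (a choice made by this sketch).
    for i, cond in enumerate(condlist):
        cond = jnp.broadcast_to(jnp.asarray(cond, dtype=bool), x.shape)
        idx = jnp.where(cond, i, idx)

    # lax.switch needs every branch to return the same shape and dtype, so
    # callable results are cast to x's dtype and scalars become constant branches.
    def as_branch(f):
        if callable(f):
            return lambda v: jnp.asarray(f(v)).astype(v.dtype)
        return lambda v: jnp.full_like(v, f)

    branches = [as_branch(f) for f in funclist]
    # vmap applies the per-element switch across the flattened input.
    out = jax.vmap(lambda i, v: lax.switch(i, branches, v))(idx.ravel(), x.ravel())
    return out.reshape(x.shape)


x = jnp.arange(-4, 5)
print(piecewise_via_switch(x, [x < -1, x > 1], [lambda v: 1 + v, lambda v: v - 1, 0]))
```

The key idea is the per-element branch index: each element selects exactly one entry of funclist, and jax.vmap lets a single lax.switch call be applied over all elements.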

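Parameters:

x – array of input values.

condlist – boolean array, or sequence of boolean arrays, giving the condition for each piece.

funclist – sequence of callables or scalar constants. It must have either the same length as condlist, or one extra entry, in which case the last entry is the default used wherever no condition is True.

*args, **kw – additional arguments, passed through to each callable in funclist.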
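The extra arguments in the signature can be exercised as follows. This sketch assumes, as with numpy.piecewise(), that keyword arguments given to jnp.piecewise are forwarded to every callable entry of funclist; the keyword name scale is illustrative only. Scalar entries such as 0 are constants rather than calls, so they never see the extra argument.

```python
import jax.numpy as jnp

x = jnp.array([-4, -3, -2, -1, 0, 1, 2, 3, 4])

# `scale` is a hypothetical keyword: piecewise forwards it to each callable,
# so the second entry is called as f(v, scale=3).
out = jnp.piecewise(
    x,
    [x < 0, x >= 0],
    [0, lambda v, scale: v * scale],
    scale=3,
)
print(out)
```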

Returns:

An array which is the result of evaluating the functions on x at the specified conditions.

Return type:

Array

See also

Examples

Here’s an example of a function that is zero for negative values and linear for non-negative values:

>>> import jax.numpy as jnp
>>> x = jnp.array([-4, -3, -2, -1, 0, 1, 2, 3, 4])
>>> condlist = [x < 0, x >= 0]
>>> funclist = [lambda x: 0 * x, lambda x: x]
>>> jnp.piecewise(x, condlist, funclist)
Array([0, 0, 0, 0, 0, 1, 2, 3, 4], dtype=int32)
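The function above is a rectified linear unit, so a quick sanity check can compare it with jnp.where() and jnp.maximum(). The sketch also wraps the call in jax.jit, assuming (as the lax.switch() basis suggests) that jnp.piecewise traces cleanly; the name relu_piecewise is illustrative.

```python
import jax
import jax.numpy as jnp


@jax.jit
def relu_piecewise(x):
    # Zero for negative inputs, identity for non-negative inputs.
    return jnp.piecewise(x, [x < 0, x >= 0], [lambda v: 0 * v, lambda v: v])


x = jnp.array([-4, -3, -2, -1, 0, 1, 2, 3, 4])
y = relu_piecewise(x)

# The same function written two other ways, for comparison.
assert (y == jnp.where(x < 0, 0, x)).all()
assert (y == jnp.maximum(x, 0)).all()
```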

Entries of funclist can also be scalar values, which act as constant functions:

>>> condlist = [x < 0, x >= 0]
>>> funclist = [0, lambda x: x]
>>> jnp.piecewise(x, condlist, funclist)
Array([0, 0, 0, 0, 0, 1, 2, 3, 4], dtype=int32)
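Using only scalars in funclist gives a piecewise-constant function. A small sketch, where reading the result as a unit step is purely illustrative:

```python
import jax.numpy as jnp

x = jnp.array([-4, -3, -2, -1, 0, 1, 2, 3, 4])

# Both entries are scalars, so each acts as a constant function:
# 0 where x < 0 and 1 where x >= 0.
step = jnp.piecewise(x, [x < 0, x >= 0], [0, 1])

# Equivalent boolean formulation, for comparison.
assert (step == (x >= 0).astype(x.dtype)).all()
```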

You can specify a default value by appending one extra entry to funclist; it applies wherever none of the conditions hold:

>>> condlist = [x < -1, x > 1]
>>> funclist = [lambda x: 1 + x, lambda x: x - 1, 0]
>>> jnp.piecewise(x, condlist, funclist)
Array([-3, -2, -1, 0, 0, 0, 1, 2, 3], dtype=int32)
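A default entry is handy when the conditions leave gaps. As an illustrative sketch, the sign function needs a separate value only at zero, where neither x < 0 nor x > 0 holds, so it can be written with two conditions and three constants and checked against jnp.sign():

```python
import jax.numpy as jnp

x = jnp.array([-2.5, -1.0, 0.0, 1.0, 2.5])

# Two conditions, three entries: the trailing 0.0 is the default,
# used only where neither condition holds (here, at x == 0).
s = jnp.piecewise(x, [x < 0, x > 0], [-1.0, 1.0, 0.0])

assert (s == jnp.sign(x)).all()
```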

condlist may also be a one-dimensional array of scalar conditions, in which case the function associated with the True condition applies to the whole input:

>>> condlist = jnp.array([False, True, False])
>>> funclist = [lambda x: x * 0, lambda x: x * 10, lambda x: x * 100]
>>> jnp.piecewise(x, condlist, funclist)
Array([-40, -30, -20, -10, 0, 10, 20, 30, 40], dtype=int32)
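When every condition is a scalar, a single branch index applies to the whole input, so the result should match calling jax.lax.switch() directly with the index of the True condition. A short check, reusing the example's funclist:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([-4, -3, -2, -1, 0, 1, 2, 3, 4])
funclist = [lambda v: v * 0, lambda v: v * 10, lambda v: v * 100]

# Scalar conditions: only the second one is True, for the whole array.
via_piecewise = jnp.piecewise(x, jnp.array([False, True, False]), funclist)
# Selecting branch 1 directly applies the same function to all of x.
via_switch = lax.switch(1, funclist, x)

assert (via_piecewise == via_switch).all()
```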