jax.numpy.power — JAX documentation

jax.numpy.power

jax.numpy.power(x1, x2, /)

Calculate x1 raised to the power x2, element-wise (x1 is the base and x2 is the exponent).

JAX implementation of numpy.power.
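Because this is the JAX counterpart of numpy.power, results on ordinary inputs should agree with NumPy up to floating-point rounding. As a JAX function it can also be used inside transformations such as jax.jit and jax.grad. The sketch below checks both on a few sample values. The helper name `cube` is hypothetical and used only for illustration, and the default 32-bit mode is assumed.

```python
import jax
import jax.numpy as jnp
import numpy as np

bases = np.array([1.0, 2.0, 3.0], dtype=np.float32)
exps = np.array([2.0, 0.5, 3.0], dtype=np.float32)

# Element-wise agreement with NumPy (up to float32 rounding).
jax_out = jnp.power(bases, exps)
np_out = np.power(bases, exps)
print(np.allclose(jax_out, np_out))  # True

# Hypothetical helper: x**3 written with jnp.power.
def cube(x):
    return jnp.power(x, 3)

# jit-compile it, and differentiate it: d/dx x**3 = 3 * x**2.
cube_jit = jax.jit(cube)
dcube = jax.grad(cube)
print(cube_jit(2.0))  # 8.0
print(dcube(2.0))     # 12.0
```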

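Parameters:

x1 (ArrayLike): scalar or array specifying the bases.

x2 (ArrayLike): scalar or array specifying the exponents. x1 and x2 must either have the same shape or be broadcast-compatible.
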
Returns:

An array containing x1 raised to the power x2, element-wise. Its dtype is the result of type promotion between x1 and x2.

Return type:

Array
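The dtype of the result follows JAX's type promotion of x1 and x2, which jnp.result_type reports. The minimal sketch below assumes the default 32-bit mode. With jax_enable_x64 enabled, the integer and float dtypes widen to 64 bits, but the output dtype should still equal jnp.result_type of the inputs.

```python
import jax.numpy as jnp

x1 = jnp.array([2, 3, 4])              # integer bases
x2_int = jnp.array([2, 2, 2])          # integer exponents
x2_float = jnp.array([0.5, 2.0, 1.0])  # floating-point exponents

r_int = jnp.power(x1, x2_int)      # integer ** integer stays integer
r_float = jnp.power(x1, x2_float)  # integer ** float is promoted to float

# In both cases the output dtype equals the promoted dtype of the inputs.
print(r_int.dtype, jnp.result_type(x1, x2_int))      # int32 int32 (default 32-bit mode)
print(r_float.dtype, jnp.result_type(x1, x2_float))  # float32 float32 (default 32-bit mode)
```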

Note

See also

Examples

Inputs with scalar integers:

>>> import jax.numpy as jnp
>>> jnp.power(4, 3)
Array(64, dtype=int32, weak_type=True)

Inputs with the same shape:

>>> x1 = jnp.array([2, 4, 5])
>>> x2 = jnp.array([3, 0.5, 2])
>>> jnp.power(x1, x2)
Array([ 8.,  2., 25.], dtype=float32)

Inputs with broadcast compatibility:

>>> x3 = jnp.array([-2, 3, 1])
>>> x4 = jnp.array([[4, 1, 6],
...                 [1.3, 3, 5]])
>>> jnp.power(x3, x4)
Array([[16.,  3.,  1.],
       [nan, 27.,  1.]], dtype=float32)
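In the broadcasting example, the entry for (-2) ** 1.3 is nan because a negative real base raised to a non-integer exponent has no real value. The sketch below shows one possible workaround, assuming a complex result is acceptable: cast the base to a complex dtype first, which yields the principal complex value. The sketch also checks that the shapes (3,) and (2, 3) broadcast to (2, 3).

```python
import jax.numpy as jnp

x3 = jnp.array([-2, 3, 1])
x4 = jnp.array([[4, 1, 6],
                [1.3, 3, 5]])

# Real inputs: a negative base with a non-integer exponent gives nan.
real_out = jnp.power(x3, x4)

# Casting the base to complex yields the principal complex power instead.
complex_out = jnp.power(x3.astype(jnp.complex64), x4)

# The result shape is the broadcast of the input shapes.
print(real_out.shape, jnp.broadcast_shapes(x3.shape, x4.shape))  # (2, 3) (2, 3)
print(jnp.isnan(real_out[1, 0]))                                 # True
print(complex_out[1, 0])  # a finite complex number
```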