jax.numpy.power - JAX documentation
jax.numpy.power
jax.numpy.power(x1, x2, /)
Calculate x1 raised to the power of x2, element-wise.
JAX implementation of numpy.power.
Parameters:
- x1 (ArrayLike) – scalar or array. Specifies the bases.
- x2 (ArrayLike) – scalar or array. Specifies the exponents.
x1 and x2 should either have the same shape or be broadcast compatible.
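As a rough illustration of the broadcasting rule, here is a minimal sketch. It assumes JAX's default 32-bit mode (`jax_enable_x64` disabled), and the variable names are illustrative only. It raises a row of bases against a 2-D array of exponents, then builds a small power table from a column of bases and a row of exponents. The last call mixes integer bases with float exponents, which, as in the examples further down, produces a floating-point result.

```python
import jax.numpy as jnp

# Shapes (3,) and (2, 3) are broadcast compatible; the result has shape (2, 3).
bases = jnp.array([2, 3, 4])
exps = jnp.array([[1, 2, 3],
                  [0, 1, 2]])
out_int = jnp.power(bases, exps)  # [[2, 9, 64], [1, 3, 16]], integer dtype

# A (3, 1) column against a (3,) row broadcasts to a (3, 3) "power table":
# table[i, j] = col[i] ** row[j].
col = jnp.arange(1, 4).reshape(3, 1)  # bases 1, 2, 3
row = jnp.arange(3)                   # exponents 0, 1, 2
table = jnp.power(col, row)           # [[1, 1, 1], [1, 2, 4], [1, 3, 9]]

# Integer bases with float exponents promote to a floating-point dtype
# (float32 when jax_enable_x64 is disabled).
out_float = jnp.power(bases, jnp.array([0.5, 1.0, 0.5]))
print(out_int.shape, out_float.dtype)
```

The column-times-row pattern is a common way to evaluate many powers at once without an explicit loop.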
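Returns:
An array containing x1 raised to the power of x2, element-wise. Its dtype follows JAX's type-promotion rules for x1 and x2; for example, integer bases with float exponents give a floating-point result.
Return type:
Array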
Note
- When x2 is a concrete integer scalar, jnp.power lowers to jax.lax.integer_pow().
- When x2 is a traced scalar or an array, jnp.power lowers to jax.lax.pow().
jnp.power raises a TypeError for an integer type raised to a concrete negative integer power. For a non-concrete (traced) negative integer power, the operation is invalid and the returned value is implementation-defined.
jnp.power returns nan for a negative value raised to the power of a non-integer value.
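The behaviors described in the note can be checked directly. The sketch below uses `jax.make_jaxpr` to show which lax primitive appears when the exponent is a concrete Python integer and when it is a traced argument. The exact jaxpr text may differ between JAX versions, and the names `concrete_jaxpr`, `traced_jaxpr`, `raised`, and `neg_result` are illustrative.

```python
import jax
import jax.numpy as jnp

# Concrete integer scalar exponent (the literal 3): the jaxpr contains integer_pow.
concrete_jaxpr = jax.make_jaxpr(lambda x: jnp.power(x, 3))(2.0)

# Exponent passed as a function argument is traced: the jaxpr contains pow.
traced_jaxpr = jax.make_jaxpr(jnp.power)(2.0, 3.0)

print(concrete_jaxpr)
print(traced_jaxpr)

# An integer base raised to a concrete negative integer power raises TypeError.
try:
    jnp.power(jnp.array(2), -1)
    raised = False
except TypeError:
    raised = True

# A negative float base with a non-integer exponent gives nan. The real cube
# root of -8 is -2, but jnp.power does not special-case odd roots.
neg_result = jnp.power(-8.0, 1.0 / 3.0)
```

If the real cube root is what you need, `jnp.cbrt` is the dedicated function.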
See also
- jax.lax.pow(): Computes element-wise power, \(x^y\).
- jax.lax.integer_pow(): Computes element-wise power \(x^y\), where \(y\) is a fixed integer.
- jax.numpy.float_power(): Computes the first array raised to the power of second array, element-wise, by promoting to the inexact dtype.
- jax.numpy.pow(): Computes the first array raised to the power of second array, element-wise.
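The related functions listed above can be compared side by side. This is a sketch that assumes a JAX version providing `jnp.pow` (as listed above) and the default 32-bit mode. `lax.pow` is given a float exponent array with the same dtype as the base to keep the call portable. Results are compared with `jnp.allclose` because the general `pow` path is not guaranteed to be bit-identical to the repeated multiplication used for fixed integer exponents.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.0, 2.0, 3.0])

# jnp.power with a concrete integer exponent versus the lax primitives.
a = jnp.power(x, 3)                    # concrete int exponent: integer_pow path
b = lax.integer_pow(x, 3)              # fixed integer exponent
c = lax.pow(x, jnp.full_like(x, 3.0))  # general element-wise x**y

# jnp.pow computes the same element-wise power.
d = jnp.pow(x, 3)

# float_power promotes integer inputs to an inexact dtype first;
# jnp.power keeps integer inputs integral.
i = jnp.array([1, 2, 3])
p_int = jnp.power(i, 2)
p_float = jnp.float_power(i, 2)
```

In short, reach for `jnp.float_power` when integer inputs should always produce floating-point output, and for `jnp.power` when the input dtype should be preserved where possible.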
Examples
>>> import jax.numpy as jnp
Inputs with scalar integers:
>>> jnp.power(4, 3)
Array(64, dtype=int32, weak_type=True)
Inputs with the same shape:
>>> x1 = jnp.array([2, 4, 5])
>>> x2 = jnp.array([3, 0.5, 2])
>>> jnp.power(x1, x2)
Array([ 8.,  2., 25.], dtype=float32)
Inputs with broadcast compatibility:
>>> x3 = jnp.array([-2, 3, 1])
>>> x4 = jnp.array([[4, 1, 6],
...                 [1.3, 3, 5]])
>>> jnp.power(x3, x4)
Array([[16.,  3.,  1.],
       [nan, 27.,  1.]], dtype=float32)