jax.numpy.fromfunction — JAX documentation

jax.numpy.fromfunction

jax.numpy.fromfunction(function, shape, *, dtype=<class 'float'>, **kwargs)

Create an array from a function applied over indices.

JAX implementation of numpy.fromfunction(). The JAX implementation differs in that it dispatches via jax.vmap(). Unlike in NumPy, function logically operates on scalar inputs and need not explicitly handle broadcast inputs (see Examples below).
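
The vmap-based dispatch can be pictured as one jax.vmap() per axis, each mapping a single index argument while leaving the others unmapped. The following is a minimal sketch under that assumption, not necessarily how JAX implements it; fromfunction_sketch and scalar_only are hypothetical names, and keyword arguments are not handled.

```python
import jax
import jax.numpy as jnp


def fromfunction_sketch(function, shape, dtype=float):
    """Hypothetical re-implementation of jnp.fromfunction using nested jax.vmap.

    Keyword arguments are deliberately not supported in this sketch.
    """
    ndim = len(shape)
    for i in range(ndim):
        # Map over argument (ndim - 1 - i) only; all other arguments are unmapped.
        # The last vmap applied is the outermost one, so it maps argument 0 and
        # produces the leading output dimension shape[0].
        in_axes = tuple(0 if j == ndim - 1 - i else None for j in range(ndim))
        function = jax.vmap(function, in_axes=in_axes)
    # One 1-D index vector per axis; each vmap level peels one of them down to
    # a scalar, so the original function only ever receives scalar inputs.
    return function(*(jnp.arange(n, dtype=dtype) for n in shape))


def scalar_only(x, y):
    # Would fail on broadcast grids, but under vmap it only sees scalars.
    assert x.shape == () and y.shape == ()
    return 10 * x + y


table = fromfunction_sketch(scalar_only, (2, 3), dtype=int)
# table holds the values [[0, 1, 2], [10, 11, 12]]
```

Because each vmap level strips one axis from one index vector, scalar_only never has to reason about broadcasting, and tuple or other pytree outputs are handled by jax.vmap itself.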

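Parameters:

function: a function taking one scalar index argument per dimension of shape; it may return a scalar, an array, or a pytree of arrays.
shape: a tuple of integers specifying the shape of the index grid, which becomes the leading dimensions of the output.
dtype: the dtype of the index values passed to function. Defaults to floating point.
kwargs: additional keyword arguments, passed to function.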
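
The dtype keyword controls the dtype of the index values handed to function, and therefore (for arithmetic on those indices) the dtype of the result. A short illustration, where encode is a hypothetical example function; the concrete dtypes reported (int32 and float32) assume JAX's default 32-bit mode.

```python
import jax.numpy as jnp


def encode(i, j):
    # Combine the two index values into one number, e.g. (1, 0) -> 10.
    return i * 10 + j


int_grid = jnp.fromfunction(encode, (2, 2), dtype=int)  # integer index values
float_grid = jnp.fromfunction(encode, (2, 2))            # default: floating-point index values

print(int_grid.dtype, float_grid.dtype)  # int32 float32 in the default 32-bit mode
```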

Returns:

An array whose shape equals shape if function returns a scalar; in general, a pytree of arrays whose leading dimensions are given by shape, as determined by the output of function.
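
To illustrate the pytree case: if function returns a dict, each leaf is mapped separately, and shape is prepended to that leaf's own per-point shape. point_info below is a hypothetical example function.

```python
import jax.numpy as jnp


def point_info(i, j):
    # Return a dict (a pytree) with a scalar leaf and a length-2 vector leaf.
    return {"sum": i + j, "coords": jnp.stack([i, j])}


info = jnp.fromfunction(point_info, shape=(2, 3))
print(info["sum"].shape)     # (2, 3): a scalar leaf gets exactly the grid shape
print(info["coords"].shape)  # (2, 3, 2): the grid shape is prepended to the leaf's shape (2,)
```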

Return type:

Array

See also

jax.vmap(): the vectorizing transformation that jax.numpy.fromfunction() dispatches through.

Examples

Generate a multiplication table of a given shape:

>>> import jax.numpy as jnp
>>> jnp.fromfunction(jnp.multiply, shape=(3, 6), dtype=int)
Array([[ 0,  0,  0,  0,  0,  0],
       [ 0,  1,  2,  3,  4,  5],
       [ 0,  2,  4,  6,  8, 10]], dtype=int32)

When function returns a non-scalar, the output has leading dimensions given by shape:

>>> def f(x):
...   return (x + 1) * jnp.arange(3)
>>> jnp.fromfunction(f, shape=(2,))
Array([[0., 1., 2.],
       [0., 2., 4.]], dtype=float32)

function may return multiple results, in which case each is mapped independently:

>>> def f(x, y):
...   return x + y, x * y
>>> x_plus_y, x_times_y = jnp.fromfunction(f, shape=(3, 5))
>>> print(x_plus_y)
[[0. 1. 2. 3. 4.]
 [1. 2. 3. 4. 5.]
 [2. 3. 4. 5. 6.]]
>>> print(x_times_y)
[[0. 0. 0. 0. 0.]
 [0. 1. 2. 3. 4.]
 [0. 2. 4. 6. 8.]]

The JAX implementation differs slightly from NumPy's implementation. In numpy.fromfunction(), the function is expected to operate element-wise on the full grid of input values:

>>> import numpy as np
>>> def f(x, y):
...   print(f"{x.shape = }\n{y.shape = }")
...   return x + y
...
>>> np.fromfunction(f, (2, 3))
x.shape = (2, 3)
y.shape = (2, 3)
array([[0., 1., 2.],
       [1., 2., 3.]])

In jax.numpy.fromfunction(), the function is vectorized via jax.vmap(), and so is expected to operate on scalar values:

>>> jnp.fromfunction(f, (2, 3))
x.shape = ()
y.shape = ()
Array([[0., 1., 2.],
       [1., 2., 3.]], dtype=float32)
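
The difference is invisible for a purely element-wise function such as f above, but it matters as soon as the function reduces over values it builds from its inputs. In the sketch below, dist is a hypothetical example function written for scalar inputs: under jnp.fromfunction() it yields one distance per grid point, while numpy.fromfunction() calls it once on the whole index grid, so the unqualified sum collapses everything into a single value.

```python
import numpy as np
import jax.numpy as jnp


def dist(x, y):
    # Written for scalar inputs: stack the two indices into a length-2 vector
    # and reduce over *all* of its elements.
    return jnp.sqrt(jnp.sum(jnp.stack([x, y]) ** 2))


# jnp.fromfunction calls dist on scalars, so each entry is sqrt(i**2 + j**2).
jax_out = jnp.fromfunction(dist, (2, 3))

# np.fromfunction calls dist once on the full (2, 3) index grids, so jnp.sum
# reduces over the entire grid and returns a single scalar.
np_out = np.fromfunction(dist, (2, 3))

print(jax_out.shape, np.shape(np_out))  # (2, 3) ()
```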