jax.numpy.fromfunction — JAX documentation
jax.numpy.fromfunction
jax.numpy.fromfunction(function, shape, *, dtype=<class 'float'>, **kwargs)
Create an array from a function applied over indices.
JAX implementation of numpy.fromfunction(). The JAX implementation differs in that it dispatches via jax.vmap(). Unlike in NumPy, the function therefore logically operates on scalar inputs and need not explicitly handle broadcasted inputs (see Examples below).
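The sketch below illustrates that vmap-based dispatch for a two-dimensional shape by nesting two jax.vmap() calls over index vectors. It is an illustration under stated assumptions, not JAX's source: the helper fromfunction_2d is hypothetical, and the final check only shows that it agrees with jnp.fromfunction() on one example.

```python
import jax
import jax.numpy as jnp

def fromfunction_2d(function, shape, dtype=float):
    """Hypothetical helper: apply a scalar `function(i, j)` over a 2-D index grid."""
    rows, cols = shape
    i = jnp.arange(rows, dtype=dtype)  # index vector for axis 0
    j = jnp.arange(cols, dtype=dtype)  # index vector for axis 1
    # Inner vmap maps over j (axis 1) while holding a single i fixed.
    inner = jax.vmap(function, in_axes=(None, 0))
    # Outer vmap maps over i (axis 0) while passing the whole j vector through.
    outer = jax.vmap(inner, in_axes=(0, None))
    return outer(i, j)

def f(x, y):
    # Written for scalar x and y; vmap supplies the batching.
    return x * 10 + y

manual = fromfunction_2d(f, (2, 3))
builtin = jnp.fromfunction(f, (2, 3))
print(jnp.allclose(manual, builtin))  # True
```

Each vmap level adds one leading output axis, so the result has shape (2, 3) with manual[i, j] == f(i, j).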
Parameters:
- function (Callable[..., Array]) – a function that takes N dynamic scalars and outputs a scalar.
- shape (Any) – a length-N tuple of integers specifying the output shape.
- dtype (DTypeLike) – optionally specify the dtype of the inputs. Defaults to floating-point.
- kwargs – additional keyword arguments are passed statically to function.
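Because dtype sets the dtype of the index values passed to function, the output dtype follows from whatever function computes on those inputs. A small sketch, assuming JAX's default configuration with 64-bit types disabled (the function name index_sum is illustrative):

```python
import jax.numpy as jnp

def index_sum(i, j):
    # i and j arrive as scalars of the requested `dtype`.
    return i + j

as_float = jnp.fromfunction(index_sum, (2, 2))           # default dtype=float
as_int = jnp.fromfunction(index_sum, (2, 2), dtype=int)  # integer indices

print(as_float.dtype)  # float32 under the default (x64-disabled) configuration
print(as_int.dtype)    # int32 under the default (x64-disabled) configuration

# The function can still produce a different dtype than its inputs:
halves = jnp.fromfunction(lambda i, j: (i + j) / 2, (2, 2), dtype=int)
print(halves.dtype)    # true division promotes the integer indices to a float dtype
```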
Returns:
An array of shape shape if function returns a scalar, or in general a pytree of arrays with leading dimensions shape, as determined by the output of function.
Return type:
Array
See also
- jax.vmap(): the core transformation that the fromfunction() API is built on.
Examples
Generate a multiplication table of a given shape:
>>> jnp.fromfunction(jnp.multiply, shape=(3, 6), dtype=int)
Array([[ 0,  0,  0,  0,  0,  0],
       [ 0,  1,  2,  3,  4,  5],
       [ 0,  2,  4,  6,  8, 10]], dtype=int32)
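For an elementwise function like jnp.multiply, the same table can also be built by broadcasting a column of row indices against a row of column indices. The comparison below is an illustrative cross-check, not a description of how fromfunction() is implemented.

```python
import jax.numpy as jnp

table = jnp.fromfunction(jnp.multiply, shape=(3, 6), dtype=int)

# Broadcast a (3, 1) column of row indices against a (1, 6) row of column indices.
rows = jnp.arange(3)[:, None]
cols = jnp.arange(6)[None, :]
broadcast_table = rows * cols  # shape (3, 6)

print(jnp.array_equal(table, broadcast_table))  # True
```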
When function returns a non-scalar, the output will have leading dimensions equal to shape:
>>> def f(x):
...   return (x + 1) * jnp.arange(3)
>>> jnp.fromfunction(f, shape=(2,))
Array([[0., 1., 2.],
       [0., 2., 4.]], dtype=float32)
function may return multiple results, in which case each is mapped independently:
>>> def f(x, y):
...   return x + y, x * y
>>> x_plus_y, x_times_y = jnp.fromfunction(f, shape=(3, 5))
>>> print(x_plus_y)
[[0. 1. 2. 3. 4.]
 [1. 2. 3. 4. 5.]
 [2. 3. 4. 5. 6.]]
>>> print(x_times_y)
[[0. 0. 0. 0. 0.]
 [0. 1. 2. 3. 4.]
 [0. 2. 4. 6. 8.]]
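Since the results form a pytree, the same mapping should extend beyond tuples. A minimal sketch, assuming a dict return value (the keys "sum" and "ramp" are arbitrary): each leaf gets the leading dimensions shape, followed by whatever shape that leaf has in a single call.

```python
import jax.numpy as jnp

def f(x, y):
    # A dict pytree with one scalar leaf and one length-4 vector leaf per call.
    return {"sum": x + y, "ramp": x * jnp.arange(4)}

out = jnp.fromfunction(f, shape=(2, 3))
print(out["sum"].shape)   # (2, 3)
print(out["ramp"].shape)  # (2, 3, 4)
```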
The JAX implementation differs slightly from NumPy’s implementation. In numpy.fromfunction(), the function is expected to explicitly operate element-wise on the full grid of input values:
>>> def f(x, y):
...   print(f"{x.shape = }\n{y.shape = }")
...   return x + y
...
>>> np.fromfunction(f, (2, 3))
x.shape = (2, 3)
y.shape = (2, 3)
array([[0., 1., 2.],
       [1., 2., 3.]])
In jax.numpy.fromfunction(), the function is vectorized via jax.vmap(), and so is expected to operate on scalar values:
>>> jnp.fromfunction(f, (2, 3))
x.shape = ()
y.shape = ()
Array([[0., 1., 2.],
       [1., 2., 3.]], dtype=float32)
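One practical consequence is that a function which is not purely elementwise can yield differently shaped results under the two implementations. In the sketch below (the function name coords is illustrative), stacking the two indices gives a trailing axis of length 2 in JAX, because each call sees scalars, whereas NumPy's version stacks the two full index grids along a new leading axis.

```python
import numpy as np
import jax.numpy as jnp

def coords(x, y):
    # Under jnp.fromfunction, x and y are scalars, so this is a length-2 vector.
    # Under np.fromfunction, x and y are full (2, 3) grids, so this is (2, 2, 3).
    return jnp.stack([x, y])

jax_coords = jnp.fromfunction(coords, (2, 3))
np_coords = np.fromfunction(coords, (2, 3))

print(jax_coords.shape)  # (2, 3, 2)
print(np_coords.shape)   # (2, 2, 3)

# Same values, with the coordinate axis in a different position.
print(np.allclose(np.moveaxis(np.asarray(jax_coords), -1, 0), np_coords))  # True
```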