jax.numpy.kron — JAX documentation

jax.numpy.kron

jax.numpy.kron(a, b)

Compute the Kronecker product of two input arrays.

JAX implementation of numpy.kron().
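As a quick sanity check of the "JAX implementation of numpy.kron()" statement, the sketch below compares jnp.kron against numpy.kron on the same inputs, both eagerly and under jax.jit. It assumes NumPy is installed alongside JAX; the input values are arbitrary illustrations.

```python
import numpy as np
import jax
import jax.numpy as jnp

a = np.array([[1, 2], [3, 4]])
b = np.array([[0, 5], [6, 7]])

# Reference result computed by NumPy.
expected = np.kron(a, b)

# jnp.kron accepts NumPy arrays and returns a jax.Array.
result = jnp.kron(a, b)

# jnp.kron is built from traceable operations, so it can be wrapped in jax.jit.
jitted = jax.jit(jnp.kron)(a, b)

# np.array_equal compares values, so int64 vs. int32 dtypes do not matter here.
print(np.array_equal(np.asarray(result), expected))  # True
print(np.array_equal(np.asarray(jitted), expected))  # True
```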

The Kronecker product is an operation on two matrices of arbitrary size that produces a block matrix. Each element of the first matrix a scales a full copy of the second matrix b, and these scaled copies form the blocks of the result. If a has shape (m, n) and b has shape (p, q), the resulting matrix has shape (m * p, n * q).
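To make the block structure concrete: for 2-D inputs, entry out[i*p + k, j*q + l] equals a[i, j] * b[k, l]. The sketch below rebuilds this with plain broadcasting and a reshape, then checks it against jnp.kron. The helper name kron_2d is hypothetical and only handles 2-D inputs; it is an illustration of the definition, not how jnp.kron is guaranteed to be implemented.

```python
import jax.numpy as jnp

def kron_2d(a, b):
    """Kronecker product of two 2-D arrays via broadcasting (illustrative sketch)."""
    m, n = a.shape
    p, q = b.shape
    # Insert singleton axes so that blocks[i, k, j, l] = a[i, j] * b[k, l].
    blocks = a[:, None, :, None] * b[None, :, None, :]
    # Merging axes (i, k) and (j, l) gives out[i*p + k, j*q + l] = a[i, j] * b[k, l].
    return blocks.reshape(m * p, n * q)

a = jnp.array([[1, 2], [3, 4]])
b = jnp.array([[5, 6], [7, 8], [9, 10]])  # shape (3, 2)
out = kron_2d(a, b)
print(out.shape)                             # (6, 4)
print(jnp.array_equal(out, jnp.kron(a, b)))  # True
```

Using a non-square b here makes it easier to see that rows scale by p and columns by q independently.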

Parameters:

a – first input array, of any shape.
b – second input array, of any shape.

Returns:

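A new array representing the Kronecker product of the inputs a and b. The shape of the output is the element-wise product of the input shapes; if the inputs have different numbers of dimensions, the shape of the lower-dimensional input is first padded with leading 1s.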
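The shape rule can be checked directly. The sketch below uses arrays of ones (the values are irrelevant) to show the element-wise product of shapes, including the case where one input has fewer dimensions and is treated as if its shape had leading 1s.

```python
import jax.numpy as jnp

# Same number of dimensions: output shape is the element-wise product.
x = jnp.ones((2, 3))
y = jnp.ones((4, 5))
print(jnp.kron(x, y).shape)  # (8, 15)

# Different numbers of dimensions: the 1-D input is treated as shape (1, 3).
v = jnp.ones(3)
M = jnp.ones((2, 2))
print(jnp.kron(v, M).shape)  # (2, 6)

# The rule is not limited to matrices.
t = jnp.ones((2, 1, 3))
u = jnp.ones((1, 4, 2))
print(jnp.kron(t, u).shape)  # (2, 4, 6)
```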

Return type:

Array

Examples

>>> import jax.numpy as jnp
>>> a = jnp.array([[1, 2],
...                [3, 4]])
>>> b = jnp.array([[5, 6],
...                [7, 8]])
>>> jnp.kron(a, b)
Array([[ 5,  6, 10, 12],
       [ 7,  8, 14, 16],
       [15, 18, 20, 24],
       [21, 24, 28, 32]], dtype=int32)
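Two standard linear-algebra identities follow from the definition and can serve as further checks: a Kronecker product with the identity matrix yields a block-diagonal matrix, and for 1-D inputs the result is the flattened outer product. This sketch uses jax.scipy.linalg.block_diag and jnp.outer for comparison; neither is mentioned in the original text.

```python
import jax.numpy as jnp
from jax.scipy.linalg import block_diag

b = jnp.array([[5, 6],
               [7, 8]])

# kron(I, b) places one copy of b on each diagonal block.
left = jnp.kron(jnp.eye(2, dtype=b.dtype), b)
print(jnp.array_equal(left, block_diag(b, b)))  # True

# For 1-D inputs, kron(u, v) is the outer product flattened in row-major order.
u = jnp.array([1, 2, 3])
v = jnp.array([10, 20])
print(jnp.kron(u, v))  # [10 20 20 40 30 60]
print(jnp.array_equal(jnp.kron(u, v), jnp.outer(u, v).ravel()))  # True
```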