jax.numpy.vecmat - JAX documentation

jax.numpy.vecmat

jax.numpy.vecmat(x1, x2, /)

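Batched conjugate vector-matrix product.

For each batch index, x1 holds a vector and x2 holds a matrix. The result is the product of the complex conjugate of the vector with the matrix: out[..., n] = sum over m of conj(x1[..., m]) * x2[..., m, n]. For real-valued inputs the conjugation has no effect, so this is an ordinary vector-matrix product.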

JAX implementation of numpy.vecmat().
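The definition can be spelled out directly with jnp.einsum and jnp.conj. The sketch below is a reference for the semantics only, not how JAX implements vecmat internally; the helper name vecmat_reference is hypothetical and not part of JAX. It shows the conjugation on a complex vector and reproduces the plain real-valued product.

```python
import jax.numpy as jnp


def vecmat_reference(x1, x2):
    """Hypothetical helper (not a JAX API) spelling out the definition:
    out[..., n] = sum over m of conj(x1[..., m]) * x2[..., m, n]."""
    x1 = jnp.asarray(x1)
    x2 = jnp.asarray(x2)
    # Contract the shared M axis; '...' carries any leading batch dimensions.
    return jnp.einsum("...m,...mn->...n", jnp.conj(x1), x2)


# Complex input: against the identity matrix the result is conj(x1) itself.
x1 = jnp.array([1 + 1j, 2 - 1j])
eye = jnp.eye(2, dtype=jnp.complex64)
print(vecmat_reference(x1, eye))  # values: [1-1j, 2+1j]

# Real input: conjugation does nothing, giving the plain product [[40, 46]].
print(vecmat_reference(jnp.array([[1, 2, 3]]),
                       jnp.array([[4, 5], [6, 7], [8, 9]])))
```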

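Parameters:

x1 (ArrayLike) - array of shape (..., M).
x2 (ArrayLike) - array of shape (..., M, N). Its leading dimensions must be broadcast-compatible with the leading dimensions of x1.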

Returns:

An array of shape (..., N) containing the batched conjugate vector-matrix product.

Return type:

Array
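NumPy defines vecmat as a generalized ufunc with signature (n),(n,m)->(m), so the leading (batch) dimensions of the two inputs broadcast against each other. One way to see the resulting shapes, sketched below under that assumption, is to treat each vector as a 1 x M row matrix and let jnp.matmul's batch broadcasting do the work. The helper vecmat_via_matmul is hypothetical, not a JAX function, and the all-ones inputs are chosen so that only the shapes matter.

```python
import jax.numpy as jnp


def vecmat_via_matmul(x1, x2):
    """Hypothetical helper (not a JAX API): view each vector as a 1 x M row
    matrix so jnp.matmul's batch broadcasting applies, then drop that axis."""
    row = jnp.conj(jnp.asarray(x1))[..., None, :]  # (..., 1, M)
    return jnp.matmul(row, x2)[..., 0, :]          # (..., 1, N) -> (..., N)


M, N = 3, 4
v = jnp.ones((M,))        # one vector, shape (M,)
A = jnp.ones((5, M, N))   # five matrices, shape (5, M, N)
V = jnp.ones((2, 1, M))   # vectors with batch shape (2, 1)

print(vecmat_via_matmul(v, A).shape)  # (5, 4): one vector against 5 matrices
print(vecmat_via_matmul(V, A).shape)  # (2, 5, 4): batch shapes (2, 1) and (5,) broadcast
```

The size-1 axis in V's batch shape (2, 1) stretches to match A's batch size 5 under the usual broadcasting rules, which is why the second result has shape (2, 5, 4).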

See also

Examples

Simple vector-matrix product:

>>> import jax.numpy as jnp
>>> x1 = jnp.array([[1, 2, 3]])
>>> x2 = jnp.array([[4, 5],
...                 [6, 7],
...                 [8, 9]])
>>> jnp.vecmat(x1, x2)
Array([[40, 46]], dtype=int32)

Batched vector-matrix product:

>>> x1 = jnp.array([[1, 2, 3],
...                 [4, 5, 6]])
>>> jnp.vecmat(x1, x2)
Array([[ 40,  46],
       [ 94, 109]], dtype=int32)