jax.numpy.vecmat (JAX documentation)
jax.numpy.vecmat(x1, x2, /)
Batched conjugate vector-matrix product.
JAX implementation of numpy.vecmat().
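Written out, this is our restatement of the definition NumPy gives for numpy.vecmat, with index names chosen here for illustration. Each output entry sums over the shared axis of length M, and the entries of x1 are complex-conjugated. The leading indices (written as \ldots) range over the broadcast batch shape. For real inputs the conjugate has no effect, so this reduces to an ordinary vector-matrix product:

```latex
\operatorname{vecmat}(x_1, x_2)[\ldots, n]
  = \sum_{m=0}^{M-1} \overline{x_1[\ldots, m]}\; x_2[\ldots, m, n],
  \qquad n = 0, \ldots, N-1
```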
Parameters:
- x1 (Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) – array of shape (..., M).
- x2 (Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) – array of shape (..., M, N). Leading dimensions must be broadcast-compatible with the leading dimensions of x1.
Returns:
An array of shape (..., N) containing the batched conjugate vector-matrix product.
Return type:
Array
See also
- jax.numpy.linalg.vecdot(): batched vector product.
- jax.numpy.matvec(): matrix-vector product.
- jax.numpy.matmul(): general matrix multiplication.
Examples
Simple vector-matrix product:
>>> import jax.numpy as jnp
>>> x1 = jnp.array([[1, 2, 3]])
>>> x2 = jnp.array([[4, 5],
...                 [6, 7],
...                 [8, 9]])
>>> jnp.vecmat(x1, x2)
Array([[40, 46]], dtype=int32)
Batched vector-matrix product:
>>> x1 = jnp.array([[1, 2, 3],
...                 [4, 5, 6]])
>>> jnp.vecmat(x1, x2)
Array([[ 40,  46],
       [ 94, 109]], dtype=int32)