jax.numpy.linalg.tensorsolve — JAX documentation
jax.numpy.linalg.tensorsolve
jax.numpy.linalg.tensorsolve(a, b, axes=None)
Solve the tensor equation a x = b for x.
JAX implementation of numpy.linalg.tensorsolve().
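Written out with explicit indices, the equation can be restated as follows. This restatement follows directly from the shape rule and the tensordot identity given under Parameters and Returns: take b of shape (m_1, ..., m_p) and x of shape (n_1, ..., n_q), so that a (after any axes reordering) has shape (m_1, ..., m_p, n_1, ..., n_q).

$$
b_{i_1 \cdots i_p} = \sum_{j_1, \ldots, j_q} a_{i_1 \cdots i_p \, j_1 \cdots j_q} \, x_{j_1 \cdots j_q}
\qquad\Longleftrightarrow\qquad
A\,\mathbf{x} = \mathbf{b}, \quad A \text{ of shape } M \times N, \quad M = \prod_{k=1}^{p} m_k, \quad N = \prod_{k=1}^{q} n_k
$$

The left-hand form is exactly tensordot(a, x, x.ndim). Flattening the multi-index (i_1, ..., i_p) into one row index and (j_1, ..., j_q) into one column index gives the ordinary linear system on the right, so a unique solution can only exist when M = N and the flattened matrix A is nonsingular.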
Parameters:
- a (ArrayLike) – input array. After reordering via axes (see below), shape must be (*b.shape, *x.shape).
- b (ArrayLike) – right-hand-side array.
- axes (tuple[int, ...] | None) – optional tuple specifying axes of a that should be moved to the end.
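The axes parameter can be illustrated by comparing it with an explicit jnp.moveaxis. This is a minimal sketch that only covers a single axis: a is stored with the length-4 x axis first, giving shape (4, 2, 2), and axes=(0,) should move that axis to the end so the solve sees shape (2, 2, 4). The shapes and the PRNG key are illustrative choices, not taken from the original page.

```python
import jax
import jax.numpy as jnp

# `a` stores the x axis (length 4) first, so its raw shape is (4, 2, 2).
key1, key2 = jax.random.split(jax.random.key(0))
a = jax.random.normal(key1, shape=(4, 2, 2))
b = jax.random.normal(key2, shape=(2, 2))

# axes=(0,) moves axis 0 of `a` to the end before solving: (4, 2, 2) -> (2, 2, 4).
x_axes = jnp.linalg.tensorsolve(a, b, axes=(0,))

# The same reordering done by hand.
x_manual = jnp.linalg.tensorsolve(jnp.moveaxis(a, 0, -1), b)

print(x_axes.shape)  # (4,)
print(jnp.allclose(x_axes, x_manual))
```

Both calls hand the solver the same reordered array, so the two results should agree.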
Returns:
array x such that, after reordering of axes of a, tensordot(a, x, x.ndim) is equivalent to b.
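The tensordot identity above suggests how the solve can be carried out: flatten a into a square matrix, flatten b into a vector, and call an ordinary linear solver. The following is a minimal sketch of that standard reduction, not necessarily JAX's actual implementation. The helper name tensorsolve_via_solve is hypothetical, and it assumes any axes reordering has already been applied to a.

```python
import math

import jax
import jax.numpy as jnp


def tensorsolve_via_solve(a, b):
    """Hypothetical helper: solve a x = b by flattening it to a matrix system.

    Assumes `a` already has shape (*b.shape, *x.shape), i.e. any `axes`
    reordering has been applied beforehand.
    """
    a = jnp.asarray(a)
    b = jnp.asarray(b)
    if a.shape[:b.ndim] != b.shape:
        raise ValueError("leading axes of a must match b.shape")
    x_shape = a.shape[b.ndim:]  # trailing axes of `a` give the shape of x
    n_rows = math.prod(b.shape)
    n_cols = math.prod(x_shape)
    if n_rows != n_cols:
        raise ValueError(f"need prod(b.shape) == prod(x.shape), got {n_rows} and {n_cols}")
    # Rows index the entries of b, columns index the entries of x.
    a_mat = a.reshape(n_rows, n_cols)
    x_flat = jnp.linalg.solve(a_mat, b.reshape(n_rows))
    return x_flat.reshape(x_shape)


# Same shapes and key as the Examples section.
key1, key2 = jax.random.split(jax.random.key(8675309))
a = jax.random.normal(key1, shape=(2, 2, 4))
b = jax.random.normal(key2, shape=(2, 2))

x_sketch = tensorsolve_via_solve(a, b)
x_ref = jnp.linalg.tensorsolve(a, b)
print(x_sketch.shape)  # (4,)
print(jnp.allclose(x_sketch, x_ref, rtol=1e-4, atol=1e-4))
```

The square-matrix check mirrors the requirement that prod(b.shape) equal prod(x.shape); without it, jnp.linalg.solve would have no square system to work on.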
Return type:
Array
Examples
>>> import jax
>>> import jax.numpy as jnp
>>> key1, key2 = jax.random.split(jax.random.key(8675309))
>>> a = jax.random.normal(key1, shape=(2, 2, 4))
>>> b = jax.random.normal(key2, shape=(2, 2))
>>> x = jnp.linalg.tensorsolve(a, b)
>>> x.shape
(4,)
Now show that x can be used to reconstruct b using tensordot():
>>> b_reconstructed = jnp.linalg.tensordot(a, x, axes=x.ndim)
>>> jnp.allclose(b, b_reconstructed)
Array(True, dtype=bool)
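The example above uses a one-dimensional x. As a sketch of the general shape rule (*b.shape, *x.shape), the following hypothetical variant makes x two-dimensional instead. The shapes, the key, and the added identity term (an arbitrary choice that keeps the flattened 4 x 4 system well conditioned) are illustrative assumptions, and the reconstruction uses the top-level jnp.tensordot.

```python
import jax
import jax.numpy as jnp

# b has 4 entries and x is chosen to be 2-D with shape (2, 2), so `a` must
# have shape (*b.shape, *x.shape) == (4, 2, 2); note 4 == 2 * 2.
noise = jax.random.normal(jax.random.key(0), shape=(4, 2, 2))
# Adding 10 * identity (reshaped to match `a`) keeps the flattened 4 x 4
# system well conditioned for this sketch.
a = noise + 10.0 * jnp.eye(4).reshape(4, 2, 2)
b = jnp.arange(4.0)

x = jnp.linalg.tensorsolve(a, b)
print(x.shape)  # (2, 2)

# Contract the last x.ndim axes of `a` with all axes of x to recover b.
b_reconstructed = jnp.tensordot(a, x, axes=x.ndim)
print(b_reconstructed.shape)  # (4,)
print(jnp.allclose(b, b_reconstructed, rtol=1e-4, atol=1e-4))
```

Passing axes=x.ndim contracts exactly the trailing x axes of a, which is the same contraction as in the Returns description.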