jax.numpy.linalg.tensorsolve - JAX documentation

jax.numpy.linalg.tensorsolve

jax.numpy.linalg.tensorsolve(a, b, axes=None)

Solve the tensor equation a x = b for x.

JAX implementation of numpy.linalg.tensorsolve().
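Because this function follows NumPy's semantics, the sketch below checks the JAX result against numpy.linalg.tensorsolve on the same inputs. It assumes NumPy is installed (JAX depends on it) and builds a deliberately well-conditioned a (a 4x4 identity reshaped to (2, 2, 4) plus small noise, an illustrative choice rather than anything required by the API). This keeps JAX's default float32 result within a loose tolerance of NumPy's.

```python
import numpy as np
import jax
import jax.numpy as jnp

key1, key2 = jax.random.split(jax.random.key(0))

# Illustrative, well-conditioned coefficient tensor: a 4x4 identity reshaped
# to shape (2, 2, 4), plus small random noise.
a = jnp.eye(4).reshape(2, 2, 4) + 0.1 * jax.random.normal(key1, (2, 2, 4))
b = jax.random.normal(key2, (2, 2))

x_jax = jnp.linalg.tensorsolve(a, b)
x_np = np.linalg.tensorsolve(np.asarray(a), np.asarray(b))

# JAX defaults to float32, so compare with a tolerance suited to that precision.
print(x_jax.shape, x_np.shape)
print(np.allclose(np.asarray(x_jax), x_np, rtol=1e-4, atol=1e-5))
```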

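Parameters:

a (ArrayLike): input array. After reordering via axes (see below), its shape must be (*b.shape, *x.shape).
b (ArrayLike): right-hand-side array.
axes (tuple[int, ...] | None): optional tuple specifying axes of a that should be moved to the end.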
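The axes argument moves the listed axes of a to the end before the system is solved. As an illustration, the sketch below starts from a coefficient array whose axis paired with x comes first (shape (4, 2, 2)) and passes axes=(0,). It then compares the result with calling tensorsolve on an array reordered by hand with jnp.moveaxis. The shapes and seed are arbitrary choices for the example.

```python
import jax
import jax.numpy as jnp

key1, key2 = jax.random.split(jax.random.key(42))

# Coefficient array with the axis that pairs with x placed first: shape (4, 2, 2).
a = jax.random.normal(key1, shape=(4, 2, 2))
b = jax.random.normal(key2, shape=(2, 2))

# axes=(0,) moves axis 0 of `a` to the end, giving an effective shape (2, 2, 4).
x_axes = jnp.linalg.tensorsolve(a, b, axes=(0,))

# Equivalent manual reordering, followed by a plain tensorsolve.
a_moved = jnp.moveaxis(a, 0, -1)
x_manual = jnp.linalg.tensorsolve(a_moved, b)

print(x_axes.shape)
print(jnp.allclose(x_axes, x_manual))
```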

Returns:

An array x such that, after reordering the axes of a, tensordot(a, x, x.ndim) is equivalent to b.
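The return condition can be made concrete by flattening. If b has shape S and x has shape Q, with prod(S) == prod(Q), then a of shape (*S, *Q) reshapes to a square matrix, and solving that ordinary linear system yields x.ravel(). The following is a minimal sketch of this reduction using jnp.linalg.solve. It mirrors the approach taken by NumPy's implementation and is not a claim about JAX's exact internals. Here x is two-dimensional, so tensordot contracts over x.ndim == 2 axes. The scaled-identity coefficients are an illustrative choice to keep the system well-conditioned.

```python
import math

import jax
import jax.numpy as jnp

key1, key2 = jax.random.split(jax.random.key(0))

# b has shape (6,) and x will have shape (2, 3); prod((6,)) == prod((2, 3)) == 6.
# An identity plus small noise keeps this illustrative system well-conditioned.
a = jnp.eye(6).reshape(6, 2, 3) + 0.1 * jax.random.normal(key1, (6, 2, 3))
b = jax.random.normal(key2, (6,))

x = jnp.linalg.tensorsolve(a, b)

# Manual reduction: flatten to a square (6, 6) system and solve it.
x_shape = a.shape[b.ndim:]  # trailing axes of `a` give the shape of x
a_mat = a.reshape(b.size, math.prod(x_shape))
x_manual = jnp.linalg.solve(a_mat, b.ravel()).reshape(x_shape)

print(x.shape)
print(jnp.allclose(x, x_manual, rtol=1e-5, atol=1e-5))

# Check the documented property: contracting over x.ndim axes recovers b.
print(jnp.allclose(jnp.tensordot(a, x, axes=x.ndim), b, atol=1e-4))
```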

Return type:

Array

Examples

>>> import jax
>>> import jax.numpy as jnp
>>> key1, key2 = jax.random.split(jax.random.key(8675309))
>>> a = jax.random.normal(key1, shape=(2, 2, 4))
>>> b = jax.random.normal(key2, shape=(2, 2))
>>> x = jnp.linalg.tensorsolve(a, b)
>>> x.shape
(4,)

Now show that x can be used to reconstruct b using tensordot():

>>> b_reconstructed = jnp.linalg.tensordot(a, x, axes=x.ndim)
>>> jnp.allclose(b, b_reconstructed)
Array(True, dtype=bool)