jax.flatten_util module — JAX documentation
jax.flatten_util module
List of Functions
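| Function | Description |
|---|---|
| ravel_pytree(pytree) | Ravel (flatten) a pytree of arrays down to a 1D array. Returns the flat array together with a function that unravels a 1D array of the same length back into the original pytree structure. |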
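A minimal sketch, assuming a standard JAX installation, of how `ravel_pytree` is typically used: it flattens a small dict-of-arrays pytree into a single vector and then calls the returned unravel function to rebuild the original structure. The pytree `params` and its keys `"w"` and `"b"` are illustrative only. Leaves are concatenated in the pytree's flattening order, which for a dict means sorted by key.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Illustrative pytree (hypothetical names): a 2x2 matrix and a length-3 vector.
params = {"w": jnp.ones((2, 2)), "b": jnp.arange(3.0)}

# Flatten all leaves into one 1D array (2*2 + 3 = 7 elements) and
# obtain the inverse function that restores the original structure.
flat, unravel = ravel_pytree(params)
print(flat.shape)  # (7,)

# Any 1D array of the same length can be unraveled, e.g. a scaled copy.
restored = unravel(flat * 2.0)
print(restored["w"].shape, restored["b"])
```

Because `unravel` remembers the tree structure and leaf shapes, it is convenient for handing pytree parameters to routines that expect a flat vector (for example, generic optimizers) and converting their results back.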