jax.test_util module
List of Functions
| Function | Description |
|---|---|
| check_grads(f, args, order[, modes, atol, ...]) | Check gradients from automatic differentiation against finite differences. |
| check_jvp(f, f_jvp, args[, atol, rtol, eps, ...]) | Check a JVP from automatic differentiation against finite differences. |
| check_vjp(f, f_vjp, args[, atol, rtol, eps, ...]) | Check a VJP from automatic differentiation against finite differences. |
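
The following is a minimal sketch of how these three functions might be called. The function `f`, the input `x`, and the wrapper `run_checks` are illustrative names that do not come from the JAX documentation. The sketch assumes that `f_jvp` can be built as `partial(jax.jvp, f)` and `f_vjp` as `partial(jax.vjp, f)`, and that each check raises an `AssertionError` when the automatic and finite-difference results disagree beyond the tolerances.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.test_util import check_grads, check_jvp, check_vjp


def f(x):
    # Hypothetical smooth test function; any differentiable function works.
    return jnp.tanh(x)


def run_checks():
    # Illustrative helper: runs all three checks and returns True if none raise.
    x = jnp.array([0.5, 1.0, 2.0])
    args = (x,)  # positional arguments to f, packed in a tuple

    # First-order check in both forward and reverse mode.
    check_grads(f, args, order=1, modes=("fwd", "rev"))

    # Forward mode only: f_jvp takes (primals, tangents).
    check_jvp(f, partial(jax.jvp, f), args)

    # Reverse mode only: f_vjp takes the primals and returns (output, vjp_fn).
    check_vjp(f, partial(jax.vjp, f), args)
    return True


if __name__ == "__main__":
    print("all gradient checks passed:", run_checks())
```

Raising `order` in `check_grads` also checks higher-order derivatives. Finite differences in single precision are noisy, so tighter tolerances may call for enabling 64-bit mode or passing explicit `atol`, `rtol`, and `eps` values.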