jax.test_util module (JAX documentation)

jax.test_util module

List of Functions

check_grads(f, args, order[, modes, atol, ...])
    Check gradients from automatic differentiation against finite differences.

check_jvp(f, f_jvp, args[, atol, rtol, eps, ...])
    Check a JVP from automatic differentiation against finite differences.

check_vjp(f, f_vjp, args[, atol, rtol, eps, ...])
    Check a VJP from automatic differentiation against finite differences.
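The sketch below shows one way the three checkers might be used together, assuming a recent JAX release that ships jax.test_util. The function f, the arrays x and y, and the helpers bad_jvp, bad_vjp and fails are hypothetical illustrations, not part of the module. The sketch also assumes the calling conventions that functools.partial(jax.jvp, f) and functools.partial(jax.vjp, f) produce: f_jvp is called as f_jvp(primals, tangents) and f_vjp as f_vjp(*primals). Enabling 64-bit mode is a choice made here to keep finite-difference error well below the default tolerances.

```python
import functools

import jax
import jax.numpy as jnp
from jax.test_util import check_grads, check_jvp, check_vjp

# Double precision keeps finite-difference error small; this must be set
# before any arrays are created (a choice for this sketch only).
jax.config.update("jax_enable_x64", True)


def f(x, y):
    # Hypothetical example function of two array arguments.
    return jnp.sin(x) * y ** 2


x = jnp.array([0.5, 1.0, 2.0])
y = jnp.array([1.5, -0.5, 3.0])
args = (x, y)

# check_grads: compare derivatives up to second order, in forward ("fwd")
# and reverse ("rev") mode, against finite differences. It returns None on
# success and raises AssertionError on a mismatch.
check_grads(f, args, order=2, modes=["fwd", "rev"])

# check_jvp / check_vjp test a single JVP or VJP implementation.
check_jvp(f, functools.partial(jax.jvp, f), args)
check_vjp(f, functools.partial(jax.vjp, f), args)


def bad_jvp(primals, tangents):
    # Hypothetical, deliberately wrong JVP: doubles the correct tangent output.
    out, t_out = jax.jvp(f, primals, tangents)
    return out, 2.0 * t_out


def bad_vjp(*primals):
    # Hypothetical, deliberately wrong VJP: negates every input cotangent.
    out, vjp_fn = jax.vjp(f, *primals)
    return out, lambda ct: tuple(-g for g in vjp_fn(ct))


def fails(check, *check_args):
    """Return True if the given check raises AssertionError."""
    try:
        check(*check_args)
    except AssertionError:
        return True
    return False


print(fails(check_jvp, f, bad_jvp, args))  # True if the doubled tangent was caught
print(fails(check_vjp, f, bad_vjp, args))  # True if the negated cotangent was caught
```

These checkers compare along randomly drawn tangent (and, for VJPs, cotangent) directions rather than materializing full Jacobians, so a passing check is strong evidence, not proof, that a derivative rule is correct. In 32-bit mode the checks may need looser atol/rtol values or a different eps.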