jax.custom_vjp (JAX documentation)
jax.custom_vjp
class jax.custom_vjp(fun, nondiff_argnums=())
Set up a JAX-transformable function for a custom VJP rule definition.
This class is meant to be used as a function decorator. Instances are callables that behave similarly to the underlying function to which the decorator was applied, except when a reverse-mode differentiation transformation (like jax.grad()) is applied, in which case a custom user-supplied VJP rule function is used instead of tracing into and performing automatic differentiation of the underlying function’s implementation. There is a single instance method, defvjp(), which may be used to define the custom VJP rule.
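To illustrate that reverse-mode differentiation uses the supplied rule instead of differentiating the implementation, here is a minimal sketch of a straight-through gradient for rounding. The names `round_ste`, `round_ste_fwd` and `round_ste_bwd` are hypothetical and chosen for illustration only. Differentiating `jnp.round` directly gives a zero gradient, while the custom rule passes the cotangent through unchanged.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def round_ste(x):
    # Primal computation: ordinary rounding.
    return jnp.round(x)

def round_ste_fwd(x):
    # Primal output plus residuals for the backward pass (none are needed here).
    return jnp.round(x), None

def round_ste_bwd(res, g):
    # Ignore the rounding and pass the incoming cotangent straight through.
    del res
    return (g,)

round_ste.defvjp(round_ste_fwd, round_ste_bwd)

x = 0.3
value = round_ste(x)                     # same primal value as jnp.round(x)
autodiff_grad = jax.grad(jnp.round)(x)   # autodiff of rounding: 0.0
custom_grad = jax.grad(round_ste)(x)     # custom VJP rule: 1.0
```

Calling `round_ste` outside of `jax.grad` behaves exactly like `jnp.round`; only the derivative changes.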
This decorator precludes the use of forward-mode automatic differentiation.
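A minimal sketch of what this restriction looks like in practice, using a hypothetical function `g` with a hand-written rule for `sin`: `jax.grad` works, while `jax.jvp` raises an error. The comment assumes the error type is `TypeError`, which is what JAX raises in the versions I am aware of.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def g(x):
    return jnp.sin(x)

def g_fwd(x):
    # Save cos(x) as the residual.
    return jnp.sin(x), jnp.cos(x)

def g_bwd(cos_x, ct):
    # Cotangent for x: cos(x) times the output cotangent.
    return (cos_x * ct,)

g.defvjp(g_fwd, g_bwd)

# Reverse mode uses the custom rule.
reverse_ok = bool(jnp.allclose(jax.grad(g)(1.0), jnp.cos(1.0)))

# Forward mode is not available for a custom_vjp function.
try:
    jax.jvp(g, (1.0,), (1.0,))
    forward_mode_supported = True
except TypeError:  # assumed error type, see the note above
    forward_mode_supported = False
```

If forward-mode derivatives are needed as well, jax.custom_jvp is the decorator to use instead.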
For example:
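```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x, y):
    return jnp.sin(x) * y

# Forward pass: return the primal output and the residuals saved for bwd.
def f_fwd(x, y):
    return f(x, y), (jnp.cos(x), jnp.sin(x), y)

# Backward pass: map residuals and the output cotangent g to one cotangent per input.
def f_bwd(res, g):
    cos_x, sin_x, y = res
    return (cos_x * g * y, sin_x * g)

f.defvjp(f_fwd, f_bwd)
```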
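A short usage sketch for the example above: the decorated `f` still composes with other JAX transformations, so per-example gradients can be computed with `jax.vmap` and compiled with `jax.jit`. The definitions are repeated so the block runs on its own; the input values are arbitrary.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x, y):
    return jnp.sin(x) * y

def f_fwd(x, y):
    return f(x, y), (jnp.cos(x), jnp.sin(x), y)

def f_bwd(res, g):
    cos_x, sin_x, y = res
    return (cos_x * g * y, sin_x * g)

f.defvjp(f_fwd, f_bwd)

# Gradients w.r.t. both arguments, batched over the leading axis and jitted.
grad_f = jax.jit(jax.vmap(jax.grad(f, argnums=(0, 1))))

xs = jnp.array([0.0, 0.5, 1.0])
ys = jnp.array([2.0, 2.0, 2.0])
dxs, dys = grad_f(xs, ys)

# Analytic derivatives of sin(x) * y: d/dx = cos(x) * y, d/dy = sin(x).
expected_dxs = jnp.cos(xs) * ys
expected_dys = jnp.sin(xs)
```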
For a more detailed introduction, see the tutorial.
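Parameters:
- fun (Callable[..., ReturnValue]): the function whose VJP rule is being customized.
- nondiff_argnums (Sequence[int]): indices of positional arguments of fun to treat as non-differentiable; defaults to ().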
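A sketch of how nondiff_argnums changes the rule signatures, using a hypothetical gradient-clipping helper whose bounds `lo` and `hi` are non-differentiable. `fwd` receives all arguments in their original positions, while `bwd` receives the non-differentiable arguments first, then the residuals and the output cotangent, and returns cotangents only for the differentiable arguments.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.custom_vjp, nondiff_argnums=(0, 1))
def clip_gradient(lo, hi, x):
    # Primal behaviour: identity on x.
    return x

def clip_gradient_fwd(lo, hi, x):
    # All arguments arrive in their usual positions; no residuals are needed.
    return x, None

def clip_gradient_bwd(lo, hi, res, g):
    # Non-differentiable arguments come first, then residuals and cotangent.
    del res
    # One cotangent, for the single differentiable argument x.
    return (jnp.clip(g, lo, hi),)

clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)

def loss(x):
    return 10.0 * clip_gradient(-1.0, 1.0, x)

grad_x = jax.grad(loss)(3.0)  # the incoming cotangent 10.0 is clipped to 1.0
```

Arguments listed in nondiff_argnums should not be traced values (for example, they should not depend on the inputs being differentiated).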
Methods
| Method | Description |
|---|---|
| __init__(fun[, nondiff_argnums]) | |
| defvjp(fwd, bwd[, symbolic_zeros, ...]) | Define a custom VJP rule for the function represented by this instance. |
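A sketch of the fwd/bwd contract that defvjp expects, using a hypothetical `weighted_sum` with a dict-valued argument: `fwd` takes the same arguments as the primal function and returns a pair (output, residuals), and `bwd` takes (residuals, output cotangent) and returns a tuple with one cotangent per primal argument, each matching that argument's pytree structure. The optional arguments elided above (symbolic_zeros and others) are not used here.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def weighted_sum(params, x):
    return jnp.sum(params["w"] * x) + params["b"]

def weighted_sum_fwd(params, x):
    # Residuals may be any pytree; here a tuple of arrays.
    return weighted_sum(params, x), (params["w"], x)

def weighted_sum_bwd(res, g):
    w, x = res
    # One cotangent per primal argument, with matching structure.
    d_params = {"w": g * x, "b": g}
    d_x = g * w
    return d_params, d_x

weighted_sum.defvjp(weighted_sum_fwd, weighted_sum_bwd)

params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}
x = jnp.array([3.0, 4.0])

# Gradient w.r.t. params (argnum 0), compiled with jit.
grads = jax.jit(jax.grad(weighted_sum))(params, x)
```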