Forward-mode differentiation rule for 'custom_lin' not implemented · Issue #2784 · jax-ml/jax
Running the following code produces the error in the title.
```python
from functools import partial

import jax
import jax.numpy as np  # assumption: `np` in this snippet refers to jax.numpy

@partial(jax.custom_vjp, nondiff_argnums=(0, 1))
def _clip_gradient(lo, hi, x):
    return x  # identity function

def clip_gradient_fwd(lo, hi, x):
    # return x, None
    return x, (hi,)

def clip_gradient_bwd(lo, hi, _, g):
    return (np.clip(g, lo, hi),)

_clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)

def clip_gradient(x):
    lo = -1
    hi = x + 1  # causes things to break
    return _clip_gradient(lo, hi, x)

print(jax.grad(clip_gradient)(1.))
```
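As a control experiment (my own sketch, not part of the original report), the same `nondiff_argnums` pattern appears to work when the bounds are plain Python constants instead of values computed from `x`. The name `clip_gradient_const` is hypothetical, and the snippet assumes the documented `custom_vjp` convention that `bwd` receives the non-differentiable arguments first, then the residuals, then the output cotangent.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.custom_vjp, nondiff_argnums=(0, 1))
def _clip_gradient(lo, hi, x):
    return x  # identity on the forward pass

def clip_gradient_fwd(lo, hi, x):
    return x, None  # no residuals needed; lo and hi are passed to bwd again

def clip_gradient_bwd(lo, hi, _, g):
    # one cotangent per differentiable argument (only x here)
    return (jnp.clip(g, lo, hi),)

_clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)

def clip_gradient_const(x):
    lo, hi = -1.0, 2.0  # plain Python floats, not derived from x
    return _clip_gradient(lo, hi, x)

# The incoming cotangent 3.0 is clipped to hi = 2.0.
grad_const = jax.grad(lambda x: 3.0 * clip_gradient_const(x))(1.0)
print(grad_const)  # 2.0
```

If this runs while the original repro fails, that supports the reading that the problem is `hi` being traced, not the `nondiff_argnums` mechanism itself.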
Replacing the residual with `None` (see the commented-out line in `clip_gradient_fwd`) avoids the error, but the printed result is a tracer rather than a concrete value:

```
Traced<ConcreteArray(1.0)>with<JVPTrace(level=1/0)>
  with primal = DeviceArray(1., dtype=float32)
       tangent = Traced<ShapedArray(float32[]):JaxprTrace(level=0/0)>
```
I also got a mismatched-tracer-levels error on a more complex example.
As I understand it, the issue is setting `hi = x + 1`: because `hi` is computed from `x`, it gets traced along with `x`, even though it is passed as a non-differentiable argument.
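If that diagnosis is right, one possible workaround (a sketch following the pattern in the JAX custom-derivatives tutorial, not a fix for the underlying bug) is to pass `lo` and `hi` as ordinary arguments, save them as residuals, and return `None` as their cotangents. The names `clip_gradient_args`, `clip_gradient_args_fwd` and `clip_gradient_args_bwd` are hypothetical.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def clip_gradient_args(lo, hi, x):
    return x  # identity on the forward pass

def clip_gradient_args_fwd(lo, hi, x):
    # lo and hi may be tracers here, so carry them as residuals
    return x, (lo, hi)

def clip_gradient_args_bwd(res, g):
    lo, hi = res
    # None marks a zero cotangent for lo and hi; only x gets a clipped gradient
    return (None, None, jnp.clip(g, lo, hi))

clip_gradient_args.defvjp(clip_gradient_args_fwd, clip_gradient_args_bwd)

def clip_gradient(x):
    lo = -1.0
    hi = x + 1  # depends on x, as in the original repro
    return clip_gradient_args(lo, hi, x)

print(jax.grad(clip_gradient)(1.0))  # 1.0
print(jax.grad(lambda x: 3.0 * clip_gradient(x))(1.0))  # 2.0
```

Because the backward rule returns `None` for `hi`, no gradient flows through the `x + 1` dependence, which matches the intent of treating the bounds as non-differentiable.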
This example may seem a bit contrived, but I originally ran into this while trying to set an initial step size for `odeint` (see #2604).
IIUC, we want to make all `JVPTracer` instances that appear in static args concrete here.