Forward-mode differentiation rule for 'custom_lin' not implemented · Issue #2784 · jax-ml/jax

Running this code produces the error in the title:

from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.custom_vjp, nondiff_argnums=(0, 1))
def _clip_gradient(lo, hi, x):
    return x  # identity function

def clip_gradient_fwd(lo, hi, x):
    # return x, None
    return x, (hi,)

def clip_gradient_bwd(lo, hi, _, g):
    return (jnp.clip(g, lo, hi),)

_clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)

def clip_gradient(x):
    lo = -1
    hi = x + 1  # causes things to break
    return _clip_gradient(lo, hi, x)

print(jax.grad(clip_gradient)(1.))

Replacing the residual with None (the commented-out line in clip_gradient_fwd) instead makes the script print a tracer rather than a number:

Traced<ConcreteArray(1.0)>with<JVPTrace(level=1/0)>
  with primal = DeviceArray(1., dtype=float32)
       tangent = Traced<ShapedArray(float32[]):JaxprTrace(level=0/0)>

I also got a mismatched tracer levels error in a more complex example.

As I understand it, setting hi = x + 1 is the problem: it makes hi a tracer that depends on x, yet hi is passed through nondiff_argnums as if it were a static, non-differentiable value.
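One way to check this diagnosis is to keep the same custom_vjp but pass a Python constant for hi, so nothing traced reaches nondiff_argnums. The sketch below assumes current JAX (behavior in the release this issue was filed against may differ) and follows the gradient-clipping pattern from the JAX custom derivatives docs; the wrapper name `scaled_clip_const` and the factor of 3 (there only to make the clipping visible) are hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.custom_vjp, nondiff_argnums=(0, 1))
def _clip_gradient(lo, hi, x):
    return x  # identity on the forward pass


def clip_gradient_fwd(lo, hi, x):
    return x, None  # no residuals needed


def clip_gradient_bwd(lo, hi, _, g):
    # Clip the incoming cotangent into [lo, hi].
    return (jnp.clip(g, lo, hi),)


_clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)


def scaled_clip_const(x):
    lo = -1.0
    hi = 1.0  # a Python constant, so no tracer is passed via nondiff_argnums
    # Multiply by 3 so the cotangent reaching _clip_gradient (3.0) gets clipped.
    return 3.0 * _clip_gradient(lo, hi, x)


print(jax.grad(scaled_clip_const)(1.0))  # 1.0
```

With a constant upper bound the incoming cotangent 3.0 is clipped to 1.0 and differentiation succeeds, which is consistent with the traced hi being the culprit.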

This example may seem contrived, but I originally ran into this while trying to set an initial step size for odeint (see #2604).

IIUC we want to make concrete all JVPTracer instances that are in static args here.
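Independently of concretizing tracers inside JAX, a user-side workaround is to stop treating hi as static: drop nondiff_argnums, pass lo and hi as ordinary arguments, and return zero cotangents for them. This is a sketch assuming current `jax.custom_vjp` semantics (without nondiff_argnums, bwd takes `(residuals, g)` and returns one cotangent per primal argument); lo is written as -1.0 rather than -1 so that a floating-point zero cotangent matches its type, and the scaling by 3 is only there to make the clipping visible.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def _clip_gradient(lo, hi, x):
    return x  # identity on the forward pass


def clip_gradient_fwd(lo, hi, x):
    # Save the (possibly traced) bounds as array residuals for the backward pass.
    return x, (jnp.asarray(lo), jnp.asarray(hi))


def clip_gradient_bwd(res, g):
    lo, hi = res
    # One cotangent per primal argument: zeros for the bounds,
    # the clipped incoming cotangent for x.
    return (jnp.zeros_like(lo), jnp.zeros_like(hi), jnp.clip(g, lo, hi))


_clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)


def clip_gradient(x):
    lo = -1.0
    hi = x + 1  # traced, but now an ordinary argument rather than a static one
    return _clip_gradient(lo, hi, x)


# Scale the output by 3 so the clipping shows up in the gradient.
print(jax.grad(lambda x: 3.0 * clip_gradient(x))(1.0))  # 2.0
print(jax.grad(lambda x: 3.0 * clip_gradient(x))(0.5))  # 1.5
```

Here hi = x + 1 is traced without issue, so at x = 1.0 the cotangent 3.0 is clipped to 2.0 and at x = 0.5 to 1.5. Returning zeros for hi means the bound's dependence on x contributes nothing to the gradient, which matches the intent of a pure gradient-clipping op.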