odeint: don't hoist non-differentiable consts by mattjj · Pull Request #3587 · jax-ml/jax

fixes #3584

This could use further revision! Left a todo.

The issue is that in #3562 we started closure-converting the dynamics function (by tracing it to a jaxpr up-front) so as to handle closed-over constants with respect to which we want to differentiate the odeint call. Closure conversion turns every closed-over constant into an explicit argument of the converted function. So if the dynamics function closes over integer-valued constants, those become integer-valued inputs, and we can no longer call vjp on the closure-converted function without getting an error.
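For concreteness, here is a minimal sketch of the kind of program this affects. The names `loss`, `dynamics`, `idx`, and `k` are made up for illustration and are not taken from #3584. The dynamics function closes over both a float parameter we differentiate with respect to and an integer index array that should never receive a cotangent.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

def loss(k):
    # Integer-valued constant closed over by the dynamics function
    # (hypothetical example, not the reproducer from the issue).
    idx = jnp.array([1, 0])

    def dynamics(y, t):
        # Closes over `k` (float, differentiated) and `idx` (int, not differentiated).
        return -k * y[idx]

    y0 = jnp.array([1.0, 2.0])
    ts = jnp.array([0.0, 1.0])
    # Sum the state at the final time point to get a scalar loss.
    return jnp.sum(odeint(dynamics, y0, ts)[-1])

# With this fix in place, the gradient w.r.t. `k` should go through
# even though `idx` is integer-valued.
g = jax.grad(loss)(0.5)
```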

One fix would be to support (trivial) differentiation with respect to integer-valued inputs. That would work if we suppress the error message for integer-valued inputs in vjp and add a trivial tangent space for integer-valued arrays. Since that's potentially a further-reaching change, this commit instead just applies a local fix to avoid adding integer-valued inputs to the dynamics function by adapting the closure-conversion code.
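A rough sketch of the idea behind the local fix, assuming constants are split by dtype: only inexact (float or complex) constants are hoisted into explicit arguments, and everything else stays closed over. The helper `partition_consts` is hypothetical and stands in for the actual closure-conversion code, which may differ in detail.

```python
import jax.numpy as jnp

def partition_consts(consts):
    """Split constants into those to hoist as arguments and those to keep closed over.

    Hypothetical helper: uses "has an inexact dtype" as a proxy for
    "might be differentiated with respect to".
    """
    hoisted, closed = [], []
    for c in consts:
        if jnp.issubdtype(jnp.result_type(c), jnp.inexact):
            hoisted.append(c)   # float/complex: becomes an explicit input
        else:
            closed.append(c)    # int/bool: stays a closed-over constant
    return hoisted, closed

consts = [jnp.array([1, 0]), jnp.float32(0.5), jnp.arange(3.0)]
hoisted, closed = partition_consts(consts)
```

Using dtype as the criterion is only a proxy: a float constant that is never differentiated still gets hoisted, which is harmless but not minimal.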