lax.cond can bind to unexpected function signature · Issue #16413 · jax-ml/jax
Description
If you call `lax.cond` in the form `lax.cond(predicate, function, function, callable_pytree, callable_pytree)`, `lax.cond` binds the arguments to the old function signature `<Signature (pred, true_operand, true_fun: Callable, false_operand, false_fun: Callable)>` and swaps them in an unexpected way.
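To make the swap concrete, the sketch below binds the same five positional arguments against `legacy_cond`, a hypothetical stand-in that only copies the legacy parameter list quoted above. It does not call into JAX, and the placeholder branch functions are illustrative. Python's `inspect.Signature.bind` assigns positional arguments strictly left to right:

```python
import inspect
from typing import Callable


# Hypothetical stand-in: same parameter list as the legacy lax.cond signature
# quoted in the issue. It is only inspected, never called.
def legacy_cond(pred, true_operand, true_fun: Callable,
                false_operand, false_fun: Callable):
    raise NotImplementedError


# Placeholder callables named after the objects in the reproducer.
def true_branch(add_one, add_two):
    return add_one(add_two(1.0))


def false_branch(add_one, add_two):
    return add_two(add_one(1.0))


def add_one(x):
    return x + 1.0


def add_two(x):
    return x + 2.0


# Bind the five positional arguments in the order the reproducer passes them.
bound = inspect.signature(legacy_cond).bind(
    True, true_branch, false_branch, add_one, add_two)

for name, value in bound.arguments.items():
    # Show the function name for callables and the raw value otherwise.
    print(f"{name:>13} <- {getattr(value, '__name__', value)}")
```

Under this binding `true_operand` is `true_branch`, `true_fun` is `false_branch`, and the two callable pytrees become `false_operand` and `false_fun`. Since the legacy form applies `true_fun` to `true_operand`, `false_branch` would receive one argument instead of two, which is one plausible source of the `TypeError` below.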
Here is a reproducing example:
import jax.lax as lax
import jax.numpy as jnp
import jax.tree_util as jtu


def true_branch(add_one, add_two):
    return add_one(add_two(jnp.array(1.)))


def false_branch(add_one, add_two):
    return add_two(add_one(jnp.array(1.)))


add_one = jtu.Partial(jnp.add, jnp.array(1.))  # A callable pytree
add_two = jtu.Partial(jnp.add, jnp.array(2.))
# Intended result: true_branch(add_one, add_two) == 1. + (2. + 1.) == 4.
four = lax.cond(True, true_branch, false_branch, add_one, add_two)  # TypeError
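A possible workaround, sketched here under the assumption that the legacy-signature path is only taken when five positional arguments are passed (as the signature above suggests), is to pack both callable pytrees into a single tuple operand and unpack it inside each branch. `jtu.Partial` is a pytree, so the tuple is flattened into its array leaves and rebuilt inside the branches. The `lambda` wrappers are an illustration, not a fix proposed in the issue.

```python
import jax.lax as lax
import jax.numpy as jnp
import jax.tree_util as jtu


def true_branch(add_one, add_two):
    return add_one(add_two(jnp.array(1.)))


def false_branch(add_one, add_two):
    return add_two(add_one(jnp.array(1.)))


add_one = jtu.Partial(jnp.add, jnp.array(1.))  # callable pytree
add_two = jtu.Partial(jnp.add, jnp.array(2.))

# Pass one tuple operand (four positional arguments in total) instead of two
# separate callable operands, and unpack it inside each branch.
four = lax.cond(
    True,
    lambda ops: true_branch(*ops),
    lambda ops: false_branch(*ops),
    (add_one, add_two),
)
print(four)
```

Both branches evaluate to 4.0 with these operands (1 + (2 + 1) and 2 + (1 + 1)), so this only shows that the call goes through the `cond(pred, true_fun, false_fun, *operands)` form without the argument swap.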
What jax/jaxlib version are you using?
0.4.11
Which accelerator(s) are you using?
CPU
Additional system info
Linux