lax.cond can bind to unexpected function signature · Issue #16413 · jax-ml/jax

Description

If you call `lax.cond` as `lax.cond(predicate, function, function, callable_pytree, callable_pytree)`, it binds to the old signature `<Signature (pred, true_operand, true_fun: Callable, false_operand, false_fun: Callable)>` and the arguments are assigned in an unexpected way. The two branch functions end up in the `true_operand` and `true_fun` slots, and the two callable pytrees end up in the `false_operand` and `false_fun` slots.
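The positional mix-up can be illustrated without JAX. The sketch below uses Python's `inspect.Signature.bind` against `old_cond`, a hypothetical stand-in that only copies the legacy signature quoted above (it is not JAX's internal function, and the plain `add_one`/`add_two` functions stand in for the `jtu.Partial` objects). It shows that five positional arguments fit that signature exactly, so binding succeeds and each argument lands one slot away from where the caller intended.

```python
import inspect

def old_cond(pred, true_operand, true_fun, false_operand, false_fun):
    """Hypothetical stand-in with the legacy lax.cond signature; never called."""
    raise NotImplementedError

# Placeholder branch functions and callables, matching the shape of the repro.
def true_branch(add_one, add_two): ...
def false_branch(add_one, add_two): ...
def add_one(x): return x + 1.0
def add_two(x): return x + 2.0

# Five positional arguments satisfy the legacy signature, so bind() succeeds.
bound = inspect.signature(old_cond).bind(
    True, true_branch, false_branch, add_one, add_two)

# Show which parameter each argument was bound to.
for name, value in bound.arguments.items():
    print(f"{name:>13} <- {getattr(value, '__name__', value)}")
```

Under this binding the intended branch functions are treated as an operand and a branch, and the callable pytrees as the false-side operand and branch.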

Here is a reproducing example:

```python
import jax.lax as lax
import jax.numpy as jnp
import jax.tree_util as jtu

def true_branch(add_one, add_two):
    return add_one(add_two(jnp.array(1.)))

def false_branch(add_one, add_two):
    return add_two(add_one(jnp.array(1.)))

add_one = jtu.Partial(jnp.add, jnp.array(1.))  # A callable pytree
add_two = jtu.Partial(jnp.add, jnp.array(2.))
four = lax.cond(True, true_branch, false_branch, add_one, add_two)  # raises TypeError instead of returning 4.0
```
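A possible workaround, sketched here under the assumption that the legacy signature is only matched by exactly five positional arguments (not verified against 0.4.11 specifically), is to avoid that call shape altogether. Option 1 packs both callable pytrees into a single tuple operand, giving four positional arguments; option 2 closes over them and passes no operands.

```python
import jax.lax as lax
import jax.numpy as jnp
import jax.tree_util as jtu

def true_branch(add_one, add_two):
    return add_one(add_two(jnp.array(1.)))

def false_branch(add_one, add_two):
    return add_two(add_one(jnp.array(1.)))

add_one = jtu.Partial(jnp.add, jnp.array(1.))  # callable pytree: leaf is the bound array
add_two = jtu.Partial(jnp.add, jnp.array(2.))

# Option 1: one tuple operand, so only four positional arguments reach lax.cond.
four = lax.cond(True,
                lambda ops: true_branch(*ops),
                lambda ops: false_branch(*ops),
                (add_one, add_two))

# Option 2: branches close over the callables; no operands are passed.
four_closure = lax.cond(True,
                        lambda: true_branch(add_one, add_two),
                        lambda: false_branch(add_one, add_two))

print(four, four_closure)
```

Option 1 keeps the callables as traced operands (their arrays are pytree leaves), while option 2 captures them as constants in the branch bodies. Either is enough to sidestep the ambiguous five-argument form.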

What jax/jaxlib version are you using?

0.4.11

Which accelerator(s) are you using?

CPU

Additional system info

Linux

NVIDIA GPU info

No response