temporarily switch off #2414 changes · jax-ml/jax@ed8dbd2 (original) (raw)

Expand Up

@@ -154,7 +154,7 @@ def fori_loop(lower, upper, body_fun, init_val):

except TypeError:

use_scan = False

else:

use_scan = True

use_scan = False # TODO(mattjj): re-enable this

if use_scan:

(_, _, result), _ = scan(_fori_scan_body_fun(body_fun),

Expand Down Expand Up

@@ -1209,7 +1209,8 @@ def scan_bind(*args, forward, length, num_consts, num_carry, jaxpr, linear):

scan_p = core.Primitive("scan")

scan_p.multiple_results = True

scan_p.def_custom_bind(scan_bind)

scan_p.def_impl(partial(xla.apply_primitive, scan_p))

scan_p.def_impl(_scan_impl)

# scan_p.def_impl(partial(xla.apply_primitive, scan_p)) # TODO(mattjj): re-enable

scan_p.def_abstract_eval(_scan_abstract_eval)

ad.primitive_jvps[scan_p] = _scan_jvp

ad.primitive_transposes[scan_p] = _scan_transpose

Expand Down