temporarily switch off #2414 changes · jax-ml/jax@ed8dbd2 (original) (raw)
@@ -154,7 +154,7 @@ def fori_loop(lower, upper, body_fun, init_val):
except TypeError:
use_scan = False
else:
use_scan = True
use_scan = False # TODO(mattjj): re-enable this
if use_scan:
(_, _, result), _ = scan(_fori_scan_body_fun(body_fun),
@@ -1209,7 +1209,8 @@ def scan_bind(*args, forward, length, num_consts, num_carry, jaxpr, linear):
scan_p = core.Primitive("scan")
scan_p.multiple_results = True
scan_p.def_custom_bind(scan_bind)
scan_p.def_impl(partial(xla.apply_primitive, scan_p))
scan_p.def_impl(_scan_impl)
# scan_p.def_impl(partial(xla.apply_primitive, scan_p)) # TODO(mattjj): re-enable
scan_p.def_abstract_eval(_scan_abstract_eval)
ad.primitive_jvps[scan_p] = _scan_jvp
ad.primitive_transposes[scan_p] = _scan_transpose