Remove the JAX lazy sublanguage. by hawkinsp · Pull Request #6002 · jax-ml/jax (original) (raw)
Back in the mists of time, before omnistaging landed in JAX, we used lazy
expressions to avoid materializing large constants inside jit
computations.
Omnistaging, which means that computations that are in the dynamic scope of ajit
are staged into the jit
computation, has subsumed most of the reasons
for laziness to exist, and this PR removes the laziness support for simplicity.
At the time of this PR, laziness is used only for broadcasts and transposes in
eager mode (i.e., outside a jit
). This allows us to:
a) fuse together multiple broadcasts and transposes, and
b) if a lazy expression is lexically captured by a jit
computation, we can
avoid materializing it in its expanded form.
It is not clear that laziness has sufficient power to weight ratio to continue
to exist, and it is making other work on improving JAX dispatch times more
difficult. As a result, this PR removes laziness to unblock that work; if we
want laziness again we would want to reimplement it in C++ anyway.