Ensure that check_jaxpr is done with abstract values by gnecula · Pull Request #4650 · jax-ml/jax
Prior to this change, checking a jaxpr that contains a Literal,
such as broadcast_in_dim[shape=(1000)] 0.0,
could perform actual floating-point operations (FLOPs) during checking.
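To illustrate the difference between computing on concrete values and reasoning about abstract ones, here is a small sketch using only public JAX APIs. It is not the check_jaxpr code changed by this PR. jax.eval_shape traces a function with inputs that carry only a shape and dtype, so broadcasting a constant to a large shape costs no arithmetic. The function f is a hypothetical example.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical example: broadcast the scalar constant 0.0 to shape (1000,).
    # When traced, the constant shows up as an input to broadcast_in_dim.
    return x + jnp.broadcast_to(0.0, (1000,))

# Abstract input: only a shape and a dtype, no array data.
x_abstract = jax.ShapeDtypeStruct((1000,), jnp.float32)

# eval_shape propagates shapes and dtypes through f without materializing
# or adding any arrays of 1000 floats.
out = jax.eval_shape(f, x_abstract)
print(out.shape, out.dtype)
```

The result is itself a ShapeDtypeStruct, so no FLOPs are spent on the broadcast or the addition.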
Many tests fail the new assertion (that check_jaxpr sees only abstract values) unless we apply raise_to_shape to Literals.
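Every atom in a jaxpr equation, whether a variable or an inlined Literal, carries an abstract value in its .aval attribute. The sketch below shows that a shape check can rely on these avals alone. It uses public tracing output and does not call raise_to_shape, which is an internal JAX helper. The function names f and abstract_inputs are hypothetical.

```python
import jax
import jax.numpy as jnp

def f(x):
    # The Python scalar 2.0 is typically inlined into the jaxpr as a Literal.
    return x * 2.0

closed = jax.make_jaxpr(f)(jnp.ones((1000,), jnp.float32))

def abstract_inputs(jaxpr):
    """Hypothetical helper: list each primitive with its input shapes/dtypes."""
    summary = []
    for eqn in jaxpr.eqns:
        # Var and Literal atoms both expose `.aval`; reading aval.shape and
        # aval.dtype never touches the underlying data.
        avals = [v.aval for v in eqn.invars]
        summary.append((eqn.primitive.name, [(a.shape, a.dtype) for a in avals]))
    return summary

print(closed.jaxpr)
for name, shapes in abstract_inputs(closed.jaxpr):
    print(name, shapes)
```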
I have timed the checks on my laptop and I do not see a reduction in the
total test time.