Ensure that check_jaxpr is done with abstract values by gnecula · Pull Request #4650 · jax-ml/jax

Prior to this change, checking could trigger actual computation (FLOPs). For example, a jaxpr containing a Literal, such as broadcast_in_dim[shape=(1000,)] 0.0, could end up being computed during checking.
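The cost difference can be seen in a toy Python model. This is not JAX's actual check_jaxpr, and ShapedAval, broadcast_in_dim_concrete and broadcast_in_dim_abstract are hypothetical names. Evaluating the broadcast on a concrete scalar materializes every output element. Checking it on an abstract value only computes the output shape and dtype.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ShapedAval:
    """Hypothetical stand-in for an abstract value: shape and dtype only, no data."""
    shape: tuple
    dtype: str


def broadcast_in_dim_concrete(x: float, shape: tuple) -> list:
    # Concrete evaluation: work and memory scale with the output size.
    n = 1
    for d in shape:
        n *= d
    return [x] * n


def broadcast_in_dim_abstract(x: ShapedAval, shape: tuple) -> ShapedAval:
    # Abstract evaluation: validate the operand type and return the output type.
    # No elements are created.
    assert x.shape == (), "this toy rule only broadcasts scalars"
    return ShapedAval(shape=shape, dtype=x.dtype)


print(broadcast_in_dim_abstract(ShapedAval((), "float32"), (1000,)))
# ShapedAval(shape=(1000,), dtype='float32')
```

With the abstract rule, checking cost is independent of the value of shape. With concrete evaluation, the checker pays for the full computation.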

Many tests trip the new assertion (that the checker only sees abstract values) unless we apply raise_to_shaped to Literals.
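Here is a minimal sketch of the idea, again as a toy model rather than the JAX code. The checker's read function converts a Literal to a value that carries only shape and dtype. Here literal_to_shaped stands in for raise_to_shaped. The checker then asserts that every value it sees is abstract. Var, Literal, ShapedAval, literal_to_shaped and read are hypothetical names. The dtype mapping is an assumption that ignores JAX details such as weak types and x64 mode.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ShapedAval:
    """Hypothetical abstract value: shape and dtype, no data."""
    shape: tuple
    dtype: str


@dataclass(frozen=True)
class Var:
    name: str
    aval: ShapedAval


@dataclass(frozen=True)
class Literal:
    val: object  # a concrete Python scalar, e.g. 0.0


def literal_to_shaped(lit: Literal) -> ShapedAval:
    # Hypothetical analogue of raise_to_shaped: keep the type, drop the value.
    # bool is tested first because bool is a subclass of int in Python.
    if isinstance(lit.val, bool):
        return ShapedAval((), "bool")
    if isinstance(lit.val, int):
        return ShapedAval((), "int32")
    if isinstance(lit.val, float):
        return ShapedAval((), "float32")
    raise TypeError(f"unsupported literal {lit.val!r}")


def read(atom, env):
    # Every atom is read as an abstract value, so no concrete data reaches the checker.
    if isinstance(atom, Literal):
        aval = literal_to_shaped(atom)
    else:
        aval = env[atom]
    assert isinstance(aval, ShapedAval), f"expected abstract value, got {aval!r}"
    return aval


x = Var("x", ShapedAval((3,), "float32"))
env = {x: x.aval}
print(read(Literal(0.0), env))  # ShapedAval(shape=(), dtype='float32')
print(read(x, env))             # ShapedAval(shape=(3,), dtype='float32')
```

Suppose a concrete value is passed through the environment, e.g. {x: [1.0, 2.0, 3.0]}. The assertion then fails, and the value is never silently computed on.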

I have timed the checks on my laptop and I do not see a reduction in the
total test time.