Add control flow callbacks for JIT-compiled frameworks

Hello Python community,
First, I’ll say that Python is the best language in the world, but it’s quite difficult to make it performant (as everybody knows).
So the community has come up with just-in-time (JIT) compilation for performance: it offloads the work to compiled C++ and CUDA code, which makes it run extremely fast. JAX is currently one of the fastest frameworks for numerical computation. In my experience it beats even well-optimized C++ code by an order of magnitude in, for example, structural analysis.

The thing is, JIT-compiled code is difficult to write. JIT cannot compile Python `if` statements whose condition depends on traced values, because they are untraceable: JAX has no way of knowing about them.
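To make the failure concrete, here is a minimal sketch (the function name `absolute` is just for illustration). Run eagerly, the Python `if` sees a concrete value and works; under `jax.jit` the argument is an abstract tracer with a shape and dtype but no value, so turning `x > 0` into a Python bool fails during tracing. The sketch catches `jax.errors.ConcretizationTypeError`; the exact subclass that is raised may vary between JAX versions.

```python
import jax
import jax.numpy as jnp


def absolute(x):
    # A plain Python `if` needs a concrete True/False value.
    if x > 0:
        return x
    return -x


x = jnp.float32(-3.0)

# Eagerly, `x > 0` is a concrete array, so the `if` works.
print(absolute(x))

# Under jax.jit, `x` is a tracer (shape and dtype only, no value),
# so converting `x > 0` to a Python bool raises during tracing.
try:
    jax.jit(absolute)(x)
    traced_ok = True
except jax.errors.ConcretizationTypeError as err:
    traced_ok = False
    print(type(err).__name__)
```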

We have to write atrocities like this:

import jax.numpy as jnp
from jax import lax

def headfunc(count_found_sums: bool, multiply: int):
    # a, b, c, d, e, f are defined higher.
    def func1(a, b, c, _d, _e, _f):
        target = a + b
        count_found_items = jnp.count_nonzero(c == target)
        return count_found_items

    def func2(_a, _b, _c, d, e, f):
        equal_to_e = d == e
        equal_to_f = d == f
        overlap = jnp.logical_and(equal_to_e, equal_to_f)
        return jnp.count_nonzero(overlap)

    result = lax.cond(  # lax.cond is the `if` statement in JAX.
        count_found_sums,
        func1,
        func2,
        a, b, c, d, e, f,  # notice: we have to pass all operands, because both branches receive the same arguments
    )
    return result

Which is equivalent to:

import jax.numpy as jnp

def headfunc(count_found_sums: bool, multiply: int):
    # a, b, c, d, e, f are defined higher.
    if count_found_sums:
        target = a + b
        count_found_items = jnp.count_nonzero(c == target)
        return count_found_items
    else:
        equal_to_e = d == e
        equal_to_f = d == f
        overlap = jnp.logical_and(equal_to_e, equal_to_f)
        return jnp.count_nonzero(overlap)

# **7 lines shorter**
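For reference, a self-contained sketch of both versions under `jax.jit`, assuming `a` through `f` are passed in as arguments rather than coming from an outer scope (the names `headfunc_cond` and `headfunc_static` and the input values are hypothetical). When the flag is traced, `lax.cond` is required. When the flag is a hashable Python value known at trace time, marking it with `static_argnames` lets a plain `if` trace, at the cost of one compilation per distinct flag value; conditions that depend on array values still need `lax.cond`.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def headfunc_cond(count_found_sums, a, b, c, d, e, f):
    # The flag is traced here, so a Python `if` would fail; lax.cond is required.
    def func1(a, b, c, _d, _e, _f):
        return jnp.count_nonzero(c == a + b)

    def func2(_a, _b, _c, d, e, f):
        return jnp.count_nonzero(jnp.logical_and(d == e, d == f))

    return lax.cond(count_found_sums, func1, func2, a, b, c, d, e, f)


@partial(jax.jit, static_argnames=("count_found_sums",))
def headfunc_static(count_found_sums, a, b, c, d, e, f):
    # The flag is a compile-time constant here, so a plain Python `if` traces fine.
    # JAX compiles one specialized version per distinct flag value.
    if count_found_sums:
        return jnp.count_nonzero(c == a + b)
    else:
        return jnp.count_nonzero(jnp.logical_and(d == e, d == f))


# Small hypothetical inputs, just to exercise both branches.
a, b = 1, 2
c = jnp.array([3, 0, 3, 5])
d = jnp.array([1, 2, 3])
e = jnp.array([1, 0, 3])
f = jnp.array([1, 2, 0])

for flag in (True, False):
    print(flag,
          int(headfunc_cond(flag, a, b, c, d, e, f)),
          int(headfunc_static(flag, a, b, c, d, e, f)))
```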

Across thousands of lines, the code becomes about 50% longer because conditionals are harder to write, at least in my use case.
PyTorch has something similar: its JIT cannot trace `if` statements either, so users have to write similar code. (I have only read about it online, though, and never worked with it.)

Proposal:
To allow much simpler JIT-compiled code, add a flag to the Python launcher that triggers a callback every time compilation encounters a Python `if` statement (and, preferably, `for` statements).

Something like this:
python3 main.py --flag-control-flow
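To illustrate the kind of hook being proposed, here is a rough user-space approximation, not a CPython feature: `--flag-control-flow` is the proposal itself, not an existing option. It parses source with the standard `ast` module and calls a callback for every `if` and `for` statement it finds. The class name `ControlFlowFinder`, the callback signature, and the sample source are all hypothetical; a real hook would presumably fire during compilation and let a framework like JAX substitute its own constructs.

```python
import ast
from typing import Callable

# Hypothetical callback signature: (statement kind, AST node) -> None.
ControlFlowCallback = Callable[[str, ast.stmt], None]


class ControlFlowFinder(ast.NodeVisitor):
    """Walks a module and reports every `if` and `for` statement."""

    def __init__(self, callback: ControlFlowCallback) -> None:
        self.callback = callback

    def visit_If(self, node: ast.If) -> None:
        self.callback("if", node)
        self.generic_visit(node)  # keep walking into nested statements

    def visit_For(self, node: ast.For) -> None:
        self.callback("for", node)
        self.generic_visit(node)


SOURCE = '''
def headfunc(count_found_sums, c):
    if count_found_sums:
        return c
    for item in c:
        pass
    return c
'''

found = []
finder = ControlFlowFinder(lambda kind, node: found.append((kind, node.lineno)))
finder.visit(ast.parse(SOURCE))
print(found)
```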

I recently had to write an 8k-SLOC codebase of performance-critical scientific computing code, and let me tell you: there is every reason to add this if Python wants to be fast for scientists and machine learning developers (which it does).

JAX developer Jake VanderPlas (jakevdp on GitHub) probably knows the implementation side better than I do, since I’m not a JAX developer, although I use it a lot.