[host_callback] Add support for pmap and for passing the device to tap by gnecula · Pull Request #5182 · jax-ml/jax
@gnecula: for the logging use case, we would like the data in the form @shoyer mentioned. But there is not enough info to get it into this form with the current setup, even with the device information. Consider:
```python
import jax
from jax.experimental import host_callback

@jax.pmap
@jax.vmap
def f(x):
    x = host_callback.id_tap(print, x)
    return x**2
```
If we switch the `vmap` and `pmap`, the output of the function is the same. But in the one case the prints are

```
[0 0] (('batch', {'batch_dims': (0,)}),)
[1 1] (('batch', {'batch_dims': (0,)}),)
```

and for the other it is

```
[0 1] (('batch', {'batch_dims': (0,)}),)
[0 1] (('batch', {'batch_dims': (0,)}),)
```
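To make the difference concrete, here is a plain-Python model (no JAX needed) of which slice each device's tap receives under the two nesting orders. It assumes a hypothetical 2×2 input `X = [[0, 1], [0, 1]]`, chosen because it is consistent with the prints above; the original does not state the input, and the helper names are made up for illustration. The transform descriptor handed to the tap is identical in both orders, so nothing in it says whether a device holds a row or a column.

```python
# Hypothetical input (not from the original), chosen to match the prints above.
X = [[0, 1],
     [0, 1]]

# Both nesting orders report this same transform descriptor to the tap.
BATCH_TRANSFORMS = (('batch', {'batch_dims': (0,)}),)


def taps_pmap_of_vmap(x):
    """Model of pmap(vmap(f)): pmap splits axis 0 across devices, so the
    vmapped tap on device d sees the whole row x[d] as one batched array."""
    return [(d, list(x[d]), BATCH_TRANSFORMS) for d in range(len(x))]


def taps_vmap_of_pmap(x):
    """Model of vmap(pmap(f)): vmap consumes axis 0, so pmap ends up splitting
    axis 1, and the batched tap on device d sees the column x[:, d]."""
    return [(d, [row[d] for row in x], BATCH_TRANSFORMS)
            for d in range(len(x[0]))]


for d, value, transforms in taps_pmap_of_vmap(X):
    print(d, value, transforms)
# 0 [0, 1] (('batch', {'batch_dims': (0,)}),)
# 1 [0, 1] (('batch', {'batch_dims': (0,)}),)

for d, value, transforms in taps_vmap_of_pmap(X):
    print(d, value, transforms)
# 0 [0, 0] (('batch', {'batch_dims': (0,)}),)
# 1 [1, 1] (('batch', {'batch_dims': (0,)}),)
```

Even with the device index attached, the records `(0, [0, 1])` and `(0, [0, 0])` cannot be placed in the logical array without knowing the nesting order.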
I do realize that this is not a common way of using `pmap` and `vmap`, and perhaps it's not worth worrying about. Other than this sort of scenario, does one have enough info to reconstruct the log as though it ran on a single device?
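A minimal sketch of what such a reconstruction would need, assuming each tap record carries its device index and that the logical axis the device index runs along is known out of band (the tap's transforms do not record where `pmap` sits relative to `vmap`). `reconstruct` and `device_axis` are hypothetical names, not part of `host_callback`, and the sample logs assume the same 2×2 input `[[0, 1], [0, 1]]` that matches the prints above.

```python
def reconstruct(taps, device_axis):
    """Reassemble a 2-D logical array from per-device tap records.

    taps: list of (device_index, batched_slice) pairs, one per device.
    device_axis: hypothetical out-of-band hint saying which logical axis
        the device index runs along (0 = rows, 1 = columns).
    """
    # Order the slices by device index; each slice is one row or one column.
    slices = [list(s) for _, s in sorted(taps, key=lambda t: t[0])]
    if device_axis == 0:
        # pmap outermost: each device holds one row.
        return slices
    if device_axis == 1:
        # pmap inside vmap: each device holds one column, so transpose.
        return [list(row) for row in zip(*slices)]
    raise ValueError("device_axis must be 0 or 1")


rows = [(0, [0, 1]), (1, [0, 1])]  # log shaped like pmap(vmap(f))
cols = [(0, [0, 0]), (1, [1, 1])]  # log shaped like vmap(pmap(f))
print(reconstruct(rows, device_axis=0))  # [[0, 1], [0, 1]]
print(reconstruct(cols, device_axis=1))  # [[0, 1], [0, 1]]
print(reconstruct(cols, device_axis=0))  # [[0, 0], [1, 1]]  (wrong guess)
```

With the wrong `device_axis` the same log silently yields the transpose, which is the ambiguity the `pmap`/`vmap` example runs into; outside that scenario, device index plus batch dims would be enough under these assumptions.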