[host_callback] Add support for pmap and for passing the device to tap by gnecula · Pull Request #5182 · jax-ml/jax

@gnecula: for the logging use case, we would like the data in the form @shoyer mentioned. However, the current setup does not give enough information to get it into that form, even with the device information. Consider:

```python
import jax
from jax.experimental import host_callback

@jax.pmap
@jax.vmap
def f(x):
  x = host_callback.id_tap(print, x)
  return x**2
```

If we swap the order of vmap and pmap, the output of the function is the same, but in one case the prints are

```
[0 0] (('batch', {'batch_dims': (0,)}),)
[1 1] (('batch', {'batch_dims': (0,)}),)
```

and in the other they are

```
[0 1] (('batch', {'batch_dims': (0,)}),)
[0 1] (('batch', {'batch_dims': (0,)}),)
```
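To see where the two sets of prints come from, here is a minimal numpy simulation of which slice each device's tap sees. It assumes 2 devices and an input of `x = [[0, 0], [1, 1]]` (the comment does not say what input was used; this one happens to reproduce both outputs). With `pmap(vmap(f))`, pmap splits axis 0, so device `d` taps the row `x[d, :]`. With `vmap(pmap(f))`, vmap maps over axis 0 and pmap splits the remaining axis, so device `d` taps the batched column `x[:, d]`. Since `x**2` is elementwise, the function output is the same either way.

```python
import numpy as np

# Assumed input (not from the original comment): 2 devices x batch of 2.
x = np.array([[0, 0],
              [1, 1]])

# pmap(vmap(f)): pmap splits axis 0, so device d's tap sees row x[d, :].
pmap_outer = [x[d, :] for d in range(x.shape[0])]

# vmap(pmap(f)): vmap maps over axis 0 and pmap splits the remaining axis,
# so device d's tap sees the batched column x[:, d].
vmap_outer = [x[:, d] for d in range(x.shape[1])]

for row in pmap_outer:
    print(row)  # [0 0], then [1 1]
for col in vmap_outer:
    print(col)  # [0 1], then [0 1]
```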

I do realize that this is not a common way of combining pmap and vmap, and perhaps it's not worth worrying about. Outside of this sort of scenario, is there enough information to reconstruct the log as though the computation had run on a single device?
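A minimal sketch of why the device index alone may not be enough, assuming each logged value is keyed by its device index and using the values printed above for an assumed input `x = [[0, 0], [1, 1]]`. `reconstruct` is a hypothetical helper, not part of `host_callback`. Stacking the per-device values in device order recovers `x` only if we also know which axis the pmap was over. Using the same axis for both nestings yields `x.T` in one of them.

```python
import numpy as np

# Assumed input and the per-device tapped values printed above for each nesting.
x = np.array([[0, 0], [1, 1]])
logs_pmap_outer = {0: np.array([0, 0]), 1: np.array([1, 1])}
logs_vmap_outer = {0: np.array([0, 1]), 1: np.array([0, 1])}

def reconstruct(logs, device_axis):
    """Hypothetical helper: stack per-device values in device order along device_axis."""
    return np.stack([logs[d] for d in sorted(logs)], axis=device_axis)

# pmap outside vmap: devices correspond to axis 0.
assert np.array_equal(reconstruct(logs_pmap_outer, device_axis=0), x)
# pmap inside vmap: devices correspond to axis 1.
assert np.array_equal(reconstruct(logs_vmap_outer, device_axis=1), x)
# Picking the axis from the device index alone gives the transpose instead.
assert np.array_equal(reconstruct(logs_vmap_outer, device_axis=0), x.T)
```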