torch.autograd.Function.vmap (PyTorch 2.7 documentation)
static Function.vmap(info, in_dims, *args)
Define the behavior for this autograd.Function underneath torch.vmap().
For a torch.autograd.Function to support torch.vmap(), you must either override this staticmethod or set generate_vmap_rule to True (you may not do both).
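The simpler of the two options is to let PyTorch derive the rule. The sketch below assumes the Function is written in the style that torch.func transforms expect (a forward() that takes no ctx, plus a separate setup_context()); the class name MySin is hypothetical and chosen only for illustration.

```python
import torch

class MySin(torch.autograd.Function):
    # Ask PyTorch to generate the vmap rule instead of overriding vmap().
    generate_vmap_rule = True

    @staticmethod
    def forward(x):
        # No ctx argument here; saving state happens in setup_context.
        return torch.sin(x)

    @staticmethod
    def setup_context(ctx, inputs, output):
        (x,) = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * torch.cos(x)

x = torch.randn(4, 3)
# vmap over dim 0: MySin is applied to each row of x independently.
out = torch.vmap(MySin.apply)(x)
```

Because sin is elementwise, the generated rule here produces the same values as calling torch.sin on the whole batch.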
If you choose to override this staticmethod, it must accept:
- an info object as the first argument. info.batch_size specifies the size of the dimension being vmapped over, while info.randomness is the randomness option passed to torch.vmap().
- an in_dims tuple as the second argument. For each arg in args, in_dims has a corresponding Optional[int]. It is None if the arg is not a Tensor or if the arg is not being vmapped over; otherwise, it is an integer specifying which dimension of the Tensor is being vmapped over.
- *args, which are the same as the args to forward().
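A hand-written rule typically uses in_dims to normalize where the batch dimension lives and info.batch_size to materialize it for inputs that are not batched. The sketch below is an illustration under assumptions: MyMul is a hypothetical Function, both inputs are assumed to have the same per-example shape, and the rule re-invokes MyMul.apply on the batched tensors, which is the pattern used in the torch.func extension guide.

```python
import torch

class MyMul(torch.autograd.Function):
    @staticmethod
    def forward(x, y):
        return x * y

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, y = inputs
        ctx.save_for_backward(x, y)

    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        return grad_output * y, grad_output * x

    @staticmethod
    def vmap(info, in_dims, x, y):
        x_bdim, y_bdim = in_dims

        def to_front(t, bdim):
            # Batched input: move its vmapped dim to position 0.
            # Unbatched input (bdim is None): expand a new leading dim of
            # size info.batch_size. Assumes x and y share per-example shape.
            if bdim is None:
                return t.expand(info.batch_size, *t.shape)
            return t.movedim(bdim, 0)

        out = MyMul.apply(to_front(x, x_bdim), to_front(y, y_bdim))
        # Single tensor output, batch dim at index 0.
        return out, 0

x = torch.randn(5, 3)
y = torch.randn(3)
# x is batched along dim 0; y is shared by every example.
out = torch.vmap(MyMul.apply, in_dims=(0, None))(x, y)

a = torch.randn(3, 5)
b = torch.randn(5, 3)
# a is batched along dim 1, b along dim 0.
out2 = torch.vmap(MyMul.apply, in_dims=(1, 0))(a, b)
```

Moving every batch dimension to the front keeps the rule simple: after normalization, the batched computation is just the original operation on larger tensors.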
The vmap staticmethod returns a tuple of (output, out_dims). Similar to in_dims, out_dims should have the same structure as output and contain one out_dim per output, specifying whether that output has the vmapped dimension and, if so, at which index.
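When forward() returns several tensors, out_dims mirrors that structure, and the batch dimension does not have to be at index 0. In the sketch below (SinCos is a hypothetical Function), the operations are elementwise, so the rule leaves the batch dimension wherever the input had it and reports that position for both outputs; it assumes the rule is only invoked when x is actually batched.

```python
import torch

class SinCos(torch.autograd.Function):
    @staticmethod
    def forward(x):
        return torch.sin(x), torch.cos(x)

    @staticmethod
    def setup_context(ctx, inputs, output):
        (x,) = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_sin, grad_cos):
        (x,) = ctx.saved_tensors
        return grad_sin * torch.cos(x) - grad_cos * torch.sin(x)

    @staticmethod
    def vmap(info, in_dims, x):
        (x_bdim,) = in_dims
        # Elementwise ops preserve layout, so the vmapped dim stays at x_bdim.
        sin_out, cos_out = SinCos.apply(x)
        # out_dims is a tuple because output is a tuple of two tensors.
        return (sin_out, cos_out), (x_bdim, x_bdim)

x = torch.randn(3, 4)
# Vmap over dim 1; torch.vmap places the batch dim of each result at 0.
s, c = torch.vmap(SinCos.apply, in_dims=1)(x)
```

torch.vmap uses the reported out_dims to move each output's batch dimension to the position requested by its own out_dims argument (0 by default), so s and c both come back with shape (4, 3).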
Please see Extending torch.func with autograd.Function for more details.