torch.autograd.Function.vmap (PyTorch 2.7 documentation)

static Function.vmap(info, in_dims, *args)

Define the behavior for this autograd.Function underneath torch.vmap().

For a torch.autograd.Function to support torch.vmap(), you must either override this static method or set generate_vmap_rule to True (you may not do both).
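As a minimal sketch of the generate_vmap_rule route, the hypothetical MyCube function below sets generate_vmap_rule = True and writes no vmap staticmethod. It assumes the torch.func-compatible style, where forward() does not take ctx and a separate setup_context() saves what backward() needs.

```python
import torch


class MyCube(torch.autograd.Function):
    # Hypothetical example. Let PyTorch derive the vmap rule automatically
    # instead of overriding the vmap staticmethod.
    generate_vmap_rule = True

    @staticmethod
    def forward(x):
        # forward() does not take ctx; it returns x^3 and its derivative.
        result = x ** 3
        dx = 3 * x ** 2
        return result, dx

    @staticmethod
    def setup_context(ctx, inputs, output):
        (x,) = inputs
        _, dx = output
        ctx.save_for_backward(x, dx)

    @staticmethod
    def backward(ctx, grad_result, grad_dx):
        x, dx = ctx.saved_tensors
        # d(x^3)/dx = dx and d(3x^2)/dx = 6x
        return grad_result * dx + grad_dx * 6 * x


x = torch.randn(4, 3)
# vmap over dim 0 of x using the generated rule.
result, dx = torch.vmap(MyCube.apply)(x)
```

Because the rule is generated, this Function must not also define a vmap staticmethod.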

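If you choose to override this staticmethod, it must accept:

- an info object as the first argument. info.batch_size specifies the size of the dimension being vmapped over, while info.randomness is the randomness option passed to torch.vmap().
- an in_dims tuple as the second argument. For each arg in args, in_dims has a corresponding Optional[int]. It is None if the arg is not a Tensor or is not being vmapped over; otherwise, it is an integer specifying which dimension of the Tensor is being vmapped over.
- *args, which are the same as the args to forward().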
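The sketch below shows how a hand-written vmap staticmethod might use info and in_dims. MyAdd is a hypothetical Function that adds two tensors, assumed to have the same logical (per-example) shape. The rule moves each vmapped dimension to index 0, gives an input that is not vmapped over (its in_dims entry is None) a leading dimension of size info.batch_size, and then calls MyAdd.apply on the physical tensors.

```python
import torch


class MyAdd(torch.autograd.Function):
    # Hypothetical example: x + y for inputs with the SAME logical shape.
    @staticmethod
    def forward(x, y):
        return x + y

    @staticmethod
    def setup_context(ctx, inputs, output):
        pass  # nothing needs to be saved for backward

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, grad_output

    @staticmethod
    def vmap(info, in_dims, x, y):
        x_bdim, y_bdim = in_dims

        def to_front(t, bdim):
            # Not vmapped over: add a leading batch dim of size info.batch_size.
            if bdim is None:
                return t.unsqueeze(0).expand(info.batch_size, *t.shape)
            # Vmapped over: move the vmapped dim to index 0.
            return t.movedim(bdim, 0)

        x = to_front(x, x_bdim)
        y = to_front(y, y_bdim)
        out = MyAdd.apply(x, y)
        # One out_dim for the single output: its vmapped dim is at index 0.
        return out, 0


x = torch.randn(5, 3)  # 5 examples stacked along dim 0
y = torch.randn(3)     # shared across the batch, not vmapped over
out = torch.vmap(MyAdd.apply, in_dims=(0, None))(x, y)
```

Moving every batched input to dimension 0 is a simple way to align inputs; it is not the only valid strategy, and the same-shape assumption avoids broadcasting misalignments between the batch dimension and logical dimensions.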

The vmap staticmethod must return a tuple (output, out_dims). Similar to in_dims, out_dims should have the same structure as output and contain one out_dim per output: None if that output does not have the vmapped dimension, otherwise the integer index at which the vmapped dimension appears.
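To illustrate that out_dims need not be 0 and must mirror the structure of output, here is a sketch using a hypothetical ExpAndSquare Function with two elementwise outputs. Elementwise ops leave every dimension in place, so the rule can skip moving the vmapped dimension and simply report it at index x_bdim for both outputs.

```python
import torch


class ExpAndSquare(torch.autograd.Function):
    # Hypothetical example returning two elementwise outputs: (exp(x), x ** 2).
    @staticmethod
    def forward(x):
        return x.exp(), x ** 2

    @staticmethod
    def setup_context(ctx, inputs, output):
        (x,) = inputs
        exp_x, _ = output
        ctx.save_for_backward(x, exp_x)

    @staticmethod
    def backward(ctx, grad_exp, grad_sq):
        x, exp_x = ctx.saved_tensors
        return grad_exp * exp_x + grad_sq * 2 * x

    @staticmethod
    def vmap(info, in_dims, x):
        (x_bdim,) = in_dims
        # Elementwise ops keep the vmapped dim where it already is (x_bdim).
        exp_x, sq = ExpAndSquare.apply(x)
        # output is a tuple of two tensors, so out_dims is a tuple of two out_dims.
        return (exp_x, sq), (x_bdim, x_bdim)


x = torch.randn(3, 5)  # vmap over dim 1 (size 5)
exp_x, sq = torch.vmap(ExpAndSquare.apply, in_dims=1)(x)
```

torch.vmap's own out_dims argument (default 0) then decides where the batch dimension ends up in the final result, so both outputs here have shape (5, 3).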

Please see Extending torch.func with autograd.Function for more details.