Dynamic shapes - PyTorch 2.7 documentation

Code: symbolic_shapes.py

See also: The dynamic shapes manual

Motivation

Deep learning compilers commonly only work for static shapes, that is to say, they produce compiled programs that work only for a single specific configuration of input shapes and must recompile if any input shape changes. This assumption works great for the majority of commonly run deep learning models today, but there are a few situations where it is insufficient:

In supporting dynamic shapes, we chose not to support dynamic rank programs, e.g., programs whose input tensors change in dimensionality, as this pattern rarely occurs in real-world deep learning programs, and it avoids the need to reason inductively over symbolic lists of shapes.

Abridged public API

The default dynamic behavior in PyTorch 2.1 is:

The Guard Model

When considering how to add support for dynamic shapes to TorchDynamo and TorchInductor, we made a major design decision: in order to reuse decompositions and other preexisting code written in Python/C++ targeting the PyTorch API, we must be able to trace through dynamic shapes. Unlike a fully symbolic system, which might capture both branches of a conditional, we always pick one branch and specialize our trace under the assumption that we will only reuse this trace when we would make the same choice for that branch again. To do this, we maintain a “hint” for every symbolic size saying what its concrete value is at compile time (as TorchDynamo is a just-in-time compiler, it always knows what the actual input sizes are). When we evaluate a condition on a tensor size, we simply consult the hint to find out which branch to take.

This greatly simplifies the symbolic shape formulas we produce, but means we have a much more involved system for managing guards. Consider, for example, the following program:

```python
import torch


def f(x, y):
    z = torch.cat([x, y])
    # The branch depends on the size of z, an intermediate tensor.
    if z.size(0) > 2:
        return z.mul(2)
    else:
        return z.add(2)
```

The final IR we will compile with TorchInductor will be either torch.cat([x, y]).add(2) or torch.cat([x, y]).mul(2) (with the condition flattened away), but to determine which branch we are in, we would need to know the size of z, an intermediate. Because TorchDynamo must know upfront whether a compiled trace is valid (we do not support bailouts, as some JIT compilers do), we must be able to reduce z.size(0) to an expression in terms of the inputs, x.size(0) + y.size(0). This is done by writing meta functions for all operators in PyTorch, which propagate size information to the output tensor without actually performing the computation.
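To make this reduction concrete, here is a minimal sketch in plain Python. It is a toy model, not PyTorch's implementation: `SymSize`, `meta_cat`, and `trace_f` are hypothetical names introduced for illustration, shapes are reduced to a single leading dimension, and expressions are kept as strings rather than Sympy objects.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SymSize:
    """Toy symbolic size: an expression over input sizes plus a compile-time hint."""

    expr: str  # for example "x.size(0) + y.size(0)"
    hint: int  # concrete value observed while tracing

    def __add__(self, other: "SymSize") -> "SymSize":
        return SymSize(f"{self.expr} + {other.expr}", self.hint + other.hint)


def meta_cat(x0: SymSize, y0: SymSize) -> SymSize:
    """Toy 'meta function' for cat along dim 0: computes only the output size."""
    return x0 + y0


def trace_f(x_len: int, y_len: int):
    """Trace the branch of f taken for these input sizes; return (ir, guard)."""
    x0 = SymSize("x.size(0)", x_len)
    y0 = SymSize("y.size(0)", y_len)
    z0 = meta_cat(x0, y0)  # size of the intermediate, expressed via the inputs
    if z0.hint > 2:  # consult the hint to pick exactly one branch
        return "torch.cat([x, y]).mul(2)", f"{z0.expr} > 2"
    return "torch.cat([x, y]).add(2)", f"not ({z0.expr} > 2)"


ir, guard = trace_f(2, 3)
print(ir)     # the specialized trace
print(guard)  # the condition under which the trace may be reused
```

The guard mentions only `x.size(0)` and `y.size(0)`, so in this model it can be checked before running any compiled code, which is the property the text requires.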

Overall architecture

Symbolic shapes workflow:

  1. When we start compiling a frame in Dynamo, we allocate a ShapeEnv (attached to FakeTensorMode) which keeps track of symbolic shapes state.
  2. We allocate symbolic sizes for tensors on entry (what is static or dynamic is a policy decision, with some knobs).
  3. We propagate the symbolic sizes through operators, maintaining both (1) FX IR so that we can faithfully export symbolic compute, and (2) Sympy expressions representing the size vars, so we can reason about them.
  4. When we condition on symbolic sizes, either in Dynamo tracing or in Inductor optimization, we add guards based on the conditional. These can be induced from both Python and C++.
  5. These guards can induce further simplifications on symbolic variables. For example, if you assert s0 == 4, we can now replace all occurrences of s0 with 4.
  6. When we’re done tracing and optimizing, we install all of these guards with the compiled code; the compiled code is only reusable if all the guards evaluate true.
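The last step of the workflow, installing guards with the compiled code and reusing it only when every guard holds, can be sketched as a small guarded cache. This is an illustrative model under simplifying assumptions: `GuardedCache` and `compile_f` are hypothetical names, guards are plain Python callables over the input sizes, and "compiled code" is just a string describing the branch of f that was traced.

```python
from typing import Callable, Dict, List, Tuple

Shapes = Dict[str, int]  # maps an input name to its size along dim 0
Guard = Callable[[Shapes], bool]


def compile_f(shapes: Shapes) -> Tuple[List[Guard], str]:
    """Pretend to compile f for these sizes; return (guards, compiled code)."""
    took_mul = shapes["x"] + shapes["y"] > 2  # branch chosen from the hints
    guard: Guard = lambda s: (s["x"] + s["y"] > 2) == took_mul
    code = "cat(x, y).mul(2)" if took_mul else "cat(x, y).add(2)"
    return [guard], code


class GuardedCache:
    """Reuses a compiled entry only if every guard installed with it holds."""

    def __init__(self) -> None:
        self.entries: List[Tuple[List[Guard], str]] = []
        self.compiles = 0

    def lookup(self, shapes: Shapes) -> str:
        for guards, code in self.entries:
            if all(g(shapes) for g in guards):
                return code
        # No existing entry is valid for these sizes, so compile a new one.
        self.compiles += 1
        entry = compile_f(shapes)
        self.entries.append(entry)
        return entry[1]


cache = GuardedCache()
cache.lookup({"x": 2, "y": 3})  # first call compiles the mul branch
cache.lookup({"x": 5, "y": 7})  # guard still holds, so the entry is reused
cache.lookup({"x": 1, "y": 1})  # guard fails, so a second compile happens
print(cache.compiles)
```

Because the guard is stated over the symbolic condition rather than the exact sizes, the second call reuses the first entry even though both input sizes changed.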

Important files:

Abridged internal API

Understanding the Python class hierarchy:

C++ is fairly similar:

When you write code that is traceable with make_fx, it must be able to deal with SymInt/SymFloat/SymBool flowing through it. The dynamic shapes manual gives some guidance for how to do this.

DimDynamic policy

Symbolic reasoning:

Unbacked SymInts

To resolve control flow, we check the hint, that is, the actual value, of a symbolic integer to determine which branch to take. However, in some cases we may not have a hint: so-called unbacked symbolic integers arise when a size variable emerges from a data-dependent operation like .nonzero() or .item(). It is illegal to perform control flow on these symbolic integers, so we must graph break on these operations.
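The difference between a backed and an unbacked symbolic integer can be sketched with a toy class whose hint may be missing. The names `ToySymInt`, `DataDependentBranchError`, and `pick_branch` are hypothetical and exist only for this illustration; the real system responds with a graph break rather than this exception.

```python
from typing import Optional


class DataDependentBranchError(RuntimeError):
    """Raised by this toy model when control flow needs a missing hint."""


class ToySymInt:
    def __init__(self, name: str, hint: Optional[int] = None) -> None:
        self.name = name
        self.hint = hint  # None models an unbacked symbolic integer

    def __gt__(self, bound: int) -> bool:
        if self.hint is None:
            raise DataDependentBranchError(
                f"cannot branch on {self.name} > {bound}: no hint available"
            )
        return self.hint > bound


def pick_branch(n: ToySymInt) -> str:
    return "mul" if n > 2 else "add"


backed = ToySymInt("s0", hint=5)  # size of an input tensor, known at compile time
unbacked = ToySymInt("u0")        # e.g. the row count produced by .nonzero()

print(pick_branch(backed))
try:
    pick_branch(unbacked)
except DataDependentBranchError as err:
    print(err)
```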

Naively implemented, this is too restrictive: most PyTorch programs will immediately fail if you try to do anything with unbacked symbolic integers. Here are the most important enhancements to make this actually work:

In future versions of PT2 (beyond PT2.1), we will extend our reasoning system to infer that an unbacked symbolic integer is size-like based on usage. For example, if you pass the result of an .item() call to a factory function like torch.empty, we will automatically infer that the result is a size (because if it were not, the call would fail). This assumption would be validated at runtime, raising an error if it was not fulfilled.
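One way to picture this inference is as a check that is recorded during tracing and evaluated later against real values. The sketch below is only an assumption about the general shape of such a mechanism, not a description of how PT2 implements it: `trace_empty`, `run_checks`, and the choice of `ValueError` are hypothetical.

```python
from typing import Callable, Dict, List

RuntimeCheck = Callable[[Dict[str, int]], None]


def trace_empty(symbol: str, checks: List[RuntimeCheck]) -> str:
    """Toy trace of a factory call: assume `symbol` is a size and defer the check."""

    def check(values: Dict[str, int]) -> None:
        if values[symbol] < 0:
            raise ValueError(
                f"{symbol} was assumed to be a size but is {values[symbol]}"
            )

    checks.append(check)
    return f"empty({symbol})"


def run_checks(checks: List[RuntimeCheck], values: Dict[str, int]) -> None:
    """Validate every deferred assumption against the actual runtime values."""
    for check in checks:
        check(values)


checks: List[RuntimeCheck] = []
graph = trace_empty("u0", checks)  # u0 stands for the result of an .item() call
run_checks(checks, {"u0": 3})      # assumption holds, nothing happens
try:
    run_checks(checks, {"u0": -1})
except ValueError as err:
    print(err)
```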