torch.compile — PyTorch 2.7 documentation

torch.compile(model: Callable[[_InputT], _RetT], *, fullgraph: bool = False, dynamic: Optional[bool] = None, backend: Union[str, Callable] = 'inductor', mode: Optional[str] = None, options: Optional[dict[str, Union[str, int, bool]]] = None, disable: bool = False) → Callable[[_InputT], _RetT][source]

torch.compile(model: None = None, *, fullgraph: bool = False, dynamic: Optional[bool] = None, backend: Union[str, Callable] = 'inductor', mode: Optional[str] = None, options: Optional[dict[str, Union[str, int, bool]]] = None, disable: bool = False) → Callable[[Callable[[_InputT], _RetT]], Callable[[_InputT], _RetT]]

Optimizes the given model/function using TorchDynamo and the specified backend. If you are compiling a torch.nn.Module, you can also use torch.nn.Module.compile() to compile the module in place without changing its structure.
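
For example, a minimal sketch of these entry points (the Linear module and tensor shapes below are illustrative, not part of the API):

import torch

# Compiling a plain function returns a new optimized callable.
def f(x):
    return torch.relu(x) + 1.0

compiled_f = torch.compile(f)

# torch.compile(module) returns a wrapper around the module, whereas
# Module.compile() optimizes the module in place, keeping its type and structure.
model = torch.nn.Linear(8, 8)
wrapped = torch.compile(model)
model.compile()

x = torch.randn(4, 8)
print(compiled_f(x).sum(), wrapped(x).sum(), model(x).sum())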

Concretely, for every frame executed within the compiled region, we will attempt to compile it and cache the compiled result on the code object for future use. A single frame may be compiled multiple times if previous compiled results are not applicable for subsequent calls (this is called a “guard failure”); you can use TORCH_LOGS=guards to debug these situations. Multiple compiled results can be associated with a frame, up to torch._dynamo.config.recompile_limit, which defaults to 8, at which point we fall back to eager. Note that compile caches are per code object, not per frame; if you dynamically create multiple copies of a function, they will all share the same code cache.
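
A minimal sketch of how recompilation can be triggered and inspected (the shapes are illustrative; whether the changed shape recompiles or produces a dynamic kernel depends on the dynamic setting):

import torch

@torch.compile
def g(x):
    return x * 2

g(torch.randn(4))   # first call: compile and cache the result on g's code object
g(torch.randn(4))   # guards pass: the cached compiled result is reused
g(torch.randn(8))   # a size guard fails: recompile (or switch to a dynamic kernel)

# Run the script with guard logging enabled to see which guards were installed
# and why they failed, e.g.:  TORCH_LOGS=guards python script.py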

Parameters

* **model** (*Callable or nn.Module*) – Module or function to optimize.
* **fullgraph** (*bool*) – If False (default), torch.compile discovers compileable regions in the function and optimizes them, falling back to eager at graph breaks. If True, the entire function must be captured into a single graph; if that is not possible, an error is raised.
* **dynamic** (*bool or None*) – Controls dynamic-shape tracing. True attempts to generate kernels that are as dynamic as possible up front, to avoid recompilations when sizes change; False always specializes on the observed sizes; None (default) specializes first and compiles a more dynamic kernel upon recompile when dynamism is detected.
* **backend** (*str or Callable*) – Backend to be used; defaults to "inductor". Available in-tree backends can be listed with torch._dynamo.list_backends().
* **mode** (*str*) – One of "default", "reduce-overhead", "max-autotune" or "max-autotune-no-cudagraphs".
* **options** (*dict*) – Backend-specific options, e.g. inductor options such as "triton.cudagraphs"; the configs that inductor supports can be listed with torch._inductor.list_options(). mode and options cannot both be specified.
* **disable** (*bool*) – Turn torch.compile() into a no-op for testing.

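As a hedged sketch of how these parameters are commonly combined (the function and the chosen mode/options are illustrative; recall that mode and options cannot both be passed):

import torch

def h(x):
    return torch.nn.functional.gelu(x)

print(torch._dynamo.list_backends())    # available in-tree backends
print(torch._inductor.list_options())   # configs understood by the inductor backend

h_fast = torch.compile(h, mode="reduce-overhead")            # pick a predefined mode...
h_tuned = torch.compile(h, options={"max_autotune": True})   # ...or pass backend options
h_noop = torch.compile(h, disable=True)                      # no-op wrapper, useful for testing
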
Example:

@torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
def foo(x):
    return torch.sin(x) + torch.cos(x)
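
The decorated function is then called like the original; a brief usage sketch (a CUDA tensor is assumed here because the triton.cudagraphs option targets CUDA):

x = torch.randn(1024, device="cuda")
y = foo(x)   # first call triggers compilation; later calls with matching guards reuse it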