Control Flow - Cond — PyTorch 2.7 documentation

torch.cond is a structured control flow operator. It can be used to specify if-else-like control flow and can logically be thought of as being implemented as follows.

def cond(
    pred: Union[bool, torch.Tensor],
    true_fn: Callable,
    false_fn: Callable,
    operands: Tuple[torch.Tensor]
):
    if pred:
        return true_fn(*operands)
    else:
        return false_fn(*operands)

Its unique power lies in its ability to express data-dependent control flow: it lowers to a conditional operator (torch.ops.higher_order.cond), which preserves the predicate, the true function, and the false function. This unlocks great flexibility in writing and deploying models whose architecture changes based on the value or shape of inputs or of intermediate outputs of tensor operations.

Examples

Below is an example that uses cond to branch based on input shape:

import torch

def true_fn(x: torch.Tensor):
    return x.cos() + x.sin()

def false_fn(x: torch.Tensor):
    return x.sin()

class DynamicShapeCondPredicate(torch.nn.Module):
    """
    A basic usage of cond based on dynamic shape predicate.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        def true_fn(x: torch.Tensor):
            return x.cos() + x.sin()

        def false_fn(x: torch.Tensor):
            return x.sin()

        return torch.cond(x.shape[0] > 4, true_fn, false_fn, (x,))

dyn_shape_mod = DynamicShapeCondPredicate()

We can run the model eagerly and expect the results to vary based on input shape:

inp = torch.randn(3)
inp2 = torch.randn(5)
assert torch.equal(dyn_shape_mod(inp), false_fn(inp))
assert torch.equal(dyn_shape_mod(inp2), true_fn(inp2))

We can export the model for further transformations and deployment:

inp = torch.randn(4, 3)
dim_batch = torch.export.Dim("batch", min=2)
ep = torch.export.export(DynamicShapeCondPredicate(), (inp,), {}, dynamic_shapes={"x": {0: dim_batch}})
print(ep)

This gives us an exported program as shown below:

class GraphModule(torch.nn.Module):
    def forward(self, arg0_1: f32[s0, 3]):
        sym_size: Sym(s0) = torch.ops.aten.sym_size.int(arg0_1, 0)
        gt: Sym(s0 > 4) = sym_size > 4;  sym_size = None
        true_graph_0 = self.true_graph_0
        false_graph_0 = self.false_graph_0
        conditional: f32[s0, 3] = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]);  gt = true_graph_0 = false_graph_0 = arg0_1 = None
        return (conditional,)

class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: f32[s0, 3]):
        cos: f32[s0, 3] = torch.ops.aten.cos.default(arg0_1)
        sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
        add: f32[s0, 3] = torch.ops.aten.add.Tensor(cos, sin);  cos = sin = None
        return add

class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: f32[s0, 3]):
        sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
        return sin

Notice that torch.cond is lowered to torch.ops.higher_order.cond: its predicate becomes a symbolic expression over the shape of the input, and the branch functions become two sub-graph attributes of the top-level graph module.
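As a hedged sketch (not part of the original walkthrough), the exported program can be turned back into a runnable module with ExportedProgram.module() and exercised with batch sizes on either side of the predicate:

# Minimal sketch, assuming `ep` is the ExportedProgram produced above.
runnable = ep.module()

small = torch.randn(3, 3)  # batch 3 -> s0 > 4 is False -> sin branch
large = torch.randn(8, 3)  # batch 8 -> s0 > 4 is True  -> cos + sin branch

assert torch.equal(runnable(small), small.sin())
assert torch.equal(runnable(large), large.cos() + large.sin())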

Here is another example that showcases how to express data-dependent control flow:

class DataDependentCondPredicate(torch.nn.Module):
    """
    A basic usage of cond based on data dependent predicate.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.cond(x.sum() > 4.0, true_fn, false_fn, (x,))
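For illustration, here is a plausible export call that mirrors the dynamic-shape example above (the original text does not show the exact invocation, so treat it as an assumption):

# Hedged sketch mirroring the earlier DynamicShapeCondPredicate export call.
inp = torch.randn(4, 3)
dim_batch = torch.export.Dim("batch", min=2)
ep = torch.export.export(DataDependentCondPredicate(), (inp,), {}, dynamic_shapes={"x": {0: dim_batch}})
print(ep)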

The exported program we get after export:

class GraphModule(torch.nn.Module):
    def forward(self, arg0_1: f32[s0, 3]):
        sum_1: f32[] = torch.ops.aten.sum.default(arg0_1)
        gt: b8[] = torch.ops.aten.gt.Scalar(sum_1, 4.0);  sum_1 = None
        true_graph_0 = self.true_graph_0
        false_graph_0 = self.false_graph_0
        conditional: f32[s0, 3] = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]);  gt = true_graph_0 = false_graph_0 = arg0_1 = None
        return (conditional,)

class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: f32[s0, 3]):
        cos: f32[s0, 3] = torch.ops.aten.cos.default(arg0_1)
        sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
        add: f32[s0, 3] = torch.ops.aten.add.Tensor(cos, sin);  cos = sin = None
        return add

class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: f32[s0, 3]):
        sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
        return sin

Invariants of torch.ops.higher_order.cond

There are several useful invariants for torch.ops.higher_order.cond:

API Reference

torch._higher_order_ops.cond.cond(pred, true_fn, false_fn, operands=())

Conditionally applies true_fn or false_fn.

cond is a structured control flow operator. That is, it is like a Python if-statement, but it has restrictions on true_fn, false_fn, and operands that enable it to be captured with torch.compile and torch.export.

Assuming the constraints on cond’s arguments are met, cond is equivalent to the following:

def cond(pred, true_branch, false_branch, operands):
    if pred:
        return true_branch(*operands)
    else:
        return false_branch(*operands)

Parameters

- pred (Union[bool, torch.Tensor]) – A boolean or a tensor with a single boolean element that selects which branch to apply.
- true_fn (Callable) – The callable applied to operands when pred is true.
- false_fn (Callable) – The callable applied to operands when pred is false.
- operands (Tuple[torch.Tensor]) – The tuple of inputs passed to the selected branch.

Return type

Any

Example:

def true_fn(x: torch.Tensor):
    return x.cos()

def false_fn(x: torch.Tensor):
    return x.sin()

def f(x: torch.Tensor):
    return cond(x.shape[0] > 4, true_fn, false_fn, (x,))
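Beyond the eager-style example above, here is a minimal, hedged sketch (not part of the original reference) of capturing the same pattern with torch.compile; fullgraph=True requires the data-dependent branch to be captured in a single graph:

import torch

def true_fn(x: torch.Tensor):
    return x.cos()

def false_fn(x: torch.Tensor):
    return x.sin()

@torch.compile(fullgraph=True)
def compiled_branch(x: torch.Tensor):
    # A plain Python `if x.sum() > 4.0:` is data-dependent and would not compile
    # with fullgraph=True; torch.cond keeps both branches in the captured graph.
    return torch.cond(x.sum() > 4.0, true_fn, false_fn, (x,))

print(compiled_branch(torch.randn(4, 3)))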

Restrictions: