TorchScript Language Reference (PyTorch 2.7 documentation)

This reference manual describes the syntax and core semantics of the TorchScript language. TorchScript is a statically typed subset of the Python language. This document explains the supported features of Python in TorchScript and also how the language diverges from regular Python. Any features of Python that are not mentioned in this reference manual are not part of TorchScript. TorchScript focuses specifically on the features of Python that are needed to represent neural network models in PyTorch.

Terminology

This document uses the following notation:

Pattern   Notes
::=       Indicates that the given symbol is defined as.
" "       Represents real keywords and delimiters that are part of the syntax.
A | B     Indicates either A or B.
( )       Indicates grouping.
[]        Indicates optional.
A+        Indicates a regular expression where term A is repeated at least once.
A*        Indicates a regular expression where term A is repeated zero or more times.

Type System

TorchScript is a statically typed subset of Python. The largest difference between TorchScript and the full Python language is that TorchScript only supports a small set of types that are needed to express neural net models.

TorchScript Types

The TorchScript type system consists of TSType and TSModuleType as defined below.

TSAllType ::= TSType | TSModuleType
TSType    ::= TSMetaType | TSPrimitiveType | TSStructuralType | TSNominalType

TSType represents the majority of TorchScript types that are composable and that can be used in TorchScript type annotations. TSType refers to any of the four categories in the grammar above: meta types, primitive types, structural types, and nominal types.

TSModuleType represents torch.nn.Module and its subclasses. It is treated differently from TSType because its type schema is inferred partly from the object instance and partly from the class definition. As such, instances of a TSModuleType may not follow the same static type schema. TSModuleType cannot be used as a TorchScript type annotation or be composed with TSType for type safety considerations.

Primitive Types

Primitive TorchScript types are types that represent a single type of value and go with a single pre-defined type name.

TSPrimitiveType ::= "int" | "float" | "double" | "complex" | "bool" | "str" | "None"

Structural Types

Structural types are types that are structurally defined without a user-defined name (unlike nominal types), such as Future[int]. Structural types are composable with any TSType.

TSStructuralType ::= TSTuple | TSNamedTuple | TSList | TSDict | TSOptional | TSUnion | TSFuture | TSRRef | TSAwait

TSTuple      ::= "Tuple" "[" (TSType ",")* TSType "]"
TSNamedTuple ::= "namedtuple" "(" (TSType ",")* TSType ")"
TSList       ::= "List" "[" TSType "]"
TSOptional   ::= "Optional" "[" TSType "]"
TSUnion      ::= "Union" "[" (TSType ",")* TSType "]"
TSFuture     ::= "Future" "[" TSType "]"
TSRRef       ::= "RRef" "[" TSType "]"
TSAwait      ::= "Await" "[" TSType "]"
TSDict       ::= "Dict" "[" KeyType "," TSType "]"
KeyType      ::= "str" | "int" | "float" | "bool" | TensorType | "Any"

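Where:

Tuple, List, Optional, Union, and Dict are type class names defined in the Python typing module and must be imported from it before use (see Example 3 below). namedtuple stands for either collections.namedtuple or typing.NamedTuple (see Examples 1 and 2 below).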

Compared to Python

Apart from being composable with TorchScript types, these TorchScript structural types often support a common subset of the operators and methods of their Python counterparts.

Example 1

This example uses typing.NamedTuple syntax to define a tuple:

import torch
from typing import NamedTuple
from typing import Tuple

class MyTuple(NamedTuple):
    first: int
    second: int

def inc(x: MyTuple) -> Tuple[int, int]:
    return (x.first+1, x.second+1)

t = MyTuple(first=1, second=2)
scripted_inc = torch.jit.script(inc)
print("TorchScript:", scripted_inc(t))

The example above produces the following output:

Example 2

This example uses collections.namedtuple syntax to define a tuple:

import torch
from typing import NamedTuple
from typing import Tuple
from collections import namedtuple

_AnnotatedNamedTuple = NamedTuple('_NamedTupleAnnotated', [('first', int), ('second', int)])
_UnannotatedNamedTuple = namedtuple('_NamedTupleAnnotated', ['first', 'second'])

def inc(x: _AnnotatedNamedTuple) -> Tuple[int, int]:
    return (x.first+1, x.second+1)

m = torch.jit.script(inc)
print(inc(_UnannotatedNamedTuple(1,2)))

The example above produces the following output:

Example 3

This example illustrates a common mistake of annotating structural types, i.e., not importing the composite type classes from the typing module:

import torch

# ERROR: Tuple not recognized because not imported from typing
@torch.jit.export
def inc(x: Tuple[int, int]):
    return (x[0]+1, x[1]+1)

m = torch.jit.script(inc)
print(m((1,2)))

Running the above code yields the following scripting error:

File "test-tuple.py", line 5, in <module>
    def inc(x: Tuple[int, int]):
NameError: name 'Tuple' is not defined

The remedy is to add the line from typing import Tuple to the beginning of the code.

Nominal Types

Nominal TorchScript types are Python classes. These types are called nominal because they are declared with a custom name and are compared using class names. Nominal classes are further classified into the following categories:

TSNominalType ::= TSBuiltinClass | TSCustomClass | TSEnum

Among them, TSCustomClass and TSEnum must be compilable to TorchScript Intermediate Representation (IR). This is enforced by the type-checker.

Built-in Class

Built-in nominal types are Python classes whose semantics are built into the TorchScript system (e.g., tensor types). TorchScript defines the semantics of these built-in nominal types, and often supports only a subset of the methods or attributes of its Python class definition.

TSBuiltinClass ::= TSTensor | "torch.device" | "torch.Stream" | "torch.dtype" |
                   "torch.nn.ModuleList" | "torch.nn.ModuleDict" | ...
TSTensor       ::= "torch.Tensor" | "common.SubTensor" | "common.SubWithTorchFunction" |
                   "torch.nn.parameter.Parameter" | and subclasses of torch.Tensor

Special Note on torch.nn.ModuleList and torch.nn.ModuleDict

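Although torch.nn.ModuleList and torch.nn.ModuleDict are defined as a list and dictionary in Python, they behave more like tuples in TorchScript. For example, a for loop over a torch.nn.ModuleList is unrolled at compile time, with one copy of the loop body per member (see the for-in statement section).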

Example

The following example highlights the use of a few built-in TorchScript classes (torch.*):

import torch

@torch.jit.script
class A:
    def __init__(self):
        self.x = torch.rand(3)

    def f(self, y: torch.device):
        return self.x.to(device=y)

def g():
    a = A()
    return a.f(torch.device("cpu"))

script_g = torch.jit.script(g)
print(script_g.graph)

Custom Class

Unlike built-in classes, semantics of custom classes are user-defined and the entire class definition must be compilable to TorchScript IR and subject to TorchScript type-checking rules.

TSClassDef ::= [ "@torch.jit.script" ]
                 "class" ClassName [ "(object)" ] ":"
                    MethodDefinition |
                 [ "@torch.jit.ignore" ] | [ "@torch.jit.unused" ]
                    MethodDefinition

Where:

Compared to Python

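TorchScript custom classes are quite limited compared to their Python counterparts. Among other restrictions, a TorchScript custom class must declare all of its instance attributes by assignment in __init__() (Example 2 below) and cannot define class variables (Example 3 below).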

Example 1

Python classes can be used in TorchScript if they are annotated with @torch.jit.script, similar to how a TorchScript function would be declared:

import torch

@torch.jit.script
class MyClass:
    def __init__(self, x: int):
        self.x = x

    def inc(self, val: int):
        self.x += val

Example 2

A TorchScript custom class type must “declare” all its instance attributes by assignments in __init__(). If an instance attribute is not defined in __init__() but accessed in other methods of the class, the class cannot be compiled as a TorchScript class, as shown in the following example:

import torch

@torch.jit.script
class foo:
    def __init__(self):
        self.y = 1

    # ERROR: self.x is not defined in __init__
    def assign_x(self):
        self.x = torch.rand(2, 3)

The class will fail to compile and issue the following error:

RuntimeError:
Tried to set nonexistent attribute: x. Did you forget to initialize it in __init__()?:
def assign_x(self):
    self.x = torch.rand(2, 3)
    ~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE

Example 3

In this example, a TorchScript custom class defines a class variable name, which is not allowed:

import torch

@torch.jit.script
class MyClass(object):
    name = "MyClass"  # class variable: not allowed in TorchScript
    def __init__(self, x: int):
        self.x = x

def fn(a: MyClass):
    return a.name

It leads to the following compile-time error:

RuntimeError:
'__torch__.MyClass' object has no attribute or method 'name'. Did you forget to initialize an attribute in __init__()?:
  File "test-class2.py", line 10
def fn(a: MyClass):
    return a.name
           ~~~~~~ <--- HERE

Enum Type

Like custom classes, semantics of the enum type are user-defined and the entire class definition must be compilable to TorchScript IR and adhere to TorchScript type-checking rules.

TSEnumDef ::= "class" Identifier "(enum.Enum | TSEnumType)" ":"
                ( MemberIdentifier "=" Value )+
                ( MethodDefinition )*

Where:

Compared to Python

Example 1

The following example defines the class Color as an Enum type:

import torch
from enum import Enum

class Color(Enum):
    RED = 1
    GREEN = 2

def enum_fn(x: Color, y: Color) -> bool:
    if x == Color.RED:
        return True
    return x == y

m = torch.jit.script(enum_fn)

print("Eager: ", enum_fn(Color.RED, Color.GREEN))
print("TorchScript: ", m(Color.RED, Color.GREEN))

Example 2

The following example shows the case of restricted enum subclassing, where BaseColor does not define any member, thus can be subclassed by Color:

import torch
from enum import Enum

class BaseColor(Enum):
    def foo(self):
        pass

class Color(BaseColor):
    RED = 1
    GREEN = 2

def enum_fn(x: Color, y: Color) -> bool:
    if x == Color.RED:
        return True
    return x == y

m = torch.jit.script(enum_fn)

print("TorchScript: ", m(Color.RED, Color.GREEN))
print("Eager: ", enum_fn(Color.RED, Color.GREEN))

TorchScript Module Class

TSModuleType is a special class type that is inferred from object instances that are created outside TorchScript. TSModuleType is named by the Python class of the object instance. The __init__() method of the Python class is not considered a TorchScript method, so it does not have to comply with TorchScript’s type-checking rules.

The type schema of a module instance class is constructed directly from an instance object (created outside the scope of TorchScript) rather than inferred from __init__() like custom classes. It is possible that two objects of the same instance class type follow two different type schemas.

In this sense, TSModuleType is not really a static type. Therefore, for type safety considerations, TSModuleType cannot be used in a TorchScript type annotation or be composed with TSType.

Module Instance Class

TorchScript module type represents the type schema of a user-defined PyTorch module instance. When scripting a PyTorch module, the module object is always created outside TorchScript (i.e., passed in as parameter to forward). The Python module class is treated as a module instance class, so the __init__() method of the Python module class is not subject to the type-checking rules of TorchScript.

TSModuleType ::= "class" Identifier "(torch.nn.Module)" ":" ClassBodyDefinition

Where:

Unlike custom classes, only the forward method and other methods decorated with @torch.jit.export of the module type need to be compilable. Most notably, __init__() is not considered a TorchScript method. Consequently, module type constructors cannot be invoked within the scope of TorchScript. Instead, TorchScript module objects are always constructed outside and passed into torch.jit.script(ModuleObj).

Example 1

This example illustrates a few features of module types:

import torch

class TestModule(torch.nn.Module):
    def __init__(self, v):
        super().__init__()
        self.x = v

    def forward(self, inc: int):
        return self.x + inc

m = torch.jit.script(TestModule(1))
print(f"First instance: {m(3)}")

m = torch.jit.script(TestModule(torch.ones([5])))
print(f"Second instance: {m(3)}")

The example above produces the following output:

First instance: 4
Second instance: tensor([4., 4., 4., 4., 4.])

Example 2

The following example shows an incorrect usage of module type. Specifically, this example invokes the constructor of TestModule inside the scope of TorchScript:

import torch

class TestModule(torch.nn.Module):
    def __init__(self, v):
        super().__init__()
        self.x = v

    def forward(self, x: int):
        return self.x + x

class MyModel:
    def __init__(self, v: int):
        self.val = v

    @torch.jit.export
    def doSomething(self, val: int) -> int:
        # error: should not invoke the constructor of module type
        myModel = TestModule(self.val)
        return myModel(val)

m = torch.jit.script(MyModel(2)) # Results in below RuntimeError

RuntimeError: Could not get name of python class object

Type Annotation

Since TorchScript is statically typed, programmers need to annotate types at strategic points of TorchScript code so that every local variable or instance data attribute has a static type, and every function and method has a statically typed signature.

When to Annotate Types

In general, type annotations are only needed in places where static types cannot be automatically inferred (e.g., parameters or sometimes return types of methods or functions). Types of local variables and data attributes are often automatically inferred from their assignment statements. Sometimes an inferred type may be too restrictive, e.g., x being inferred as NoneType through the assignment x = None, whereas x is actually used as an Optional. In such cases, a type annotation may be needed to override automatic inference, e.g., x: Optional[int] = None. Note that it is always safe to type annotate a local variable or data attribute even if its type can be automatically inferred. The annotated type must be congruent with TorchScript’s type-checking.

When a parameter, local variable, or data attribute is not type annotated and its type cannot be automatically inferred, TorchScript assumes it to be a default type of TensorType, List[TensorType], or Dict[str, TensorType].

Annotate Function Signature

Since the type of a parameter may not be automatically inferable from the body of the function or method, parameters need to be type annotated. Otherwise, they assume the default type TensorType.

TorchScript supports two styles for method and function signature type annotation, Python3 style and Mypy style:

Python3Annotation ::= "def" Identifier [ "(" ParamAnnot* ")" ] [ReturnAnnot] ":"
                          FuncOrMethodBody
ParamAnnot        ::= Identifier [ ":" TSType ] ","
ReturnAnnot       ::= "->" TSType

Note that when using Python3 style, the type of self is automatically inferred and should not be annotated.

MyPyAnnotation ::= "# type:" "(" ParamAnnot* ")" [ ReturnAnnot ]
ParamAnnot     ::= TSType ","
ReturnAnnot    ::= "->" TSType

Example 1

In this example, a is not annotated and assumes the default type TensorType, while b is annotated as int:

import torch

def f(a, b: int):
    return a+b

m = torch.jit.script(f)
print("TorchScript:", m(torch.ones([6]), 100))

Example 2

The following example uses Mypy style annotation. Note that parameters or return values must be annotated even if some of them assume the default type.

import torch

def f(a, b):
    # type: (torch.Tensor, int) -> torch.Tensor
    return a+b

m = torch.jit.script(f)
print("TorchScript:", m(torch.ones([6]), 100))

Annotate Variables and Data Attributes

In general, types of data attributes (including class and instance data attributes) and local variables can be automatically inferred from assignment statements. Sometimes, however, if a variable or attribute is associated with values of different types (e.g., as None or TensorType), then they may need to be explicitly type annotated as a wider type such as Optional[int] or Any.

Local Variables

Local variables can be annotated according to Python3 typing module annotation rules, i.e.,

LocalVarAnnotation ::= Identifier [":" TSType] "=" Expr

In general, types of local variables can be automatically inferred. In some cases, however, you may need to annotate a multi-type for local variables that may be associated with different concrete types. Typical multi-types include Optional[T] and Any.

Example

import torch
from typing import Optional

def f(a, setVal: bool):
    value: Optional[torch.Tensor] = None
    if setVal:
        value = a
    return value

ones = torch.ones([6])
m = torch.jit.script(f)
print("TorchScript:", m(ones, True), m(ones, False))

Instance Data Attributes

For ModuleType classes, instance data attributes can be annotated according to Python3 typing module annotation rules. Instance data attributes can be annotated (optionally) as final via Final.

"class" ClassIdentifier "(torch.nn.Module):" InstanceAttrIdentifier ":" ["Final("] TSType [")"] ...

Where:

Example

import torch

class MyModule(torch.nn.Module):
    offset_: int

    def __init__(self, offset):
        super().__init__()
        self.offset_ = offset

    ...

Type Annotation APIs

torch.jit.annotate(T, expr)

This API annotates an expression expr with type T. This is often used when the default type of an expression is not the type intended by the programmer. For instance, an empty list (dictionary) has the default type of List[TensorType] (Dict[str, TensorType]), but sometimes it may be used to initialize a list (dictionary) of some other type. Another common use case is annotating the return type of tensor.tolist(). Note, however, that it cannot be used to annotate the type of a module attribute in __init__; torch.jit.Attribute should be used for this instead.

Example

In this example, [] is declared as a list of integers via torch.jit.annotate (instead of assuming [] to be the default type of List[TensorType]):

import torch
from typing import List

def g(l: List[int], val: int):
    l.append(val)
    return l

def f(val: int):
    l = g(torch.jit.annotate(List[int], []), val)
    return l

m = torch.jit.script(f)
print("Eager:", f(3))
print("TorchScript:", m(3))

See torch.jit.annotate() for more information.
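The same API applies to dictionaries. The following is a minimal sketch, assuming an empty dictionary literal whose intended type is Dict[str, int]; the function name make_counts and its keys are hypothetical and chosen only for illustration.

```python
import torch
from typing import Dict

def make_counts(n: int) -> Dict[str, int]:
    # Without the annotation, {} would take the default dictionary type
    # rather than Dict[str, int].
    counts = torch.jit.annotate(Dict[str, int], {})
    counts["n"] = n
    counts["double"] = 2 * n
    return counts

scripted = torch.jit.script(make_counts)
print("Eager:", make_counts(3))
print("TorchScript:", scripted(3))
```

In eager mode torch.jit.annotate simply returns the expression, so both calls are expected to build the same dictionary.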

Type Annotation Appendix

TorchScript Type System Definition

TSAllType ::= TSType | TSModuleType
TSType    ::= TSMetaType | TSPrimitiveType | TSStructuralType | TSNominalType

TSMetaType      ::= "Any"
TSPrimitiveType ::= "int" | "float" | "double" | "complex" | "bool" | "str" | "None"

TSStructuralType ::= TSTuple | TSNamedTuple | TSList | TSDict | TSOptional |
                     TSUnion | TSFuture | TSRRef | TSAwait
TSTuple          ::= "Tuple" "[" (TSType ",")* TSType "]"
TSNamedTuple     ::= "namedtuple" "(" (TSType ",")* TSType ")"
TSList           ::= "List" "[" TSType "]"
TSOptional       ::= "Optional" "[" TSType "]"
TSUnion          ::= "Union" "[" (TSType ",")* TSType "]"
TSFuture         ::= "Future" "[" TSType "]"
TSRRef           ::= "RRef" "[" TSType "]"
TSAwait          ::= "Await" "[" TSType "]"
TSDict           ::= "Dict" "[" KeyType "," TSType "]"
KeyType          ::= "str" | "int" | "float" | "bool" | TensorType | "Any"

TSNominalType  ::= TSBuiltinClass | TSCustomClass | TSEnum
TSBuiltinClass ::= TSTensor | "torch.device" | "torch.Stream" | "torch.dtype" |
                   "torch.nn.ModuleList" | "torch.nn.ModuleDict" | ...
TSTensor       ::= "torch.Tensor" and subclasses

Unsupported Typing Constructs

TorchScript does not support all features and types of the Python3 typing module. Any functionality from the typing module that is not explicitly specified in this documentation is unsupported. The following table summarizes typing constructs that are either unsupported or supported with restrictions in TorchScript.

Item                Description
typing.Any          In development
typing.NoReturn     Not supported
typing.Callable     Not supported
typing.Literal      Not supported
typing.ClassVar     Not supported
typing.Final        Supported for module attributes, class attributes, and annotations, but not for functions.
typing.AnyStr       Not supported
typing.overload     In development
Type aliases        Not supported
Nominal typing      In development
Structural typing   Not supported
NewType             Not supported
Generics            Not supported

Expressions

The following section describes the grammar of expressions that are supported in TorchScript. It is modeled after the expressions chapter of the Python language reference.

Arithmetic Conversions

There are a number of implicit type conversions that are performed in TorchScript:

Explicit conversions can be invoked using the float, int, bool, and str built-in functions, which accept primitive data types as arguments and can accept user-defined types if they implement __bool__, __str__, etc.

Atoms

Atoms are the most basic elements of expressions.

atom      ::= identifier | literal | enclosure
enclosure ::= parenth_form | list_display | dict_display

Identifiers

The rules that dictate what is a legal identifier in TorchScript are the same as their Python counterparts.

Literals

literal ::= stringliteral | integer | floatnumber

Evaluation of a literal yields an object of the appropriate type with the specific value (with approximations applied as necessary for floats). Literals are immutable, and multiple evaluations of identical literals may obtain the same object or distinct objects with the same value. stringliteral, integer, and floatnumber are defined in the same way as their Python counterparts.

Parenthesized Forms

parenth_form ::= '(' [expression_list] ')'

A parenthesized expression list yields whatever the expression list yields. If the list contains at least one comma, it yields a Tuple; otherwise, it yields the single expression inside the expression list. An empty pair of parentheses yields an empty Tuple object (Tuple[]).
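As a small illustration of the rule above, the following sketch uses plain Python (not a scripted function) to show how the presence of a comma decides whether a parenthesized form yields a tuple. The variable names are illustrative only.

```python
# Plain-Python illustration of parenthesized forms; names are illustrative.
grouped = (1 + 2)   # no comma: yields the inner expression, an int
pair = (1, 2)       # at least one comma: yields a tuple
singleton = (1,)    # a trailing comma is enough to make a tuple
empty = ()          # empty parentheses: an empty tuple

print(type(grouped).__name__)    # int
print(type(pair).__name__)       # tuple
print(type(singleton).__name__)  # tuple
print(len(empty))                # 0
```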

List and Dictionary Displays

list_comprehension ::= expression comp_for
comp_for           ::= 'for' target_list 'in' or_expr
list_display       ::= '[' [expression_list | list_comprehension] ']'
dict_display       ::= '{' [key_datum_list | dict_comprehension] '}'
key_datum_list     ::= key_datum (',' key_datum)*
key_datum          ::= expression ':' expression
dict_comprehension ::= key_datum comp_for

Lists and dicts can be constructed by either listing the container contents explicitly or by providing instructions on how to compute them via a set of looping instructions (i.e. a comprehension). A comprehension is semantically equivalent to using a for loop and appending to an ongoing list. Comprehensions implicitly create their own scope to make sure that the items of the target list do not leak into the enclosing scope. In the case that container items are explicitly listed, the expressions in the expression list are evaluated left-to-right. If a key is repeated in a dict_display that has a key_datum_list, the resultant dictionary uses the value from the rightmost datum in the list that uses the repeated key.
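The three behaviors described above (comprehension as a loop, comprehension scoping, and repeated dictionary keys) can be sketched in plain Python, whose display semantics this section follows. This is an eager-mode illustration with made-up variable names, not a scripted function.

```python
# Plain-Python illustration of list and dict displays; names are illustrative.
i = "outer"

# A comprehension and the equivalent explicit loop build the same list.
squares = [i * i for i in range(4)]
squares_loop = []
for n in range(4):
    squares_loop.append(n * n)

# The comprehension has its own scope, so the outer `i` is unchanged.
print(i)                        # outer
print(squares == squares_loop)  # True

# With a repeated key, the rightmost datum supplies the value.
d = {"a": 1, "b": 2, "a": 3}
print(d)                        # {'a': 3, 'b': 2}
```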

Primaries

primary ::= atom | attributeref | subscription | slicing | call

Attribute References

attributeref ::= primary '.' identifier

The primary must evaluate to an object of a type that supports attribute references and that has an attribute named identifier.

Subscriptions

subscription ::= primary '[' expression_list ']'

The primary must evaluate to an object that supports subscription.

Slicings

A slicing selects a range of items in a str, Tuple, List, or Tensor. Slicings may be used as expressions or targets in assignment or del statements.

slicing      ::= primary '[' slice_list ']'
slice_list   ::= slice_item (',' slice_item)* [',']
slice_item   ::= expression | proper_slice
proper_slice ::= [expression] ':' [expression] [':' [expression] ]

Slicings with more than one slice item in their slice lists can only be used with primaries that evaluate to an object of type Tensor.
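The following is a minimal sketch of the distinction above: a Tensor accepts a slice list with several items, while a List takes a single proper slice. The function name take_slices and the sample values are hypothetical.

```python
import torch
from typing import List, Tuple

def take_slices(t: torch.Tensor, xs: List[int]) -> Tuple[torch.Tensor, List[int]]:
    # Two slice items (a range and an index) are allowed because t is a Tensor.
    # The list is sliced with a single proper slice.
    return t[0:2, 1], xs[1:3]

scripted = torch.jit.script(take_slices)

t = torch.arange(12).reshape(3, 4)
sub, part = scripted(t, [10, 20, 30, 40])
print(sub.tolist(), part)  # [1, 5] [20, 30]
```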

Calls

call          ::= primary '(' argument_list ')'
argument_list ::= args [',' kwargs] | kwargs
args          ::= [arg (',' arg)*]
kwargs        ::= [kwarg (',' kwarg)*]
kwarg         ::= arg '=' expression
arg           ::= identifier

The primary must desugar or evaluate to a callable object. All argument expressions are evaluated before the call is attempted.

Power Operator

power ::= primary ['**' u_expr]

The power operator has the same semantics as Python's built-in pow function (which is itself not supported in TorchScript); it computes its left argument raised to the power of its right argument. It binds more tightly than unary operators on the left, but less tightly than unary operators on the right; i.e., -2 ** -3 == -(2 ** (-3)). The left and right operands can be int, float, or Tensor. Scalars are broadcast in the case of scalar-tensor/tensor-scalar exponentiation operations, and tensor-tensor exponentiation is done elementwise without any broadcasting.
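The binding rule can be checked directly in plain Python, which this grammar is modeled after. The snippet below is an eager-mode illustration only; the names lhs and rhs are made up.

```python
# Plain-Python illustration of how ** binds relative to unary minus.
lhs = -2 ** -3       # parsed as -(2 ** (-3))
rhs = -(2 ** (-3))

print(lhs, rhs)      # -0.125 -0.125
print((-2) ** 2)     # 4: explicit parentheses negate before exponentiation
print(-2 ** 2)       # -4: ** binds more tightly than the unary minus on its left
```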

Unary and Arithmetic Bitwise Operations

u_expr ::= power | '-' power | '~' power

The unary - operator yields the negation of its argument. The unary ~ operator yields the bitwise inversion of its argument. - can be used with int, float, and Tensor of int and float. ~ can only be used with int and Tensor of int.

Binary Arithmetic Operations

m_expr ::= u_expr | m_expr '*' u_expr | m_expr '@' m_expr | m_expr '//' u_expr |
           m_expr '/' u_expr | m_expr '%' u_expr
a_expr ::= m_expr | a_expr '+' m_expr | a_expr '-' m_expr

The binary arithmetic operators can operate on Tensor, int, and float. For tensor-tensor ops, both arguments must have the same shape. For scalar-tensor or tensor-scalar ops, the scalar is usually broadcast to the size of the tensor. Division ops can only accept scalars as their right-hand side argument, and do not support broadcasting. The @ operator is for matrix multiplication and only operates on Tensor arguments. The multiplication operator (*) can be used with a list and integer in order to get a result that is the original list repeated a certain number of times.

Shifting Operations

shift_expr ::= a_expr | shift_expr ( '<<' | '>>' ) a_expr

These operators accept two int arguments, two Tensor arguments, or a Tensor argument and an int or float argument. In all cases, a right shift by n is defined as floor division by pow(2, n), and a left shift by n is defined as multiplication by pow(2, n). When both arguments are Tensors, they must have the same shape. When one is a scalar and the other is a Tensor, the scalar is logically broadcast to match the size of the Tensor.
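For int operands, the definitions above can be checked in plain Python, including the floor behavior for negative values. This is an eager-mode sketch; the helper names are hypothetical and exist only to spell out the definition.

```python
# Plain-Python check of the shift definitions for int operands.
def right_shift_by_definition(x: int, n: int) -> int:
    return x // pow(2, n)   # floor division by 2**n

def left_shift_by_definition(x: int, n: int) -> int:
    return x * pow(2, n)    # multiplication by 2**n

values = [17, -17, 0, 255]
shifts = [0, 1, 3]
right_ok = all((x >> n) == right_shift_by_definition(x, n)
               for x in values for n in shifts)
left_ok = all((x << n) == left_shift_by_definition(x, n)
              for x in values for n in shifts)

print(right_ok, left_ok)   # True True
print(-17 >> 2, -17 // 4)  # -5 -5 (rounds toward negative infinity)
```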

Binary Bitwise Operations

and_expr ::= shift_expr | and_expr '&' shift_expr
xor_expr ::= and_expr | xor_expr '^' and_expr
or_expr  ::= xor_expr | or_expr '|' xor_expr

The & operator computes the bitwise AND of its arguments, the ^ the bitwise XOR, and the | the bitwise OR. Both operands must be int or Tensor, or the left operand must be Tensor and the right operand must be int. When both operands are Tensor, they must have the same shape. When the right operand is int, and the left operand is Tensor, the right operand is logically broadcast to match the shape of the Tensor.

Comparisons

comparison    ::= or_expr (comp_operator or_expr)*
comp_operator ::= '<' | '>' | '==' | '>=' | '<=' | '!=' | 'is' ['not'] | ['not'] 'in'

A comparison yields a boolean value (True or False), or if one of the operands is a Tensor, a boolean Tensor. Comparisons can be chained arbitrarily as long as they do not yield boolean Tensors that have more than one element. a op1 b op2 c ... is equivalent to a op1 b and b op2 c and ....
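The equivalence between a chained comparison and its expanded form can be sketched in plain Python with scalar operands. The two function names are hypothetical.

```python
# Plain-Python illustration: a chained comparison and its expanded form.
def in_range_chained(lo: int, x: int, hi: int) -> bool:
    return lo <= x < hi

def in_range_expanded(lo: int, x: int, hi: int) -> bool:
    return lo <= x and x < hi

samples = [(0, 5, 10), (0, 10, 10), (0, -1, 10), (3, 3, 4)]
results = [in_range_chained(*s) for s in samples]
print(results)  # [True, False, False, True]
print(results == [in_range_expanded(*s) for s in samples])  # True
```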

Value Comparisons

The operators <, >, ==, >=, <=, and != compare the values of two objects. The two objects generally need to be of the same type, unless there is an implicit type conversion available between the objects. User-defined types can be compared if rich comparison methods (e.g., __lt__) are defined on them. Comparison of built-in types works as it does in Python.

Membership Test Operations

The operators in and not in test for membership. x in s evaluates to True if x is a member of s and False otherwise. x not in s is equivalent to not x in s. This operator is supported for lists, dicts, and tuples, and can be used with user-defined types if they implement the __contains__ method.
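The following plain-Python sketch shows membership tests on the built-in containers named above and on a user-defined type that implements __contains__. The class EvenNumbers is hypothetical and is shown in eager mode only; a TorchScript custom class would additionally need to satisfy the custom class rules described earlier.

```python
# Plain-Python illustration of membership tests; EvenNumbers is hypothetical.
class EvenNumbers:
    """Treats every even int between 0 and `limit` as a member."""

    def __init__(self, limit: int):
        self.limit = limit

    def __contains__(self, x: int) -> bool:
        return 0 <= x <= self.limit and x % 2 == 0

evens = EvenNumbers(10)
print(4 in evens, 5 in evens, 12 in evens)  # True False False
print(5 not in evens, not 5 in evens)       # True True

print(2 in [1, 2, 3])         # list membership: True
print("k" in {"k": 1})        # dict membership tests keys: True
print(3 in (1, 2))            # tuple membership: False
```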

Identity Comparisons

For all types except int, double, bool, and torch.device, the operators is and is not test for object identity; x is y is True if and only if x and y are the same object. For int, double, bool, and torch.device, is is equivalent to comparing the operands using ==. x is not y yields the inverse of x is y.

Boolean Operations

or_test  ::= and_test | or_test 'or' and_test
and_test ::= not_test | and_test 'and' not_test
not_test ::= 'bool' '(' or_expr ')' | comparison | 'not' not_test

User-defined objects can customize their conversion to bool by implementing a __bool__ method. The operator not yields True if its operand is false, False otherwise. The expression x and y first evaluates x; if it is False, its value (False) is returned; otherwise, y is evaluated and its value is returned (False or True). The expression x or y first evaluates x; if it is True, its value (True) is returned; otherwise, y is evaluated and its value is returned (False or True).
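The evaluation order described above means the right operand is skipped whenever the left operand already decides the result. The plain-Python sketch below records which operands are evaluated; the helper check and the calls list are hypothetical names used only for this illustration.

```python
# Plain-Python illustration of short-circuit evaluation for `and` / `or`.
calls = []

def check(name: str, value: bool) -> bool:
    calls.append(name)  # record that this operand was evaluated
    return value

r1 = check("a", False) and check("b", True)  # "b" is never evaluated
r2 = check("c", True) or check("d", False)   # "d" is never evaluated
r3 = check("e", True) and check("f", False)  # both operands are evaluated

print(r1, r2, r3)  # False True False
print(calls)       # ['a', 'c', 'e', 'f']
```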

Conditional Expressions

conditional_expression ::= or_expr ['if' or_test 'else' conditional_expression]
expression             ::= conditional_expression

The expression x if c else y first evaluates the condition c rather than x. If c is True, x is evaluated and its value is returned; otherwise, y is evaluated and its value is returned. As with if-statements, x and y must evaluate to a value of the same type.
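A minimal sketch of a conditional expression in a scripted function follows. Both branches are int, so the expression has a single static type; the function name clamp_nonneg is hypothetical.

```python
import torch

def clamp_nonneg(x: int) -> int:
    # Both branches evaluate to int, satisfying the same-type requirement.
    return x if x > 0 else 0

scripted = torch.jit.script(clamp_nonneg)
print(scripted(5), scripted(-3))  # 5 0
```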

Expression Lists

expression_list ::= expression (',' expression)* [',']
starred_item    ::= '*' primary

A starred item can only appear on the left-hand side of an assignment statement, e.g., a, *b, c = ....
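The unpacking form mentioned above is shown here in plain Python (eager mode); the values are illustrative only. The starred target collects whatever the other targets leave over.

```python
# Plain-Python illustration of a starred item on the left-hand side.
a, *b, c = [1, 2, 3, 4, 5]
print(a, b, c)  # 1 [2, 3, 4] 5

# With exactly two elements, the starred target receives an empty list.
first, *middle, last = (10, 20)
print(first, middle, last)  # 10 [] 20
```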

Simple Statements

The following section describes the syntax of simple statements that are supported in TorchScript. It is modeled after the simple statements chapter of the Python language reference.

Expression Statements

expression_stmt    ::= starred_expression
starred_expression ::= expression | (starred_item ",")* [starred_item]
starred_item       ::= assignment_expression | "*" or_expr

Assignment Statements

assignment_stmt ::= (target_list "=")+ (starred_expression)
target_list     ::= target ("," target)* [","]
target          ::= identifier
                    | "(" [target_list] ")"
                    | "[" [target_list] "]"
                    | attributeref
                    | subscription
                    | slicing
                    | "*" target

Augmented Assignment Statements

augmented_assignment_stmt ::= augtarget augop (expression_list)
augtarget                 ::= identifier | attributeref | subscription
augop                     ::= "+=" | "-=" | "*=" | "/=" | "//=" | "%=" |
                              "**=" | ">>=" | "<<=" | "&=" | "^=" | "|="

Annotated Assignment Statements

annotated_assignment_stmt ::= augtarget ":" expression ["=" (starred_expression)]

The raise Statement

raise_stmt ::= "raise" [expression ["from" expression]]

Raise statements in TorchScript do not support try/except/finally.

The assert Statement

assert_stmt ::= "assert" expression ["," expression]

Assert statements in TorchScript do not support try/except/finally.

The return Statement

return_stmt ::= "return" [expression_list]

Return statements in TorchScript do not support try/except/finally.

The del Statement

del_stmt ::= "del" target_list

The pass Statement

pass_stmt ::= "pass"

The print Statement

print_stmt ::= "print" "(" expression [, expression] [.format{expression_list}] ")"

The break Statement

break_stmt ::= "break"

The continue Statement

continue_stmt ::= "continue"

Compound Statements

The following section describes the syntax of compound statements that are supported in TorchScript. The section also highlights how TorchScript differs from regular Python statements. It is modeled after the compound statements chapter of the Python language reference.

The if Statement

TorchScript supports both basic if/else and ternary if/else.

Basic if/else Statement

if_stmt ::= "if" assignment_expression ":" suite
            ("elif" assignment_expression ":" suite)*
            ["else" ":" suite]

An elif clause can repeat any number of times, but every elif clause must appear before the else clause.

Ternary if/else Statement

if_stmt ::= return [expression_list] "if" assignment_expression "else" [expression_list]

Example 1

A tensor with a single element is promoted to bool:

import torch

@torch.jit.script
def fn(x: torch.Tensor):
    if x: # The tensor gets promoted to bool
        return True
    return False

print(fn(torch.rand(1)))

The example above produces the following output:

Example 2

A tensor with more than one element is not promoted to bool:

import torch

# Multi dimensional Tensors error out.

@torch.jit.script
def fn():
    if torch.rand(2):
        print("Tensor is available")

    if torch.rand(4,5,6):
        print("Tensor is available")

print(fn())

Running the above code yields the following RuntimeError.

RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
@torch.jit.script
def fn():
    if torch.rand(2):
       ~~~~~~~~~~~~ <--- HERE
        print("Tensor is available")
RuntimeError: Boolean value of Tensor with more than one value is ambiguous

If a conditional variable is annotated as final, either the true or false branch is evaluated depending on the evaluation of the conditional variable.

Example 3

In this example, only the True branch is evaluated, since a is annotated as final and set to True:

import torch

a : torch.jit.Final[bool] = True

if a:
    return torch.empty(2,3)
else:
    return []

The while Statement

while_stmt ::= "while" assignment_expression ":" suite

while...else statements are not supported in TorchScript. Using one results in a RuntimeError.
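A minimal sketch of a plain while loop in a scripted function. The name `count_halvings` and its parameters are hypothetical. The loop condition is converted to a Python `float` before the comparison so that the condition is an ordinary bool:

```python
import torch


@torch.jit.script
def count_halvings(x: torch.Tensor, limit: float) -> int:
    steps = 0
    # Halve the tensor until its largest entry no longer exceeds `limit`.
    while float(x.max()) > limit:
        x = x / 2.0
        steps += 1
    return steps


steps = count_halvings(torch.tensor([8.0, 2.0]), 1.0)
print(steps)
```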

The for-in Statement

for_stmt ::= "for" target_list "in" expression_list ":" suite ["else" ":" suite]

for...else statements are not supported in TorchScript. Using one results in a RuntimeError.

Example 1

For loops on tuples: these unroll the loop, generating a body for each member of the tuple. The body must type-check correctly for each member.

import torch
from typing import Tuple

@torch.jit.script
def fn():
    tup = (3, torch.ones(4))
    for x in tup:
        print(x)

fn()

The example above produces the following output:

3
 1
 1
 1
 1
[ CPUFloatType{4} ]

Example 2

For loops on module lists: a for loop over an nn.ModuleList unrolls the body of the loop at compile time, once for each member of the module list.

import torch

class SubModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(2))

    def forward(self, input):
        return self.weight + input

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mods = torch.nn.ModuleList([SubModule() for i in range(10)])

    def forward(self, v):
        # Unrolled at compile time, once per submodule
        for module in self.mods:
            v = module(v)
        return v

model = torch.jit.script(MyModule())

The with Statement

The with statement is used to wrap the execution of a block with methods defined by a context manager.

with_stmt ::= "with" with_item ("," with_item)* ":" suite
with_item ::= expression ["as" target]
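As a sketch of the with statement in a scripted function, the example below uses `torch.no_grad()` as the context manager. The function name `doubled_without_grad` is hypothetical:

```python
import torch


@torch.jit.script
def doubled_without_grad(x: torch.Tensor) -> torch.Tensor:
    # The multiplication runs inside the no_grad context manager.
    with torch.no_grad():
        y = x * 2
    return y


x = torch.ones(3, requires_grad=True)
y = doubled_without_grad(x)
print(y)
```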

The tuple Statement

tuple_stmt ::= tuple([iterables])

Unpacking all outputs into a tuple is covered by:

abc = func() # Function that returns a tuple
a,b = func()
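The two forms above can be sketched in a runnable example. The functions `min_max` and `value_range` are hypothetical names introduced for illustration:

```python
from typing import Tuple

import torch


@torch.jit.script
def min_max(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    return x.min(), x.max()


@torch.jit.script
def value_range(x: torch.Tensor) -> torch.Tensor:
    pair = min_max(x)    # Keep the whole tuple in one variable
    lo, hi = min_max(x)  # Unpack the tuple into two variables
    # Both forms give access to the same two values.
    return (hi - lo) + (pair[1] - pair[0])


spread = value_range(torch.tensor([1.0, 4.0, 2.0]))
print(spread)
```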

The getattr Statement

getattr_stmt ::= getattr(object, name[, default])

The hasattr Statement

hasattr_stmt ::= hasattr(object, name)

The zip Statement

zip_stmt ::= zip(iterable1, iterable2)

Example 1

Both iterables must be of the same container type:

a = [1, 2] # List
b = [2, 3, 4] # List
zip(a, b) # works

Example 2

This example fails because the iterables are of different container types:

a = (1, 2) # Tuple
b = [2, 3, 4] # List
zip(a, b) # Runtime error

Running the above code yields the following RuntimeError.

RuntimeError: Can not iterate over a module list or tuple with a value that does not have a statically determinable length.

Example 3

Two iterables of the same container type but with different element types are supported:

a = [1.3, 2.4]
b = [2, 3, 4]
zip(a, b) # Works

Iterable types in TorchScript include Tensors, lists, tuples, dictionaries, strings, torch.nn.ModuleList, and torch.nn.ModuleDict.
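A minimal sketch of zip inside a scripted function, pairing a list of floats with a list of ints as in Example 3. The function name `weighted_sum` is hypothetical:

```python
from typing import List

import torch


@torch.jit.script
def weighted_sum(weights: List[float], values: List[int]) -> float:
    total = 0.0
    # Both arguments are lists, so they share the same container type.
    for w, v in zip(weights, values):
        total += w * float(v)
    return total


total = weighted_sum([0.5, 2.0], [2, 3])
print(total)
```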

The enumerate Statement

enumerate_stmt ::= enumerate([iterable])
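As a sketch, enumerate can be used in a scripted loop to obtain both the index and the element. The function name `position_weighted_total` is hypothetical:

```python
from typing import List

import torch


@torch.jit.script
def position_weighted_total(values: List[int]) -> int:
    total = 0
    # i is the zero-based index, v is the element at that index.
    for i, v in enumerate(values):
        total += i * v
    return total


weighted = position_weighted_total([10, 20, 30])
print(weighted)
```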

Python Values

Resolution Rules

When given a Python value, TorchScript attempts to resolve it in the following five different ways:

Python Built-in Functions Support

TorchScript Support for Python Built-in Functions

Built-in Function Support Level Notes
abs() Partial Only supports Tensor/Int/Float type inputs. | Doesn’t honor __abs__ override.
all() Full
any() Full
ascii() None
bin() Partial Only supports Int type input.
bool() Partial Only supports Tensor/Int/Float type inputs.
breakpoint() None
bytearray() None
bytes() None
callable() None
chr() Partial Only ASCII character set is supported.
classmethod() Full
compile() None
complex() None
delattr() None
dict() Full
dir() None
divmod() Full
enumerate() Full
eval() None
exec() None
filter() None
float() Partial Doesn’t honor __index__ override.
format() Partial Manual index specification not supported. | Format type modifier not supported.
frozenset() None
getattr() Partial Attribute name must be string literal.
globals() None
hasattr() Partial Attribute name must be string literal.
hash() Full Tensor’s hash is based on identity not numeric value.
hex() Partial Only supports Int type input.
id() Full Only supports Int type input.
input() None
int() Partial base argument not supported. | Doesn’t honor __index__ override.
isinstance() Full torch.jit.isinstance provides better support when checking against container types like Dict[str, int].
issubclass() None
iter() None
len() Full
list() Full
ord() Partial Only ASCII character set is supported.
pow() Full
print() Partial sep, end and file arguments are not supported.
property() None
range() Full
repr() None
reversed() None
round() Partial ndigits argument is not supported.
set() None
setattr() None
slice() Full
sorted() Partial key argument is not supported.
staticmethod() Full
str() Partial encoding and errors arguments are not supported.
sum() Full
super() Partial It can only be used in nn.Module’s __init__ method.
type() None
vars() None
zip() Full
__import__() None
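The sketch below combines a few built-ins from the table in one scripted function: `sum`, `len`, and `divmod`, which are listed with full support, and `sorted`, called without the unsupported `key` argument. The function name `summarize` is hypothetical:

```python
from typing import List, Tuple

import torch


@torch.jit.script
def summarize(values: List[int]) -> Tuple[int, int, List[int]]:
    total = sum(values)
    # Integer quotient and remainder of the total divided by the count.
    quotient, remainder = divmod(total, len(values))
    return quotient, remainder, sorted(values)


summary = summarize([3, 1, 6])
print(summary)
```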

Python Built-in Values Support

TorchScript Support for Python Built-in Values

Built-in Value Support Level Notes
False Full
True Full
None Full
NotImplemented None
Ellipsis Full

torch.* APIs

Remote Procedure Calls

TorchScript supports a subset of the RPC APIs, which allow a function to run on a specified remote worker instead of locally.

Specifically, the following APIs are fully supported:

Asynchronous Execution

TorchScript enables you to create asynchronous computation tasks to make better use of computation resources. This is done via supporting a list of APIs that are only usable within TorchScript:

Type Annotations

TorchScript is statically typed. It provides and supports a set of utilities to help annotate variables and attributes:
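As one illustration, assuming `torch.jit.annotate` is among these utilities, an empty list can be given an explicit element type that cannot be read off the literal itself. The function name `squares` is hypothetical:

```python
from typing import List

import torch


@torch.jit.script
def squares(n: int) -> List[int]:
    # Annotate the empty list so its element type is known to be int.
    out = torch.jit.annotate(List[int], [])
    for i in range(n):
        out.append(i * i)
    return out


first_squares = squares(4)
print(first_squares)
```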

Type Refinement