torch.testing — PyTorch 2.7 documentation
torch.testing.assert_close(actual, expected, *, allow_subclasses=True, rtol=None, atol=None, equal_nan=False, check_device=True, check_dtype=True, check_layout=True, check_stride=False, msg=None)
Asserts that `actual` and `expected` are close.
If `actual` and `expected` are strided, non-quantized, real-valued, and finite, they are considered close if

$$\lvert \text{actual} - \text{expected} \rvert \le \texttt{atol} + \texttt{rtol} \cdot \lvert \text{expected} \rvert$$
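The inequality can be evaluated by hand to see which elements would pass. The following is a minimal sketch, assuming float32 inputs and the float32 default tolerances listed in the table further down; `is_close_manual` is a hypothetical helper written for illustration and is not part of `torch.testing`.

```python
import torch

def is_close_manual(actual, expected, rtol, atol):
    # Element-wise |actual - expected| <= atol + rtol * |expected|
    return (actual - expected).abs() <= atol + rtol * expected.abs()

expected = torch.tensor([1.0, 100.0, 0.0])
actual = torch.tensor([1.000001, 100.01, 0.00002])

# float32 defaults: rtol=1.3e-6, atol=1e-5
manual = is_close_manual(actual, expected, rtol=1.3e-6, atol=1e-5)

# torch.isclose evaluates the same element-wise inequality.
builtin = torch.isclose(actual, expected, rtol=1.3e-6, atol=1e-5)
```

Only the first element satisfies the bound here. Note that the right-hand side scales with `expected`, not `actual`, so the check is not symmetric in its arguments.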
Non-finite values (`-inf` and `inf`) are considered close if and only if they are equal. `NaN`s are considered equal to each other only if `equal_nan` is `True`.
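As an illustration of the rule for infinities, the sketch below compares same-sign and opposite-sign infinities. The very large tolerances are arbitrary values chosen only to show that tolerances do not apply to non-finite values.

```python
import torch

inf = torch.tensor([float("inf"), float("-inf")])

# Infinities of the same sign are equal, so this passes.
torch.testing.assert_close(inf, inf.clone())

# Infinities of opposite sign are not equal, so they are not close,
# no matter how large the tolerances are.
try:
    torch.testing.assert_close(inf, -inf, rtol=1.0, atol=1e30)
    opposite_inf_close = True
except AssertionError:
    opposite_inf_close = False
```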
In addition, they are only considered close if they have the same

- `device` (if `check_device` is `True`),
- `dtype` (if `check_dtype` is `True`),
- `layout` (if `check_layout` is `True`), and
- stride (if `check_stride` is `True`).
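The attribute checks can be sketched as follows, assuming CPU tensors. The variable names are illustrative only.

```python
import torch

a = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
b = a.to(torch.float64)

# Same values but different dtype: fails while check_dtype=True (the default).
try:
    torch.testing.assert_close(a, b)
    dtype_mismatch_passes = True
except AssertionError:
    dtype_mismatch_passes = False

# With the check disabled, only the values are compared.
torch.testing.assert_close(a, b, check_dtype=False)

# Strides are not checked by default, so a transposed view matches
# its contiguous copy.
m = torch.arange(6.0).reshape(2, 3)
view = m.t()               # non-contiguous view
copy = m.t().contiguous()  # same values, different strides
torch.testing.assert_close(view, copy)

# Opting in to the stride check makes the same comparison fail.
try:
    torch.testing.assert_close(view, copy, check_stride=True)
    stride_mismatch_passes = True
except AssertionError:
    stride_mismatch_passes = False
```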
If either `actual` or `expected` is a meta tensor, only the attribute checks will be performed.
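A minimal sketch of the meta tensor behavior, assuming the `"meta"` device is available in the installed build. Meta tensors carry no data, so only attributes such as shape and dtype can differ.

```python
import torch

meta_a = torch.empty(2, 3, device="meta")
meta_b = torch.empty(2, 3, device="meta")

# No values exist to compare; matching attributes are enough.
torch.testing.assert_close(meta_a, meta_b)

# A real tensor against a meta tensor: with the device check disabled,
# the remaining attributes match and no value comparison takes place.
real = torch.ones(2, 3)
torch.testing.assert_close(real, meta_a, check_device=False)

# Attribute mismatches such as a different shape are still reported.
try:
    torch.testing.assert_close(meta_a, torch.empty(3, 2, device="meta"))
    shape_mismatch_detected = False
except AssertionError:
    shape_mismatch_detected = True
```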
If `actual` and `expected` are sparse (having COO, CSR, CSC, BSR, or BSC layout), their strided members are checked individually. Indices, namely `indices` for COO, `crow_indices` and `col_indices` for CSR and BSR, or `ccol_indices` and `row_indices` for CSC and BSC layouts, respectively, are always checked for equality, whereas the values are checked for closeness according to the definition above.
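The split between exact index comparison and approximate value comparison can be sketched with COO tensors. The inputs are coalesced explicitly here so that index order is well defined; the perturbation size is an arbitrary choice that stays inside the float32 default tolerances.

```python
import torch

indices = torch.tensor([[0, 1, 2], [2, 0, 1]])
values = torch.tensor([1.0, 2.0, 3.0])
expected = torch.sparse_coo_tensor(indices, values, size=(3, 3)).coalesce()

# Values may differ within tolerance.
actual = torch.sparse_coo_tensor(indices, values + 1e-7, size=(3, 3)).coalesce()
torch.testing.assert_close(actual, expected)

# Moving one entry to another column changes the indices,
# which have to match exactly.
moved = torch.sparse_coo_tensor(
    torch.tensor([[0, 1, 2], [2, 0, 0]]), values, size=(3, 3)
).coalesce()
try:
    torch.testing.assert_close(moved, expected)
    indices_mismatch_detected = False
except AssertionError:
    indices_mismatch_detected = True
```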
If `actual` and `expected` are quantized, they are considered close if they have the same `qscheme()` and the result of `dequantize()` is close according to the definition above.
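A rough sketch of the quantized case, assuming a CPU build with per-tensor affine quantization available. The scale and zero point are arbitrary illustrative values.

```python
import torch

x = torch.tensor([0.0, 0.5, 1.0])

# Two tensors quantized with the same scheme and parameters.
expected = torch.quantize_per_tensor(x, scale=0.1, zero_point=0, dtype=torch.quint8)
actual = torch.quantize_per_tensor(x.clone(), scale=0.1, zero_point=0, dtype=torch.quint8)

# Same qscheme() and identical dequantized values: passes.
torch.testing.assert_close(actual, expected)

# Comparing a quantized tensor with a regular one is rejected.
try:
    torch.testing.assert_close(x, expected)
    mixed_quantization_detected = False
except AssertionError:
    mixed_quantization_detected = True
```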
`actual` and `expected` can be `Tensor`s or any tensor-or-scalar-likes from which `torch.Tensor`s can be constructed with `torch.as_tensor()`. Except for Python scalars, the input types have to be directly related. In addition, `actual` and `expected` can be `Sequence`s or `Mapping`s, in which case they are considered close if their structure matches and all their elements are considered close according to the above definition.
Note
Python scalars are an exception to the type relation requirement, because their `type()`, i.e. `int`, `float`, and `complex`, is equivalent to the `dtype` of a tensor-like. Thus, Python scalars of different types can be checked, but require `check_dtype=False`.
Parameters
- actual (Any) – Actual input.
- expected (Any) – Expected input.
- allow_subclasses (bool) – If `True` (default) and except for Python scalars, inputs of directly related types are allowed. Otherwise type equality is required.
- rtol (Optional[float]) – Relative tolerance. If specified, `atol` must also be specified. If omitted, default values based on the `dtype` are selected from the table below.
- atol (Optional[float]) – Absolute tolerance. If specified, `rtol` must also be specified. If omitted, default values based on the `dtype` are selected from the table below.
- equal_nan (Union[bool, str]) – If `True`, two `NaN` values will be considered equal.
- check_device (bool) – If `True` (default), asserts that corresponding tensors are on the same `device`. If this check is disabled, tensors on different `device`s are moved to the CPU before being compared.
- check_dtype (bool) – If `True` (default), asserts that corresponding tensors have the same `dtype`. If this check is disabled, tensors with different `dtype`s are promoted to a common `dtype` (according to `torch.promote_types()`) before being compared.
- check_layout (bool) – If `True` (default), asserts that corresponding tensors have the same `layout`. If this check is disabled, tensors with different `layout`s are converted to strided tensors before being compared.
- check_stride (bool) – If `True` and corresponding tensors are strided, asserts that they have the same stride.
- msg (Optional[Union[str, Callable[[str], str]]]) – Optional error message to use in case a failure occurs during the comparison. Can also be passed as a callable, in which case it will be called with the generated message and should return the new message.
Raises
- ValueError – If no torch.Tensor can be constructed from an input.
- ValueError – If only `rtol` or `atol` is specified.
- AssertionError – If corresponding inputs are not Python scalars and are not directly related.
- AssertionError – If `allow_subclasses` is `False`, but corresponding inputs are not Python scalars and have different types.
- AssertionError – If the inputs are `Sequence`s, but their lengths do not match.
- AssertionError – If the inputs are `Mapping`s, but their sets of keys do not match.
- AssertionError – If corresponding tensors do not have the same `shape`.
- AssertionError – If `check_layout` is `True`, but corresponding tensors do not have the same `layout`.
- AssertionError – If only one of corresponding tensors is quantized.
- AssertionError – If corresponding tensors are quantized, but have different `qscheme()`s.
- AssertionError – If `check_device` is `True`, but corresponding tensors are not on the same `device`.
- AssertionError – If `check_dtype` is `True`, but corresponding tensors do not have the same `dtype`.
- AssertionError – If `check_stride` is `True`, but corresponding strided tensors do not have the same stride.
- AssertionError – If the values of corresponding tensors are not close according to the definition above.
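Two of the error conditions listed above can be reproduced with a short sketch. The tolerance value is arbitrary; the point is only that `rtol` and `atol` have to be passed together.

```python
import torch

a = torch.tensor([1.0, 2.0])

# Passing rtol without atol is rejected before any comparison happens.
try:
    torch.testing.assert_close(a, a.clone(), rtol=1e-3)
    tolerance_error = None
except ValueError as exc:
    tolerance_error = exc

# Passing both tolerances is fine.
torch.testing.assert_close(a, a.clone(), rtol=1e-3, atol=1e-6)

# A shape mismatch is reported as an AssertionError.
try:
    torch.testing.assert_close(torch.zeros(2), torch.zeros(3))
    shape_error = None
except AssertionError as exc:
    shape_error = exc
```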
The following table displays the default `rtol` and `atol` for different `dtype`s. In case of mismatching `dtype`s, the maximum of both tolerances is used.
dtype | rtol | atol |
---|---|---|
float16 | 1e-3 | 1e-5 |
bfloat16 | 1.6e-2 | 1e-5 |
float32 | 1.3e-6 | 1e-5 |
float64 | 1e-7 | 1e-7 |
complex32 | 1e-3 | 1e-5 |
complex64 | 1.3e-6 | 1e-5 |
complex128 | 1e-7 | 1e-7 |
quint8 | 1.3e-6 | 1e-5 |
quint2x4 | 1.3e-6 | 1e-5 |
quint4x2 | 1.3e-6 | 1e-5 |
qint8 | 1.3e-6 | 1e-5 |
qint32 | 1.3e-6 | 1e-5 |
other | 0.0 | 0.0 |
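The effect of the dtype-dependent defaults can be sketched with a difference of 2**-10 (about 9.8e-4), a value chosen because it is exactly representable in both float16 and float32.

```python
import torch

expected = torch.tensor([1.0])
actual = torch.tensor([1.0 + 2**-10])

# float16 defaults (rtol=1e-3, atol=1e-5) tolerate this difference.
torch.testing.assert_close(actual.half(), expected.half())

# float32 defaults (rtol=1.3e-6, atol=1e-5) do not.
try:
    torch.testing.assert_close(actual, expected)
    passes_float32 = True
except AssertionError:
    passes_float32 = False

# With mismatching dtypes and check_dtype=False, the larger tolerances
# of the two dtypes (here those of float16) are used.
torch.testing.assert_close(actual.half(), expected, check_dtype=False)
```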
Note
`assert_close()` is highly configurable with strict default settings. Users are encouraged to `partial()` it to fit their use case. For example, if an equality check is needed, one might define an `assert_equal` that uses zero tolerances for every `dtype` by default:

    >>> import functools
    >>> import torch
    >>> assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0)
    >>> assert_equal(1e-9, 1e-10)
    Traceback (most recent call last):
    ...
    AssertionError: Scalars are not equal!

    Expected 1e-10 but got 1e-09.
    Absolute difference: 9.000000000000001e-10
    Relative difference: 9.0
Examples
    >>> import torch
    >>> # tensor to tensor comparison
    >>> expected = torch.tensor([1e0, 1e-1, 1e-2])
    >>> actual = torch.acos(torch.cos(expected))
    >>> torch.testing.assert_close(actual, expected)

    >>> # scalar to scalar comparison
    >>> import math
    >>> expected = math.sqrt(2.0)
    >>> actual = 2.0 / math.sqrt(2.0)
    >>> torch.testing.assert_close(actual, expected)

    >>> # numpy array to numpy array comparison
    >>> import numpy as np
    >>> expected = np.array([1e0, 1e-1, 1e-2])
    >>> actual = np.arccos(np.cos(expected))
    >>> torch.testing.assert_close(actual, expected)

    >>> # sequence to sequence comparison
    >>> import numpy as np
    >>> # The types of the sequences do not have to match. They only have to have the same
    >>> # length and their elements have to match.
    >>> expected = [torch.tensor([1.0]), 2.0, np.array(3.0)]
    >>> actual = tuple(expected)
    >>> torch.testing.assert_close(actual, expected)

    >>> # mapping to mapping comparison
    >>> from collections import OrderedDict
    >>> import numpy as np
    >>> foo = torch.tensor(1.0)
    >>> bar = 2.0
    >>> baz = np.array(3.0)
    >>> # The types and a possible ordering of mappings do not have to match. They only
    >>> # have to have the same set of keys and their elements have to match.
    >>> expected = OrderedDict([("foo", foo), ("bar", bar), ("baz", baz)])
    >>> actual = {"baz": baz, "bar": bar, "foo": foo}
    >>> torch.testing.assert_close(actual, expected)
    >>> expected = torch.tensor([1.0, 2.0, 3.0])
    >>> actual = expected.clone()
    >>> # By default, directly related instances can be compared
    >>> torch.testing.assert_close(torch.nn.Parameter(actual), expected)
    >>> # This check can be made more strict with allow_subclasses=False
    >>> torch.testing.assert_close(
    ...     torch.nn.Parameter(actual), expected, allow_subclasses=False
    ... )
    Traceback (most recent call last):
    ...
    TypeError: No comparison pair was able to handle inputs of type
    <class 'torch.nn.parameter.Parameter'> and <class 'torch.Tensor'>.
    >>> # If the inputs are not directly related, they are never considered close
    >>> torch.testing.assert_close(actual.numpy(), expected)
    Traceback (most recent call last):
    ...
    TypeError: No comparison pair was able to handle inputs of type <class 'numpy.ndarray'>
    and <class 'torch.Tensor'>.
    >>> # Exceptions to these rules are Python scalars. They can be checked regardless of
    >>> # their type if check_dtype=False.
    >>> torch.testing.assert_close(1.0, 1, check_dtype=False)
    >>> # NaN != NaN by default.
    >>> expected = torch.tensor(float("Nan"))
    >>> actual = expected.clone()
    >>> torch.testing.assert_close(actual, expected)
    Traceback (most recent call last):
    ...
    AssertionError: Scalars are not close!

    Expected nan but got nan.
    Absolute difference: nan (up to 1e-05 allowed)
    Relative difference: nan (up to 1.3e-06 allowed)
    >>> torch.testing.assert_close(actual, expected, equal_nan=True)

    >>> expected = torch.tensor([1.0, 2.0, 3.0])
    >>> actual = torch.tensor([1.0, 4.0, 5.0])
    >>> # The default error message can be overwritten.
    >>> torch.testing.assert_close(actual, expected, msg="Argh, the tensors are not close!")
    Traceback (most recent call last):
    ...
    AssertionError: Argh, the tensors are not close!
    >>> # If msg is a callable, it can be used to augment the generated message with
    >>> # extra information
    >>> torch.testing.assert_close(
    ...     actual, expected, msg=lambda msg: f"Header\n\n{msg}\n\nFooter"
    ... )
    Traceback (most recent call last):
    ...
    AssertionError: Header

    Tensor-likes are not close!

    Mismatched elements: 2 / 3 (66.7%)
    Greatest absolute difference: 2.0 at index (1,) (up to 1e-05 allowed)
    Greatest relative difference: 1.0 at index (1,) (up to 1.3e-06 allowed)

    Footer
torch.testing.make_tensor(*shape, dtype, device, low=None, high=None, requires_grad=False, noncontiguous=False, exclude_zero=False, memory_format=None)
Creates a tensor with the given `shape`, `device`, and `dtype`, and filled with values uniformly drawn from `[low, high)`.
If `low` or `high` are specified and are outside the range of the `dtype`'s representable finite values, they are clamped to the lowest or highest representable finite value, respectively. If they are `None`, the following table describes the default values for `low` and `high`, which depend on `dtype`.
dtype | low | high |
---|---|---|
boolean type | 0 | 2 |
unsigned integral type | 0 | 10 |
signed integral types | -9 | 10 |
floating types | -9 | 9 |
complex types | -9 | 9 |
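The default ranges in the table can be spot-checked with a short sketch, assuming a CPU device. The sample size of 1000 is arbitrary; the values themselves are random, so only their bounds are inspected.

```python
import torch
from torch.testing import make_tensor

# No low/high given, so the dtype-dependent defaults apply.
ints = make_tensor((1000,), dtype=torch.int32, device="cpu")
uints = make_tensor((1000,), dtype=torch.uint8, device="cpu")
floats = make_tensor((1000,), dtype=torch.float32, device="cpu")
bools = make_tensor((1000,), dtype=torch.bool, device="cpu")

# Observed bounds; high is exclusive for the integral types.
int_bounds = (ints.min().item(), ints.max().item())
uint_bounds = (uints.min().item(), uints.max().item())
float_bounds = (floats.min().item(), floats.max().item())
```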
Parameters
- shape (Tuple[int, ...]) – Single integer or a sequence of integers defining the shape of the output tensor.
- dtype (torch.dtype) – The data type of the returned tensor.
- device (Union[str, torch.device]) – The device of the returned tensor.
- low (Optional[Number]) – Sets the lower limit (inclusive) of the given range. If a number is provided it is clamped to the least representable finite value of the given dtype. When `None` (default), this value is determined based on the `dtype` (see the table above). Default: `None`.
- high (Optional[Number]) – Sets the upper limit (exclusive) of the given range. If a number is provided it is clamped to the greatest representable finite value of the given dtype. When `None` (default), this value is determined based on the `dtype` (see the table above). Default: `None`.

  Deprecated since version 2.1: Passing `low==high` to `make_tensor()` for floating or complex types is deprecated since 2.1 and will be removed in 2.3. Use `torch.full()` instead.
- requires_grad (Optional[bool]) – If autograd should record operations on the returned tensor. Default: `False`.
- noncontiguous (Optional[bool]) – If `True`, the returned tensor will be noncontiguous. This argument is ignored if the constructed tensor has fewer than two elements. Mutually exclusive with `memory_format`.
- exclude_zero (Optional[bool]) – If `True`, zeros are replaced with a small positive value that depends on the `dtype`. For bool and integer types zero is replaced with one. For floating point types it is replaced with the dtype's smallest positive normal number (the "tiny" value of the `dtype`'s `finfo()` object), and for complex types it is replaced with a complex number whose real and imaginary parts are both the smallest positive normal number representable by the complex type. Default: `False`.
- memory_format (Optional[torch.memory_format]) – The memory format of the returned tensor. Mutually exclusive with `noncontiguous`.
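The optional flags can be sketched as follows, assuming a CPU device. The shapes and the `low`/`high` bounds are arbitrary choices for illustration.

```python
import torch
from torch.testing import make_tensor

# A noncontiguous tensor with the requested shape.
strided = make_tensor((4, 5), dtype=torch.float32, device="cpu", noncontiguous=True)

# Integers drawn from [-2, 3) with every zero replaced by one.
nonzero = make_tensor(
    (1000,), dtype=torch.int64, device="cpu", low=-2, high=3, exclude_zero=True
)

# A 4-D tensor laid out in channels-last memory format.
channels_last = make_tensor(
    (2, 3, 4, 4), dtype=torch.float32, device="cpu", memory_format=torch.channels_last
)

# A floating point tensor that records operations for autograd.
leaf = make_tensor((3,), dtype=torch.float64, device="cpu", requires_grad=True)
```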
Raises
- ValueError – If `requires_grad=True` is passed for integral `dtype`.
- ValueError – If `low >= high`.
- ValueError – If either `low` or `high` is `nan`.
- ValueError – If both `noncontiguous` and `memory_format` are passed.
- TypeError – If `dtype` isn't supported by this function.
Return type
Tensor
Examples
    >>> import torch
    >>> from torch.testing import make_tensor
    >>> # Creates a float tensor with values in [-1, 1)
    >>> make_tensor((3,), device='cpu', dtype=torch.float32, low=-1, high=1)
    tensor([ 0.1205, 0.2282, -0.6380])
    >>> # Creates a bool tensor on CUDA
    >>> make_tensor((2, 2), device='cuda', dtype=torch.bool)
    tensor([[False, False],
            [False, True]], device='cuda:0')
torch.testing.assert_allclose(actual, expected, rtol=None, atol=None, equal_nan=True, msg='')
Warning
`torch.testing.assert_allclose()` is deprecated since 1.12 and will be removed in a future release. Please use `torch.testing.assert_close()` instead. You can find detailed upgrade instructions here.
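One possible starting point for a migration is sketched below. `assert_allclose_compat` is a hypothetical wrapper, not a PyTorch API. It carries over the `equal_nan=True` default from the signature above and, as an assumption about the old behavior, disables the dtype check. The upgrade instructions remain the authoritative reference.

```python
import functools
import torch

# Hypothetical compatibility wrapper: treats NaNs as equal (the old default)
# and, by assumption, does not compare dtypes.
assert_allclose_compat = functools.partial(
    torch.testing.assert_close, equal_nan=True, check_dtype=False
)

actual = torch.tensor([1.0, float("nan")], dtype=torch.float64)
expected = torch.tensor([1.0, float("nan")], dtype=torch.float32)

# Passes: NaNs match each other and the dtype difference is ignored.
assert_allclose_compat(actual, expected)

# Value mismatches are still reported.
try:
    assert_allclose_compat(actual, torch.tensor([2.0, float("nan")]))
    mismatch_detected = False
except AssertionError:
    mismatch_detected = True
```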