🐛 [Bug] failed correctness check when using F.interpolate(align_corners=False)

Bug Description

INFO:torch_tensorrt.fx.passes.pass_utils:== Log pass <function fuse_permute_matmul at 0x000001F8996789D0> before/after graph to C:\Users\Holy\AppData\Local\Temp\tmpsek2ezsi, before/after are the same = True
INFO:torch_tensorrt.fx.passes.pass_utils:== Log pass <function fuse_permute_linear at 0x000001F899678790> before/after graph to C:\Users\Holy\AppData\Local\Temp\tmpp6a4seyh, before/after are the same = True

Supported node types in the model:
acc_ops.interpolate: ((), {'input': torch.float32})

Unsupported node types in the model:

Got 1 acc subgraphs and 0 non-acc subgraphs
INFO:torch_tensorrt.fx.passes.lower_pass_manager_builder:Now lowering submodule _run_on_acc_0
INFO:torch_tensorrt.fx.lower:split_name=_run_on_acc_0, input_specs=[InputTensorSpec(shape=torch.Size([1, 3, 256, 256]), dtype=torch.float32, device=device(type='cuda', index=0), shape_ranges=[], has_batch_dim=True)]
INFO:torch_tensorrt.fx.lower:Timing cache is used!
INFO:torch_tensorrt.fx.fx2trt:TRT INetwork construction elapsed time: 0:00:00
[12/17/2022-14:17:40] [TRT] [W] TensorRT was linked against cuDNN 8.6.0 but loaded cuDNN 8.5.0
INFO:torch_tensorrt.fx.fx2trt:Build TRT engine elapsed time: 0:00:00.809599
INFO:torch_tensorrt.fx.passes.lower_pass_manager_builder:Lowering submodule _run_on_acc_0 elapsed time 0:00:04.167018
Traceback (most recent call last):
  File "C:\Users\Holy\Downloads\test.py", line 22, in <module>
    trt_mod = compile(
  File "C:\Python310\lib\site-packages\torch_tensorrt\fx\lower.py", line 88, in compile
    return lowerer(module, input)
  File "C:\Python310\lib\site-packages\torch_tensorrt\fx\lower.py", line 323, in __call__
    return do_lower(module, inputs)
  File "C:\Python310\lib\site-packages\torch_tensorrt\fx\passes\pass_utils.py", line 155, in pass_with_validation
    raise e
  File "C:\Python310\lib\site-packages\torch_tensorrt\fx\passes\pass_utils.py", line 141, in pass_with_validation
    torch.testing.assert_close(x, y, **kwargs2)
  File "C:\Python310\lib\site-packages\torch\testing\_comparison.py", line 1342, in assert_close
    assert_equal(
  File "C:\Python310\lib\site-packages\torch\testing\_comparison.py", line 1093, in assert_equal
    raise error_metas[0].to_error(msg)
AssertionError: Pass <function Lowerer.__call__.<locals>.do_lower at 0x000001F8996D3E20> failed correctness check due at output 0:
Tensor-likes are not close!

Mismatched elements: 353039 / 3145728 (11.2%)
Greatest absolute difference: 0.4991211108863354 at index (0, 2, 24, 4) (up to 0.1 allowed)
Greatest relative difference: 56.67618494203449 at index (0, 1, 4, 4) (up to 0.1 allowed)
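The check that fails is torch.testing.assert_close, which counts an element as mismatched when |actual - expected| > atol + rtol * |expected|. The log reports both tolerances as 0.1. A minimal pure-Python sketch of that elementwise rule may help in reading the numbers above. The helper names are hypothetical, and the sketch ignores assert_close's NaN, dtype and device handling. With scale_factor=4, a (1, 3, 256, 256) input produces 3 x 1024 x 1024 = 3,145,728 output elements. A relative difference as large as 56 only says that the reference value at that index was close to zero.

```python
# Sketch of the elementwise closeness rule used by torch.testing.assert_close,
# restricted to finite floats. Helper names are hypothetical.


def is_close(actual: float, expected: float, atol: float, rtol: float) -> bool:
    # An element passes if its absolute error fits within atol + rtol * |expected|.
    return abs(actual - expected) <= atol + rtol * abs(expected)


def mismatch_report(actual, expected, atol=0.1, rtol=0.1):
    """Return (mismatched, total, greatest absolute diff, greatest relative diff)."""
    mismatched = 0
    max_abs = 0.0
    max_rel = 0.0
    for a, e in zip(actual, expected):
        diff = abs(a - e)
        if not is_close(a, e, atol, rtol):
            mismatched += 1
        max_abs = max(max_abs, diff)
        # The relative difference divides by |expected|, so it explodes near zero.
        if e != 0:
            max_rel = max(max_rel, diff / abs(e))
        elif diff > 0:
            max_rel = float("inf")
    return mismatched, len(expected), max_abs, max_rel


# Output of F.interpolate(scale_factor=4) on a (1, 3, 256, 256) input.
total_elems = 1 * 3 * (256 * 4) * (256 * 4)
print(total_elems)
print(f"{100 * 353039 / total_elems:.1f}%")

# Illustrative values only, not taken from the failing run.
print(mismatch_report([0.52, 0.3, 0.9], [0.5, 0.005, 0.9]))
```

The first two prints reproduce the element count and the 11.2% figure from the log. In the illustrative call, only the middle element fails, because its reference value is tiny.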

To Reproduce

import torch
from torch import nn
from torch.nn import functional as F
from torch_tensorrt.fx import compile
from torch_tensorrt.fx.utils import LowerPrecision


class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()

    def forward(self, x):
        return F.interpolate(x, scale_factor=4, mode='bilinear', align_corners=False)


if __name__ == '__main__':
    device = torch.device('cuda')
    x = torch.rand(1, 3, 256, 256, dtype=torch.float32, device=device)

    with torch.inference_mode():
        mod = MyModule().eval().to(device)
        trt_mod = compile(
            mod,
            [x],
            min_acc_module_size=1,
            explicit_batch_dimension=True,
            lower_precision=LowerPrecision.FP32,
            dynamic_batch=False,
        )
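For readers comparing outputs, the two align_corners settings map output pixels to input coordinates differently. The sketch below is a pure-Python, 1-D version of the mapping PyTorch uses for linear modes. Bilinear mode applies the same mapping along height and then along width.

The function names are hypothetical. The sketch assumes the scale is in_size / out_size, which is exact here because 256 * 4 = 1024. It says nothing about how TensorRT's resize layer implements either convention.

```python
import math


def source_index(dst: int, in_size: int, out_size: int, align_corners: bool) -> float:
    """Map an output index to a fractional input index (linear-mode convention)."""
    if align_corners:
        # Corner pixels of input and output are aligned exactly.
        if out_size == 1:
            return 0.0
        return dst * (in_size - 1) / (out_size - 1)
    # Half-pixel convention: pixel centers are aligned, then negatives clamp to 0.
    src = (dst + 0.5) * in_size / out_size - 0.5
    return max(src, 0.0)


def upsample_linear_1d(values, out_size: int, align_corners: bool):
    """1-D linear interpolation; bilinear applies this along H and then W."""
    in_size = len(values)
    out = []
    for dst in range(out_size):
        src = source_index(dst, in_size, out_size, align_corners)
        i0 = min(int(math.floor(src)), in_size - 1)
        i1 = min(i0 + 1, in_size - 1)  # clamp the right neighbour at the border
        frac = src - i0
        out.append(values[i0] * (1.0 - frac) + values[i1] * frac)
    return out


row = [0.0, 1.0]
print(upsample_linear_1d(row, 4, align_corners=False))
print(upsample_linear_1d(row, 4, align_corners=True))
```

Upsampling [0.0, 1.0] to four samples with align_corners=False gives [0.0, 0.25, 0.75, 1.0]. With align_corners=True, the result is evenly spaced from 0 to 1 (0, 1/3, 2/3, 1).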

Environment