🐛 [Bug] Support for modules with multiple outputs seems broken in v1.2.0

Bug Description

It appears that modules with multiple outputs no longer compile when using dynamic input shapes in v1.2.0.

The following example works in v1.1.1 but fails in v1.2.0:

import torch
import torch.nn as nn
import torch_tensorrt as trt

from torch import Tensor
from typing import List, Tuple

trt.logging.set_reportable_log_level(trt.logging.Level.Debug)

class Net(nn.Module):
    def __init__(self):
        nn.Module.__init__(self)

        self.h = nn.Conv2d(1, 4, 3, padding=1)
        self.g = nn.Conv2d(1, 4, 3, padding=1)

    def forward(self, x) -> Tuple[Tensor, Tensor]:
        return self.h(x), self.g(x)

model = Net().eval()
model = torch.jit.trace(model, torch.randn(1, 1, 128, 128))
model = trt.compile(
    model.cuda(),
    inputs=[
        trt.Input(min_shape=(1, 1, 128, 128),
                  opt_shape=(4, 1, 256, 256),
                  max_shape=(8, 1, 512, 512))
    ],
    min_block_size=1,
    require_full_compilation=True
)

Fails with error:

RuntimeError: [Error thrown at core/conversion/conversion.cpp:230] Tuple type. Only a single tensor or a TensorList type is supported.

In v1.1.1, the graph returns two output tensors, while in v1.2.0 it inserts an intermediate prim::TupleConstruct node (%13) and returns that single tuple value. As a result, MarkOutputs in core/conversion/conversion.cpp receives only one tuple-typed output and throws the error above.

Graphs are given below:

v1.1.1

  %11 : Tensor = aten::_convolution(%x, %self.h.weight, %self.h.bias, %3, %3, %3, %5, %2, %4, %5, %5, %6, %6), scope: __module.h # /home/cjenkins/dev/mat/venv/lib/python3.8/site-packages/torch/nn/modules/conv.py:453:0
  %12 : Tensor = aten::_convolution(%x, %self.g.weight, %self.g.bias, %3, %3, %3, %5, %2, %4, %5, %5, %6, %6), scope: __module.g # /home/cjenkins/dev/mat/venv/lib/python3.8/site-packages/torch/nn/modules/conv.py:453:0
  return (%11, %12)

v1.2.0

  %11 : Tensor = aten::_convolution(%x, %self.h.weight, %self.h.bias, %3, %3, %3, %5, %2, %4, %5, %5, %6, %6), scope: __module.h # /home/cjenkins/dev/mat/venv/lib/python3.8/site-packages/torch/nn/modules/conv.py:453:0
  %12 : Tensor = aten::_convolution(%x, %self.g.weight, %self.g.bias, %3, %3, %3, %5, %2, %4, %5, %5, %6, %6), scope: __module.g # /home/cjenkins/dev/mat/venv/lib/python3.8/site-packages/torch/nn/modules/conv.py:453:0
  %13 : (Float(1, 4, 128, 128, strides=[65536, 16384, 128, 1], requires_grad=0, device=cpu), Float(1, 4, 128, 128, strides=[65536, 16384, 128, 1], requires_grad=0, device=cpu)) = prim::TupleConstruct(%11, %12)
  return (%13)
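The only difference between the two graphs is that v1.2.0 wraps %11 and %12 in a prim::TupleConstruct before returning. A minimal sketch of the unpacking that would make both graphs expose the same outputs is shown below. It works on the textual graph dumps (trimmed here), and `resolve_outputs` is a hypothetical helper for illustration, not Torch-TensorRT's actual lowering code.

```python
import re
from typing import Dict, List

def resolve_outputs(graph_text: str) -> List[str]:
    """Hypothetical helper: list the values a graph dump returns, expanding
    a returned prim::TupleConstruct into its individual operands."""
    tuple_defs: Dict[str, List[str]] = {}
    returned: List[str] = []
    for line in graph_text.splitlines():
        line = line.strip()
        # Record definitions like "%13 : (...) = prim::TupleConstruct(%11, %12)".
        m = re.match(r"(%\w+) :.*= prim::TupleConstruct\(([^)]*)\)", line)
        if m:
            tuple_defs[m.group(1)] = [v.strip() for v in m.group(2).split(",")]
            continue
        # Record the graph's return list, e.g. "return (%11, %12)".
        m = re.match(r"return \(([^)]*)\)", line)
        if m:
            returned = [v.strip() for v in m.group(1).split(",") if v.strip()]
    outputs: List[str] = []
    for name in returned:
        # A tuple-typed return becomes one output per tuple element.
        outputs.extend(tuple_defs.get(name, [name]))
    return outputs

# Trimmed copies of the graph dumps from this issue.
V1_1_1 = """
  %11 : Tensor = aten::_convolution(%x, ...)
  %12 : Tensor = aten::_convolution(%x, ...)
  return (%11, %12)
"""
V1_2_0 = """
  %11 : Tensor = aten::_convolution(%x, ...)
  %12 : Tensor = aten::_convolution(%x, ...)
  %13 : (Float(1, 4, 128, 128), Float(1, 4, 128, 128)) = prim::TupleConstruct(%11, %12)
  return (%13)
"""

print(resolve_outputs(V1_1_1))  # ['%11', '%12']
print(resolve_outputs(V1_2_0))  # ['%11', '%12']
```

Both dumps resolve to the same two outputs, which is what MarkOutputs would need to see in v1.2.0 for this model to convert as it did in v1.1.1.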

Expected behavior

A return type of Tuple[Tensor, Tensor] should be treated as two separate outputs, not one.
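One way to meet this expectation is to flatten a (possibly nested) tuple of outputs into a flat list, one engine output per tensor, while recording the nesting so the original Tuple[Tensor, Tensor] can be rebuilt for the caller. The pure-Python sketch below illustrates the idea under that assumption. `flatten` and `unflatten` are hypothetical names, not Torch-TensorRT API, and the strings stand in for tensors.

```python
from typing import Any, List, Tuple, Union

# A spec is either a leaf index into the flat output list or a tuple of specs.
Spec = Union[int, Tuple["Spec", ...]]

def flatten(value: Any, leaves: List[Any]) -> Spec:
    """Append each non-tuple/list leaf of `value` to `leaves`; return its spec."""
    if isinstance(value, (tuple, list)):
        return tuple(flatten(v, leaves) for v in value)
    leaves.append(value)
    return len(leaves) - 1

def unflatten(spec: Spec, leaves: List[Any]) -> Any:
    """Rebuild the nested structure from the flat outputs (lists come back as tuples)."""
    if isinstance(spec, tuple):
        return tuple(unflatten(s, leaves) for s in spec)
    return leaves[spec]

leaves: List[Any] = []
spec = flatten(("h_out", "g_out"), leaves)
print(leaves)                   # ['h_out', 'g_out']
print(spec)                     # (0, 1)
print(unflatten(spec, leaves))  # ('h_out', 'g_out')
```

Keeping the structure separate from the flat list means the conversion step only ever deals with individual tensors, while the caller still receives the tuple that Net.forward declares.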

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

Additional context