🐛 [Bug] Expected to find type str for value why_not_sparsity_fast_path.76 but get nothing. when trying to partially compile TransformerEncoder · Issue #1756 · pytorch/TensorRT

Bug Description

When trying to partially compile a model containing an `nn.TransformerEncoder` with Torch-TensorRT, compilation fails with the following error:

```
Expected to find type str for value why_not_sparsity_fast_path.76 but get nothing.
```

To Reproduce

Steps to reproduce the behavior:

```python
import torch
import torch.nn as nn

import torch_tensorrt


class TransformerModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers, num_heads):
        super(TransformerModel, self).__init__()

        # define embedding layer
        self.embedding = nn.Embedding(input_dim, hidden_dim)

        # define transformer encoder
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=num_heads),
            num_layers=num_layers
        )

        # define output layer
        self.fc = nn.Linear(hidden_dim, input_dim)

    def forward(self, x):
        # apply embedding layer
        x = self.embedding(x)

        # apply transformer encoder
        x = self.transformer_encoder(x)

        # apply output layer
        x = self.fc(x)

        return x


model = TransformerModel(input_dim=100, hidden_dim=128, num_layers=2, num_heads=4)

# shape (32, 10), intended as batch size 32 and sequence length 10; note that with the
# default batch_first=False, the encoder actually treats dim 0 as the sequence dimension
input_data = torch.randint(low=0, high=100, size=(32, 10))

input_data = input_data.to("cuda").to(torch.int)
model.to("cuda")
output = model(input_data)
model.eval()

inputs = [
    torch_tensorrt.Input(
        min_shape=[32, 10],
        opt_shape=[32, 10],
        max_shape=[32, 10],
        dtype=torch.int,
    )
]

enabled_precisions = {torch.float, torch.half}  # Run with fp16

with torch_tensorrt.logging.graphs():
    trt_ts_module = torch_tensorrt.compile(
        model,
        inputs=inputs,
        enabled_precisions=enabled_precisions,
        require_full_compilation=True,
    )

result = trt_ts_module(input_data)

with open("../saved_models/trt_ts_module.ts", "wb") as f:
    torch.jit.save(trt_ts_module, f)
```
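Because the value named in the error is a TorchScript IR value (the `.76` appears to be the unique-name suffix TorchScript appends to a Python variable name), one way to narrow the report down is to script only the encoder and list the IR lines that mention that variable. This is a minimal sketch, assuming a PyTorch release whose `nn.TransformerEncoder.forward` defines a local `why_not_sparsity_fast_path` string, as the error message suggests the installed one does. `find_ir_lines` is a hypothetical helper written for this sketch, not part of PyTorch or Torch-TensorRT, and the sketch runs on CPU without Torch-TensorRT.

```python
from typing import List

import torch
import torch.nn as nn


def find_ir_lines(module: nn.Module, name: str) -> List[str]:
    """Hypothetical helper: script `module` and return the lines of its
    inlined TorchScript IR whose text contains `name`."""
    scripted = torch.jit.script(module)
    # inlined_graph inlines calls into submodules (here, each TransformerEncoderLayer),
    # so values defined inside those submodules also appear in the text.
    ir_text = str(scripted.inlined_graph)
    return [line.strip() for line in ir_text.splitlines() if name in line]


# Same encoder configuration as in the reproduction (hidden_dim=128, num_heads=4, num_layers=2).
encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=128, nhead=4),
    num_layers=2,
).eval()

for line in find_ir_lines(encoder, "why_not_sparsity_fast_path"):
    print(line)
```

The printed lines show which nodes define and consume the string value, which can be attached to the report alongside the `torch_tensorrt.logging.graphs()` output.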

Expected behavior

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

Additional context