🐛 [Bug] Expected to find type str for value why_not_sparsity_fast_path.76 but get nothing. when trying to partially compile TransformerEncoder · Issue #1756 · pytorch/TensorRT
Bug Description
When trying to partially compile a model that contains an nn.TransformerEncoder, compilation fails with the following error:
Expected to find type str for value why_not_sparsity_fast_path.76 but get nothing.
To Reproduce
Steps to reproduce the behavior:
```python
import torch
import torch.nn as nn
import torch_tensorrt


class TransformerModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers, num_heads):
        super(TransformerModel, self).__init__()
        # define embedding layer
        self.embedding = nn.Embedding(input_dim, hidden_dim)
        # define transformer encoder
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=num_heads),
            num_layers=num_layers
        )
        # define output layer
        self.fc = nn.Linear(hidden_dim, input_dim)

    def forward(self, x):
        # apply embedding layer
        x = self.embedding(x)
        # apply transformer encoder
        x = self.transformer_encoder(x)
        # apply output layer
        x = self.fc(x)
        return x


model = TransformerModel(input_dim=100, hidden_dim=128, num_layers=2, num_heads=4)
# shape (32, 10); with the default batch_first=False the encoder reads this as seq_len=32, batch=10
input_data = torch.randint(low=0, high=100, size=(32, 10))
input_data = input_data.to("cuda").to(torch.int)
model.to("cuda")
model.eval()
output = model(input_data)

inputs = [torch_tensorrt.Input(min_shape=[32, 10], opt_shape=[32, 10], max_shape=[32, 10], dtype=torch.int)]
enabled_precisions = {torch.float, torch.half}  # Run with fp16

with torch_tensorrt.logging.graphs():
    trt_ts_module = torch_tensorrt.compile(
        model, inputs=inputs, enabled_precisions=enabled_precisions, require_full_compilation=True
    )

result = trt_ts_module(input_data)

with open("../saved_models/trt_ts_module.ts", "wb") as f:
    torch.jit.save(trt_ts_module, f)
```
Expected behavior
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
- Torch-TensorRT Version (e.g. 1.0.0): main
- PyTorch Version (e.g. 1.0): 2.0
- CPU Architecture: x64
- OS (e.g., Linux): Linux
- How you installed PyTorch (conda, pip, libtorch, source): source
- Build command you used (if compiling from source): pip install -e py
- Are you using local sources or building from archives: archives
- Python version: 3.9
- CUDA version: 11.8
- GPU models and configuration: 3080Ti
- Any other relevant information:
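Most of the fields above can also be read at runtime rather than typed by hand. The following is a minimal sketch, assuming only the standard library plus, when they are installed, PyTorch and Torch-TensorRT. The helper name `collect_env_info` is hypothetical, and `torch_tensorrt.get_build_info()` is assumed from the Torch-TensorRT 1.x Python API (it returns a string of build details); check that it exists in your version. PyTorch's own `python -m torch.utils.collect_env` gives a fuller report.

```python
import platform


def collect_env_info() -> dict:
    """Gather the Environment-template fields that can be detected at runtime (hypothetical helper)."""
    info = {
        "CPU Architecture": platform.machine(),
        "OS": platform.system(),
        "Python version": platform.python_version(),
    }
    try:
        import torch
    except ImportError:
        return info  # PyTorch-specific fields cannot be filled in
    info["PyTorch Version"] = torch.__version__
    # torch.version.cuda is None on CPU-only builds of PyTorch
    info["CUDA version"] = torch.version.cuda
    if torch.cuda.is_available():
        info["GPU models and configuration"] = torch.cuda.get_device_name(0)
    try:
        import torch_tensorrt
    except (ImportError, OSError):
        return info  # Torch-TensorRT missing, or its native libraries failed to load
    info["Torch-TensorRT Version"] = torch_tensorrt.__version__
    # Assumption: get_build_info() is available in the installed Torch-TensorRT version
    info["Torch-TensorRT build info"] = torch_tensorrt.get_build_info()
    return info


if __name__ == "__main__":
    # Print in the same "- Field: value" shape as the template list
    for field, value in collect_env_info().items():
        print(f"- {field}: {value}")
```

Fields that cannot be detected this way, such as how PyTorch was installed or the build command used, still have to be filled in manually.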