🐛 [Bug] Compilation Error on GPT-2 · Issue #1455 · pytorch/TensorRT
Bug Description
When converting the GPT-2 network (https://huggingface.co/gpt2) from TorchScript to Torch-TRT, the following error is encountered:
compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
RuntimeError: [Error thrown at core/partitioning/shape_analysis.cpp:167] Unsupported input data type unsigned char
To Reproduce
Steps to reproduce the behavior:
- Run torch_tensorrt.compile with GPT-2 model as input, using fp32 precision.
- Choose fixed input size of [1, 128] and enable truncate_long_and_double with 12 GB workspace.
- Pass in model keyword args to disable attention and hidden state outputs (a reproduction sketch follows this list)
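A minimal reproduction sketch under the settings above. The model-loading arguments, the tracing step, and the exact keyword names (e.g. workspace_size, the Input dtype) are assumptions for illustration, not the reporter's verbatim script:

```python
import torch
import torch_tensorrt
from transformers import GPT2LMHeadModel

# Load GPT-2 in TorchScript-friendly mode; the keyword args disabling the
# attention / hidden-state outputs are assumptions matching the steps above.
model = GPT2LMHeadModel.from_pretrained(
    "gpt2",
    torchscript=True,
    use_cache=False,
    output_attentions=False,
    output_hidden_states=False,
).eval().cuda()

# Fixed [1, 128] token-id input for tracing.
input_ids = torch.randint(0, 50257, (1, 128), dtype=torch.int64).cuda()
traced = torch.jit.trace(model, [input_ids])

trt_model = torch_tensorrt.compile(
    traced,
    inputs=[torch_tensorrt.Input(shape=[1, 128], dtype=torch.int32)],  # dtype is an assumption
    enabled_precisions={torch.float32},  # fp32 precision
    truncate_long_and_double=True,
    workspace_size=12 << 30,             # 12 GB workspace
)
```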
Expected behavior
Model should successfully compile to Torch-TRT. Specifically, internal (non-user-provided) type-casting issues should not cause errors.
Environment
- Torch-TensorRT Version: 1.3.0a0+e3b99294
- PyTorch Version: 1.13.0.dev20220921+cu116
- CPU Architecture: Intel Xeon CPU
- OS: Ubuntu 20.04
- How you installed PyTorch: pip
- Build command you used:
python setup.py develop
- Are you using local sources or building from archives: local
- Python version: 3.8.13
- CUDA version: 11.6
Additional context
The problematic data in GPT-2 seems to be this bias term, instantiated in the attention module, which has type uint8. In both the TorchScript IR and the model code (example 1, example 2), it seems that this bias term is generally cast to a bool. The error is thrown in this code segment:
c10::optional<nvinfer1::DataType> dtype = util::optTypeMetaToTRTDataType(cur_ivalue.toTensor().dtype());
if (dtype == c10::nullopt) {
  TORCHTRT_THROW_ERROR("Unsupported input data type " << cur_ivalue.toTensor().dtype());
}
The conversion of a uint8 type to a TRT data type fails; however, simply patching this conversion does not fix the issue, as an out-of-bounds error follows later.
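For reference, a minimal sketch of how this buffer is constructed, mirroring (a version of) the Hugging Face GPT-2 attention module; the slicing example and variable names are assumptions:

```python
import torch

# The attention module registers a lower-triangular causal mask as a uint8
# buffer (max_positions corresponds to n_positions = 1024 for GPT-2).
max_positions = 1024
bias = torch.tril(
    torch.ones((max_positions, max_positions), dtype=torch.uint8)
).view(1, 1, max_positions, max_positions)

# In the forward pass the mask is sliced and treated as a bool before being
# used with torch.where -- the uint8 tensor itself is what trips shape analysis.
query_length, key_length = 128, 128
causal_mask = bias[:, :, key_length - query_length : key_length, :key_length].bool()
```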
Temporary Solution
A temporary fix to this problem is to add the following to the compilation arguments in torch_tensorrt.compile:
torch_tensorrt.compile(..., torch_executed_ops=["aten::where"], ...)
This works because it happens to exclude the code which uses and processes the uint8 tensor; however, it is only a temporary fix and does not resolve the underlying issue.
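A sketch of the workaround applied to the compile call from the reproduction sketch above (all other arguments unchanged and assumed as before):

```python
trt_model = torch_tensorrt.compile(
    traced,
    inputs=[torch_tensorrt.Input(shape=[1, 128], dtype=torch.int32)],
    enabled_precisions={torch.float32},
    truncate_long_and_double=True,
    workspace_size=12 << 30,
    # Run aten::where in Torch instead of TensorRT so the uint8 mask path is
    # excluded from TRT partitioning and shape analysis.
    torch_executed_ops=["aten::where"],
)
```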
Steps to a Solution
- Fix mismatched dimension issue in aten::where
- Make at::kByte a valid input type