fix: Add lowering pass to remove output repacking in convert_method_to_trt_engine
calls by gs-olive · Pull Request #1945 · pytorch/TensorRT (original) (raw)
Description
- Adds improved support for full-conversion and TRT engine export for a variety of models
- Automatically remove output repacking for
convert_method_to_trt_engine
calls, to improve parity between models which can be converted directly to TRT engines, and models which can be fully compiled - Add new internal
CompileSpec
argument for lowering which indicates whether the lowering passes originate from aconvert_method_to_trt_engine
call or a regularcompile
call, which affects whether the lowering pass is applied - Regular TorchScript graphs cannot have this pass applied, as it can otherwise break the output graph. Newer versions of Torch disallow graph outputs with 0 or 2+ arguments which are not packed in a struct
- Current lowering pass detects outputs which are flat Lists or Tuples of Tensors and returns the outputs as-is (direct from the TRT Engine), so the entire model can be converted to a single TRT engine
Type of change
- Bug fix (non-breaking change which fixes an issue)
- New feature (non-breaking change which adds functionality)
Checklist:
- [ x ] My code follows the style guidelines of this project (You can use the linters)
- [ x ] I have performed a self-review of my own code
- [ x ] I have commented my code, particularly in hard-to-understand areas and hacks
- [ x ] I have made corresponding changes to the documentation
- [ x ] I have added tests to verify my fix or my feature
- [ x ] New and existing unit tests pass locally with my changes
- [ x ] I have added the relevant labels to my PR in so that relevant reviewers are notified
- Automatically remove output repacking for
convert_method_to_trt_engine
calls, to improve parity between models which can be converted directly to TRT engines, and models which can be fully compiled - Add new internal
CompileSpec
argument for lowering which indicates whether the lowering passes originate from aconvert_method_to_trt_engine
call or a regularcompile
call, which affects whether the lowering pass is applied - Regular TorchScript graphs cannot have this pass applied, as it can otherwise break the output graph. Newer versions of Torch disallow graph outputs with 0 or 2+ arguments which are not packed in a struct
- Current lowering pass detects outputs which are flat Lists or Tuples of Tensors and returns the outputs as-is (direct from the TRT Engine), so the entire model can be converted to a single TRT engine
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
gs-olive deleted the remove_output_collection_casts branch
narendasan pushed a commit that referenced this pull request
…to_trt_engine` calls (#1945)
narendasan pushed a commit that referenced this pull request
…to_trt_engine` calls (#1945)
narendasan pushed a commit that referenced this pull request
…to_trt_engine` calls (#1945)