fix/feat: Add lowering pass to resolve most aten::Int.Tensor
uses by gs-olive · Pull Request #1937 · pytorch/TensorRT (original) (raw)
Description
- Adds improved support for full-conversion for a variety of models
- Implement lowering pass which detects canonical
aten::Int.Tensor
cases and recursively replaces input Value pointers until all 0D tensors have been resolved to their scalar components - Lowering pass is specialized to replacing strictly integer-typed Value pointers and can only trace through aten::mul and aten::floor_divide operators, which are two of the most common cases of use
- Lowering pass traverses the graph until one of three base cases are encountered (or an invalid Value type is detected). These cases are
prim::NumToTensor
,prim::Constant
(0D tensor), or simple integers. It then replaces the child nodes with the integer equivalents of the produced Tensors - Added extensive testing of new capabilities for accuracy, robustness, and functionality
Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change.
Fixes #1880
Fixes #1836
Fixes #513
Type of change
- Bug fix (non-breaking change which fixes an issue)
- New feature (non-breaking change which adds functionality)
Checklist:
- [ x ] My code follows the style guidelines of this project (You can use the linters)
- [ x ] I have performed a self-review of my own code
- [ x ] I have commented my code, particularly in hard-to-understand areas and hacks
- [ x ] I have made corresponding changes to the documentation
- [ x ] I have added tests to verify my fix or my feature
- [ x ] New and existing unit tests pass locally with my changes
- [ x ] I have added the relevant labels to my PR in so that relevant reviewers are notified
gs-olive changed the title
fix/feat: Add lowering pass to resolve aten::Int.Tensor fix/feat: Add lowering pass to resolve most aten::Int.Tensor
invocations
gs-olive changed the title
fix/feat: Add lowering pass to resolve most fix/feat: Add lowering pass to resolve most aten::Int.Tensor
invocationsaten::Int.Tensor
uses
- Implement lowering pass which detects canonical
aten::Int.Tensor
cases and recursively replaces input Value pointers until all 0D tensors have been resolved to their scalar components - Lowering pass is specialized to replacing strictly integer-typed Value pointers and can only trace through aten::mul and aten::floor_divide operators, which are two of the most common cases of use
- Lowering pass traverses the graph until one of three base cases are
encountered (or an invalid Value type is detected). These cases are
prim::NumToTensor
,prim::Constant
(0D tensor), or simple integers. It then replaces the child nodes with the integer equivalents of the produced Tensors - Added extensive testing of new capabilities for accuracy, robustness, and functionality
torch::jit::aten::floor_divide, |
---|
}; |
c10::optionaltorch::jit::Value\* Validate0DTensor(torch::jit::Value* value) { |
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Refactored to use c10::optional
wrapper instead of nullptr
+ replaced pointer checks with .has_value()
- Edit in favor of
c10::optional
type usage
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
gs-olive deleted the replace_aten_int_schema branch
narendasan pushed a commit that referenced this pull request
narendasan pushed a commit that referenced this pull request
narendasan pushed a commit that referenced this pull request