fix: Repair invalid schema arising from lowering pass by gs-olive · Pull Request #1786 · pytorch/TensorRT (original) (raw)
Description
- When
remove_unnecessary_casts
replaces both tensors in anaten::div
call with the corresponding scalars, and the rounding mode (third argument) is specified, the schema becomes invalid - To avoid this, we intercept the call and split the result based on what the third argument in the input was, either delegating the result to
aten::Int(aten::div(...))
, for "trunc" oraten::floordiv(...)
for "floor" - Update lowering pass to incorporate the above change, with appropriate documentation
- Add testing to verify catching and correctly handling these edge cases
Fixes #1780
Type of change
Please delete options that are not relevant and/or add your own.
- Bug fix (non-breaking change which fixes an issue)
Checklist:
- [ x ] My code follows the style guidelines of this project (You can use the linters)
- [ x ] I have performed a self-review of my own code
- [ x ] I have commented my code, particularly in hard-to-understand areas and hacks
- [ x ] I have made corresponding changes to the documentation
- [ x ] I have added tests to verify my fix or my feature
- [ x ] New and existing unit tests pass locally with my changes
- [ x ] I have added the relevant labels to my PR in so that relevant reviewers are notified
- When
remove_unnecessary_casts
replaces both tensors in anaten::div
call with the corresponding scalars, and the rounding mode (third argument) is specified, the schema becomes invalid - To avoid this, we intercept the call and split the result based on
what the third argument in the input was, either delegating the result
to
aten::Int(aten::div(...))
, for "trunc" oraten::floordiv(...)
for "floor" - Update lowering pass to incorporate the above change, with appropriate documentation
- Add testing to verify catching and correctly handling these edge cases
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
bowang007 pushed a commit that referenced this pull request
bowang007 pushed a commit that referenced this pull request