✨[Feature] Remove requirement for require_full_compilation=False when using input_signature
Feature Context
Models that are fully supported in TRT, except that their inputs are collections, should be able to be fully compiled in Torch-TRT. Because Torch-executed list packing and unpacking code is already inserted (by necessity) even when a model is fully supported, there should be no need to disable full compilation when complex input types are provided. Additionally, operators such as prim::ListUnpack should not be added to torch_executed_ops automatically when input_signature is used, as they currently are, since evaluators for them already exist.
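To illustrate why forcing collection operators into fallback matters, the toy model below splits a linear sequence of op names into alternating Torch and TRT segments. It is a hypothetical simplification, not Torch-TRT's actual partitioning algorithm: the `segment` helper and the example op sequence are assumptions made for illustration, and "trt" here only means "not forced into a Torch segment" (ops with evaluators are resolved during conversion rather than executed by TensorRT).

```python
from itertools import groupby
from typing import List, Set, Tuple

# Hypothetical toy partitioner: consecutive ops with the same target
# ("torch" if forced into fallback, otherwise "trt") form one segment.
def segment(ops: List[str], forced_fallback: Set[str]) -> List[Tuple[str, List[str]]]:
    def target(op: str) -> str:
        return "torch" if op in forced_fallback else "trt"
    return [(t, list(group)) for t, group in groupby(ops, key=target)]

# Illustrative op sequence for a model taking a tuple input and returning a list.
graph = [
    "prim::TupleUnpack",    # unpack the input collection
    "aten::add",
    "aten::__getitem__",    # index into an intermediate collection
    "aten::relu",
    "prim::ListConstruct",  # pack the outputs into a list
]

# The ops currently appended to the forced-fallback list for input_signature.
collection_ops = {
    "aten::__getitem__", "prim::ListConstruct", "prim::ListUnpack",
    "prim::TupleIndex", "prim::TupleConstruct", "prim::TupleUnpack",
}

current = segment(graph, collection_ops)  # collection ops forced into Torch
proposed = segment(graph, set())          # collection ops left to evaluators

for name, parts in (("current", current), ("proposed", proposed)):
    print(name, [t for t, _ in parts])
```

In this toy, forcing the collection ops yields five alternating segments, while leaving them to evaluators yields a single segment, which is the precondition for full compilation.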
Desired Solution
The preferred solution is to remove the requirement for require_full_compilation=False when using input_signature, and to stop forcing collection-based operators to execute in fallback, both of which are currently enforced by the following code:
```python
elif compile_spec["input_signature"] is not None:
    log(
        Level.Warning,
        "Input signature parsing is an experimental feature, behavior and APIs may change",
    )
    signature = _parse_input_signature(compile_spec["input_signature"])
    info.input_signature = _C.InputSignature(signature)  # py_object
    if not compile_spec["torch_fallback"]["enabled"]:
        raise ValueError(
            "Grouped inputs currently requires partial compilation to be enabled, this restriction will be relaxed in a future release"
        )
    log(
        Level.Debug,
        "Grouped inputs currently requires additional settings to enable the feature",
    )
    log(
        Level.Debug,
        """Adding the following ops to torch_executed_ops:
    - aten::__getitem__
    - prim::ListConstruct
    - prim::ListUnpack
    - prim::TupleIndex
    - prim::TupleConstruct
    - prim::TupleUnpack
    """,
    )
    compile_spec["torch_fallback"]["forced_fallback_ops"].append(
        "aten::__getitem__"
    )
    compile_spec["torch_fallback"]["forced_fallback_ops"].append(
        "prim::ListConstruct"
    )
    compile_spec["torch_fallback"]["forced_fallback_ops"].append("prim::ListUnpack")
    compile_spec["torch_fallback"]["forced_fallback_ops"].append("prim::TupleIndex")
    compile_spec["torch_fallback"]["forced_fallback_ops"].append(
        "prim::TupleConstruct"
    )
    compile_spec["torch_fallback"]["forced_fallback_ops"].append(
        "prim::TupleUnpack"
    )
```
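The contrast between the current Python-side handling and the proposed one can be sketched as a standalone function operating on a plain dict. This is a hypothetical illustration, not Torch-TRT code: `apply_input_signature`, the `relaxed` flag, and the placeholder signature value are assumptions, and the dict layout only mirrors the keys used in the snippet above. It deliberately ignores the C++ core changes the proposal would also need.

```python
import copy
from typing import Any, Dict, List

# Ops the current code appends to forced_fallback_ops.
COLLECTION_OPS: List[str] = [
    "aten::__getitem__",
    "prim::ListConstruct",
    "prim::ListUnpack",
    "prim::TupleIndex",
    "prim::TupleConstruct",
    "prim::TupleUnpack",
]

def apply_input_signature(spec: Dict[str, Any], relaxed: bool) -> Dict[str, Any]:
    """Hypothetical helper: return a copy of spec with input_signature handling applied.

    relaxed=False mirrors the current behavior; relaxed=True sketches the proposal.
    """
    spec = copy.deepcopy(spec)
    if spec["input_signature"] is None:
        return spec
    if relaxed:
        # Proposed: full compilation stays allowed, and collection ops are
        # left to their evaluators instead of being forced into fallback.
        return spec
    # Current: partial compilation is mandatory and collection ops are forced.
    if not spec["torch_fallback"]["enabled"]:
        raise ValueError(
            "Grouped inputs currently requires partial compilation to be enabled"
        )
    spec["torch_fallback"]["forced_fallback_ops"].extend(COLLECTION_OPS)
    return spec

# Placeholder signature; a real spec would hold torch_tensorrt.Input objects.
full_compile_spec = {
    "input_signature": (("x", "y"),),
    "torch_fallback": {"enabled": False, "forced_fallback_ops": []},
}

try:
    apply_input_signature(full_compile_spec, relaxed=False)
except ValueError as err:
    print("current:", err)
print("proposed:", apply_input_signature(full_compile_spec, relaxed=True))
```

Returning a deep copy keeps the caller's spec untouched, so the two behaviors can be compared side by side on the same input.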
This would also require modifying the C++ core code, to ensure that relaxing this requirement does not cause further issues in the existing compilation phases.
Additional Context
A proof of concept for this feature already exists in PR #1599, which could serve as a template for enabling full compilation with collection inputs as well. This would complete the plan for Collection IO discussed in #629 (comment).