fix: Allow full model compilation with collection outputs by gs-olive · Pull Request #1599 · pytorch/TensorRT
Thanks for the comments and review @narendasan - I have incorporated the feedback and updated two of the user warnings to compilation-halting errors.

One note I wanted to make is that despite `min_block_size=1` and allowing collection-type nodes to run in Torch, this implementation still respects full compilation and will not execute intermediate pack/unpack operations in Torch. This is because `prim::TupleUnpack` and other such operators are not automatically added to `torch_executed_ops` - that is only done in the case where `input_signature` is used, which is not the intent of this PR (it will be a future PR). As a result, only the collection ops needed to pack the final model output are run in Torch, as per this function (a usage sketch follows the snippet below):
```cpp
// Check if the inputs and outputs of the graph are Tensor. If not, then fallback connected nodes
void setInputsOutputsConnectedNodes(PartitioningCtx* ctx, torch::jit::Block* block) {
  // fallback nodes that produce entire graph's nonTensor output
  for (auto i : block->outputs()) {
    if (!isTensor(i)) {
      ctx->setNodeExecutorDecision(i->node(), NodeExecutorDecision::kNON_TENSOR);
    }
  }
  // fallback nodes that consume entire graph's nonTensor input
  for (auto i : block->inputs()) {
    if (!isTensor(i)) {
      for (auto use : i->uses()) {
        ctx->setNodeExecutorDecision(use.user, NodeExecutorDecision::kNON_TENSOR);
      }
    }
  }
}
```
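To make the end-to-end behavior concrete, here is a minimal usage sketch of how this path might be exercised - the module, input shape, and variable names are hypothetical, and the sketch assumes the TorchScript frontend of `torch_tensorrt.compile`:

```python
import torch
import torch_tensorrt

class TupleOutputModel(torch.nn.Module):
    # Hypothetical module whose output is a collection (Tuple[Tensor, Tensor])
    def forward(self, x):
        return x + 1, x * 2

model = torch.jit.script(TupleOutputModel().eval().cuda())

# With this change, full compilation is expected to succeed even though the
# final prim::TupleConstruct that packs the model output must run in Torch
trt_model = torch_tensorrt.compile(
    model,
    inputs=[torch_tensorrt.Input((1, 3, 224, 224))],
    require_full_compilation=True,
    min_block_size=1,
)
```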
Any intermediate packing/unpacking is handled by the evaluators and does not cause graph segmentation, since those nodes are not directly graph outputs.
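As an illustrative (hypothetical) example of that intermediate case, a module like the following packs and unpacks a tuple internally; since the resulting `prim::TupleConstruct`/`prim::TupleUnpack` nodes are not graph outputs, they should be resolved by the evaluators and the graph should still compile fully:

```python
import torch

class IntermediateTupleModel(torch.nn.Module):
    def forward(self, x):
        # Intermediate pack/unpack: these become prim::TupleConstruct and
        # prim::TupleUnpack in the scripted graph, but neither is a graph
        # output, so they are evaluated rather than segmented out to Torch
        packed = (x + 1, x * 2)
        a, b = packed
        return a - b
```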