fix: Allow full model compilation with collection outputs by gs-olive · Pull Request #1599 · pytorch/TensorRT

Thanks for the comments and review @narendasan - I have incorporated the feedback and updated two of the user warnings to compilation-halting errors.

One note I wanted to make: despite min_block_size=1 and allowing collection-type nodes to run in Torch, this implementation still respects full compilation and will not execute intermediate pack/unpack operations in Torch. This is because prim::TupleUnpack and other such operators are not automatically added to torch_executed_ops - that is only done when input_signature is used, which is not the intent of this PR (it will be addressed in a future PR). As a result, only the collection ops needed to pack the final model output are run in Torch, as per this function:

// Check if the inputs and outputs of the graph are Tensor. If not, then fallback connected nodes
void setInputsOutputsConnectedNodes(PartitioningCtx* ctx, torch::jit::Block* block) {
  // fallback nodes that produce entire graph's nonTensor output
  for (auto i : block->outputs()) {
    if (!isTensor(i)) {
      ctx->setNodeExecutorDecision(i->node(), NodeExecutorDecision::kNON_TENSOR);
    }
  }
  // fallback nodes that consume entire graph's nonTensor input
  for (auto i : block->inputs()) {
    if (!isTensor(i)) {
      for (auto use : i->uses()) {
        ctx->setNodeExecutorDecision(use.user, NodeExecutorDecision::kNON_TENSOR);
      }
    }
  }
}
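
For illustration, here is a minimal sketch of the kind of model this path applies to, using the Python TorchScript frontend (the module, input shape, and exact compile kwargs here are hypothetical, not taken from this PR):

import torch
import torch_tensorrt

# Hypothetical toy module whose TorchScript graph output is a prim::TupleConstruct
class TupleOutput(torch.nn.Module):
    def forward(self, x):
        return x + 1, x * 2  # packed into a tuple at the graph output

scripted = torch.jit.script(TupleOutput())

trt_model = torch_tensorrt.compile(
    scripted,
    inputs=[torch_tensorrt.Input((1, 3, 224, 224))],
    min_block_size=1,  # even at 1, only the final output pack op runs in Torch
)

Under the function above, only that final prim::TupleConstruct is marked kNON_TENSOR and runs in Torch; the rest of the graph remains eligible for TensorRT.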

Any intermediate packing/unpacking is handled by the evaluators and does not cause graph segmentation, since those nodes are not direct graph outputs.
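
To make the intermediate case concrete, here is a minimal sketch (hypothetical module, same assumptions as above). Because the tuple is both produced and consumed inside the graph, and the final output is a Tensor, the pack/unpack pair is resolved by the evaluators and no fallback segment is created:

# Hypothetical module with an intermediate tuple: the prim::TupleConstruct /
# prim::TupleUnpack pair is internal to the graph, not a graph output
class IntermediateTuple(torch.nn.Module):
    def forward(self, x):
        pair = (x + 1, x * 2)  # intermediate pack (prim::TupleConstruct)
        a, b = pair            # intermediate unpack (prim::TupleUnpack)
        return a + b           # Tensor output, so no fallback is triggered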