Collection IO · pytorch/TensorRT · Discussion #629

Collections

Goal

Currently TRTorch programs that can be compiled must be trivially reducible to the form f([Tensor]) -> [Tensor]. Cases like f(Tensor) -> ((Tensor, Tensor)) are supported through this method. This means that no sort of Input/Output formatting is currently handled by TRTorch. We would like to add support for cases like f(Tensor[]) -> (Tensor, Tensor, (Tensor, Tensor)) or f(Tensor, Tensor, (Tensor, Tensor)) -> (Tensor, (Tensor, Tensor)), which have non-trivial subgrouping of tensors.

API Considerations

Considering that the formatting of the function signature is now more complex, we might want to think about ways to make it easy to convey the input specification.

Proposed API

For a module with a signature such as:

... def forward(x, y, (a, b, c)): ...

We could change the API to expect a tuple formatted in the same way someone might call the function. In conjunction with the example tensor feature (#616), this might provide a natural way to reuse or more easily provide input specs vs. doing some sort of mental computation about aligning specs with inputs.

Example

x = trtorch.Input()
y = trtorch.Input()
a = torch.randn()
b = torch.randn()
c = torch.randn()

trtorch.compile(mod, inputs=(x,y,(a,b,c)))

This is as opposed to

x = trtorch.Input()
y = trtorch.Input()
a = torch.randn()
b = torch.randn()
c = torch.randn()

trtorch.compile(mod, inputs=[x,y,a,b,c])

Where the inputs must be aligned properly and paired internally with the graph signature.

The advantage is that we can create an internal structure which encodes the format of the inputs directly from the tuple the user provides. It also gives us an input of fixed size. Alternative methods that examine the graph input signature may have these fixed sizes obfuscated by type information. For instance, in a graph that uses a list to group subsets of arguments instead of a tuple, you might see a signature like:

graph(%x : Tensor, %y : Tensor, %abc : Tensor[]):
  ...
       %trt_ins : Tensor[] = prim::ListConstruct(%x, %y, %a, %b, %c)
       %trt_outs : Tensor[] = tensorrt::execute_engine(... %trt_ins)
 ...
  -> (%i, (%j, %k))

This will not tell us how to align the inputs provided by the user as a flat list.
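To make that ambiguity concrete, here is a minimal Python sketch (the signature tuples and flatten helper are purely illustrative, not TRTorch code) showing two different groupings that flatten to the same list:

```python
# Two different signatures flatten to the same flat list, so the flat
# list alone cannot recover the grouping (illustrative sketch only).
sig_a = ("x", "y", ("a", "b", "c"))
sig_b = (("x", "y"), ("a", ("b",), "c"))

def flatten(node):
    """Recursively flatten nested tuples into a single list."""
    out = []
    for e in node:
        if isinstance(e, tuple):
            out.extend(flatten(e))
        else:
            out.append(e)
    return out

print(flatten(sig_a))  # ['x', 'y', 'a', 'b', 'c']
print(flatten(sig_b))  # ['x', 'y', 'a', 'b', 'c']
```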

One limitation of this design may be its usage in C++; more exploration will be required to determine whether this is ergonomic and consistent with PyTorch.

Internal Implementation

Leverage TorchScript IR

1. Inputs

We could look to make trtorch::core::ir::Input compatible with IValues by registering it as a torch custom class. This would let us nest Inputs in PyTorch types. This means we can pass around one IValue which holds the full input spec. This can then be parsed in the graph construction phase directly.

1. Go from user spec to IValue

It's unclear what the exact process is to go from a presumably standard Python or C++ tuple to an IValue, but this is something PyTorch is able to do, so it should just require looking at the PyTorch source.

2. Assign IDs to Inputs and create list of Inputs to pass to TensorRT

The next step is to populate a data structure like the one below which assigns each input an ID so that we can create a flattened vector of inputs to pass to TensorRT.

We should add a field to the trtorch::core::ir::Input class called ID. This will be the unique identifier for the Input during compilation. The order in which we add these inputs will be determined by an in-order traversal of the tuple provided by the user. We only increment the ID counter when we hit a new unlabeled Input (i.e. the leaves of the syntax tree). At the same time we can create a list of Inputs which will be passed to the conversion phase. This should likely be stored in a single struct:

namespace trtorch {
namespace core {
namespace ir {

struct GraphInputs {
  torch::jit::IValue* input_signature;
  std::vector<Input> flattened_inputs;
};

} // namespace ir
} // namespace core
} // namespace trtorch

This object should then be added to the CompileSpec (this could potentially replace the vector of Inputs we use right now).
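The ID-assignment traversal described above can be sketched in plain Python (this stands in for the C++ implementation; the function name and string leaves are hypothetical):

```python
def build_graph_inputs(spec):
    """Walk the nested input spec in order, replacing each leaf with a
    numeric ID and collecting the leaves into a flat list — a sketch of
    the proposed GraphInputs construction, not TRTorch code."""
    flattened_inputs = []

    def walk(node):
        if isinstance(node, tuple):
            return tuple(walk(e) for e in node)
        flattened_inputs.append(node)      # leaf: an Input placeholder
        return len(flattened_inputs) - 1   # the ID assigned to this leaf

    input_signature = walk(spec)
    return input_signature, flattened_inputs

sig, flat = build_graph_inputs(("x", "y", ("a", "b", "c")))
# sig == (0, 1, (2, 3, 4)); flat == ["x", "y", "a", "b", "c"]
```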

3. Parse IValue and Construct Graph

Once we get to the graph construction phase, we need to amend it so that the first step is to create the graph inputs, then take those inputs and flatten them into a list where each index of the list corresponds to the ID of the matching Input in TorchScript. This will involve using the IValue created in step 1 as the spec for the access procedure for each Input.

2. Outputs

%trt_out : Tensor[] = tensorrt::execute_engine(...) 
...
-> (x, y, [a, b], c)

Returned from Conversion: torch::jit::IValue((0, 1, List[2, 3], 4))

1. Evaluating collection operations to get list of outputs

The evaluation system should automatically construct any sort of collections that will be used in the output during conversion. However, MarkOutputs currently only handles ITensors and TensorContainers; it will need to be extended to handle parsing the collection types. At this time we should construct an IValue, similar to the input IValue, which encodes the mapping from the output indices of TensorRT to the final output tuple. This IValue should be returned from the conversion process along with the serialized TensorRT engine. We already have an ID for each output to deal with the fact that TensorRT doesn't guarantee output order; these IDs can be reused in the IValue.

2. Parse IValue and Construct Graph

In the graph construction phase, once the TensorRT engine is embedded, we need to add the nodes that pack the outputs into the right format. This should use a system similar to the input system, except it packs Tensors from a list into a format rather than unpacking them.
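A rough sketch of this repacking step in plain Python (the encoding tuple mirrors the IValue((0, 1, List[2, 3], 4)) example above; Python lists stand in for TorchScript List types, and the helper name is hypothetical):

```python
def repack(encoding, trt_outputs):
    """Rebuild the nested output structure from TensorRT's flat output
    list, driven by the index-encoding produced at conversion time."""
    if isinstance(encoding, tuple):
        return tuple(repack(e, trt_outputs) for e in encoding)
    if isinstance(encoding, list):
        return [repack(e, trt_outputs) for e in encoding]
    return trt_outputs[encoding]  # leaf: index into the flat output list

encoding = (0, 1, [2, 3], 4)
flat_outputs = ["x", "y", "a", "b", "c"]
print(repack(encoding, flat_outputs))  # ('x', 'y', ['a', 'b'], 'c')
```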

Data Structures

namespace trtorch {
namespace core {
namespace ir {

struct GraphInputs {
  torch::jit::IValue input_signature;
  std::vector<Input> flattened_inputs;
};

typedef std::pair<GraphInputs, torch::jit::IValue> GraphIO;

} // namespace ir
} // namespace core
} // namespace trtorch

GraphIO is a pair whose first element is a struct holding both the formatted input tuple containing core::ir::Input structs and a flattened version of that tuple. The second element holds an IValue which is a formatted tuple of ints that defines how to go from the list output of TensorRT to the output tuple.

Implementation Phases

WAR

We should first check whether partial compilation can handle some of this trivially, so that users can get unblocked.

MVP

We should implement support for one to two simple collection types. Tuples will likely be the simplest, so we should start with those and get the system working end to end, from user API to graph synthesis.

Additional Data Types

The next least complex type would most likely be lists. They should be implementable like tuples with very few changes if we use the API described above. After that we may want to look at dictionaries (this could even be pushed to a later release), which have the added complexity of keys.

Syntax Sugar

Finally, we should consider whether there is any way to make the API simpler than what we have proposed here, and whether there is any further work we could do on behalf of the user.