NVIDIA TensorRT Operators Documentation 10.14.1


Boilerplate for All Operator Examples

This boilerplate code provides a framework to run all the operator examples. To make them runnable, copy and paste the specific example code between the designated ‘example begin’ and ‘example end’ comments.
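For instance, a snippet along the following lines could be pasted in place of the Activation example shown in the listing below. This is a minimal sketch using the ElementWise SUM operator; the authoritative snippet for each operator is on that operator's documentation page, and the input values here are illustrative only.

    # Illustrative snippet for the Example Begin/End region: elementwise SUM of two inputs.
    in1 = network.add_input("input1", dtype=trt.float32, shape=(2, 3))
    in2 = network.add_input("input2", dtype=trt.float32, shape=(2, 3))
    layer = network.add_elementwise(in1, in2, trt.ElementWiseOperation.SUM)
    network.mark_output(layer.get_output(0))

    inputs[in1.name] = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
    inputs[in2.name] = np.array([[6.0, 5.0, 4.0], [3.0, 2.0, 1.0]])

    outputs[layer.get_output(0).name] = layer.get_output(0).shape

    expected[layer.get_output(0).name] = np.array([[7.0, 7.0, 7.0], [7.0, 7.0, 7.0]])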

import numpy as np
import math

# example_plugin_v2.py
import ctypes

import tensorrt as trt
from cuda.bindings import driver as cuda, runtime as cudart
from common import cuda_call, CudaStreamContext, DeviceMem, PinnedHostMem, memcpy_host_to_device_async, memcpy_device_to_host_async

class OutputAllocator(trt.IOutputAllocator):
    def __init__(self, curr_size):
        trt.IOutputAllocator.__init__(self)
        self.curr_size = curr_size
        self.allocated_mem = None
        if curr_size > 0:
            self.allocated_mem = DeviceMem(curr_size)
        self.tensor_shape = None

    def reallocate_output(self, tensor_name, memory, size, alignment):
        assert size > 0
        if size > self.curr_size:
            self.allocated_mem = DeviceMem(size)
            self.curr_size = size
        return int(self.allocated_mem.device_ptr)

    def notify_shape(self, tensor_name, shape):
        self.tensor_shape = shape

class Runner:
    def __init__(self, logger=trt.Logger(min_severity=trt.ILogger.Severity.INFO)):
        self.builder = trt.Builder(logger)
        self.network = self.builder.create_network(flags=0)
        self.config = self.builder.create_builder_config()
        self.runtime = trt.Runtime(logger)
        self.inputs = {}
        self.outputs = {}
        self.expected = {}
        self.results = {}
        self.logger = logger
        self.atol = 0.1

def example(get_runner: Runner):
    network = get_runner.network
    inputs = get_runner.inputs
    outputs = get_runner.outputs
    expected = get_runner.expected

    # -------------------- Example Begin --------------------
    # Paste the code examples here
    # e.g. for Activation
    in1 = network.add_input("input1", dtype=trt.float32, shape=(2, 3))
    layer = network.add_activation(in1, type=trt.ActivationType.RELU)
    network.mark_output(layer.get_output(0))

    inputs[in1.name] = np.array([[-3.0, -2.0, -1.0], [0.0, 1.0, 2.0]])

    outputs[layer.get_output(0).name] = layer.get_output(0).shape

    expected[layer.get_output(0).name] = np.array([[0.0, 0.0, 0.0], [0.0, 1.0, 2.0]])
    # --------------------- Example End ---------------------

    return get_runner

def run_example():
    example_runner = Runner()
    inputs = example_runner.inputs
    outputs = example_runner.outputs
    expected = example_runner.expected
    builder = example_runner.builder
    config = example_runner.config
    runtime = example_runner.runtime
    results = example_runner.results

    def log_info(info):
        example_runner.logger.log(trt.ILogger.Severity.INFO, f"[Example] {info}")

    def log_error(info):
        example_runner.logger.log(trt.ILogger.Severity.ERROR, f"[Example] {info}")

    example_runner = example(example_runner)

    log_info("Building serialized network")
    serialized_engine = builder.build_serialized_network(example_runner.network, config)
    assert serialized_engine is not None

    log_info("Creating engine")
    engine = runtime.deserialize_cuda_engine(serialized_engine)
    context = engine.create_execution_context()

    # Allocate host and device buffers
    in_mem = []
    out_mem = dict()
    output_allocators = dict()
    tensor_names = [engine.get_tensor_name(i) for i in range(engine.num_io_tensors)]
    for tensor in tensor_names:
        dtype = np.dtype(trt.nptype(engine.get_tensor_dtype(tensor)))
        if engine.get_tensor_mode(tensor) == trt.TensorIOMode.INPUT:
            if engine.is_shape_inference_io(tensor):
                context.set_input_shape(tensor, inputs[tensor].shape)
                # Get the input memory address from the numpy object.
                input_address = inputs[tensor].ctypes.data
                context.set_tensor_address(tensor, input_address)
            else:
                # Handle input tensors
                context.set_input_shape(tensor, inputs[tensor].shape)
                input_buffer = np.ascontiguousarray(inputs[tensor], dtype=dtype)
                input_memory = DeviceMem(input_buffer.nbytes)
                context.set_tensor_address(tensor, int(input_memory.device_ptr))
                in_mem.append((input_memory, input_buffer))
        else:  # Handle output tensors
            # Check if the output tensor has an unknown shape
            if trt.volume(context.get_tensor_shape(tensor)) < 0:
                # Set an output allocator for the output tensor with unknown shape.
                # Initialize the output allocator with 0 memory size, so reallocate_output always allocates.
                output_allocator = OutputAllocator(0)
                context.set_output_allocator(tensor, output_allocator)
                output_allocators[tensor] = output_allocator
                # No need to initialize the output buffer and output memory here.
                out_mem[tensor] = None
            else:
                size = trt.volume(context.get_tensor_shape(tensor))
                output_buffer = PinnedHostMem(size, dtype)
                output_memory = DeviceMem(output_buffer.nbytes)
                context.set_tensor_address(tensor, int(output_memory.device_ptr))
                out_mem[tensor] = (output_buffer.array, output_memory)

    # Run inference
    with CudaStreamContext() as stream:
        # Transfer input data to the GPU.
        for input in in_mem:
            memcpy_host_to_device_async(input[0].device_ptr, input[1], stream.stream)
        log_info("Running example")
        context.execute_async_v3(stream_handle=stream.stream)
        # Transfer prediction output from the GPU.
        for output in out_mem:
            output_mem = out_mem[output]
            if output_mem is None:
                # Must have been allocated using OutputAllocator.reallocate_output.
                assert output in output_allocators
                assert output_allocators[output].allocated_mem
                shape = output_allocators[output].tensor_shape
                assert shape is not None
                size = trt.volume(shape)
                dtype = np.dtype(trt.nptype(engine.get_tensor_dtype(output)))
                output_buffer = PinnedHostMem(size, dtype)
                output_memory = context.get_tensor_address(output)
                output_mem = (output_buffer.array, output_memory)
                # Store the tensor-to-output-buffer and output-memory mappings.
                out_mem[output] = output_mem
            # Handle both DeviceMem objects and raw pointers
            device_ptr = output_mem[1].device_ptr if hasattr(output_mem[1], 'device_ptr') else output_mem[1]
            memcpy_device_to_host_async(output_mem[0], device_ptr, stream.stream)
        log_info("Synchronizing with cuda stream")
        stream.synchronize()
        log_info("Sync done")
    for output in out_mem:
        output_mem = out_mem[output][0]
        shape = outputs[output]
        if trt.volume(context.get_tensor_shape(output)) < 0:
            # Get the real output tensor shape from the output allocator.
            shape = output_allocators[output].tensor_shape
            assert shape is not None
            size = trt.volume(shape)
            output_mem = output_mem[:size]
        output_mem = output_mem.reshape(shape)
        results[output] = output_mem

log_info(f"Network inputs: {inputs}")
log_info(f"Inference results: {results}")
log_info(f"Expected results: {expected}")

# Check result
is_equal = {}
all_are_equal = True
for output in expected:
    is_equal[output] = np.allclose(results[output], expected[output], atol=example_runner.atol)
    all_are_equal &= is_equal[output]
log_info(f"All results are expected: {all_are_equal}")
if all_are_equal is False:
    for output in is_equal:
        if is_equal[output] is False:
            log_error(f"{output} mismatch:")
            log_error(f"expected - content:{expected[output]}")
            log_error(f"actual - content:{repr(results[output])}")

log_info("Example complete")

if __name__ == "__main__":
    run_example()

TensorRT operators have stricter volume limits for input and output tensors on Turing GPUs: the product of a tensor's dimensions cannot exceed \(2^{31}\). Volume limits for each operator on other GPUs can be found on the corresponding operator documentation page.
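As a quick sanity check before building an engine for Turing, the tensor volume can be computed directly from the shape. The constant and helper below are illustrative only and not part of the TensorRT API:

import math

TURING_VOLUME_LIMIT = 2**31  # element-count limit described above (illustrative name)

def within_turing_volume_limit(shape) -> bool:
    # True if the product of the tensor dimensions does not exceed 2**31.
    return math.prod(shape) <= TURING_VOLUME_LIMIT

assert within_turing_volume_limit((2, 3))                   # tiny example tensor
assert not within_turing_volume_limit((1 << 16, 1 << 16))   # 2**32 elements exceeds the limit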