Functionals — TensorRT-LLM

class tensorrt_llm.functional.AllReduceFusionOp(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)[source]#

Bases: IntEnum

LAST_PROCESS_FOR_UB = 2#

MOE_ALLREDUCE_RESIDUAL_RMS_NORM = 8#

NONE = 0#

RESIDUAL_RMS_NORM = 1#

RESIDUAL_RMS_NORM_OUT_QUANT_FP8 = 6#

RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4 = 7#

RESIDUAL_RMS_NORM_QUANT_FP8 = 4#

RESIDUAL_RMS_NORM_QUANT_NVFP4 = 5#

RESIDUAL_RMS_PREPOST_NORM = 3#

class tensorrt_llm.functional.AllReduceParams(

strategy: AllReduceStrategy = AllReduceStrategy.AUTO,

fusion_op: AllReduceFusionOp = AllReduceFusionOp.NONE,

bias: Tensor | None = None,

residual: Tensor | None = None,

norm_weight: Tensor | None = None,

scale: Tensor | None = None,

norm_pre_residual_weight: Tensor | None = None,

eps: float = 1e-06,

enable_allreduce: bool = True,

)[source]#

Bases: object

has_affine()[source]#

has_bias()[source]#

has_scale()[source]#

update_strategy()[source]#

class tensorrt_llm.functional.AllReduceStrategy(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)[source]#

Bases: IntEnum

AUTO = 3#

MIN_LATENCY = 1#

NCCL = 0#

ONESHOT = 4#

TWOSHOT = 5#

UB = 2#

class tensorrt_llm.functional.AttentionMaskType(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)[source]#

Bases: IntEnum

bidirectional = 3#

bidirectionalglm = 4#

blocksparse = 5#

causal = 1#

custom_mask = 6#

padding = 0#

sliding_window_causal = 2#

class tensorrt_llm.functional.Conditional(condition: Tensor)[source]#

Bases: object

Add an operation to conditionally execute two code paths/subgraphs.

Usage:

  1. conditional = Conditional(condition)
  2. input_1_ = conditional.add_input(input_1), …, input_n_ = conditional.add_input(input_n)
  3. Construct the graph to get true_output_value and false_output_value using input_1_, …, input_n_
  4. output = conditional.add_output(true_output_value, false_output_value)
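A minimal sketch of that pattern, assuming a network is already under construction and that condition, x and y are Tensor objects in it (the names and the branch bodies are purely illustrative):

    from tensorrt_llm.functional import Conditional

    conditional = Conditional(condition)   # 1. wrap the boolean condition
    x_ = conditional.add_input(x)          # 2. register every tensor used by a branch
    y_ = conditional.add_input(y)
    true_value = x_ + y_                   # 3. build both subgraphs from the inputs
    false_value = x_ - y_
    output = conditional.add_output(true_value, false_value)  # 4. select one result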

add_input(

input: Tensor,

) → Tensor[source]#

add_output(

true_value: Tensor,

false_value: Tensor,

) → Tensor[source]#

class tensorrt_llm.functional.DimRange(

shape: List[int | List[int] | Tuple[int, int, int]],

names: List[str],

)[source]#

Bases: object

One DimRange object stores the ranges of all the dimensions of one tensor in one optimization profile. For example, if a tensor has 2 dimensions, the data members are:

self.min = [dim 0 min, dim 1 min]
self.opt = [dim 0 opt, dim 1 opt]
self.max = [dim 0 max, dim 1 max]

For a static dimension, min == opt == max, so the corresponding entry of the shape parameter in the constructor can be a plain integer.

class tensorrt_llm.functional.LayerNormPositionType(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)[source]#

Bases: IntEnum

post_layernorm = 1#

pre_layernorm = 0#

class tensorrt_llm.functional.LayerNormType(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)[source]#

Bases: IntEnum

GroupNorm = 2#

LayerNorm = 0#

RmsNorm = 1#

class tensorrt_llm.functional.MLPType(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)[source]#

Bases: IntEnum

FusedGatedMLP = 2#

GatedMLP = 1#

MLP = 0#

class tensorrt_llm.functional.PositionEmbeddingType(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)[source]#

Bases: IntEnum

alibi = 4#

alibi_with_scale = 5#

chatglm = 7#

static choices() → List[str][source]#

deferred = 10#

static from_string(s)[source]#

is_alibi() → bool[source]#

is_deferred() → bool[source]#

is_mrope() → bool[source]#

is_rope() → bool[source]#

learned_absolute = 0#

long_rope = 3#

mrope = 9#

relative = 6#

rope_gpt_neox = 2#

rope_gptj = 1#

yarn = 8#

class tensorrt_llm.functional.RopeEmbeddingUtils[source]#

Bases: object

static apply_llama3_scaling(

inv_freqs: ndarray,

rope_scaling_config: dict,

)[source]#

static apply_rotary_pos_emb(

tensor: Tensor,

position_embedding: List[Tensor] = None,

pos_emb_type: PositionEmbeddingType = PositionEmbeddingType.rope_gptj,

) → Tensor[source]#

static apply_rotary_pos_emb_chatglm(

qkv,

position_embedding,

num_attention_heads,

attention_head_size,

max_position_embeddings,

rotary_embedding_scale,

remove_input_padding,

) → Tensor[source]#

static apply_rotary_pos_emb_cogvlm(

qkv,

position_embedding,

num_attention_heads,

attention_head_size,

max_position_embeddings,

rotary_embedding_scale,

remove_input_padding,

) → Tensor[source]#

static create_fake_weight(

dim: int,

dtype=<class 'numpy.float16'>,

)[source]#

static create_sinusoidal_positions(

num_pos: int,

dim: int,

theta: float = 10000.0,

dtype=<class 'numpy.float32'>,

)[source]#
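For reference, a self-contained NumPy sketch of the conventional sinusoidal-position computation that this helper's signature suggests; the exact layout produced by the library (for example interleaved versus concatenated sin/cos halves) may differ:

    import numpy as np

    def sinusoidal_positions(num_pos, dim, theta=10000.0, dtype=np.float32):
        # One frequency per pair of channels: 1 / theta^(2i/dim).
        inv_freq = 1.0 / (theta ** (np.arange(0, dim, 2, dtype=np.float64) / dim))
        angles = np.outer(np.arange(num_pos, dtype=np.float64), inv_freq)
        # Concatenate sin and cos along the last dimension (one common layout).
        return np.concatenate([np.sin(angles), np.cos(angles)], axis=-1).astype(dtype)

    table = sinusoidal_positions(num_pos=8, dim=16)
    print(table.shape)  # (8, 16)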

static create_sinusoidal_positions_for_attention_plugin(

num_pos: int,

dim: int,

theta: float = 10000.0,

scale: float = 1.0,

scale_type: ~tensorrt_llm.functional.RotaryScalingType = RotaryScalingType.none,

rope_scaling_config: dict = None,

dtype=<class 'numpy.float32'>,

)[source]#

static create_sinusoidal_positions_for_cogvlm_attention_plugin(

num_pos: int,

dim: int,

theta: float = 10000.0,

scale: float = 1.0,

scale_type: ~tensorrt_llm.functional.RotaryScalingType = RotaryScalingType.none,

vision_start: int = 1,

vision_length: int = 1225,

dtype=<class 'numpy.float32'>,

)[source]#

create_sinusoidal_positions_long_rope(

num_orig_pos: int,

dim: int,

theta: float = 10000.0,

scaling_short_factors: ~tensorrt_llm.functional.Tensor = 1.0,

scaling_long_factors: ~tensorrt_llm.functional.Tensor = 1.0,

short_mscale=None,

long_mscale=None,

dtype=<class 'numpy.float32'>,

)[source]#

static create_sinusoidal_positions_yarn(

num_pos: int,

dim: int,

base: int = 10000,

scaling_factor: float = 1.0,

original_max_position_embeddings: int = 4096,

beta_fast: int = 32,

beta_slow: int = 1,

mscale: float = 1.0,

mscale_all_dim: float = 1.0,

dtype=<class 'numpy.float32'>,

)[source]#

static rotate_every_two(

tensor: Tensor,

) → Tensor[source]#

static rotate_half(

tensor: Tensor,

) → Tensor[source]#
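As an illustration of what these two rotation helpers conventionally compute in RoPE implementations (a NumPy sketch of the standard definitions, not the library's exact code):

    import numpy as np

    def rotate_half(x):
        # GPT-NeoX style: [x1, x2] -> [-x2, x1] over halves of the last dimension.
        x1, x2 = np.split(x, 2, axis=-1)
        return np.concatenate([-x2, x1], axis=-1)

    def rotate_every_two(x):
        # GPT-J style: rotate adjacent channel pairs, [a, b, c, d] -> [-b, a, -d, c].
        x1, x2 = x[..., 0::2], x[..., 1::2]
        return np.stack([-x2, x1], axis=-1).reshape(x.shape)

    print(rotate_every_two(np.array([1.0, 2.0, 3.0, 4.0])))  # [-2.  1. -4.  3.]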

class tensorrt_llm.functional.RotaryScalingType(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)[source]#

Bases: IntEnum

dynamic = 2#

static from_string(s)[source]#

linear = 1#

llama3 = 4#

longrope = 3#

mrope = 6#

none = 0#

yarn = 5#

class tensorrt_llm.functional.SideStreamIDType(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)[source]#

Bases: IntEnum

disable = 0#

moe = 1#

class tensorrt_llm.functional.SliceInputType(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)[source]#

Bases: IntEnum

axes = 5#

data = 0#

fill_value = 4#

size = 2#

start = 1#

stride = 3#

class tensorrt_llm.functional.Tensor(

name=None,

dtype=None,

shape=None,

dim_range=None,

is_network_input=True,

location=<TensorLocation.DEVICE: 0>,

network=None,

trt_tensor=None,

)[source]#

Bases: object

The class to represent dense tensors.

A dense tensor is named, has a shape and contains typed elements. Each dimension of a tensor can either be static or dynamic. Static dimensions are known at engine compilation by TensorRT. Dynamic dimensions can take values determined at runtime. The tensor can be located on the host (CPU) or the device (GPU).

abs()[source]#

See functional.abs.

cast(dtype)[source]#

See functional.cast.

property dtype#

The type of the elements in the tensor.

flatten(start_dim=0, end_dim=-1)[source]#

See functional.flatten.

get_parent()[source]#

Get the layer that produces this tensor.

get_users()[source]#

Get the layers that use this tensor as an input.

is_dynamic(dim=None)[source]#

If the argument ‘dim’ is None, that function returns a boolean that indicates whether the tensor contains a dynamic dimension (True) or not (False). In that case, the first dimension is excluded (as it usually corresponds to the batch size). If the argument is an integer, that function returns a boolean that indicates whether the dimension ‘dim’ is dynamic (True) or not (False).

is_trt_wrapper()[source]#

Check if there is a trt.ITensor member inside, which is required for graph rewriter. In order to differentiate usages, it may be necessary to have an inheritance hierarchy.

property location#

The physical location of the tensor (on the host or the device).

log()[source]#

See functional.log.

mark_output(

name: str | None = None,

dtype: str | DataType | None = None,

)[source]#

Mark a tensor as a network output.

When a tensor is marked as an output, its content can be obtained after the execution of the TensorRT engine. The user is responsible for allocating buffers to store the output tensors when preparing the execution of the TensorRT engine.

max(dim, keepdim=False)[source]#

See functional.max.

mean(dim, keepdim=False)[source]#

See functional.mean.

property name#

The name of the tensor.

ndim()[source]#

Returns the rank (i.e. the number of dimensions) of the tensor.

property network#

permute(dims)[source]#

See functional.permute.

rank()[source]#

Returns the rank (i.e. the number of dimensions) of the tensor.

repeat(sizes)[source]#

See functional.repeat

replace_all_uses_with(new_tensor)[source]#

Replace all uses of this tensor as an input to consumer layers

select(dim, index)[source]#

See functional.select.

property shape#

The shape of the tensor.

size(dim=None)[source]#

Returns the shape of the tensor if the dim parameter is None. Otherwise, returns a size of the dimension indicated by dim. The behavior is undefined if dim is negative or exceeds the rank of the tensor.

split(split_size_or_sections, dim=0)[source]#

See functional.split.

sqrt()[source]#

See functional.sqrt.

squeeze(dim, zero_is_placeholder)[source]#

See functional.squeeze.

transpose(dim0, dim1)[source]#

See functional.transpose.

unbind(dim=0)[source]#

See functional.unbind.

unsqueeze(dim)[source]#

See functional.unsqueeze.

view(shape, zero_is_placeholder=True)[source]#

See functional.view.

tensorrt_llm.functional.abs(

input: ~tensorrt_llm.functional.Tensor,

*,

op: ~tensorrt.tensorrt.UnaryOperation = <UnaryOperation.ABS: 4>,

) → Tensor#

Add an elementwise operation on a single input.

The following closures are defined in functional.*:

round for op=trt.UnaryOperation.ROUND
sqrt for op=trt.UnaryOperation.SQRT
exp for op=trt.UnaryOperation.EXP
sin for op=trt.UnaryOperation.SIN
cos for op=trt.UnaryOperation.COS
abs for op=trt.UnaryOperation.ABS
log for op=trt.UnaryOperation.LOG

It is implemented using the IUnaryLayer from TensorRT.

Parameters:

Returns:

The tensor produced by this elementwise operation.

tensorrt_llm.functional.activation(

input: Tensor,

act_type: ActivationType,

) → Tensor[source]#

Add an activation function.

Parameters:

The following closures are defined in functional.*:

relu for op=trt.ActivationType.RELU
tanh for op=trt.ActivationType.TANH
sigmoid for op=trt.ActivationType.SIGMOID

Returns:

The tensor produced by the activation layer.

tensorrt_llm.functional.add(

left: ~tensorrt_llm.functional.Tensor | int | float,

right: ~tensorrt_llm.functional.Tensor | int | float,

*,

op: ~tensorrt.tensorrt.ElementWiseOperation = <ElementWiseOperation.SUM: 0>,

) → Tensor#

Add an elementwise operation with two inputs.

For each input, that function first creates a constant tensor if the input is an integer or a float. Then, if needed, it expands the smaller tensor to make sure its rank is the same as the larger one. Then, it performs the elementwise operation ‘op’.

The following closures are defined in functional.*:

add for op=trt.ElementWiseOperation.SUM
sub for op=trt.ElementWiseOperation.SUB
mul for op=trt.ElementWiseOperation.PROD
div for op=trt.ElementWiseOperation.DIV
floordiv for op=trt.ElementWiseOperation.FLOOR_DIV
gt for op=trt.ElementWiseOperation.GREATER
lt for op=trt.ElementWiseOperation.LESS
op_and for op=trt.ElementWiseOperation.AND
op_or for op=trt.ElementWiseOperation.OR
eq for op=trt.ElementWiseOperation.EQUAL
minimum for op=trt.ElementWiseOperation.MIN
maximum for op=trt.ElementWiseOperation.MAX
pow for op=trt.ElementWiseOperation.POW

It is implemented using the IElementWiseLayer from TensorRT.

Parameters:

Returns:

The tensor produced by this elementwise operation.

tensorrt_llm.functional.allgather(

tensor: Tensor,

group: List[int],

gather_dim: int = 0,

) → Tensor[source]#

Add an operation that performs a collective all-gather.

Let’s define ‘group_size’ as the length of the ‘group’ list. That function creates a layer to gather ‘group_size’ tensors distributed amongst the ‘group_size’ participating ranks (one GPU per rank).

The list ‘group’ contains the identifiers of the ranks participating into the collective operation.

Note that ‘group’ here can be either TP group or PP group, because allgather communication is not limited to a specific split pattern. Therefore ‘group_size’ does not need to equal MPI ‘world_size’.

The tensors in the different ranks must be 1D tensors (or views) and the output tensor will have that same shape.

Given ‘section_size = input.shape[0] / group_size’, each rank contributes the section of its input tensor that corresponds to ‘rank*section_size:(rank+1)*section_size’.

That operation is implemented using a plugin that wraps the NCCL all-gather collective operation. See https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allgather for details.

Parameters:

Returns:

The tensor produced by that layer.

tensorrt_llm.functional.allreduce(tensor: ~tensorrt_llm.functional.Tensor, group: ~typing.List[int], all_reduce_params: ~tensorrt_llm.functional.AllReduceParams | None = <tensorrt_llm.functional.AllReduceParams object>) → Tensor[source]#

Add an operation that performs a collective all-reduce.

Let’s define ‘world_size’ as the length of the ‘group’ list. That function creates a layer to compute the sum of ‘world_size’ tensors distributed amongst the ‘world_size’ participating ranks (one GPU per rank).

The list ‘group’ contains the identifiers of the ranks participating into the collective operation.

The tensors in the different ranks must be 1D tensors (or views) and the output tensor will have that same shape. The output tensor will be replicated on the ‘world_size’ ranks.

That operation is implemented using a plugin that wraps the NCCL all-reduce collective operation. See https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allreduce for details.

Parameters:

Returns:

The tensor produced by that layer.

tensorrt_llm.functional.arange(

start: Tensor | int,

end: Tensor | int,

dtype: str,

) → Tensor[source]#

Add an operation to fill a 1D tensor.

The tensor is filled with the values between start and end with a step of 1 between the different elements. In pseudo-code, it corresponds to a tensor populated with the values:

output = Tensor([dtype(ii) for ii in range(start, end, 1)])

For example, a call to arange(3, 6, ‘int32’) will add an operation to the TensorRT graph that will produce [3, 4, 5] when executed. The call to arange(2, 5, ‘float32’) will add a layer to generate [2.0, 3.0, 4.0].

This operation is implemented using a tensorrt.IFillLayer in trt.FillOperation.LINSPACE mode.

Parameters:

Returns:

The tensor produced by the fill layer. It is a 1D tensor containing end-start elements of type dtype.

tensorrt_llm.functional.argmax(

input: Tensor,

dim: int,

keepdim: bool = False,

) → Tensor[source]#

Add an argmax operation.

As explained in the ONNX documentation, that function creates a layer computing the indices of the max elements of the input tensor along the provided dim. The resulting tensor has the same rank as the input if keepdim is True. If keepdim is False, the resulting tensor has the reduced dimension pruned.

Parameters:

Returns:

The tensor produced by this argmax operation.

tensorrt_llm.functional.assertion(

condition: Tensor,

message: str = '',

) → None[source]#

tensorrt_llm.functional.avg_pool2d(

input: Tensor,

kernel_size: Tuple[int],

stride: Tuple[int] | None = None,

padding: Tuple[int] | None = (0, 0),

ceil_mode: bool = False,

count_include_pad: bool = True,

) → Tensor[source]#

tensorrt_llm.functional.bert_attention(

tensor: Tensor,

input_lengths: Tensor,

num_heads: int,

head_size: int,

q_scaling: float,

relative_attention: bool = False,

relative_attention_bias: Tensor = None,

max_distance: int = 0,

max_input_length: Tensor = None,

sage_attn: bool = False,

sage_attn_q_block_size: int = 0,

sage_attn_k_block_size: int = 0,

sage_attn_v_block_size: int = 0,

) → Tuple[Tensor][source]#

Add an operation that performs the multi-head attention in BERT.

The multi-head attention (MHA) is the sequence of a batched matmul, a softmax and a batched matmul as described inhttps://arxiv.org/abs/1706.03762. That function adds an operation that performs those computations using a single GPU kernel.

The input tensor contains the Q, K and V elements. It is a 2D tensor and its shape is ‘[sum_of_tokens, 3*hidden_dim]’ where the ‘sum_of_tokens’ is the sum of the sequence lengths in the batch.

In MHA, the output of the Q*K^T product is scaled by a constant value that is computed as:

1.f / (q_scaling * sqrt(head_size)).

That ‘q_scaling’ constant is the last argument of that function.

That layer is implemented using a plugin (see bertAttentionPlugin).

Parameters:

Returns:

The tensor produced by that layer.

tensorrt_llm.functional.broadcast_helper(

left: Tensor | int | float,

right: Tensor | int | float,

) → Tuple[Tensor, Tensor][source]#

Helper function to perform a broadcast.

For each input, that function first creates a constant tensor if the input is an integer or a float. Then, if needed, it expands the smaller tensor to make sure its rank is the same as the larger one.

Parameters:

Returns:

A pair of tensors of same rank.

tensorrt_llm.functional.cast(

input: Tensor,

dtype: str | DataType,

)[source]#

Add a cast operation.

For an input tensor of type INT8, this function sets the dynamic range of the input to [-127, 127] for automatic dequantization. For a cast into INT8, that function sets the dynamic range of the output to [-127, 127] for automatic quantization.

Parameters:

Returns:

The tensor produced by the inserted layer.

tensorrt_llm.functional.categorical_sample(

probs: Tensor,

rand_data: Tensor = None,

) → Tensor[source]#

This is a sampling operation, equivalent to torch.distributions.Categorical.sample(), i.e. given a probability distribution tensor, it samples an index of that tensor. See https://pytorch.org/docs/stable/distributions.html#torch.distributions.categorical.Categorical.sample. NOTE: This assumes that the given probabilities are not normalized.

Parameters:

Returns:

A tensor containing a single index of the probs tensor representing the sample.
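A NumPy sketch of the sampling semantics described above (unnormalized probabilities, one index drawn per call); the plugin’s actual kernel and random-number plumbing differ:

    import numpy as np

    def categorical_sample(probs, rand_value):
        # probs need not be normalized: scale a uniform draw by the total mass
        # and find the first bin whose inclusive prefix sum exceeds it.
        cdf = np.cumsum(probs)
        return int(np.searchsorted(cdf, rand_value * cdf[-1], side='right'))

    rng = np.random.default_rng(0)
    print(categorical_sample(np.array([1.0, 3.0, 6.0]), rng.uniform()))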

tensorrt_llm.functional.chunk(

tensor: Tensor,

chunks: int,

dim: int = 0,

) → Tensor[source]#

Add an operation that splits a tensor into sub-tensors.

This operation creates a list of tensors that are obtained from the input tensor by chunking it along the dimension ‘dim’. It produces ‘chunks’ sub-tensors.

That operation is only defined for static tensors (no dynamic dimension) and the size of the tensor in the dimension ‘dim’ must be a multiple of ‘chunks’: ‘input.shape[dim] % chunks == 0’.

It maps to ‘split’ with ‘split_size = input.shape[dim] / chunks’.

Parameters:

Returns:

The list of tensors produced by the different operations.

tensorrt_llm.functional.clip(

input: Tensor,

alpha: float,

beta: float,

) → Tensor[source]#

Add a CLIP operation that sets the range to [alpha, beta].

Parameters:

Returns:

The tensor produced by the activation layer.

tensorrt_llm.functional.concat(

inputs: Sequence[Tensor | int],

dim: int = 0,

) → Tensor[source]#

Add an operation to concatenate tensors.

The function creates an operation that concatenates the tensors from the sequence ‘inputs’. The concatenation is done along the dimension ‘dim’.

All the tensors in ‘inputs’ must have the same shape except for the dimension ‘dim’.

for ii in range(inputs[0].rank()):
    assert (ii == dim) or all(inp.shape[ii] == inputs[0].shape[ii] for inp in inputs)

The shape of the output tensor is defined as:

for ii in range(inputs[0].rank()):
    # Same size as all the inputs in dimension ii != dim.
    output.shape[ii] = inputs[0].shape[ii]
    # Sum of the sizes in the different inputs in dimension ‘dim’.
    if ii == dim:
        for jj in range(1, len(inputs)):
            output.shape[ii] += inputs[jj].shape[ii]

For example, given a sequence of two 2D tensors [[0, 1], [2, 3]] and [[4, 5], [6, 7]] both of shape [2, 2],

concat(inputs, 0)

will produce [[0, 1], [2, 3], [4, 5], [6, 7]] of shape [4, 2] and

concat(inputs, 1)

will produce [[0, 1, 4, 5], [2, 3, 6, 7]] of shape [2, 4].

Parameters:

Returns:

A tensor that contains the concatenation of the tensors.

tensorrt_llm.functional.constant(

ndarray: ndarray,

as_dtype: DataType | None = None,

as_shape=None,

) → Tensor[source]#

Add a constant layer.

TensorRT graphs encapsulate constant values in the form of constant layers (tensorrt.IConstantLayer). That function creates such a layer from a Numpy array of values. After compilation of the network by TensorRT, those weights are stored in the serialized TensorRT engine.

Parameters:

ndarray – numpy.ndarray The array of values (weights) encapsulated by this constant layer.

Returns:

The tensor produced by the inserted layer.

tensorrt_llm.functional.constant_to_tensor_(

input: Tensor | int | float | bool,

dtype: DataType | str = None,

to_array=True,

) → Tensor[source]#

tensorrt_llm.functional.constants_to_tensors_(

*inputs: Tensor | int | float,

) → Tuple[Tensor, ...][source]#

Helper function to create tensors from multiple inputs.

For each input, that function first creates a constant tensor if the input is an integer or a float. Then, if any input is int64, it upcasts the other integer inputs to int64.

Parameters:

inputs – Tuple[Union[Tensor, int, float], …] The inputs to create tensors from.

Returns:

A tuple of tensors.

tensorrt_llm.functional.conv1d(

input: Tensor,

weight: Tensor,

bias: Tensor | None = None,

stride: int = 1,

padding: int = 0,

dilation: int = 1,

groups: int = 1,

) → Tensor[source]#

tensorrt_llm.functional.conv2d(

input: Tensor,

weight: Tensor,

bias: Tensor | None = None,

stride: Tuple[int, int] = (1, 1),

padding: Tuple[int, int] = (0, 0),

dilation: Tuple[int, int] = (1, 1),

groups: int = 1,

pre_padding: Tuple[int, int] | None = None,

post_padding: Tuple[int, int] | None = None,

) → Tensor[source]#

tensorrt_llm.functional.conv3d(

input: Tensor,

weight: Tensor,

bias: Tensor | None = None,

stride: int | Tuple[int, int] = (1, 1, 1),

padding: int | Tuple[int, int] = (0, 0, 0),

dilation: int | Tuple[int, int] = (1, 1, 1),

groups: int = 1,

) → Tensor[source]#

tensorrt_llm.functional.conv_transpose2d(

input: Tensor,

weight: Tensor,

bias: Tensor | None = None,

stride: Tuple[int, int] = (1, 1),

padding: Tuple[int, int] = (0, 0),

output_padding: Tuple[int, int] = (0, 0),

dilation: Tuple[int, int] = (1, 1),

groups: int = 1,

) → Tensor[source]#

tensorrt_llm.functional.cos(

input: ~tensorrt_llm.functional.Tensor,

*,

op: ~tensorrt.tensorrt.UnaryOperation = <UnaryOperation.COS: 7>,

) → Tensor#

Add an elementwise operation on a single input.

The following closures are defined in functional.*:

round for op=trt.UnaryOperation.ROUND
sqrt for op=trt.UnaryOperation.SQRT
exp for op=trt.UnaryOperation.EXP
sin for op=trt.UnaryOperation.SIN
cos for op=trt.UnaryOperation.COS
abs for op=trt.UnaryOperation.ABS
log for op=trt.UnaryOperation.LOG

It is implemented using the IUnaryLayer from TensorRT.

Parameters:

Returns:

The tensor produced by this elementwise operation.

tensorrt_llm.functional.cp_split_plugin(

input_ids: Tensor,

host_request_types: Tensor,

host_context_lengths: Tensor,

cp_size: int = 1,

cp_rank: int = 0,

) → Tensor[source]#

Add an operation to perform splitting for context parallelism.

This operation splits input_ids into cp_size chunks and returns the cp_rank-th chunk. When seqlen % cp_size != 0, the chunk sizes of the ranks are [seqlen // cp_size, seqlen // cp_size, …, seqlen - (cp_size - 1) * (seqlen // cp_size)].

It inserts an IPluginV3Layer.

Parameters:

Returns:

The output split tensor, the length of the output split tensor, and the index for rebuilding the sequence.

tensorrt_llm.functional.create_allreduce_plugin(

network: INetworkDefinition,

tensor: ITensor,

workspace: ITensor | None,

group: array,

dtype: DataType,

all_reduce_params: AllReduceParams,

)[source]#

tensorrt_llm.functional.cuda_stream_sync(

input_list: List[Tensor],

side_stream_id: SideStreamIDType,

) → Tensor[source]#

Wait for the side stream on the main stream.

output = input_list[0]

Parameters:

tensorrt_llm.functional.cumsum(

input: Tensor,

dim: int,

prefer_plugin: bool = True,

) → Tensor[source]#

Add an operation to calculate inclusive cumulative sum of elements of a tensor in a given dimension.

Given an input tensor, that function creates an operation that calculates inclusive cumulative sum of elements in the dimension ‘dim’ to create a new tensor. The output tensor has the same shape as the input tensor.

The input tensor must have rank >= 1. The ‘dim’ must be valid; negative values are supported.

For example, on input=[[4, 2, 5], [2, 1, 2], [4, 7, 1]], which has a shape [3, 3],

cumsum(input, 0)

will produce [[4, 2, 5], [6, 3, 7], [10, 10, 8]].

cumsum(input, 1)

will produce [[4, 6, 11], [2, 3, 5], [4, 11, 12]].

That operation is implemented by TensorRT ILoopLayer.

Parameters:

Returns:

The tensor containing the inclusive cumulative sum of input.

tensorrt_llm.functional.div(

left: ~tensorrt_llm.functional.Tensor | int | float,

right: ~tensorrt_llm.functional.Tensor | int | float,

*,

op: ~tensorrt.tensorrt.ElementWiseOperation = <ElementWiseOperation.DIV: 5>,

) → Tensor#

Add an elementwise operation with two inputs.

For each input, that function first creates a constant tensor if the input is an integer or a float. Then, if needed, it expands the smaller tensor to make sure its rank is the same as the larger one. Then, it performs the elementwise operation ‘op’.

The following closures are defined in functional.*:

add for op=trt.ElementWiseOperation.SUM
sub for op=trt.ElementWiseOperation.SUB
mul for op=trt.ElementWiseOperation.PROD
div for op=trt.ElementWiseOperation.DIV
floordiv for op=trt.ElementWiseOperation.FLOOR_DIV
gt for op=trt.ElementWiseOperation.GREATER
lt for op=trt.ElementWiseOperation.LESS
op_and for op=trt.ElementWiseOperation.AND
op_or for op=trt.ElementWiseOperation.OR
eq for op=trt.ElementWiseOperation.EQUAL
minimum for op=trt.ElementWiseOperation.MIN
maximum for op=trt.ElementWiseOperation.MAX
pow for op=trt.ElementWiseOperation.POW

It is implemented using the IElementWiseLayer from TensorRT.

Parameters:

Returns:

The tensor produced by this elementwise operation.

tensorrt_llm.functional.dora_plugin(

activations: Tensor,

out_hidden_sizes: list[int],

lora_weights_pointers: list[Tensor],

host_request_types: Tensor,

host_context_lengths: Tensor | None = None,

) → Tensor[source]#

The DoRA plugin applies column-wise scaling to the output of a LoRA layer.

Parameters:

Returns:

The tensor produced by that layer.

tensorrt_llm.functional.einsum(

einsum_eq: str,

inputs: Sequence[Tensor],

) → Tensor[source]#

Add an Einsum operation.

That operation maps to tensorrt.IEinsumLayer. As explained in the TensorRT documentation, this layer implements a summation over the elements of the inputs along dimensions specified by the equation parameter, based on the Einstein summation convention. The layer can have one or more inputs of rank >= 0. All the inputs must be of the same data type. This layer supports all TensorRT data types except bool. There is one output tensor of the same type as the input tensors. The shape of the output tensor is determined by the equation.

The equation specifies ASCII lower-case letters for each dimension in the inputs in the same order as the dimensions, separated by comma for each input. The dimensions labeled with the same subscript must match or be able to be broadcasted. Repeated subscript labels in one input take the diagonal. Repeating a label across multiple inputs means that those axes will be multiplied. Omitting a label from the output means values along those axes will be summed. In implicit mode, the indices which appear once in the expression will be part of the output in increasing alphabetical order. In explicit mode, the output can be controlled by specifying output subscript labels by adding an arrow (‘->’) followed by subscripts for the output. For example, “ij,jk->ik” is equivalent to “ij,jk”. Ellipsis (‘…’) can be used in place of subscripts to broadcast the dimensions. See the TensorRT Developer Guide for more details on equation syntax.

Many common operations can be expressed using the Einsum equation. For example:

Matrix Transpose: ij->ji
Sum: ij->
Matrix-Matrix Multiplication: ik,kj->ij
Dot Product: i,i->
Matrix-Vector Multiplication: ik,k->i
Batch Matrix Multiplication: ijk,ikl->ijl
Batch Diagonal: …ii->…i

Note that TensorRT does not support ellipsis or diagonal operations, so neither does TensorRT-LLM.

Parameters:

Returns:

The tensor produced by the Einsum operation.

tensorrt_llm.functional.elementwise_binary(

left: Tensor | int | float,

right: Tensor | int | float,

op: ElementWiseOperation,

) → Tensor[source]#

Add an elementwise operation with two inputs.

For each input, that function first creates a constant tensor if the input is an integer or a float. Then, if needed, it expands the smaller tensor to make sure its rank is the same as the larger one. Then, it performs the elementwise operation ‘op’.

The following closures are defined in functional.*:

add for op=trt.ElementWiseOperation.SUM
sub for op=trt.ElementWiseOperation.SUB
mul for op=trt.ElementWiseOperation.PROD
div for op=trt.ElementWiseOperation.DIV
floordiv for op=trt.ElementWiseOperation.FLOOR_DIV
gt for op=trt.ElementWiseOperation.GREATER
lt for op=trt.ElementWiseOperation.LESS
op_and for op=trt.ElementWiseOperation.AND
op_or for op=trt.ElementWiseOperation.OR
eq for op=trt.ElementWiseOperation.EQUAL
minimum for op=trt.ElementWiseOperation.MIN
maximum for op=trt.ElementWiseOperation.MAX
pow for op=trt.ElementWiseOperation.POW

It is implemented using the IElementWiseLayer from TensorRT.

Parameters:

Returns:

The tensor produced by this elementwise operation.

tensorrt_llm.functional.embedding(

input: Tensor,

weight: Tensor,

tp_size=1,

tp_group=None,

sharding_dim=0,

tp_rank=None,

per_token_scale=None,

padding=None,

) → Tensor[source]#

Add an operation to perform embedding lookup.

That operation performs the embedding lookup. The ‘input’ tensor contains the identifiers of the rows of ‘weight’ to gather.

1. Distribute the embedding lookup table over multiple GPUs. When ‘tp_size’ is greater than 1 and ‘tp_group’ is defined, this embedding lookup is distributed among multiple GPUs.

When ‘sharding_dim == 0’, each GPU stores a subset of the rows of the embedding table (the number of rows per GPU is given by weights.shape[0] and the offset to the first row stored on the GPU is given by rank * weights.shape[0]). Each parallel rank queries all the indices and sets 0s for the weights that are not stored on the associated GPU. To compute the final result, a parallel all-reduce operation is added to the TensorRT graph, as the sketch below illustrates. That lookup can be performed using either the plugin or the operators TensorRT supports.

When ‘sharding_dim == 1’, each GPU stores a subset of the embedding table’s columns. Each rank can obtain a portion of the embedding results. The embedding is then collected using an all-gather operation. Related transposition operations are also used to obtain the final results.

2. Store the embedding lookup table as a whole. When ‘tp_size’ is not greater than 1, the embedding lookup table is not divided. In this case, when default_net().plugin_config.lookup_plugin is set, the operation is implemented using a plugin (without the all-reduce operation). Otherwise, this operation is implemented using the standard IGatherLayer in TensorRT.

Parameters:

Returns:

The tensor produced by the embedding lookup layer.
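A self-contained NumPy sketch of the ‘sharding_dim == 0’ scheme described above: each rank zeroes the rows it does not own, and summing the per-rank partial lookups (the role of the all-reduce) recovers the full lookup. Shapes and names are illustrative:

    import numpy as np

    vocab, hidden, tp_size = 8, 4, 2
    table = np.arange(vocab * hidden, dtype=np.float32).reshape(vocab, hidden)
    ids = np.array([1, 6, 3])

    rows_per_rank = vocab // tp_size
    partials = []
    for rank in range(tp_size):
        lo = rank * rows_per_rank                      # offset of this rank's shard
        shard = table[lo:lo + rows_per_rank]
        local = ids - lo                               # shift ids into shard coordinates
        mask = (local >= 0) & (local < rows_per_rank)  # ids owned by this rank
        out = shard[np.clip(local, 0, rows_per_rank - 1)]
        out[~mask] = 0.0                               # zeros for rows stored elsewhere
        partials.append(out)

    # The all-reduce sums the partial results across ranks.
    assert np.array_equal(sum(partials), table[ids])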

tensorrt_llm.functional.eq(

left: ~tensorrt_llm.functional.Tensor | int | float,

right: ~tensorrt_llm.functional.Tensor | int | float,

*,

op: ~tensorrt.tensorrt.ElementWiseOperation = <ElementWiseOperation.EQUAL: 11>,

) → Tensor#

Add an elementwise operation with two inputs.

For each input, that function first creates a constant tensor if the input is an integer or a float. Then, if needed, it expands the smaller tensor to make sure its rank is the same as the larger one. Then, it performs the elementwise operation ‘op’.

The following closures are defined in functional.*:

add for op=trt.ElementWiseOperation.SUM
sub for op=trt.ElementWiseOperation.SUB
mul for op=trt.ElementWiseOperation.PROD
div for op=trt.ElementWiseOperation.DIV
floordiv for op=trt.ElementWiseOperation.FLOOR_DIV
gt for op=trt.ElementWiseOperation.GREATER
lt for op=trt.ElementWiseOperation.LESS
op_and for op=trt.ElementWiseOperation.AND
op_or for op=trt.ElementWiseOperation.OR
eq for op=trt.ElementWiseOperation.EQUAL
minimum for op=trt.ElementWiseOperation.MIN
maximum for op=trt.ElementWiseOperation.MAX
pow for op=trt.ElementWiseOperation.POW

It is implemented using the IElementWiseLayer from TensorRT.

Parameters:

Returns:

The tensor produced by this elementwise operation.

tensorrt_llm.functional.exp(

input: ~tensorrt_llm.functional.Tensor,

*,

op: ~tensorrt.tensorrt.UnaryOperation = <UnaryOperation.EXP: 0>,

) → Tensor#

Add an elementwise operation on a single input.

The following closures are defined in functional.*:

round for op=trt.UnaryOperation.ROUND
sqrt for op=trt.UnaryOperation.SQRT
exp for op=trt.UnaryOperation.EXP
sin for op=trt.UnaryOperation.SIN
cos for op=trt.UnaryOperation.COS
abs for op=trt.UnaryOperation.ABS
log for op=trt.UnaryOperation.LOG

It is implemented using the IUnaryLayer from TensorRT.

Parameters:

Returns:

The tensor produced by this elementwise operation.

tensorrt_llm.functional.expand(

input: Tensor,

expand_shape: Tensor,

) → Tensor[source]#

Add an operation to expand a tensor.

The operation expands the input tensor in the singleton dimensions to the size indicated by the corresponding dimension in the expand_shape tensor. In other words, given an input tensor with dimensions of size 1, those dimensions will be expanded to the size in expand_shape.

For example, a tensor of shape [4, 3, 1, 3] will be expanded to a tensor of shape [4, 3, 2, 3] by the layer created using expand(input, [4, 3, 2, 3]).

The expansion may either replicate the values or be mapped to a view with a stride of 0 in the expanded dimensions. For example, for a tensor [[3, 2]] of shape [1, 2],

expand([[3, 2]], [2, 2])

can be used to expand the input to [[3, 2], [3, 2]].

This operation is implemented using a tensorrt.ISliceLayer. The current implementation does not verify that non-singleton dimensions are not shrunk. In other words, for an input of shape [4, 1, 2],

expand(input, [3, 2, 2])

will produce a tensor of shape [3, 2, 2]. That behavior is subject to change in the future.

Parameters:

Returns:

The tensor produced by the expand layer.

tensorrt_llm.functional.expand_dims(

input: Tensor,

dim: int | Sequence[int],

shape_cast_dtype=None,

) → Tensor[source]#

Add an operation to expand the tensor shape with singleton dimensions.

That function adds a tensorrt.IShuffleLayer to the network. Given an ‘input’ of rank N and a sequence of M dimensions, the output tensor produced by this operation (when executed by TensorRT) will have a rank of N+M. Singleton dimensions will be inserted at the different positions in ‘dim’.

The pseudo-code for that operation is:

new_shape, ii = [], 0
for jj in range(input.rank() + len(dim)):
    if jj in dim:
        new_shape.append(1)
    else:
        new_shape.append(input.shape[ii])
        ii += 1

For example, for a tensor of shape [3, 4, 1, 5]

expand_dims(input, [0, 2])

will produce a tensor of shape [1, 3, 1, 4, 1, 5].

Parameters:

Returns:

The tensor produced by the shuffle layer.

tensorrt_llm.functional.expand_dims_like(

left: Tensor | int | float,

right: Tensor,

) → Tensor[source]#

Add an operation to expand the first tensor to the same rank as the second tensor.

That function takes a first tensor. It also accepts an integer or a float, in which case it creates a constant tensor from it. In both cases, the rank of that first tensor is compared to the rank of the second tensor. If they are of the same rank, the first tensor is returned. Otherwise, the first tensor is expanded on the left to match the rank of the second tensor.

Note that the shapes do not have to match, only the rank is considered in that function.

For example, for a pair of tensors of shapes [3, 4] and [4, 3, 2], the first tensor will be expanded to a tensor of rank 3 and shape [1, 3, 4].

Parameters:

Returns:

The tensor produced by the shuffle layer.

tensorrt_llm.functional.expand_mask(

mask: Tensor,

tgt_len: Tensor | None = None,

) → Tensor[source]#

Expand an attention mask.

That function adds the sequence of operations to expand from a tensor of shape ‘[batch_size, src_seq_len]’ to a tensor of shape ‘[batch_size, 1, tgt_seq_len, src_seq_len]’. It can be used to create the mask applied to the Q*K^T product before the softmax operation in the multi-head attention block.

Parameters:

Returns:

The tensor created by that sequence of operations.

tensorrt_llm.functional.flatten(

input: Tensor,

start_dim: int = 0,

end_dim: int = -1,

)[source]#

Flattens input by reshaping it into a one-dimensional tensor.

If start_dim or end_dim are passed, only dimensions starting with start_dim and ending with end_dim are flattened. The order of elements in input is unchanged.

Parameters:

Returns:

The tensor produced by the flatten layer.

tensorrt_llm.functional.flip(

input: Tensor,

dims: Sequence[int],

) → Tensor[source]#

Reverses the order of an n-D tensor along given axis in dims.

That flip operation maps to a TensorRT ISliceLayer. For the dimensions listed in dims it copies the elements from the last one to the first one (from (N-1) down to 0 with a step of -1). For the dimensions not in ‘dims’, it copies the elements from the first one to the last one (from 0 to N-1 with a step of 1).

Parameters:

Returns:

The tensor produced by the inserted layer.

tensorrt_llm.functional.floordiv(

left: ~tensorrt_llm.functional.Tensor | int | float,

right: ~tensorrt_llm.functional.Tensor | int | float,

*,

op: ~tensorrt.tensorrt.ElementWiseOperation = <ElementWiseOperation.FLOOR_DIV: 7>,

) → Tensor#

Add an elementwise operation with two inputs.

For each input, that function first creates a constant tensor if the input is an integer or a float. Then, if needed, it expands the smaller tensor to make sure its rank is the same as the larger one. Then, it performs the elementwise operation ‘op’.

The following closures are defined in functional.*:

add for op=trt.ElementWiseOperation.SUM
sub for op=trt.ElementWiseOperation.SUB
mul for op=trt.ElementWiseOperation.PROD
div for op=trt.ElementWiseOperation.DIV
floordiv for op=trt.ElementWiseOperation.FLOOR_DIV
gt for op=trt.ElementWiseOperation.GREATER
lt for op=trt.ElementWiseOperation.LESS
op_and for op=trt.ElementWiseOperation.AND
op_or for op=trt.ElementWiseOperation.OR
eq for op=trt.ElementWiseOperation.EQUAL
minimum for op=trt.ElementWiseOperation.MIN
maximum for op=trt.ElementWiseOperation.MAX
pow for op=trt.ElementWiseOperation.POW

It is implemented using the IElementWiseLayer from TensorRT.

Parameters:

Returns:

The tensor produced by this elementwise operation.

tensorrt_llm.functional.gather(

input: Tensor,

dim: int,

indices: Tensor | int,

) → Tensor[source]#

Add an operation to gather elements from a tensor.

That function implements the GatherElements operator from the ONNX specification.

The input and indices arguments must have the same rank >= 1. The operation produces a tensor with the same shape as the indices tensor. The axis is the dimension to gather on.

As shown in the ONNX description, for a 3D tensor, the output is:

out[i][j][k] = input[indices[i][j][k]][j][k] if axis = 0,
out[i][j][k] = input[i][indices[i][j][k]][k] if axis = 1,
out[i][j][k] = input[i][j][indices[i][j][k]] if axis = 2.

For example,

gather([[4, 2], [5, 3]], 0, [[1, 0], [0, 1]])

will produce [[5, 2], [4, 3]].

gather([[1, 2, 3], [4, 5, 6]], 1, [[1], [0]])

will produce [[2], [4]]. See the ONNX documentation for more examples.

That operation maps to the TensorRT IGatherLayer.

Parameters:

Returns:

The tensor containing the gathered elements. It has the same shape as the indices tensor.
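The GatherElements semantics above match NumPy's take_along_axis, which makes the first example easy to check (an illustration of the semantics, not of the TensorRT kernel):

    import numpy as np

    inp = np.array([[4, 2], [5, 3]])
    idx = np.array([[1, 0], [0, 1]])
    # out[i][j] = inp[idx[i][j]][j] for dim (axis) = 0.
    print(np.take_along_axis(inp, idx, axis=0))  # [[5 2], [4 3]]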

tensorrt_llm.functional.gather_last_token_logits(

hidden_states: Tensor,

last_token_ids: Tensor,

remove_input_padding: bool,

) → Tensor[source]#

Extract the logits that correspond to the last token from the hidden states.

That function adds the operations to extract the logits of the last tokens in a batch of sequences.

Depending on whether ‘remove_input_padding’ is ‘True’ or ‘False’, that function assumes inputs of different shapes.

When ‘remove_input_padding’ is ‘True’, the ‘hidden_states’ tensor is assumed to be packed. It has a shape ‘[num_tokens, hidden_dim]’ where ‘num_tokens’ is the sum of the lengths of the sequences in the batch and ‘hidden_dim’ is the hidden dimension. The ‘last_token_ids’ is a 1D tensor that encodes the inclusive prefix-sums of the lengths of the sequences in the batch.

When ‘remove_input_padding’ is ‘False’, the ‘hidden_states’ tensor is assumed to be padded. It has a shape ‘[batch_size, max_seqlen, hidden_dim]’ where ‘max_seqlen’ is the length of the longest sequence in the batch and ‘hidden_dim’ is the hidden dimension. The ‘last_token_ids’ is a 1D tensor that encodes the length of each sequence in the batch.

In both cases, that function produces a tensor of shape ‘[batch_size, hidden_size]’ where the row at index ‘i’ corresponds to the logits of the last token from the ‘i’-th sequence.

Parameters:

Returns:

The tensor created by that sequence of operations.
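A NumPy sketch of the two indexing schemes described above, with illustrative shapes:

    import numpy as np

    hidden = 4
    lengths = np.array([2, 3])            # two sequences, lengths 2 and 3

    # Packed case ('remove_input_padding' True): [num_tokens, hidden_dim], and
    # last_token_ids holds the inclusive prefix sums of the lengths.
    packed = np.arange(5 * hidden, dtype=np.float32).reshape(5, hidden)
    last_packed = packed[np.cumsum(lengths) - 1]      # rows 1 and 4

    # Padded case ('remove_input_padding' False): [batch_size, max_seqlen,
    # hidden_dim], and last_token_ids holds the per-sequence lengths.
    padded = np.zeros((2, 3, hidden), dtype=np.float32)
    padded[0, :2] = packed[0:2]
    padded[1, :3] = packed[2:5]
    last_padded = padded[np.arange(2), lengths - 1]

    assert np.array_equal(last_packed, last_padded)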

tensorrt_llm.functional.gather_nd(

input: Tensor,

indices: Tensor,

batch_dims: int = 1,

) → Tensor[source]#

Adds a layer that performs a gather with some element-wise dimensions. See https://onnx.ai/onnx/operators/onnx__GatherND.html. The gather is performed on dim=batch_dims.

Parameters:

Returns:

A tensor created by the gather layer with GatherMode.ND.

tensorrt_llm.functional.gegelu(

x: Tensor,

limit: float | None = None,

) → Tensor[source]#

tensorrt_llm.functional.geglu(

x: Tensor,

) → Tensor[source]#

Add a Gated-GELU operation.

That function takes a tensor, splits it into two halves along the last dimension, applies GELU to the second half and multiplies the results. The behavior is undefined if the last dimension is not even.

Parameters:

input – Tensor The input tensor on which the activation function is applied.

Returns:

The tensor produced by the activation layer.

tensorrt_llm.functional.gelu(

x: Tensor,

) → Tensor[source]#

Add a GELU operation.

Parameters:

input – Tensor The input tensor on which the activation function is applied.

Returns:

The tensor produced by the activation layer.

tensorrt_llm.functional.gemm_allreduce(

a: Tensor,

b: Tensor,

group: List[int],

transa: bool = False,

transb: bool = False,

alpha: ndarray | Tensor | None = None,

output_dtype: DataType | None = None,

fp8_inputs_override: bool = False,

a_sf: Tensor | None = None,

b_sf: Tensor | None = None,

)[source]#

Add an operation that performs fused GEMM+AllReduce.

Parameters:

Returns:

Returns GEMM output tensor which has been reduced across ranks.

tensorrt_llm.functional.gemm_swiglu(

input: Tensor,

weight: Tensor,

bias: Tensor | None = None,

scale_d0: float = 1.0,

scale_d1: float = 1.0,

scale_output: float = 1.0,

) → Tensor[source]#

Add a matrix multiplication, followed by a SwiGLU (x * SiLU(gate)) operation.

The SwiGLU operation takes the output of the preceding matrix multiplication, splits it into two halves along the last dimension, applies SiLU to the second half and multiplies the results. The behavior is undefined if the last dimension is not even.

Parameters:

input : Tensor
    The first tensor (often called A).

weight : Tensor
    The second tensor (often called B).

bias : Optional[Tensor]
    The per-channel bias. The plugin with fp8 dtype does not support bias yet.

scale_d0 : float
    The scale for dequantizing x, used for fp8.

scale_d1 : float
    The scale for dequantizing gate, used for fp8.

scale_output : float
    The scale for quantizing the output, used for fp8.

Returns:

The tensor produced by the inserted layer.
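A NumPy sketch of the SwiGLU part in isolation (split, SiLU on the gate half, elementwise product), ignoring the GEMM and the fp8 scales:

    import numpy as np

    def swiglu(h):
        # First half is x, second half is the gate.
        x, gate = np.split(h, 2, axis=-1)
        return x * (gate / (1.0 + np.exp(-gate)))  # SiLU(g) = g * sigmoid(g)

    h = np.random.default_rng(0).standard_normal((2, 8)).astype(np.float32)
    print(swiglu(h).shape)  # (2, 4)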

tensorrt_llm.functional.generate_alibi_biases(

slopes: Tensor,

key_length: Tensor,

) → Tensor[source]#

Compute the ALiBi biases as described in https://arxiv.org/abs/2211.05100.

The ALiBi biases are added to the result of the Q*K^T product in the multi-head attention block.

Parameters:

Returns:

A constant tensor that contains the ALiBi biases.

tensorrt_llm.functional.generate_alibi_slopes(

num_heads: int,

tp_size: int = 1,

tp_rank: int = 0,

alibi_scale: float = 1.0,

alibi_bias_max: int = 8,

) → ndarray[source]#

Compute the ALiBi slopes as described in https://arxiv.org/abs/2211.05100.

Parameters:

Returns:

A constant tensor that contains the ALiBi slopes.
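A NumPy sketch of one common closed form for the slopes when num_heads is a power of two, slope_i = 2^(-alibi_bias_max * (i + 1) / num_heads); the actual helper's handling of non-power-of-two head counts, alibi_scale and tensor parallelism may differ:

    import numpy as np

    def alibi_slopes(num_heads, alibi_bias_max=8):
        i = np.arange(1, num_heads + 1, dtype=np.float64)
        return np.power(2.0, -alibi_bias_max * i / num_heads)

    print(alibi_slopes(8))  # [0.5, 0.25, ..., 0.00390625]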

tensorrt_llm.functional.generate_logn_scaling(

seq_length: int = 8192,

max_position_embeddings: int = 32768,

) → ndarray[source]#

Compute the Log-N scaling vector for Qwen inference extrapolation

Parameters:

Returns:

A constant np.ndarray that contains the Log-N scaling vector.

tensorrt_llm.functional.gpt_attention(*, qkv: ~tensorrt_llm.functional.Tensor, past_key_value: ~tensorrt_llm.functional.Tensor, attention_mask: ~tensorrt_llm.functional.Tensor | None = None, attention_packed_mask: ~tensorrt_llm.functional.Tensor | None = None, sequence_length: ~tensorrt_llm.functional.Tensor, host_past_key_value_lengths: ~tensorrt_llm.functional.Tensor | None, host_max_attention_window_sizes: ~tensorrt_llm.functional.Tensor, host_sink_token_length: ~tensorrt_llm.functional.Tensor, context_lengths: ~tensorrt_llm.functional.Tensor | None, cache_indirection: ~tensorrt_llm.functional.Tensor | None, host_request_types: ~tensorrt_llm.functional.Tensor, layer_idx: int, num_heads: int, num_kv_heads: int, hidden_size_per_head: int, q_scaling: float, attn_logit_softcapping_scale: float = 0.0, rotary_embedding_dim: int = 0, rotary_embedding_base: float = 10000.0, rotary_embedding_scale_type: ~tensorrt_llm.functional.RotaryScalingType = RotaryScalingType.none, rotary_embedding_short_m_scale: float = 1.0, rotary_embedding_long_m_scale: float = 1.0, rotary_embedding_scale: float = 1.0, rotary_embedding_max_positions: int = 1024, rotary_embedding_original_max_positions: int = 1024, position_embedding_type: ~tensorrt_llm.functional.PositionEmbeddingType = PositionEmbeddingType.learned_absolute, rotary_inv_freq: ~tensorrt_llm.functional.Tensor | None = None, rotary_cos_sin: ~tensorrt_llm.functional.Tensor | None = None, kv_orig_quant_scale: ~tensorrt_llm.functional.Tensor | None = None, kv_quant_orig_scale: ~tensorrt_llm.functional.Tensor | None = None, attention_output_orig_quant_scale: ~tensorrt_llm.functional.Tensor | None = None, attention_output_sf_scale: ~tensorrt_llm.functional.Tensor | None = None, kv_cache_quant_mode: ~tensorrt_llm._utils.QuantModeWrapper | ~tensorrt_llm.quantization.mode.QuantMode = <QuantMode: 0>, max_context_length: int | None = None, mask_type: ~tensorrt_llm.functional.AttentionMaskType = AttentionMaskType.causal, block_sparse_block_size: int = 64, block_sparse_homo_head_pattern: bool = False, block_sparse_num_local_blocks: int = 16, block_sparse_vertical_stride: int = 8, alibi_slopes: ~tensorrt_llm.functional.Tensor | None = None, tp_size: int = 1, tp_rank: int = 0, vision_start: int = -1, vision_length: int = -1, kv_cache_block_offsets: ~tensorrt_llm.functional.Tensor | None = None, host_kv_cache_block_offsets: ~tensorrt_llm.functional.Tensor = None, host_kv_cache_pool_pointers: ~tensorrt_llm.functional.Tensor = None, host_kv_cache_pool_mapping: ~tensorrt_llm.functional.Tensor = None, do_cross_attention: bool = False, cross_kv: ~tensorrt_llm.functional.Tensor | None = None, cross_kv_length: ~tensorrt_llm.functional.Tensor | None = None, encoder_input_lengths: ~tensorrt_llm.functional.Tensor | None = None, relative_attention_bias: ~tensorrt_llm.functional.Tensor | None = None, logn_scaling: ~tensorrt_llm.functional.Tensor | None = None, max_distance: int = 0, host_context_lengths: ~tensorrt_llm.functional.Tensor | None = None, qkv_bias: ~tensorrt_llm.functional.Tensor | None = None, use_cache: bool = True, spec_decoding_is_generation_length_variable: bool = False, spec_decoding_max_generation_length: int = 0, spec_decoding_generation_lengths: ~tensorrt_llm.functional.Tensor = None, spec_decoding_position_offsets: ~tensorrt_llm.functional.Tensor = None, spec_decoding_packed_mask: ~tensorrt_llm.functional.Tensor = None, spec_decoding_use: ~tensorrt_llm.functional.Tensor = None, long_rope_rotary_inv_freq: ~tensorrt_llm.functional.Tensor | None = None, 
long_rope_rotary_cos_sin: ~tensorrt_llm.functional.Tensor | None = None, mrope_rotary_cos_sin: ~tensorrt_llm.functional.Tensor = None, mrope_position_deltas: ~tensorrt_llm.functional.Tensor = None, host_runtime_perf_knobs: ~tensorrt_llm.functional.Tensor | None = None, host_context_progress: ~tensorrt_llm.functional.Tensor = None, is_mla_enabled_flag: bool = False, q_lora_rank: int = 0, kv_lora_rank: int = 0, qk_nope_head_dim: int = 0, qk_rope_head_dim: int = 0, v_head_dim: int = 0, q_b_proj: ~tensorrt_llm.functional.Tensor | None = None, kv_b_proj: ~tensorrt_llm.functional.Tensor | None = None, k_b_proj_trans: ~tensorrt_llm.functional.Tensor | None = None, skip_attn=None, cp_group: ~typing.List[int] = [0], cp_size: int = 1, cp_rank: int = 0, num_kv_heads_origin: int = -1) → Tuple[Tensor, Tensor | None][source]#

Add an operation that performs the multi-head attention in GPT-like models.

The signature of the function will change in a future release - we are in the process of simplifying the API. The current version is still work in progress! The following API is provided with hints regarding the arguments that are likely to be removed or merged with others in a future release.

See docs/source/advanced/gpt-attention.md for the documentation of that function.

Parameters:

Returns:

The tensor produced by that layer.

tensorrt_llm.functional.group_norm(

input: Tensor,

num_groups: int,

weight: Tensor | None = None,

bias: Tensor | None = None,

eps: float = 1e-05,

)[source]#

tensorrt_llm.functional.gt(

left: ~tensorrt_llm.functional.Tensor | int | float,

right: ~tensorrt_llm.functional.Tensor | int | float,

*,

op: ~tensorrt.tensorrt.ElementWiseOperation = <ElementWiseOperation.GREATER: 12>,

) → Tensor#

Add an elementwise operation with two inputs.

For each input, that function first creates a constant tensor if the input is an integer or a float. Then, if needed, it expands the smaller tensor to make sure its rank is the same as the larger one. Then, it performs the elementwise operation ‘op’.

The following closures are defined in functional.*:

add for op=trt.ElementWiseOperation.SUM
sub for op=trt.ElementWiseOperation.SUB
mul for op=trt.ElementWiseOperation.PROD
div for op=trt.ElementWiseOperation.DIV
floordiv for op=trt.ElementWiseOperation.FLOOR_DIV
gt for op=trt.ElementWiseOperation.GREATER
lt for op=trt.ElementWiseOperation.LESS
op_and for op=trt.ElementWiseOperation.AND
op_or for op=trt.ElementWiseOperation.OR
eq for op=trt.ElementWiseOperation.EQUAL
minimum for op=trt.ElementWiseOperation.MIN
maximum for op=trt.ElementWiseOperation.MAX
pow for op=trt.ElementWiseOperation.POW

It is implemented using the IElementWiseLayer from TensorRT.

Parameters:

Returns:

The tensor produced by this elementwise operation.

tensorrt_llm.functional.identity(

input: Tensor,

) → Tensor[source]#

Add an identity operation.

TODO: Document why it can be done using a plugin!!!

Parameters:

input – Tensor The input tensor.

Returns:

The tensor produced by this identity operation.

tensorrt_llm.functional.index_select(

input: Tensor,

dim: int,

index: Tensor,

) → Tensor[source]#

Add an operation to select slices of elements from a tensor.

Given an input tensor, that function creates an operation that selects the slices of elements in the dimension ‘dim’ at the indices listed in ‘index’ to create a new tensor. The output tensor has the same rank as the input tensor.

The ‘index’ is a tensor of rank 1.

For example, on input=[[4, 2, 5], [2, 1, 2], [4, 7, 1]], which has a shape [3, 3],

index_select(input, 0, [0, 1])

will create a tensor of shape [2, 3] that contains the [[4, 2, 5], [2, 1, 2]].

Regarding the shape of the output tensor, the dimension ‘dim’ has the same size as the ‘index’ tensor. It means that for an input tensor of shape [4, 2, 6, 3],

index_select(input, 2, [1, 4])

will select the 2nd and 5th slices (index == 1 or 4) from the 3rd dimension (dim == 2) and return a tensor of shape [4, 2, 2, 3] (i.e. the 3rd dimension is shrunk to 2).

Note that this operation can also be used to expand a tensor in the ‘dim’ dimension, for example, on input [[0, 1], [2, 3]],

index_select(input, 1, [0, 0, 0])

will produce a tensor of shape [2, 3] containing [[0, 0, 0], [2, 2, 2]].

That operation maps to the TensorRT IGatherLayer.

Parameters:

Returns:

The tensor containing the selected slices.

tensorrt_llm.functional.int_clip(

input: Tensor,

lower: int,

upper: int,

) → Tensor[source]#

tensorrt_llm.functional.interpolate(

input: Tensor,

size: int | List[int] = None,

scale_factor: float | List[float] = None,

mode: str = 'nearest',

align_corners: bool = False,

recompute_scale_factor: bool = False,

antialias: bool = False,

) → Tensor[source]#

tensorrt_llm.functional.is_gated_activation(activation)[source]#

Is a given activation function gated?

Parameters:

activation – str The name of the activation function.

Returns:

True if the function is gated, False otherwise.

tensorrt_llm.functional.layer_norm(

input: Tensor,

normalized_shape: int | Tuple[int],

weight: Tensor | None = None,

bias: Tensor | None = None,

eps: float = 1e-05,

use_diff_of_squares: bool = True,

) → Tensor[source]#

Add a layer-norm operation on a tensor.

That operation applies the layer-normalization to its input tensor. In its simplest form, for large language models, the ‘normalized_shape’ should be set to the hidden dimension of the activation tensor. Otherwise, it is the shape of the normalized fraction of the tensor (starting from the right-most dimension).

The ‘weight’ tensor corresponds to ‘gamma’ in the layer-norm formula and ‘bias’ is ‘beta’. The ‘eps’ value is added to the variance before computing the squared-root.

This implementation (when using the plugin) supports an additional flag to enable/disable the use of a difference of squares (‘Var = Mean(X^2) - Mean(X)^2’).

Parameters:

Returns:

The output tensor of that operation.
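A NumPy sketch of the normalization described above, including the difference-of-squares variance option; purely illustrative of the math, not of the plugin:

    import numpy as np

    def layer_norm(x, weight, bias, eps=1e-5, use_diff_of_squares=True):
        mean = x.mean(axis=-1, keepdims=True)
        if use_diff_of_squares:
            var = (x * x).mean(axis=-1, keepdims=True) - mean * mean  # E[X^2] - E[X]^2
        else:
            var = ((x - mean) ** 2).mean(axis=-1, keepdims=True)
        return (x - mean) / np.sqrt(var + eps) * weight + bias

    x = np.random.default_rng(0).standard_normal((2, 8))
    print(layer_norm(x, np.ones(8), np.zeros(8)).std(axis=-1))  # ~1 per row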

tensorrt_llm.functional.log(

input: ~tensorrt_llm.functional.Tensor,

*,

op: ~tensorrt.tensorrt.UnaryOperation = <UnaryOperation.LOG: 1>,

) → Tensor#

Add an elementwise operation on a single input.

The following closures are defined in functional.*:

round for op=trt.UnaryOperation.ROUND
sqrt for op=trt.UnaryOperation.SQRT
exp for op=trt.UnaryOperation.EXP
sin for op=trt.UnaryOperation.SIN
cos for op=trt.UnaryOperation.COS
abs for op=trt.UnaryOperation.ABS
log for op=trt.UnaryOperation.LOG

It is implemented using the IUnaryLayer from TensorRT.

Parameters:

Returns:

The tensor produced by this elementwise operation.

tensorrt_llm.functional.log_softmax(

input: Tensor,

dim: int,

) → Tensor[source]#

This function is the equivalent of torch.nn.functional.log_softmax(), i.e. it performs log(softmax(input)) in a safer and faster way.

Parameters:

Returns:

A tensor of the same shape as input with log_softmax computed on the specified dim.
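The ‘safer’ formulation referred to above is the usual max-shifted log-sum-exp; a NumPy sketch:

    import numpy as np

    def log_softmax(x, dim=-1):
        # log(softmax(x)) = (x - m) - log(sum(exp(x - m))) with m = max(x),
        # which avoids overflow in exp for large inputs.
        m = x.max(axis=dim, keepdims=True)
        shifted = x - m
        return shifted - np.log(np.exp(shifted).sum(axis=dim, keepdims=True))

    print(log_softmax(np.array([1000.0, 1001.0, 1002.0])))  # finite values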

tensorrt_llm.functional.lora_plugin(

input: Tensor = None,

in_hidden_size: int = 0,

out_hidden_sizes: List[int] = [0],

host_request_types: Tensor = None,

transa: bool = False,

transb: bool = False,

host_context_lengths: Tensor = None,

max_low_rank: int = 0,

lora_ranks: List[Tensor] = None,

lora_weights_pointers: List[Tensor] = None,

weight_index: int = 0,

)[source]#

Parameters:

Returns:

The tensor produced by that layer.

tensorrt_llm.functional.low_latency_gemm(

input: Tensor,

mat2: Tensor,

alpha: ndarray | None = None,

strict_dtype: DataType | None = None,

) → Tensor[source]#

tensorrt_llm.functional.low_latency_gemm_swiglu(

input: Tensor,

weight: Tensor,

scale_d0: float = 1.0,

scale_d1: float = 1.0,

scale_output: float = 1.0,

) → Tensor[source]#

Add a matrix multiplication followed by a SwiGLU (x * SiLU(gate)) operation.

The SwiGLU operation takes the result of the matrix multiplication, splits it into two halves along the last dimension, applies SiLU to the second half and multiplies the results. The behavior is undefined if the last dimension is not even.

Parameters:

input : Tensor

The first tensor (often called A).

weight : Tensor

The second tensor (often called B).

scale_d0 : float

The scale for dequantizing x, used for fp8.

scale_d1 : float

The scale for dequantizing gate, used for fp8.

scale_output : float

The scale for quantizing output, used for fp8.

Returns:

The tensor produced by the inserted layer.

tensorrt_llm.functional.lt(

left: ~tensorrt_llm.functional.Tensor | int | float,

right: ~tensorrt_llm.functional.Tensor | int | float,

*,

op: ~tensorrt.tensorrt.ElementWiseOperation = <ElementWiseOperation.LESS: 13>,

) → Tensor#

Add an elementwise operation with two inputs.

For each input, that function first creates a constant tensor if the input is an integer or a float. Then, if needed, it expands the smaller tensor to make sure its rank is the same as the larger one. Then, it performs the elementwise operation ‘op’.

The following closures are defined in functional.*:

add for op=trt.ElementWiseOperation.SUM
sub for op=trt.ElementWiseOperation.SUB
mul for op=trt.ElementWiseOperation.PROD
div for op=trt.ElementWiseOperation.DIV
floordiv for op=trt.ElementWiseOperation.FLOOR_DIV
gt for op=trt.ElementWiseOperation.GREATER
lt for op=trt.ElementWiseOperation.LESS
op_and for op=trt.ElementWiseOperation.AND
op_or for op=trt.ElementWiseOperation.OR
eq for op=trt.ElementWiseOperation.EQUAL
minimum for op=trt.ElementWiseOperation.MIN
maximum for op=trt.ElementWiseOperation.MAX
pow for op=trt.ElementWiseOperation.POW

It is implemented using the IElementWiseLayer from TensorRT.

Parameters:

Returns:

The tensor produced by this elementwise operation.

tensorrt_llm.functional.mamba_conv1d(

input: Tensor,

conv_state_or_ptr: Tensor,

conv_weight: Tensor,

conv_bias: Tensor,

host_request_types: Tensor,

last_token_ids: Tensor,

dim: int,

dconv: int,

dtype: str,

pre_stride: int = 0,

post_stride: int = 0,

host_context_lengths: Tensor | None = None,

slot_mapping: Tensor | None = None,

apply_silu: bool = True,

)[source]#

Parameters:

tensorrt_llm.functional.masked_scatter(

input: Tensor,

mask: Tensor,

source: Tensor,

) → Tensor[source]#

Add the masked_scatter operation based on the PyTorch definition.

See https://pytorch.org/docs/stable/generated/torch.Tensor.masked_scatter_.html#torch-tensor-masked-scatter for a description of that function.

Parameters:

Returns:

The tensor containing the source tensor selected by mask.

tensorrt_llm.functional.masked_select(

input: Tensor,

mask: Tensor,

) → Tensor[source]#

Add an operation to select elements from a tensor according to a boolean mask tensor.

Given an input tensor, that function creates an operation that selects elements at the indices indicated by the boolean mask tensor to create a new tensor. The output tensor is a 1-D tensor.

The input tensor must have rank >= 1. The shapes of the input tensor and the mask tensor don’t need to match, but they must be able to be broadcasted.

For example, on input=[[4, 2, 5], [2, 1, 2], [4, 7, 1]], which has a shape [3, 3],

masked_select(input, [[True, False, True], [False, True, False], [True, False, True]])

will create a tensor of shape [5] that contains [4, 5, 1, 4, 1].

masked_select(input, [[True], [False], [True]])

will create a tensor of shape [6] that contains [4, 2, 5, 4, 7, 1].

masked_select(input, [[False, False, True]])

will create a tensor of shape [3] that contains [5, 2, 1].

masked_select(input, [False])

will create a tensor of shape [0] which is empty.

That operation is implemented by NonZero, Shuffle and GatherV2 layers in TensorRT.

Parameters:

Returns:

The 1-D tensor containing the selected elements.
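The examples above can be reproduced with NumPy by broadcasting the mask to the input shape before indexing (a sketch of the semantics only, not TensorRT-LLM graph code):

import numpy as np

x = np.array([[4, 2, 5], [2, 1, 2], [4, 7, 1]])

mask = np.array([[True, False, True],
                 [False, True, False],
                 [True, False, True]])
x[np.broadcast_to(mask, x.shape)]   # [4, 5, 1, 4, 1]

col = np.array([[True], [False], [True]])
x[np.broadcast_to(col, x.shape)]    # [4, 2, 5, 4, 7, 1]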

tensorrt_llm.functional.matmul(

input: Tensor,

mat2: Tensor,

transa: bool = False,

transb: bool = False,

use_fp32_acc: bool = True,

) → Tensor[source]#

Add a matrix multiplication.

That operation maps to a tensorrt.IMatrixMultiplyLayer layer. As explained in the TensorRT documentation, it computes the inner product between the two inputs after applying an optional transposition on the inputs.

Parameters:

Returns:

The tensor produced by the inserted layer.

tensorrt_llm.functional.max(

input: Tensor,

dim: int,

keepdim: bool = False,

) → Tensor[source]#

Add an operation to compute the max along a dimension.

Computes the max along the dimension ‘dim’ of the input tensor.

It is implemented using the IReduceLayer from TensorRT.

Parameters:

Returns:

The tensor produced by this reduction operation.

tensorrt_llm.functional.maximum(

left: ~tensorrt_llm.functional.Tensor | int | float,

right: ~tensorrt_llm.functional.Tensor | int | float,

*,

op: ~tensorrt.tensorrt.ElementWiseOperation = <ElementWiseOperation.MAX: 2>,

) → Tensor#

Add an elementwise operation with two inputs.

For each input, that function first creates a constant tensor if the input is an integer or a float. Then, if needed, it expands the smaller tensor to make sure its rank is the same as the larger one. Then, it performs the elementwise operation ‘op’.

The following closures are defined in functional.*:

add for op=trt.ElementWiseOperation.SUM
sub for op=trt.ElementWiseOperation.SUB
mul for op=trt.ElementWiseOperation.PROD
div for op=trt.ElementWiseOperation.DIV
floordiv for op=trt.ElementWiseOperation.FLOOR_DIV
gt for op=trt.ElementWiseOperation.GREATER
lt for op=trt.ElementWiseOperation.LESS
op_and for op=trt.ElementWiseOperation.AND
op_or for op=trt.ElementWiseOperation.OR
eq for op=trt.ElementWiseOperation.EQUAL
minimum for op=trt.ElementWiseOperation.MIN
maximum for op=trt.ElementWiseOperation.MAX
pow for op=trt.ElementWiseOperation.POW

It is implemented using the IElementWiseLayer from TensorRT.

Parameters:

Returns:

The tensor produced by this elementwise operation.

tensorrt_llm.functional.mean(

input: Tensor,

dim: int | Tuple[int],

keepdim: bool = False,

) → Tensor[source]#

Add an operation to compute the mean along a dimension.

Computes the mean along the dimension ‘dim’ of the input tensor.

It is implemented using the IReduceLayer from TensorRT.

Parameters:

Returns:

The tensor produced by this reduction operation.

tensorrt_llm.functional.meshgrid2d(

x: Tensor,

y: Tensor,

) → Tuple[Tensor][source]#

Creates grids (2D) of coordinates specified by the 1D inputs (only supports indexing=’xy’).

Parameters:

Returns:

The tuple of two tensors produced.

TODO: Add full support for torch.meshgrid.

See https://pytorch.org/docs/stable/generated/torch.meshgrid.html#torch-meshgrid

tensorrt_llm.functional.min(

input: ~tensorrt_llm.functional.Tensor,

*,

op: ~tensorrt.tensorrt.ReduceOperation = <ReduceOperation.MIN: 3>,

dim: int | ~typing.Tuple[int],

keepdim: bool = False,

) → Tensor#

Add a reduction operation along a dimension.

It is implemented using the IReduceLayer from TensorRT.

Parameters:

Returns:

The tensor produced by this reduction operation.

tensorrt_llm.functional.minimum(

left: ~tensorrt_llm.functional.Tensor | int | float,

right: ~tensorrt_llm.functional.Tensor | int | float,

*,

op: ~tensorrt.tensorrt.ElementWiseOperation = <ElementWiseOperation.MIN: 3>,

) → Tensor#

Add an elementwise operation with two inputs.

For each input, that function first creates a constant tensor if the input is an integer or a float. Then, if needed, it expands the smaller tensor to make sure its rank is the same as the larger one. Then, it performs the elementwise operation ‘op’.

The following closures are defined in functional.*:

add for op=trt.ElementWiseOperation.SUM
sub for op=trt.ElementWiseOperation.SUB
mul for op=trt.ElementWiseOperation.PROD
div for op=trt.ElementWiseOperation.DIV
floordiv for op=trt.ElementWiseOperation.FLOOR_DIV
gt for op=trt.ElementWiseOperation.GREATER
lt for op=trt.ElementWiseOperation.LESS
op_and for op=trt.ElementWiseOperation.AND
op_or for op=trt.ElementWiseOperation.OR
eq for op=trt.ElementWiseOperation.EQUAL
minimum for op=trt.ElementWiseOperation.MIN
maximum for op=trt.ElementWiseOperation.MAX
pow for op=trt.ElementWiseOperation.POW

It is implemented using the IElementWiseLayer from TensorRT.

Parameters:

Returns:

The tensor produced by this elementwise operation.

tensorrt_llm.functional.modulo(

x: Tensor,

y: Tensor | int,

) → Tensor[source]#

This function adds an element-wise modulo (x % y) operation for a given tensor. Since there is no TensorRT layer that can directly perform this, it is implemented using some of the basic operations.

Returns:

A tensor that represents (x % y) modulo operation.
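One standard construction from the elementwise primitives documented above is x - floordiv(x, y) * y; here is a NumPy sketch under floor-division semantics (modulo_ref is a hypothetical helper, not the library's exact composition):

import numpy as np

def modulo_ref(x, y):
    # x % y expressed with basic ops: sub(x, mul(floordiv(x, y), y))
    return x - (x // y) * y

x = np.array([7, -7, 7, -7])
y = np.array([3, 3, -3, -3])
assert np.array_equal(modulo_ref(x, y), x % y)   # [1, 2, -2, -1]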

tensorrt_llm.functional.mul(

left: ~tensorrt_llm.functional.Tensor | int | float,

right: ~tensorrt_llm.functional.Tensor | int | float,

*,

op: ~tensorrt.tensorrt.ElementWiseOperation = <ElementWiseOperation.PROD: 1>,

) → Tensor#

Add an elementwise operation with two inputs.

For each input, that function first creates a constant tensor if the input is an integer or a float. Then, if needed, it expands the smaller tensor to make sure its rank is the same as the larger one. Then, it performs the elementwise operation ‘op’.

The following closures are defined in functional.*:

add for op=trt.ElementWiseOperation.SUM
sub for op=trt.ElementWiseOperation.SUB
mul for op=trt.ElementWiseOperation.PROD
div for op=trt.ElementWiseOperation.DIV
floordiv for op=trt.ElementWiseOperation.FLOOR_DIV
gt for op=trt.ElementWiseOperation.GREATER
lt for op=trt.ElementWiseOperation.LESS
op_and for op=trt.ElementWiseOperation.AND
op_or for op=trt.ElementWiseOperation.OR
eq for op=trt.ElementWiseOperation.EQUAL
minimum for op=trt.ElementWiseOperation.MIN
maximum for op=trt.ElementWiseOperation.MAX
pow for op=trt.ElementWiseOperation.POW

It is implemented using the IElementWiseLayer from TensorRT.

Parameters:

Returns:

The tensor produced by this elementwise operation.

tensorrt_llm.functional.non_gated_version(activation)[source]#

Given an activation function, get the non-gated version.

If the activation function is non-gated, it returns the same activation function name.

For example, that function returns ‘silu’ for ‘swiglu’ and ‘relu’ for ‘relu’.

Parameters:

activation – str The name of the activation function.

Returns:

The name of the non-gated activation function.

tensorrt_llm.functional.nonzero(

input: Tensor,

) → Tensor[source]#

Adds a layer that finds the indices of non-zero values of the input tensor.

Parameters:

input – Tensor The input tensor for which we need to find the indices of non-zero values.

Returns:

A tensor of shape [D, C] where D is the number of dimensions of input and C is the number of non-zero values in it. Each column of this 2D tensor represents the index tuple for each non-zero value.
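For reference, the [D, C] layout matches stacking NumPy's per-dimension index arrays (a sketch of the output convention only, not TensorRT-LLM graph code):

import numpy as np

x = np.array([[1, 0], [0, 2], [3, 0]])
np.stack(np.nonzero(x))   # shape [2, 3]: [[0, 1, 2], [0, 1, 0]]
# Each column is the index tuple of one non-zero value: (0,0), (1,1), (2,0).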

tensorrt_llm.functional.not_op(

input: ~tensorrt_llm.functional.Tensor,

*,

op: ~tensorrt.tensorrt.UnaryOperation = <UnaryOperation.NOT: 20>,

) → Tensor#

Add an elementwise operation on a single input.

The following closures are defined in functional.*:

round for op=trt.UnaryOperation.ROUND
sqrt for op=trt.UnaryOperation.SQRT
exp for op=trt.UnaryOperation.EXP
sin for op=trt.UnaryOperation.SIN
cos for op=trt.UnaryOperation.COS
abs for op=trt.UnaryOperation.ABS
log for op=trt.UnaryOperation.LOG

It is implemented using the IUnaryLayer from TensorRT.

Parameters:

Returns:

The tensor produced by this elementwise operation.

tensorrt_llm.functional.op_and(

left: ~tensorrt_llm.functional.Tensor | int | float,

right: ~tensorrt_llm.functional.Tensor | int | float,

*,

op: ~tensorrt.tensorrt.ElementWiseOperation = <ElementWiseOperation.AND: 8>,

) → Tensor#

Add an elementwise operation with two inputs.

For each input, that function first creates a constant tensor if the input is an integer or a float. Then, if needed, it expands the smaller tensor to make sure its rank is the same as the larger one. Then, it performs the elementwise operation ‘op’.

The following closures are defined in functional.*:

add for op=trt.ElementWiseOperation.SUM
sub for op=trt.ElementWiseOperation.SUB
mul for op=trt.ElementWiseOperation.PROD
div for op=trt.ElementWiseOperation.DIV
floordiv for op=trt.ElementWiseOperation.FLOOR_DIV
gt for op=trt.ElementWiseOperation.GREATER
lt for op=trt.ElementWiseOperation.LESS
op_and for op=trt.ElementWiseOperation.AND
op_or for op=trt.ElementWiseOperation.OR
eq for op=trt.ElementWiseOperation.EQUAL
minimum for op=trt.ElementWiseOperation.MIN
maximum for op=trt.ElementWiseOperation.MAX
pow for op=trt.ElementWiseOperation.POW

It is implemented using the IElementWiseLayer from TensorRT.

Parameters:

Returns:

The tensor produced by this elementwise operation.

tensorrt_llm.functional.op_or(

left: ~tensorrt_llm.functional.Tensor | int | float,

right: ~tensorrt_llm.functional.Tensor | int | float,

*,

op: ~tensorrt.tensorrt.ElementWiseOperation = <ElementWiseOperation.OR: 9>,

) → Tensor#

Add an elementwise operation with two inputs.

For each input, that function first creates a constant tensor if the input is an integer or a float. Then, if needed, it expands the smaller tensor to make sure its rank is the same as the larger one. Then, it performs the elementwise operation ‘op’.

The following closures are defined in functional.*:

add for op=trt.ElementWiseOperation.SUM
sub for op=trt.ElementWiseOperation.SUB
mul for op=trt.ElementWiseOperation.PROD
div for op=trt.ElementWiseOperation.DIV
floordiv for op=trt.ElementWiseOperation.FLOOR_DIV
gt for op=trt.ElementWiseOperation.GREATER
lt for op=trt.ElementWiseOperation.LESS
op_and for op=trt.ElementWiseOperation.AND
op_or for op=trt.ElementWiseOperation.OR
eq for op=trt.ElementWiseOperation.EQUAL
minimum for op=trt.ElementWiseOperation.MIN
maximum for op=trt.ElementWiseOperation.MAX
pow for op=trt.ElementWiseOperation.POW

It is implemented using the IElementWiseLayer from TensorRT.

Parameters:

Returns:

The tensor produced by this elementwise operation.

tensorrt_llm.functional.op_xor(

left: ~tensorrt_llm.functional.Tensor | int | float,

right: ~tensorrt_llm.functional.Tensor | int | float,

*,

op: ~tensorrt.tensorrt.ElementWiseOperation = <ElementWiseOperation.XOR: 10>,

) → Tensor#

Add an elementwise operation with two inputs.

For each input, that function first creates a constant tensor if the input is an integer or a float. Then, if needed, it expands the smaller tensor to make sure its rank is the same as the larger one. Then, it performs the elementwise operation ‘op’.

The following closures are defined in functional.*:

add for op=trt.ElementWiseOperation.SUM
sub for op=trt.ElementWiseOperation.SUB
mul for op=trt.ElementWiseOperation.PROD
div for op=trt.ElementWiseOperation.DIV
floordiv for op=trt.ElementWiseOperation.FLOOR_DIV
gt for op=trt.ElementWiseOperation.GREATER
lt for op=trt.ElementWiseOperation.LESS
op_and for op=trt.ElementWiseOperation.AND
op_or for op=trt.ElementWiseOperation.OR
eq for op=trt.ElementWiseOperation.EQUAL
minimum for op=trt.ElementWiseOperation.MIN
maximum for op=trt.ElementWiseOperation.MAX
pow for op=trt.ElementWiseOperation.POW

It is implemented using the IElementWiseLayer from TensorRT.

Parameters:

Returns:

The tensor produced by this elementwise operation.

tensorrt_llm.functional.outer(

input: Tensor,

vec2: Tensor,

) → Tensor[source]#

Add an operation to compute the outer product between two tensors.

That operation creates an Einsum node.

Parameters:

Returns:

The output tensor produced by this layer.

tensorrt_llm.functional.pad(

input: Tensor,

pad: Sequence[int] | Tensor,

mode: str = 'constant',

value: float | None = None,

) → Tensor[source]#

Add a pad layer.

The padding layer adds zero-padding at the start and end of the input tensor. The padding sizes by which to pad some dimensions of the input are described starting from the last dimension and moving forward.

[len(pad) / 2] dimensions of input will be padded. For example, to pad only the last dimension of the input tensor, then pad has the form [padding_left, padding_right]; to pad the last 2 dimensions of the input tensor, then use [padding_left, padding_right, padding_top, padding_bottom]; to pad the last 3 dimensions, use [padding_left, padding_right, padding_top, padding_bottom, padding_front, padding_back].

Parameters:

Returns:

The tensor produced by the inserted layer.
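A NumPy sketch of how the last-dimension-first pad list maps onto per-dimension padding (pad_ref is a hypothetical helper illustrating the ‘constant’ mode semantics only):

import numpy as np

def pad_ref(x, pad, value=0.0):
    # pad = [left_last, right_last, left_2nd_last, right_2nd_last, ...]
    width = [(0, 0)] * x.ndim
    for i in range(len(pad) // 2):
        width[x.ndim - 1 - i] = (pad[2 * i], pad[2 * i + 1])
    return np.pad(x, width, mode='constant', constant_values=value)

x = np.ones((2, 3))
pad_ref(x, [1, 2]).shape        # (2, 6): last dim padded by 1 on the left, 2 on the right
pad_ref(x, [1, 2, 0, 1]).shape  # (3, 6): additionally 0 rows on top, 1 row at the bottom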

tensorrt_llm.functional.permute(

input: Tensor,

dims: Sequence[int],

) → Tensor[source]#

Add an operation to permute the dimensions of a tensor.

The dimensions of the input tensor are permuted according to the sequence of dimensions in ‘dims’. That operation maps to tensorrt.IShuffleLayer where the second transposition is described by the indices in ‘dims’.

Given a tensor of rank N, the result of the permutation is a tensor of rank N in which the i-th input dimension maps to the dims[i]-th dimension.

For example, permute(input, [1, 0]) will transpose a 2D tensor by permuting the rows and columns.

Parameters:

Returns:

The tensor produced by the permutation layer.

tensorrt_llm.functional.pow(

left: ~tensorrt_llm.functional.Tensor | int | float,

right: ~tensorrt_llm.functional.Tensor | int | float,

*,

op: ~tensorrt.tensorrt.ElementWiseOperation = <ElementWiseOperation.POW: 6>,

) → Tensor#

Add an elementwise operation with two inputs.

For each input, that function first creates a constant tensor if the input is an integer or a float. Then, if needed, it expands the smaller tensor to make sure its rank is the same as the larger one. Then, it performs the elementwise operation ‘op’.

The following closures are defined in functional.*:

add for op=trt.ElementWiseOperation.SUM
sub for op=trt.ElementWiseOperation.SUB
mul for op=trt.ElementWiseOperation.PROD
div for op=trt.ElementWiseOperation.DIV
floordiv for op=trt.ElementWiseOperation.FLOOR_DIV
gt for op=trt.ElementWiseOperation.GREATER
lt for op=trt.ElementWiseOperation.LESS
op_and for op=trt.ElementWiseOperation.AND
op_or for op=trt.ElementWiseOperation.OR
eq for op=trt.ElementWiseOperation.EQUAL
minimum for op=trt.ElementWiseOperation.MIN
maximum for op=trt.ElementWiseOperation.MAX
pow for op=trt.ElementWiseOperation.POW

It is implemented using the IElementWiseLayer from TensorRT.

Parameters:

Returns:

The tensor produced by this elementwise operation.

tensorrt_llm.functional.prod(

input: ~tensorrt_llm.functional.Tensor,

*,

op: ~tensorrt.tensorrt.ReduceOperation = <ReduceOperation.PROD: 1>,

dim: int | ~typing.Tuple[int],

keepdim: bool = False,

) → Tensor#

Add a reduction operation along a dimension.

It is implemented using the IReduceLayer from TensorRT.

Parameters:

Returns:

The tensor produced by this reduction operation.

tensorrt_llm.functional.quick_gelu(

x: Tensor,

) → Tensor[source]#

tensorrt_llm.functional.rand(

shape: Tensor,

low: float = 0,

high: float = 1,

dtype: str | DataType = 'float32',

) → Tensor[source]#

This operation adds a fill layer that generates a random (uniform) tensor with the specified shape and data type.

Parameters:

Returns:

The generated random tensor produced by the fill layer.

tensorrt_llm.functional.rearrange(

inputs: Tensor | Sequence[Tensor],

expression: str,

**kwargs,

) → Tensor[source]#

Add a rearrange operation on a tensor.

This operation is a reader-friendly smart element reordering for multidimensional tensors, including functionality of transpose (axes permutation), reshape (view), squeeze, unsqueeze, stack, concatenate and other operations. Please see: https://einops.rocks/api/rearrange/

For example, if the shape of input tensor is [32, 30, 40, 3], and run:

rearrange(x, ‘b (h h1) (w w1) c -> b h w 1 (c h1 w1) 1’, h1=2, w1=2)

it would produce a tensor of shape [32, 15, 20, 1, 12, 1].

Parameters:

Returns:

The output tensor of this operation.
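Since the expression language is that of the einops package linked above, the example can be checked against einops directly (assuming einops is installed):

import numpy as np
from einops import rearrange

x = np.zeros((32, 30, 40, 3))
rearrange(x, 'b (h h1) (w w1) c -> b h w 1 (c h1 w1) 1', h1=2, w1=2).shape
# (32, 15, 20, 1, 12, 1)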

tensorrt_llm.functional.recv(

tensor: Tensor,

src: int,

) → Tensor[source]#

Add an operation that performs a recv on a rank from another rank.

The recv operation receives a tensor on one rank from another. If a rank ‘i’ receives a tensor from a rank ‘j’, the rank ‘j’ must have a corresponding ‘send’ operation to rank ‘i’. See ‘send’.

That operation is implemented using a plugin that wraps the NCCL recv point-to-point operation. See https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/p2p.html#ncclrecv for details.

Parameters:

Returns:

The tensor produced by that layer.

tensorrt_llm.functional.reduce(

input: Tensor,

op: ReduceOperation,

dim: int | Tuple[int],

keepdim: bool = False,

) → Tensor[source]#

Add a reduction operation along a dimension.

It is implemented using the IReduceLayer from TensorRT.

Parameters:

Returns:

The tensor produced by this reduction operation.

tensorrt_llm.functional.reduce_scatter(

tensor: Tensor,

group: List[int],

) → Tensor[source]#

tensorrt_llm.functional.relu(

input: ~tensorrt_llm.functional.Tensor,

*,

act_type: ~tensorrt.tensorrt.ActivationType = <ActivationType.RELU: 0>,

) → Tensor#

Add an activation function.

Parameters:

The following closures are defined in functional.*:

relu for op=trt.ActivationType.RELU
tanh for op=trt.ActivationType.TANH
sigmoid for op=trt.ActivationType.SIGMOID

Returns:

The tensor produced by the activation layer.

tensorrt_llm.functional.repeat(

input: Tensor,

sizes: Sequence[int],

) → Tensor[source]#

Repeats the tensor along the specified dimensions.

Parameters:

Returns:

The tensor produced by repeating the input along the specified dimensions.

tensorrt_llm.functional.repeat_interleave(

tensor: Tensor,

repeats: int,

dim: int,

) → Tensor[source]#

Repeats elements of a tensor along an axis.

Parameters:

Returns:

A tensor with the same shape as the input, except that elements are repeated along the specified dim.

TODO: Allow repeats to be a list of integers and dim to be unspecified.
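The semantics match NumPy's np.repeat with an axis argument (a reference for the semantics only, not TensorRT-LLM graph code):

import numpy as np

x = np.array([[1, 2], [3, 4]])
np.repeat(x, 3, axis=1)   # [[1, 1, 1, 2, 2, 2], [3, 3, 3, 4, 4, 4]]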

tensorrt_llm.functional.rg_lru(

input: Tensor,

A: Tensor,

state_or_ptr: Tensor,

host_request_types: Tensor,

last_token_ids: Tensor,

dim: int,

dtype: str,

block_size: int = 0,

y: Tensor | None = None,

y_bias: Tensor | None = None,

gate: Tensor | None = None,

gate_bias: Tensor | None = None,

gate_x: Tensor | None = None,

gate_x_bias: Tensor | None = None,

gate_a: Tensor | None = None,

gate_a_bias: Tensor | None = None,

slot_mapping: Tensor | None = None,

)[source]#

Parameters:

tensorrt_llm.functional.rms_norm(

input: Tensor,

normalized_shape: int | Tuple[int],

num_groups: int = 1,

weight: Tensor | None = None,

eps: float = 1e-06,

) → Tensor[source]#

Add a RMS norm operation on a tensor.

That operation applies the rms-normalization to its input tensor. In its simplest form, for large language models, the ‘normalized_shape’ should be set to the hidden dimension of the activation tensor. Otherwise, it is the shape of the normalized fraction of the tensor (starting from the right-most dimension).

The ‘weight’ tensor corresponds to ‘gamma’ in the rms-norm formula. The ‘eps’ value is added to the variance before computing the square root.

Parameters:

Returns:

The output tensor of that operation.

tensorrt_llm.functional.round(

input: ~tensorrt_llm.functional.Tensor,

*,

op: ~tensorrt.tensorrt.UnaryOperation = <UnaryOperation.ROUND: 22>,

) → Tensor#

Add an elementwise operation on a single input.

The following closures are defined in functional.*:

round for op=trt.UnaryOperation.ROUND
sqrt for op=trt.UnaryOperation.SQRT
exp for op=trt.UnaryOperation.EXP
sin for op=trt.UnaryOperation.SIN
cos for op=trt.UnaryOperation.COS
abs for op=trt.UnaryOperation.ABS
log for op=trt.UnaryOperation.LOG

It is implemented using the IUnaryLayer from TensorRT.

Parameters:

Returns:

The tensor produced by this elementwise operation.

tensorrt_llm.functional.scatter(

input: Tensor,

dim: int,

indices: Tensor,

updates: Tensor,

) → Tensor[source]#

This operation adds a layer that creates an output tensor by element-wise copying values from the input tensor and then updating the values at the positions given by the ‘indices’ and ‘updates’ tensors. For a 2D input tensor, it first copies the input to the output, then updates the output like the following for each entry in updates:

output[indices[i][j]][j] = updates[i][j] if dim=0
output[i][indices[i][j]] = updates[i][j] if dim=1

If the input tensor is [[1, 2, 3], [4, 5, 6]], the indices tensor is [[1, 2], [0, 1]], the updates tensor is [[-1, -2], [-3, -4]], and dim=1, the output tensor will be [[1, -1, -2], [-3, -4, 6]].

Parameters:

input: Tensor

The input data that needs to be updated.

dim: int

The axis on which the scatter is to be performed.

indices: Tensor

An integer tensor of the same rank as input that indicates the positions to be updated.

updates: Tensor

A data tensor of the same shape as the indices tensor that contains the update values.

Returns:

A tensor created by the element-wise scatter layer.
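A NumPy sketch of the dim=1 update rule, reproducing the example above (semantics only, not TensorRT-LLM graph code):

import numpy as np

x = np.array([[1, 2, 3], [4, 5, 6]])
indices = np.array([[1, 2], [0, 1]])
updates = np.array([[-1, -2], [-3, -4]])

out = x.copy()
rows = np.arange(x.shape[0])[:, None]   # row index i, broadcast over j
out[rows, indices] = updates            # output[i][indices[i][j]] = updates[i][j]
# out == [[1, -1, -2], [-3, -4, 6]]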

tensorrt_llm.functional.scatter_nd(

input: Tensor,

mask: Tensor,

source: Tensor,

) → Tensor[source]#

Scatter_nd is a tensor operation that writes or updates values in a tensor based on indices.

Parameters:

Returns:

New tensor with the same shape as the input tensor data, where the values from the source tensor are scattered or written into the output tensor at the locations specified by the mask tensor.

tensorrt_llm.functional.select(

input: Tensor,

dim: int,

index: Tensor | int,

) → Tensor[source]#

Add an operation to select a slice of elements from a tensor.

Given an input tensor, that function creates an operation that selects the index-th slice of elements in the dimension ‘dim’ to create a new tensor. The output tensor has a shape in which the input dimension ‘dim’ is removed.

The ‘index’ can either be an integer or a 1D tensor containing a single element.

For example, on input=[[4, 2, 5], [2, 1, 2], [4, 7, 1]], which has a shape [3, 3],

select(input, 0, 1)

will create a tensor of shape [3] that contains [2, 1, 2].

Regarding the shape of the output tensor, the dimension ‘dim’ is removed. It means that for a tensor of shape [4, 2, 6, 3],

select(input, 2, 4)

will select the 5th slice (index == 4) from the 3rd dimension (dim == 2) and return a tensor of shape [4, 2, 3] (i.e. the 3rd dimension is removed).

That operation maps to the TensorRT IGatherLayer.

Parameters:

Returns:

The tensor containing the selected slice.
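With NumPy, the same dimension-removing behavior comes from np.take with a scalar index (a sketch of the semantics only, not TensorRT-LLM graph code):

import numpy as np

x = np.array([[4, 2, 5], [2, 1, 2], [4, 7, 1]])
np.take(x, 1, axis=0)          # [2, 1, 2], shape [3]

y = np.zeros((4, 2, 6, 3))
np.take(y, 4, axis=2).shape    # (4, 2, 3): dim 2 is removed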

tensorrt_llm.functional.selective_scan(

input: Tensor,

state_or_ptr: Tensor,

delta: Tensor,

delta_bias: Tensor,

A: Tensor,

BC: Tensor,

D: Tensor,

host_request_types: Tensor,

last_token_ids: Tensor,

dim: int,

dstate: int,

dt_rank: int,

delta_softplus: bool,

dtype: str,

z: Tensor | None = None,

host_context_lengths: Tensor | None = None,

slot_mapping: Tensor | None = None,

nheads: int = 1,

ngroups: int = 1,

chunk_size: int = 256,

mamba_version: str = 'Mamba1',

)[source]#

Parameters:

tensorrt_llm.functional.send(

tensor: Tensor,

tgt: int,

) → Tensor[source]#

Add an operation that performs a send from a rank to another.

The send operation sends a tensor from one rank to another. If a rank ‘i’ sends a tensor to a rank ‘j’, the rank ‘j’ must have a corresponding ‘recv’ operation from rank ‘i’. See ‘recv’.

That operation is implemented using a plugin that wraps the NCCL send point-to-point operation. See https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/p2p.html#ncclsend for details.

Parameters:

Returns:

The tensor produced by that layer.

tensorrt_llm.functional.shape(

input: Tensor,

dim: int | None = None,

cast_to_dtype: str | DataType | None = None,

clip_before_cast: Sequence[int] = None,

) → Tensor[source]#

Add an operation to create a shape tensor.

The shape tensor can either be the shape of the input tensor when the parameter dim is None or a scalar (tensor of rank 0) that corresponds to the size of dim-th dimension.

Parameters:

Returns:

A tensor that contains the shape of the input tensor (if ‘dim’ is None) or the size in the dimension ‘dim’ of the input tensor. If ‘dim’ is ‘None’, that tensor has the same rank as the input tensor, otherwise its rank is 0.

tensorrt_llm.functional.sigmoid(

input: ~tensorrt_llm.functional.Tensor,

*,

act_type: ~tensorrt.tensorrt.ActivationType = <ActivationType.SIGMOID: 1>,

) → Tensor#

Add an activation function.

Parameters:

The following closures are defined in functional.*:

relu for op=trt.ActivationType.RELU
tanh for op=trt.ActivationType.TANH
sigmoid for op=trt.ActivationType.SIGMOID

Returns:

The tensor produced by the activation layer.

tensorrt_llm.functional.silu(

input: Tensor,

) → Tensor[source]#

Add a SiLU (x * sigmoid(x)) operation.

Parameters:

input – Tensor The input tensor on which the activation function is applied.

Returns:

The tensor produced by the activation layer.

tensorrt_llm.functional.sin(

input: ~tensorrt_llm.functional.Tensor,

*,

op: ~tensorrt.tensorrt.UnaryOperation = <UnaryOperation.SIN: 6>,

) → Tensor#

Add an elementwise operation on a single input.

The following closures are defined in functional.*:

round for op=trt.UnaryOperation.ROUND
sqrt for op=trt.UnaryOperation.SQRT
exp for op=trt.UnaryOperation.EXP
sin for op=trt.UnaryOperation.SIN
cos for op=trt.UnaryOperation.COS
abs for op=trt.UnaryOperation.ABS
log for op=trt.UnaryOperation.LOG

It is implemented using the IUnaryLayer from TensorRT.

Parameters:

Returns:

The tensor produced by this elementwise operation.

tensorrt_llm.functional.slice(

input: Tensor,

starts: Tensor | Sequence[int],

sizes: Tensor | Sequence[int],

strides: Tensor | Sequence[int] = None,

mode: SampleMode = None,

fill_value: float | Tensor = None,

) → Tensor[source]#

Add an operation to extract a slice from a tensor.

As described in the TensorRT documentation of the ISliceLayer, the slice layer has two variants: Static and dynamic.

For static slicing, this function takes the starts and sizes values in the different dimensions to slice at layer creation time via a sequence of integers. For dynamic slicing, it accepts starts and sizes as tensorrt.ITensor objects.

The slice layer selects for each dimension a start location from within the input tensor, and copies elements to the output tensor using a stride of 1 across the input tensor. Start and size tensors must be 1-D int32 shape tensors if not specified as a sequence of integers.

As an example, on input = [[0, 2, 4], [1, 3, 5]], the call to

slice(input, start=[1, 0], size=[1, 2])

will produce the tensor [[1, 3]] as output. The slice operator when executed by TensorRT will copy one row (because size[0] == 1) starting from the 2nd row (because start[0] == 1) and two columns (size[1] == 2) starting from the 1st column (because start[1] == 0).

In pseudo-code the behavior of that operation can be described as follows for a 2D tensor (and easily be extended to more dimensions):

output = Tensor(shape=sizes)
for ii in range(sizes[0]):
    for jj in range(sizes[1]):
        output[ii][jj] = input[starts[0]+ii][starts[1]+jj]

Note that it is common in deep-learning frameworks to use ranges [start:end] for similar operations. It can be emulated by setting the sizes argument such that in each dimension [start:start+size] == [start:end] i.e. size = end-start.

TensorRT supports different slice modes but that function restricts that choice to mode == tensorrt.SampleMode.STRICT_BOUNDS.

Parameters:

Returns:

The tensor produced by the slice layer.
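The documented example, expressed with NumPy ranges using the size = end - start relation noted above (semantics only, not TensorRT-LLM graph code):

import numpy as np

x = np.array([[0, 2, 4], [1, 3, 5]])
starts, sizes = [1, 0], [1, 2]
x[starts[0]:starts[0] + sizes[0],
  starts[1]:starts[1] + sizes[1]]   # [[1, 3]]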

tensorrt_llm.functional.softmax(

input: Tensor,

dim: int | None = None,

) → Tensor[source]#

Add an operation to compute softmax on a tensor.

That operation computes the softmax on the input tensor in the dimension ‘dim’ if specified. Otherwise, it is applied on the last dimension.

It inserts an ISoftmaxLayer into the TensorRT graph.

Parameters:

Returns:

The output tensor of the softmax layer.

tensorrt_llm.functional.softplus(

input: Tensor,

beta: float,

threshold: float,

) → Tensor[source]#

Add the softplus activation based on the PyTorch definition.

See https://pytorch.org/docs/stable/generated/torch.nn.functional.softplus.html#torch-nn-functional-softplus for a description of that function.

Parameters:

Returns:

The output tensor created by that layer.

tensorrt_llm.functional.split(

tensor: Tensor,

split_size_or_sections: int | Sequence[int],

dim: int = 0,

) → Sequence[Tensor][source]#

Add an operation that splits a tensor into sub-tensors.

This operation creates a list of tensors that are obtained from the input tensor by slicing it along the dimension ‘dim’. If ‘split_size_or_sections’ is an integer, the tensor is split into ‘input.shape[dim] / split_size_or_sections’ slices. If ‘split_size_or_sections’ is a list of sizes, the tensor is split into ‘len(split_size_or_sections)’ slices and the size of the ith slice is given by ‘split_size_or_sections[i]’.

There are several constraints with the current implementation:

That operation is implemented using a ‘slice’ operation for each output slice.

Parameters:

Returns:

The list of tensors produced by the different operations.
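A NumPy sketch of both calling conventions (semantics only; np.split takes cut points rather than section sizes, hence the cumulative sum):

import numpy as np

x = np.arange(12).reshape(6, 2)

# Integer split size: input.shape[dim] / split_size == 3 slices of size 2.
[t.shape for t in np.split(x, 6 // 2, axis=0)]   # [(2, 2), (2, 2), (2, 2)]

# Explicit section sizes: convert sizes to cut points with a cumulative sum.
sections = [1, 2, 3]
cuts = np.cumsum(sections)[:-1]                  # [1, 3]
[t.shape for t in np.split(x, cuts, axis=0)]     # [(1, 2), (2, 2), (3, 2)]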

tensorrt_llm.functional.sqrt(

input: ~tensorrt_llm.functional.Tensor,

*,

op: ~tensorrt.tensorrt.UnaryOperation = <UnaryOperation.SQRT: 2>,

) → Tensor#

Add an elementwise operation on a single input.

The following closures are defined in functional.*:

round for op=trt.UnaryOperation.ROUND
sqrt for op=trt.UnaryOperation.SQRT
exp for op=trt.UnaryOperation.EXP
sin for op=trt.UnaryOperation.SIN
cos for op=trt.UnaryOperation.COS
abs for op=trt.UnaryOperation.ABS
log for op=trt.UnaryOperation.LOG

It is implemented using the IUnaryLayer from TensorRT.

Parameters:

Returns:

The tensor produced by this elementwise operation.

tensorrt_llm.functional.squared_relu(

x: Tensor,

) → Tensor[source]#

Add a Squared ReLU operation.

This function applies ReLU and squares the output.

Parameters:

input – Tensor The input tensor on which the activation function is applied.

Returns:

The tensor produced by the activation layer.

tensorrt_llm.functional.squeeze(

input: Tensor,

dim: int | Sequence[int] | None = None,

zero_is_placeholder: bool = False,

)[source]#

Add an operation to remove singleton dimensions of a tensor.

This function creates an operation that removes singleton dimensions (dimensions of size 1) at the positions ‘dim’ in the input tensor. It works with negative values for ‘dim’.

For example, for a tensor ‘input’ of shape [1, 4, 1, 4]:

squeeze(input, 0) will produce an output of shape [4, 1, 4],
squeeze(input, 2) will produce an output of shape [1, 4, 4],
squeeze(input, [0, 2]) will produce an output of shape [4, 4],
squeeze(input, [-2]) will produce an output of shape [1, 4, 4].

Parameters:

Returns:

The tensor produced by the layer.

tensorrt_llm.functional.stack(

inputs: Sequence[Tensor],

dim: int = 0,

) → Tensor[source]#

Add an operation to concatenate input tensors along a new dimension.

The function creates an operation that adds a new dimension to all the input tensors and then concatenates them along that new dimension.

All the tensors in ‘inputs’ must have the same shape.

for ii in range(inputs[0].rank()):
    assert all(inp.shape[ii] == inputs[0].shape[ii] for inp in inputs)

The shape of the output tensor is defined as:

output.rank() = inputs[0].rank() + 1
output.shape[dim] = len(inputs)
for ii in range(inputs[0].rank()):
    if ii < dim:
        output.shape[ii] = inputs[0].shape[ii]
    else:
        output.shape[ii+1] = inputs[0].shape[ii]

For example, given a sequence of two 2D tensors [[0, 1], [2, 3]] and [[4, 5], [6, 7]] both of shape [2, 2],

stack(inputs, 0)

will produce [[[0, 1], [2, 3]], [[4, 5], [6, 7]]] of shape [2, 2, 2] and

stack(inputs, 1)

will produce [[[0, 1], [4, 5]], [[2, 3], [6, 7]]] of shape [2, 2, 2].

Parameters:

inputs : Sequence[Tensor]

The sequence of tensors to stack.

dim : int

The dimension in which the stack is performed.

Returns:

A tensor that contains the input tensors stacked along a new dimension.

tensorrt_llm.functional.sub(

left: ~tensorrt_llm.functional.Tensor | int | float,

right: ~tensorrt_llm.functional.Tensor | int | float,

*,

op: ~tensorrt.tensorrt.ElementWiseOperation = <ElementWiseOperation.SUB: 4>,

) → Tensor#

Add an elementwise operation with two inputs.

For each input, that function first creates a constant tensor if the input is an integer or a float. Then, if needed, it expands the smaller tensor to make sure its rank is the same as the larger one. Then, it performs the elementwise operation ‘op’.

The following closures are defined in functional.*:

add for op=trt.ElementWiseOperation.SUM
sub for op=trt.ElementWiseOperation.SUB
mul for op=trt.ElementWiseOperation.PROD
div for op=trt.ElementWiseOperation.DIV
floordiv for op=trt.ElementWiseOperation.FLOOR_DIV
gt for op=trt.ElementWiseOperation.GREATER
lt for op=trt.ElementWiseOperation.LESS
op_and for op=trt.ElementWiseOperation.AND
op_or for op=trt.ElementWiseOperation.OR
eq for op=trt.ElementWiseOperation.EQUAL
minimum for op=trt.ElementWiseOperation.MIN
maximum for op=trt.ElementWiseOperation.MAX
pow for op=trt.ElementWiseOperation.POW

It is implemented using the IElementWiseLayer from TensorRT.

Parameters:

Returns:

The tensor produced by this elementwise operation.

tensorrt_llm.functional.sum(

input: Tensor,

dim: int,

keepdim: bool = False,

) → Tensor[source]#

Add an operation to compute the sum along a dimension.

Computes the sum along the dimension ‘dim’ of the input tensor.

It is implemented using the IReduceLayer from TensorRT.

Parameters:

Returns:

The tensor produced by this reduction operation.

tensorrt_llm.functional.swiglu(

input: Tensor,

) → Tensor[source]#

Add a SwiGLU (x * SiLU(gate)) operation.

That function takes a tensor, splits it into two halves along the last dimension, applies SiLU to the second half and multiplies the results. The behavior is undefined if the last dimension is not even.

Parameters:

input – Tensor The input tensor on which the activation function is applied.

Returns:

The tensor produced by the activation layer.

tensorrt_llm.functional.tanh(

input: ~tensorrt_llm.functional.Tensor,

*,

act_type: ~tensorrt.tensorrt.ActivationType = <ActivationType.TANH: 2>,

) → Tensor#

Add an activation function.

Parameters:

The following closures are defined in functional.*:

relu for op=trt.ActivationType.RELU
tanh for op=trt.ActivationType.TANH
sigmoid for op=trt.ActivationType.SIGMOID

Returns:

The tensor produced by the activation layer.

tensorrt_llm.functional.topk(

input: Tensor,

k: Tensor | int,

dim: int,

largest: bool = True,

prefer_plugin: bool = True,

) → Tuple[Tensor, Tensor][source]#

Add a topk operation.

As explained in the ONNX documentation:

NOTE: one distinction from the ONNX topk op is that the output is always sorted by the TensorRT layer.

Retrieve the top-K largest elements along a specified axis. Given an input tensor of shape [a_1, a_2, …, a_n, r] and integer argument k, return two outputs:

a value tensor of shape [a_1, a_2, …, a_{axis-1}, k, a_{axis+1}, … a_n] which contains the values of the top k elements along the specified axis;

an index tensor of shape [a_1, a_2, …, a_{axis-1}, k, a_{axis+1}, … a_n] which contains the indices of the top k elements (original indices from the input tensor).

Parameters:

Returns:

The tensors (values, indices) produced by this topk operation.
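A NumPy sketch of the always-sorted top-k semantics along one axis (topk_ref is a hypothetical helper, not the plugin or the TensorRT layer):

import numpy as np

def topk_ref(x, k, dim, largest=True):
    # Sort descending (or ascending), then keep the first k entries, so the
    # output is always sorted, as the note above states.
    order = np.argsort(-x if largest else x, axis=dim)
    indices = np.take(order, np.arange(k), axis=dim)
    values = np.take_along_axis(x, indices, axis=dim)
    return values, indices

values, indices = topk_ref(np.array([[3, 1, 4], [1, 5, 9]]), k=2, dim=1)
# values == [[4, 3], [9, 5]], indices == [[2, 0], [1, 2]]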

tensorrt_llm.functional.transpose(

input: Tensor,

dim0: int,

dim1: int,

) → Tensor[source]#

Add an operation to transpose two dimensions of a tensor.

That operation produces a tensor in which the dimensions ‘dim0’ and ‘dim1’ are permuted. The other dimensions, if the rank of the tensor is greater than 2, remain untouched.

That function is a helper built on the ‘functional.permute’ function.

Parameters:

Returns:

The tensor produced by the permutation layer.

tensorrt_llm.functional.unary(

input: Tensor,

op: UnaryOperation,

) → Tensor[source]#

Add an elementwise operation on a single input.

The following closures are defined in functional.*:

round for op=trt.UnaryOperation.ROUND
sqrt for op=trt.UnaryOperation.SQRT
exp for op=trt.UnaryOperation.EXP
sin for op=trt.UnaryOperation.SIN
cos for op=trt.UnaryOperation.COS
abs for op=trt.UnaryOperation.ABS
log for op=trt.UnaryOperation.LOG

It is implemented using the IUnaryLayer from TensorRT.

Parameters:

Returns:

The tensor produced by this elementwise operation.

tensorrt_llm.functional.unbind(input: Tensor, dim: int = 0)[source]#

Removes a tensor dimension.

Returns a tuple of all slices along a given dimension, already without it.

tensorrt_llm.functional.unsqueeze(input: Tensor, axis: int)[source]#

Add an operation to insert a singleton dimension to a tensor.

That function creates an operation that inserts a singleton dimension (dimension of size 1) at position ‘axis’ in the output tensor. It works with negative values for ‘axis’.

For example, for a tensor ‘input’ of shape [4, 4]:

unsqueeze(input, 0) will produce an output of shape [1, 4, 4],
unsqueeze(input, 1) will produce an output of shape [4, 1, 4],
unsqueeze(input, -1) will produce an output of shape [4, 4, 1],
unsqueeze(input, -2) will produce an output of shape [4, 1, 4].

Parameters:

Returns:

The tensor produced by the layer.

tensorrt_llm.functional.view(

input: Tensor,

shape: Tensor | Sequence[int],

zero_is_placeholder: bool = True,

) → Tensor[source]#

Add an operation to create a view of a tensor.

That operation adds a tensorrt.IShuffleLayer to the network. If the ‘shape’ parameter is a Tensor, that view is dynamic. Otherwise, it is a static view.

Note that TensorRT limits the number of inferred dimensions to 1. It means that the shape sequence or tensor cannot contain more than one -1. This function enforces that constraint and will assert if it is not respected.

Parameters:

Returns:

The tensor produced by the view/shuffle layer.

tensorrt_llm.functional.where(

condition: Tensor | bool,

left: Tensor | int | float,

right: Tensor | int | float,

) → Tensor[source]#

Add a where (aka select or if-then-else) operation.

Assuming the three input parameters have the same shape, that function creates the operation to compute a tensor of the same shape such that:

for ii in range(mul(condition.shape)):
    output[ii] = left[ii] if condition[ii] else right[ii]

For each input, that function first creates a constant tensor if the condition is boolean or the left/right input is an integer or a float. Then, if needed, it expands the smaller tensor to make sure its rank is the same as the larger one. Then, it performs the selection.

It is implemented using the ISelectLayer from TensorRT.

Parameters:

Returns:

The tensor produced by this where operation.