# tensorrt_llm/models/mllama/model.py

# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from collections import OrderedDict
from typing import List, Optional, Union

import tensorrt as trt
import torch

from tensorrt_llm._common import default_net
from tensorrt_llm._utils import numpy_to_torch, str_dtype_to_torch
from tensorrt_llm.bindings import KVCacheType
from tensorrt_llm.functional import (Conditional, LayerNormPositionType,
                                     LayerNormType, MLPType,
                                     PositionEmbeddingType, Tensor, assertion,
                                     gather_last_token_logits, maximum,
                                     minimum, recv, reduce, send, shape, tanh)
from tensorrt_llm.layers import (MLP, Attention, AttentionMaskParams,
                                 AttentionMaskType, AttentionParams,
                                 ColumnLinear, Embedding, FusedGatedMLP,
                                 GatedMLP, GroupNorm, KeyValueCacheParams,
                                 LayerNorm, LoraParams, RmsNorm)
from tensorrt_llm.lora_manager import (LoraConfig,
                                       get_default_trtllm_modules_to_hf_modules,
                                       use_lora)
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.model_weights_loader import ModelWeightsLoader
from tensorrt_llm.models.modeling_utils import PretrainedModel, QuantConfig
from tensorrt_llm.module import Module, ModuleList
from tensorrt_llm.parameter import Parameter
from tensorrt_llm.quantization import QuantMode

from .config import MLLaMAConfig

layernorm_map = {
    LayerNormType.LayerNorm: LayerNorm,
    LayerNormType.RmsNorm: RmsNorm,
    LayerNormType.GroupNorm: GroupNorm,
}

mlp_map = {
    MLPType.MLP: MLP,
    MLPType.GatedMLP: GatedMLP,
    MLPType.FusedGatedMLP: FusedGatedMLP,
}

ADD_DEBUG_TENSOR = False
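# When True, intermediate activations of every block are registered as engine
# outputs via mark_output() so they can be dumped for debugging.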

class CrossAttentionTransformerBlock(Module):
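# Gated cross-attention block: cross-attention over the vision encoder output
# followed by a gated MLP. The learned scalars gate_attn / gate_ffwd are passed
# through tanh and scale how much the cross-attention and FFN outputs contribute
# to the residual stream; full_text_row_masked_out_mask zeroes the FFN output
# for text rows that attend to no encoder position, and skip_cross_attn_blocks
# can bypass the whole block (e.g. for text-only requests) via TRT Conditionals.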

def __init__(
        self,
        *,
        local_layer_idx,
        hidden_size,
        ffn_hidden_size,
        num_attention_heads,
        num_kv_heads,
        head_size,
        max_position_embeddings=None,
        q_scaling=1.0,
        has_attention_qkvo_bias=False,
        has_mlp_bias=False,
        layernorm_position=LayerNormPositionType.pre_layernorm,
        layernorm_type=LayerNormType.RmsNorm,
        layernorm_eps=1e-5,
        hidden_act="gated-silu",
        mlp_type=MLPType.GatedMLP,
        mapping=Mapping(),
        dtype=None,
        residual_scaling=1.0,
        relative_attention=False,
        max_distance=0,
        num_buckets=0,
        fp16_clamping=False,
        skip_cross_kv=False,
        use_implicit_relative_attention=False,
        rotary_embedding_base=None,
        rotary_embedding_scaling=None,
        quant_mode=QuantMode(0),
):
    super().__init__()
    self.local_layer_idx = local_layer_idx
    self.layernorm_type = layernorm_type
    ln_type = layernorm_map[layernorm_type]

    self.layernorm_position = layernorm_position
    assert self.layernorm_position == LayerNormPositionType.pre_layernorm

    self.cross_attention = Attention(
        local_layer_idx=local_layer_idx,
        hidden_size=hidden_size,
        num_attention_heads=num_attention_heads,
        attention_head_size=head_size,
        num_kv_heads=num_kv_heads,
        max_position_embeddings=max_position_embeddings,
        q_scaling=q_scaling,
        bias=has_attention_qkvo_bias,
        attention_mask_type=AttentionMaskType.causal,
        tp_group=mapping.tp_group,
        tp_size=mapping.tp_size,
        tp_rank=mapping.tp_rank,
        dtype=dtype,
        cross_attention=True,
        relative_attention=
        False,  # Cross attention has no relative attention bias
        max_distance=max_distance,
        num_buckets=num_buckets,
        position_embedding_type=PositionEmbeddingType.
        learned_absolute,  # we don't use rope for cross attn
        skip_cross_kv=skip_cross_kv,
        qk_layernorm=True,
        layernorm_type=layernorm_type,
        quant_mode=quant_mode,
    )

    self.input_layernorm = ln_type(normalized_shape=hidden_size,
                                   eps=layernorm_eps,
                                   dtype=dtype)
    self.gate_attn = Parameter(shape=tuple((1, )), dtype=dtype)

    self.mlp_type = mlp_type
    mlp_f = mlp_map[mlp_type]

    self.mlp = mlp_f(
        hidden_size=hidden_size,
        ffn_hidden_size=ffn_hidden_size,
        hidden_act=hidden_act,
        bias=has_mlp_bias,
        tp_group=mapping.tp_group,
        tp_size=mapping.tp_size,
        dtype=dtype,
    )

    self.post_layernorm = ln_type(normalized_shape=hidden_size,
                                  eps=layernorm_eps,
                                  dtype=dtype)
    self.gate_ffwd = Parameter(shape=tuple((1, )), dtype=dtype)

    self.residual_scaling = residual_scaling

    self.fp16_clamping = fp16_clamping
    self.no_ffn = False

def forward(self,
            hidden_states: Tensor,
            encoder_output: Optional[Tensor] = None,
            attention_mask_params=None,
            use_cache=False,
            kv_cache_params=None,
            attention_params=None,
            lora_layer_params=None,
            cross_kv_cache_gen: Optional[Tensor] = None,
            cross_kv_reuse: Optional[Tensor] = None,
            full_text_row_masked_out_mask: Tensor = None,
            skip_cross_attn_blocks: Tensor = None):

    assert isinstance(hidden_states, Tensor)

    if encoder_output:
        assert isinstance(encoder_output, Tensor)

    if ADD_DEBUG_TENSOR:
        hidden_states.mark_output(
            f'{self.local_layer_idx:2d}/1.0: hidden_states',
            hidden_states.dtype)
    # cross attention
    residual = hidden_states * self.residual_scaling

    # skip input_layernorm
    if skip_cross_attn_blocks is not None:
        input_ln_conditional = Conditional(skip_cross_attn_blocks)
        skip_result = input_ln_conditional.add_input(hidden_states)
        hidden_states = input_ln_conditional.add_input(hidden_states)
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = input_ln_conditional.add_output(
            skip_result, hidden_states)
    else:
        hidden_states = self.input_layernorm(hidden_states)

    if ADD_DEBUG_TENSOR:
        hidden_states.mark_output(
            f'{self.local_layer_idx:2d}/2.1: normed_input',
            hidden_states.dtype)
    # pass full_text_row_masked_out_mask and xattn_mask
    attention_output = self.cross_attention(
        hidden_states=hidden_states,
        attention_mask=attention_mask_params.cross_attention_mask,
        attention_packed_mask=attention_mask_params.
        cross_attention_packed_mask,
        encoder_output=encoder_output,
        use_cache=use_cache,
        kv_cache_params=kv_cache_params,
        attention_params=attention_params,
        lora_layer_params=lora_layer_params,
        cross_kv_cache_gen=cross_kv_cache_gen,
        cross_kv_reuse=cross_kv_reuse,
        skip_attn=skip_cross_attn_blocks,
    )

    if use_cache:
        attention_output, presents_cross = attention_output
    if ADD_DEBUG_TENSOR:
        attention_output.mark_output(
            f'{self.local_layer_idx:2d}/3.1: cross_attention_output',
            attention_output.dtype)

    attn_residual_scale = tanh(self.gate_attn.value.cast(trt.float32)).cast(
        attention_output.dtype)

    attention_input = hidden_states
    hidden_states = residual + attn_residual_scale * attention_output

    # use to skip attention_output with residual
    # Since conditional does not work for gpt_attention_plugin, we replace the
    # attention_output by hidden_states (input of attention) now.
    if skip_cross_attn_blocks is not None:
        attn_conditional = Conditional(skip_cross_attn_blocks)
        skip_result = attn_conditional.add_input(attention_input)
        hidden_states = attn_conditional.add_input(hidden_states)
        hidden_states = attn_conditional.add_output(skip_result,
                                                    hidden_states)

    if ADD_DEBUG_TENSOR:
        hidden_states.mark_output(
            f'{self.local_layer_idx:2d}/3.2: cross_attn_output_with_residual',
            hidden_states.dtype)

    if self.fp16_clamping:
        hidden_states = maximum(-64000.0, hidden_states)
        hidden_states = minimum(64000.0, hidden_states)

    # MLP
    # skip post_layernorm and mlp
    if skip_cross_attn_blocks is not None:
        mlp_conditional = Conditional(skip_cross_attn_blocks)
        skip_case = mlp_conditional.add_input(hidden_states)
        hidden_states = mlp_conditional.add_input(hidden_states)

    attention_output = attention_output * full_text_row_masked_out_mask  # TODO should move this mask into attention?

    residual = hidden_states * self.residual_scaling

    hidden_states = self.post_layernorm(hidden_states)

    hidden_states = self.mlp(hidden_states,
                             lora_layer_params=lora_layer_params)
    if ADD_DEBUG_TENSOR:
        hidden_states.mark_output(
            f'{self.local_layer_idx:2d}/4.1: mlp_output',
            hidden_states.dtype)

    hidden_states = hidden_states * full_text_row_masked_out_mask
    if ADD_DEBUG_TENSOR:
        hidden_states.mark_output(
            f'{self.local_layer_idx:2d}/4.2: masked_mlp_output',
            hidden_states.dtype)
    ffn_residual_scale = tanh(self.gate_ffwd.value.cast(trt.float32)).cast(
        hidden_states.dtype)
    hidden_states = residual + ffn_residual_scale * hidden_states * float(
        not self.no_ffn)

    if skip_cross_attn_blocks is not None:
        hidden_states = mlp_conditional.add_output(skip_case, hidden_states)

    if self.fp16_clamping:
        hidden_states = maximum(-64000.0, hidden_states)
        hidden_states = minimum(64000.0, hidden_states)

    if ADD_DEBUG_TENSOR:
        hidden_states.mark_output(
            f'{self.local_layer_idx:2d}/4.4: transformer_out',
            hidden_states.dtype)

    if use_cache:
        return (hidden_states, presents_cross)
    return hidden_states

class TransformerBlock(Module):
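# Standard pre-norm self-attention decoder block (self-attention + MLP with
# residual connections). Cross-attention related arguments are accepted but
# ignored so that both block types share one call signature inside MLLaMAModel.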

def __init__(
        self,
        *,
        local_layer_idx,
        hidden_size,
        ffn_hidden_size,
        num_attention_heads,
        num_kv_heads,
        head_size,
        max_position_embeddings=None,
        q_scaling=1.0,
        has_attention_qkvo_bias=False,
        has_mlp_bias=False,
        layernorm_position=LayerNormPositionType.pre_layernorm,
        layernorm_type=LayerNormType.RmsNorm,
        layernorm_eps=1e-5,
        hidden_act="gated-silu",
        mlp_type=MLPType.GatedMLP,
        mapping=Mapping(),
        dtype=None,
        residual_scaling=1.0,
        relative_attention=False,
        max_distance=0,
        num_buckets=0,
        fp16_clamping=False,
        skip_cross_kv=False,
        use_implicit_relative_attention=False,
        rotary_embedding_base=None,
        rotary_embedding_scaling=None,
        quant_mode=QuantMode(0),
):
    super().__init__()
    self.local_layer_idx = local_layer_idx
    self.layernorm_type = layernorm_type
    ln_type = layernorm_map[layernorm_type]

    self.layernorm_position = layernorm_position
    assert self.layernorm_position == LayerNormPositionType.pre_layernorm

    self.self_attention = Attention(
        local_layer_idx=local_layer_idx,
        hidden_size=hidden_size,
        num_attention_heads=num_attention_heads,
        attention_head_size=head_size,
        num_kv_heads=num_kv_heads,
        max_position_embeddings=max_position_embeddings,
        q_scaling=q_scaling,
        bias=has_attention_qkvo_bias,
        attention_mask_type=AttentionMaskType.causal,
        tp_group=mapping.tp_group,
        tp_size=mapping.tp_size,
        tp_rank=mapping.tp_rank,
        dtype=dtype,
        cross_attention=False,
        relative_attention=relative_attention,
        max_distance=max_distance if use_implicit_relative_attention else 0,
        num_buckets=num_buckets,
        position_embedding_type=PositionEmbeddingType.relative
        if relative_attention else PositionEmbeddingType.rope_gpt_neox,
        use_implicit_relative_attention=use_implicit_relative_attention,
        rotary_embedding_base=rotary_embedding_base,
        rotary_embedding_scaling=rotary_embedding_scaling,
        quant_mode=quant_mode,
    )

    self.input_layernorm = ln_type(normalized_shape=hidden_size,
                                   eps=layernorm_eps,
                                   dtype=dtype)

    self.mlp_type = mlp_type
    mlp_f = mlp_map[mlp_type]
    self.mlp = mlp_f(
        hidden_size=hidden_size,
        ffn_hidden_size=ffn_hidden_size,
        hidden_act=hidden_act,
        bias=has_mlp_bias,
        tp_group=mapping.tp_group,
        tp_size=mapping.tp_size,
        dtype=dtype,
    )

    self.post_layernorm = ln_type(normalized_shape=hidden_size,
                                  eps=layernorm_eps,
                                  dtype=dtype)

    self.residual_scaling = residual_scaling

    self.fp16_clamping = fp16_clamping

def forward(
    self,
    hidden_states: Tensor,
    encoder_output: Optional[Tensor] = None,  # not used
    attention_mask_params=None,
    use_cache=False,
    kv_cache_params=None,
    attention_params=None,
    lora_layer_params=None,
    cross_kv_cache_gen: Optional[Tensor] = None,
    cross_kv_reuse: Optional[Tensor] = None,
    full_text_row_masked_out_mask: Tensor = None,  # not used
    skip_cross_attn_blocks=None,
):
    assert isinstance(hidden_states, Tensor)

    # self-attention
    residual = hidden_states * self.residual_scaling
    if ADD_DEBUG_TENSOR:
        hidden_states.mark_output(
            f'{self.local_layer_idx:2d}/1.0: hidden_states',
            hidden_states.dtype)

    hidden_states = self.input_layernorm(hidden_states)
    if ADD_DEBUG_TENSOR:
        hidden_states.mark_output(
            f'{self.local_layer_idx:2d}/2.1: normed attn_input',
            hidden_states.dtype)

    attention_output = self.self_attention(
        hidden_states=hidden_states,
        attention_mask=attention_mask_params.self_attention_mask,
        use_cache=use_cache,
        kv_cache_params=kv_cache_params,
        attention_params=attention_params,
        lora_layer_params=lora_layer_params)

    if use_cache:
        attention_output, presents_self = attention_output

    if ADD_DEBUG_TENSOR:
        attention_output.mark_output(
            f'{self.local_layer_idx:2d}/3.1: self_attention_output',
            attention_output.dtype)

    hidden_states = residual + attention_output
    if ADD_DEBUG_TENSOR:
        hidden_states.mark_output(
            f'{self.local_layer_idx:2d}/3.1: attention_output_with_residual',
            hidden_states.dtype)

    if self.fp16_clamping:
        hidden_states = maximum(-64000.0, hidden_states)
        hidden_states = minimum(64000.0, hidden_states)

    # MLP
    residual = hidden_states * self.residual_scaling

    hidden_states = self.post_layernorm(hidden_states)
    if ADD_DEBUG_TENSOR:
        hidden_states.mark_output(
            f'{self.local_layer_idx:2d}/3.2: normed_mlp_input',
            hidden_states.dtype)

    hidden_states = self.mlp(hidden_states,
                             lora_layer_params=lora_layer_params)

    if ADD_DEBUG_TENSOR:
        hidden_states.mark_output(
            f'{self.local_layer_idx:2d}/4.1: mlp_output',
            hidden_states.dtype)

    hidden_states = residual + hidden_states
    if ADD_DEBUG_TENSOR:
        hidden_states.mark_output(
            f'{self.local_layer_idx:2d}/4.2: mlp_output_residual',
            hidden_states.dtype)

    if self.fp16_clamping:
        hidden_states = maximum(-64000.0, hidden_states)
        hidden_states = minimum(64000.0, hidden_states)

    if use_cache:
        return (hidden_states, presents_self)
    return hidden_states

class MLLaMAModel(Module):
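# The mllama (Llama 3.2 Vision) text decoder: token embedding, a stack of
# TransformerBlock layers with CrossAttentionTransformerBlock substituted at
# the indices in config.cross_attention_layers, and an optional final layernorm.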

def __init__(self, config: MLLaMAConfig) -> None:
    super().__init__()
    self.config = config
    self.position_embedding_type = config.position_embedding_type

    self.mapping = self.config.mapping

    self.layernorm_type = self.config.layernorm_type
    ln_type = layernorm_map[self.layernorm_type]

    self.has_attention_qkvo_bias = self.config.has_attention_qkvo_bias
    self.has_mlp_bias = self.config.has_mlp_bias

    self.has_model_final_layernorm = self.config.has_model_final_layernorm
    self._dtype = self.config.dtype
    # no quantization considered for now
    self._kv_dtype = self._dtype
    self._logits_dtype = self.config.logits_dtype

    self.total_num_layers = self.config.num_hidden_layers
    self.num_layers = self.config.num_hidden_layers // self.mapping.pp_size

    self.hidden_size = self.config.hidden_size
    self.encoder_hidden_size = self.config.hidden_size
    self.num_heads = self.config.num_attention_heads
    # num_kv_heads = self.num_heads
    num_kv_heads = self.config.num_key_value_heads
    if num_kv_heads is None or num_kv_heads <= 0:
        num_kv_heads = self.num_heads
    self.num_kv_heads = num_kv_heads
    self.head_size = self.hidden_size // self.num_heads if self.config.head_size is None else self.config.head_size

    self.fp16_clamping = False

    self.skip_cross_kv = self.config.skip_cross_kv
    self.mlp_type = MLPType.MLP if not hasattr(
        self.config, "mlp_type") else self.config.mlp_type
    self.use_implicit_relative_attention = self.config.use_implicit_relative_attention if hasattr(
        self.config, "use_implicit_relative_attention") else False

    self.cross_attention_layers = self.config.cross_attention_layers

    if self.mapping.is_first_pp_rank():
        self.vocab_embedding = Embedding(
            self.config.embed_vocab_size,
            self.config.hidden_size,
            dtype=self._dtype,
            tp_size=self.mapping.tp_size
            if self.config.use_parallel_embedding else 1,
            tp_group=self.mapping.tp_group
            if self.config.use_parallel_embedding else None,
            sharding_dim=self.config.embedding_sharding_dim,
            tp_rank=self.mapping.tp_rank)

    layers_range = self.mapping.pp_layers(self.total_num_layers)
    _layers = []
    for layer_idx in layers_range:
        local_layer_idx = layer_idx - layers_range[0]
        args = {
            "local_layer_idx": local_layer_idx,
            "hidden_size": self.config.hidden_size,
            "ffn_hidden_size": self.config.intermediate_size,
            "num_attention_heads": self.num_heads,
            "num_kv_heads": self.num_kv_heads,
            "head_size": self.head_size,
            "max_position_embeddings": self.config.max_position_embeddings,
            "layernorm_position": self.config.layernorm_position,
            "layernorm_eps": self.config.norm_epsilon,
            "layernorm_type": self.config.layernorm_type,
            "hidden_act": self.config.hidden_act,
            "mlp_type": self.mlp_type,
            "mapping": self.mapping,
            "dtype": self._dtype,
            "residual_scaling": self.config.residual_scaling,
            "max_distance": self.config.max_distance,
            "num_buckets": self.config.num_buckets,
            "fp16_clamping": self.fp16_clamping,
            "skip_cross_kv": self.skip_cross_kv,
            "rotary_embedding_base": self.config.rotary_base,
            "rotary_embedding_scaling": self.config.rotary_scaling,
            "quant_mode": self.config.quant_mode,
        }
        if layer_idx in self.cross_attention_layers:
            assert layers_range[0] == 0, "pipeline parallelism is not supported yet"
            _layers.append(CrossAttentionTransformerBlock(**args))
        else:
            _layers.append(TransformerBlock(**args))

    self.layers = ModuleList(_layers)

    if self.mapping.is_last_pp_rank():
        self.ln_f = None
        if self.has_model_final_layernorm:
            self.ln_f = ln_type(normalized_shape=self.config.hidden_size,
                                eps=self.config.norm_epsilon,
                                dtype=self.config.dtype)

    if self.config.relative_attention and not self.use_implicit_relative_attention:
        self.rel_attn_table = Parameter(
            shape=(self.config.num_attention_heads // self.mapping.tp_size,
                   self.config.num_buckets),
            dtype=self._dtype)

def forward(
    self,
    decoder_input_ids: Tensor,
    encoder_output: Tensor,
    use_cache=False,
    attention_mask_params=None,
    last_token_ids=None,
    kv_cache_params=None,
    attention_params=None,
    hidden_states=None,
    lora_params: LoraParams = None,
    cross_kv_cache_gen: Optional[Tensor] = None,
    cross_kv_reuse: Optional[Tensor] = None,
    prompt_embedding_table: Optional[Tensor] = None,
    prompt_tasks: Optional[Tensor] = None,
    prompt_vocab_size: Optional[Tensor] = None,
    skip_cross_attn_blocks: Optional[Tensor] = None,
):
    if self.mapping.is_first_pp_rank():
        assert isinstance(decoder_input_ids, Tensor)
    else:
        assert isinstance(hidden_states, Tensor)

    # In PP, layer 0 has ids as inputs, all other layers have hidden_states as inputs
    if self.mapping.is_first_pp_rank():
        hidden_states = self.vocab_embedding(decoder_input_ids)
        self.register_network_output('embedding_layer_output',
                                     hidden_states)
    else:
        hidden_states = recv(hidden_states, self.mapping.prev_pp_rank())

    kv_cache_params.fill_none_tensor_list(len(self.layers))
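    # A text row is "masked out" when it attends to no encoder (vision) position:
    # reducing the boolean cross-attention mask with MAX over the encoder axis
    # yields a per-row 0/1 flag that later zeroes the cross-attention MLP output.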

    full_text_row_masked_out_mask = reduce(
        (attention_mask_params.cross_attention_mask).cast(
            hidden_states.dtype),
        trt.ReduceOperation.MAX,
        dim=-1,
        keepdim=True)
    if ADD_DEBUG_TENSOR:
        full_text_row_masked_out_mask.mark_output(
            "full_text_row_masked_out_mask",
            full_text_row_masked_out_mask.dtype)

    cross_attention_mask_type = attention_mask_params.cross_attention_mask.dtype
    attention_mask_params.cross_attention_mask = (
        attention_mask_params.cross_attention_mask.cast(
            full_text_row_masked_out_mask.dtype) *
        full_text_row_masked_out_mask).cast(cross_attention_mask_type)

    invert_mask = 1.0 - attention_mask_params.cross_attention_mask.cast(
        hidden_states.dtype)
    invert_full_text_row_masked_out_mask = 1.0 - full_text_row_masked_out_mask
    final_mask = invert_mask - invert_full_text_row_masked_out_mask
    attention_mask_params.cross_attention_mask = final_mask.cast(
        cross_attention_mask_type)
    if ADD_DEBUG_TENSOR:
        attention_mask_params.cross_attention_mask.mark_output(
            "attention_mask_params.cross_attention_mask",
            attention_mask_params.cross_attention_mask.dtype)

    if use_cache:
        presents = []
    for i, (decoder_layer, past) in enumerate(
            zip(self.layers, kv_cache_params.past_key_value)):

        lora_layer_params = None
        if lora_params is not None and lora_params.lora_ranks is not None:
            lora_layer_params = lora_params.get_layer_params(i)
        hidden_states = decoder_layer(
            hidden_states,
            encoder_output=encoder_output,
            attention_mask_params=attention_mask_params,
            use_cache=use_cache,
            kv_cache_params=KeyValueCacheParams(
                past_key_value=past,
                host_past_key_value_lengths=kv_cache_params.
                host_past_key_value_lengths,
                host_max_attention_window_sizes=kv_cache_params.
                host_max_attention_window_sizes,
                host_sink_token_length=kv_cache_params.
                host_sink_token_length,
                cache_indirection=kv_cache_params.cache_indirection,
                kv_cache_block_offsets=kv_cache_params.
                kv_cache_block_offsets,
                host_kv_cache_block_offsets=kv_cache_params.
                host_cross_kv_cache_block_offsets,
                host_kv_cache_pool_pointers=kv_cache_params.
                host_kv_cache_pool_pointers,
                host_kv_cache_pool_mapping=kv_cache_params.
                host_kv_cache_pool_mapping,
                cross_kv_cache_block_offsets=kv_cache_params.
                cross_kv_cache_block_offsets,
                host_cross_kv_cache_block_offsets=kv_cache_params.
                host_cross_kv_cache_block_offsets,
                host_cross_kv_cache_pool_pointers=kv_cache_params.
                host_cross_kv_cache_pool_pointers,
                host_cross_kv_cache_pool_mapping=kv_cache_params.
                host_cross_kv_cache_pool_mapping,
            ),
            skip_cross_attn_blocks=skip_cross_attn_blocks if isinstance(
                decoder_layer, CrossAttentionTransformerBlock) else None,
            attention_params=attention_params,
            lora_layer_params=lora_layer_params,
            cross_kv_cache_gen=cross_kv_cache_gen,
            cross_kv_reuse=cross_kv_reuse,
            full_text_row_masked_out_mask=full_text_row_masked_out_mask,
        )

        if use_cache:
            present = hidden_states[1]
            presents.append((present))
            hidden_states = hidden_states[0]

    if self.mapping.is_last_pp_rank():
        if self.ln_f:
            hidden_states = self.ln_f(hidden_states)
    else:
        hidden_states = send(hidden_states, self.mapping.next_pp_rank())

    if use_cache:
        return (hidden_states, tuple(presents))
    return hidden_states

def precompute_relative_attention_bias(self, build_config):
    if self.config.relative_attention and not self.use_implicit_relative_attention:
        relative_attention_bias_builder = torch.ops.tensorrt_llm.relative_attention_bias
        rel_attn_precomputed = torch.zeros(
            (self.config.num_attention_heads // self.mapping.tp_size,
             build_config.max_seq_len + 1, build_config.max_seq_len + 1),
            dtype=str_dtype_to_torch(self.config.dtype),
            device='cuda')
        rel_attn_table = numpy_to_torch(
            self.rel_attn_table.raw_value).to('cuda')
        relative_attention_bias_builder(
            rel_attn_precomputed,
            rel_attn_table,
            self.config.num_attention_heads // self.mapping.tp_size,
            build_config.max_seq_len,
            self.config.num_buckets,
            False,
            self.config.max_distance,
        )
        for layer_idx in range(self.num_layers):
            self.layers[layer_idx].self_attention.set_rel_attn_table(
                build_config.max_seq_len, rel_attn_precomputed)

# TODO: try to inherit from DecoderModelForCausalLM

class MLLaMAForCausalLM(PretrainedModel):

config_class = MLLaMAConfig
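# Ties the MLLaMAModel decoder to an lm_head and exposes the TensorRT-LLM
# entry points used at build time: forward(), prepare_inputs(), use_lora()
# and from_hugging_face().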

def __init__(self, config: MLLaMAConfig):
    super().__init__(config)
    Attention.create_attention_const_params(self, config)
    self.position_embedding_type = config.position_embedding_type
    self.transformer = MLLaMAModel(config)

    self.mapping = self.config.mapping

    self.has_model_final_layernorm = self.config.has_model_final_layernorm
    self._dtype = self.config.dtype
    self._kv_dtype = self._dtype
    self._logits_dtype = self.config.logits_dtype

    if self.mapping.is_last_pp_rank():
        self.lm_head = ColumnLinear(
            self.config.hidden_size,
            self.config.vocab_size,
            bias=False if not hasattr(self.config, "has_lm_head_bias") else
            self.config.has_lm_head_bias,
            dtype=self.config.dtype,
            tp_group=self.config.mapping.tp_group,
            tp_size=self.config.mapping.tp_size,
            gather_output=True,
        )

    self.trtllm_modules_to_hf_modules = {
        **get_default_trtllm_modules_to_hf_modules(),
        "attn_q": "self_attn.q_proj",
        "attn_k": "self_attn.k_proj",
        "attn_v": "self_attn.v_proj",
        "attn_dense": "self_attn.o_proj",
        "cross_attn_q": "encoder_attn.q_proj",
        "cross_attn_k": "encoder_attn.k_proj",
        "cross_attn_v": "encoder_attn.v_proj",
        "cross_attn_dense": "encoder_attn.o_proj",
    }

def forward(
    self,
    decoder_input_ids: Tensor,
    encoder_output: Tensor,
    use_cache=False,
    attention_mask_params=None,
    last_token_ids=None,
    kv_cache_params=None,
    attention_params=None,
    hidden_states=None,
    lora_params: LoraParams = None,
    cross_kv_cache_gen: Optional[Tensor] = None,
    cross_kv_reuse: Optional[Tensor] = None,
    prompt_embedding_table: Optional[Tensor] = None,
    prompt_tasks: Optional[Tensor] = None,
    prompt_vocab_size: Optional[Tensor] = None,
    skip_cross_attn_blocks: Optional[Tensor] = None,
):
    if self.mapping.is_first_pp_rank():
        assert isinstance(decoder_input_ids, Tensor)
    else:
        assert isinstance(hidden_states, Tensor)

    attention_params = Attention.fill_attention_params(
        self, attention_params)

    hidden_states = self.transformer(
        decoder_input_ids=decoder_input_ids,
        encoder_output=encoder_output,
        use_cache=use_cache,
        attention_mask_params=attention_mask_params,
        last_token_ids=last_token_ids,
        kv_cache_params=kv_cache_params,
        attention_params=attention_params,
        hidden_states=hidden_states,
        lora_params=lora_params,
        cross_kv_cache_gen=cross_kv_cache_gen,
        cross_kv_reuse=cross_kv_reuse,
        prompt_embedding_table=prompt_embedding_table,
        prompt_tasks=prompt_tasks,
        prompt_vocab_size=prompt_vocab_size,
        skip_cross_attn_blocks=skip_cross_attn_blocks,
    )

    if use_cache:
        hidden_states, presents = hidden_states

    if self.mapping.is_last_pp_rank():
        # [bs, seq, hidden_size] or [num_tokens, hidden_size] -> [bs, hidden_size]
        hidden_states = gather_last_token_logits(
            hidden_states, last_token_ids,
            default_net().plugin_config.remove_input_padding)

        # [bs, hidden_size] -> [bs, vocab_size]
        lm_logits = self.lm_head(hidden_states)
        lm_logits.mark_output(f'logits', self._logits_dtype)
    else:
        hidden_states = send(hidden_states, self.mapping.next_pp_rank())
        hidden_states.mark_output(f'hidden_states_output', self._dtype)

    if use_cache and default_net().plugin_config.paged_kv_cache == False:
        for i, present in zip(
                self.mapping.pp_layers(self.transformer.total_num_layers),
                presents):
            present[0].mark_output(f'present_key_value_{i}', self._kv_dtype)
            if default_net().plugin_config.gpt_attention_plugin:
                present[1].mark_output(f'cross_present_key_value_{i}',
                                       self._kv_dtype)
        if self.mapping.is_last_pp_rank():
            return (lm_logits, tuple(presents))
        return (hidden_states, tuple(presents))
    else:
        if self.mapping.is_last_pp_rank():
            return lm_logits
        return hidden_states

def prepare_inputs(self,
                   max_batch_size,
                   max_beam_width,
                   max_decoder_input_len,
                   max_seq_len,
                   max_encoder_input_len,
                   gather_context_logits: bool = False,
                   gather_generation_logits: bool = False,
                   lora_target_modules: List[str] = None,
                   prompt_embedding_table_size: int = 0,
                   use_cache=True,
                   *args,
                   **kwargs):
    '''@brief: Prepare input Tensors for the model. The given sizes are used
        to determine the ranges of the dynamic dimensions when using TRT
        dynamic shapes.

        @return: a dict whose values can be fed into self.forward()
    '''
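    # Every *_range below is a [min, opt, max] triple that becomes a TensorRT
    # optimization-profile range for one dynamic dimension. For example, with
    # max_batch_size=8 and max_beam_width=1, bb_range evaluates to [1, 4, 8].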

    # Prepare inputs
    max_output_len = max_decoder_input_len + max_seq_len

    head_size = self.transformer.head_size
    num_kv_heads = (self.transformer.num_kv_heads + self.mapping.tp_size -
                    1) // self.mapping.tp_size

    encoder_head_size = head_size
    encoder_num_kv_heads = num_kv_heads

    bb_range = [
        1, (max_batch_size * max_beam_width + 1) // 2,
        max_batch_size * max_beam_width
    ]
    bs_range = [1, (max_batch_size + 1) // 2, max_batch_size]
    beam_width_range = [1, (max_beam_width + 1) // 2, max_beam_width]
    inlen_range = [
        1, 1, max_decoder_input_len
    ]  # context phase >= 1 (if forced_input_ids), generation phase = 1
    encoder_inlen_range = [
        1, (max_encoder_input_len + 1) // 2, max_encoder_input_len
    ]
    mask_len_range = [1, (max_output_len + 1) // 2 + 1, max_output_len + 1]
    max_output_len_range = [0, (max_output_len + 1) // 2, max_output_len]

    encoder_num_tokens_range = [
        0,  # 0 for generation phase, >0 for context phase
        (max_encoder_input_len * max_batch_size + 1) // 2,
        max_encoder_input_len * max_batch_size,
    ]
    decoder_num_tokens_range = [
        1,
        max_batch_size * max_beam_width,
        max(max_decoder_input_len * max_batch_size,
            max_beam_width * max_batch_size),
    ]

    # No enable_two_optimization_profiles support yet
    encoder_input_len_range = [
        0,  # 0 for generation phase, >0 for context phase
        (max_encoder_input_len + 1) // 2,
        max_encoder_input_len
    ]
    # pack masks into bits (store as int32).
    max_cross_packed_mask_dim0 = max_batch_size * (
        (max_decoder_input_len + 128 - 1) // 128) * 128
    max_cross_packed_mask_dim1 = (
        (max_encoder_input_len + 256 - 1) // 256) * 256 // 32
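    # Illustrative sizing (example values, not from the config): with
    # max_decoder_input_len=64 and max_encoder_input_len=1000, dim0 is
    # max_batch_size * ceil(64/128)*128 = 128 * max_batch_size rows and
    # dim1 is ceil(1000/256)*256 // 32 = 32 int32 words per row, i.e. each
    # row packs up to 1024 encoder positions as individual mask bits.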
    cross_packed_mask_dim0_range = [
        1, (max_cross_packed_mask_dim0 + 1) // 2, max_cross_packed_mask_dim0
    ]
    cross_packed_mask_dim1_range = [
        0,  # 0 for generation phase, >0 for context phase
        (max_cross_packed_mask_dim1 + 1) // 2,
        max_cross_packed_mask_dim1
    ]
    past_key_value = []
    sequence_length = None
    host_past_key_value_lengths = None
    attention_mask_params = AttentionMaskParams()
    use_gpt_attention_plugin = default_net(
    ).plugin_config.gpt_attention_plugin
    remove_input_padding = default_net().plugin_config.remove_input_padding
    paged_kv_cache = default_net().plugin_config.paged_kv_cache
    tokens_per_block = default_net().plugin_config.tokens_per_block
    use_lora_plugin = default_net().plugin_config.lora_plugin
    kv_cache_type = None
    if not use_cache:
        kv_cache_type = KVCacheType.DISABLED
    else:
        if paged_kv_cache:
            kv_cache_type = KVCacheType.PAGED
        else:
            kv_cache_type = KVCacheType.CONTINUOUS

    input_ids, hidden_states = None, None
    if remove_input_padding:
        if self.mapping.is_first_pp_rank():
            input_ids = Tensor(name='input_ids',
                               dtype=trt.int32,
                               shape=[-1],
                               dim_range=OrderedDict([
                                   ('decoder_num_tokens',
                                    [decoder_num_tokens_range]),
                               ]))
        else:
            hidden_states = Tensor(name='hidden_states_input',
                                   dtype=self._dtype,
                                   shape=[-1, self.transformer.hidden_size],
                                   dim_range=OrderedDict([
                                       ('decoder_num_tokens',
                                        [decoder_num_tokens_range]),
                                       ('hidden_size',
                                        [self.transformer.hidden_size]),
                                   ]))
    else:
        if self.mapping.is_first_pp_rank():
            input_ids = Tensor(name='input_ids',
                               dtype=trt.int32,
                               shape=[-1, -1],
                               dim_range=OrderedDict([
                                   ('batch_size_beam_width', [bb_range]),
                                   ('input_len', [inlen_range]),
                               ]))
        else:
            hidden_states = Tensor(name='hidden_states_input',
                                   dtype=self._dtype,
                                   shape=[-1, -1,
                                          self.transformer.hidden_size],
                                   dim_range=OrderedDict([
                                       ('batch_size_beam_width', [bb_range]),
                                       ('input_len', [inlen_range]),
                                       ('hidden_size',
                                        [self.transformer.hidden_size]),
                                   ]))

    encoder_input_lengths = Tensor(
        name="encoder_input_lengths",
        dtype=trt.int32,
        shape=[-1],
        dim_range=OrderedDict([("batch_size_beam_width", [bb_range])]),
    )
    encoder_max_input_length = Tensor(
        name="encoder_max_input_length",
        dtype=trt.int32,
        shape=[-1],
        dim_range=OrderedDict([("encoder_max_input_length",
                                [encoder_inlen_range])]),
    )
    if remove_input_padding:
        encoder_output = Tensor(
            name="encoder_output",
            dtype=self._dtype,
            shape=[-1, self.config.hidden_size],
            dim_range=OrderedDict([
                ("encoder_num_tokens", [encoder_num_tokens_range]),
                ("hidden_size", [self.config.hidden_size]),
            ]),
        )
    else:
        encoder_output = Tensor(
            name="encoder_output",
            dtype=self._dtype,
            shape=[-1, -1, self.config.hidden_size],
            dim_range=OrderedDict([
                ("batch_size_beam_width_encoder", [bb_range]),
                ("encoder_input_len", [encoder_input_len_range]),
                ("hidden_size", [self.config.hidden_size]),
            ]),
        )

    context_lengths = None
    host_context_lengths = None
    host_request_types = None
    host_runtime_perf_knobs = None
    host_context_progress = None
    if use_gpt_attention_plugin and remove_input_padding:
        host_context_lengths = Tensor(name='host_context_lengths',
                                      dtype=trt.int32,
                                      shape=[-1],
                                      dim_range=OrderedDict([
                                          ('batch_size_beam_width',
                                           [bb_range])
                                      ]))

    if use_gpt_attention_plugin:
        if kv_cache_type != KVCacheType.DISABLED:
            sequence_length = Tensor(
                name='sequence_length',
                dtype=trt.int32,
                shape=[-1],
                dim_range=OrderedDict([('batch_size_beam_width', [bb_range])
                                       ]),
            )
            host_past_key_value_lengths = Tensor(
                name='host_past_key_value_lengths',
                dtype=trt.int32,
                shape=[-1],
                dim_range=OrderedDict([('batch_size_beam_width', [bb_range])
                                       ]),
            )

        context_lengths = Tensor(name='context_lengths',
                                 dtype=trt.int32,
                                 shape=[-1],
                                 dim_range=OrderedDict([
                                     ('batch_size_beam_width', [bb_range])
                                 ]))
        host_request_types = Tensor(name='host_request_types',
                                    dtype=trt.int32,
                                    shape=[-1],
                                    dim_range=OrderedDict([
                                        ('batch_size_beam_width',
                                         [bb_range])
                                    ]))

        host_runtime_perf_knobs = Tensor(name='host_runtime_perf_knobs',
                                         dtype=trt.int64,
                                         shape=[16],
                                         dim_range=OrderedDict([
                                             ('perf_knob_size', [16])
                                         ]))

        host_context_progress = Tensor(name='host_context_progress',
                                       dtype=trt.int64,
                                       shape=[1],
                                       dim_range=OrderedDict([
                                           ('context_progress_size', [1])
                                       ]))

    last_token_ids = None
    if self.mapping.is_last_pp_rank() and not gather_context_logits:
        last_token_ids = Tensor(
            name="last_token_ids",
            dtype=trt.int32,
            shape=[-1],
            dim_range=OrderedDict([("batch_size_last_token_ids", [bb_range])
                                   ]),
        )

    attention_mask = None
    if not use_gpt_attention_plugin:
        attention_mask = Tensor(
            name='attention_mask',
            dtype=trt.int32,
            shape=[-1, -1],
            dim_range=OrderedDict([
                ('batch_size_beam_width', [bb_range]),
                ('mask_len', [mask_len_range]),
            ]),
        )
        assert False, "not support non-attention-plugin case now"

    cross_attention_mask = Tensor(
        name='cross_attention_mask',
        dtype=trt.bool,
        shape=[-1, -1],
        dim_range=OrderedDict([
            ('decoder_num_tokens_2',
             [decoder_num_tokens_range
              ]),  # TODO should use same name as input_ids
            ('encoder_input_len_2', [encoder_input_len_range]),
        ]),
    )

    cross_attention_packed_mask = Tensor(
        name='cross_attention_packed_mask',
        dtype=trt.int32,
        shape=[-1, -1],
        dim_range=OrderedDict([
            ('cross_packed_mask_dim0', [cross_packed_mask_dim0_range]),
            ('cross_packed_mask_dim1', [cross_packed_mask_dim1_range]),
        ]),
    )

    # create the attention_mask_params.
    attention_mask_params = AttentionMaskParams(
        attention_mask, None, cross_attention_mask,
        cross_attention_packed_mask)

    cache_indirection = Tensor(
        name='cache_indirection',
        dtype=trt.int32,
        shape=[-1, -1, -1],
        dim_range=OrderedDict([
            ('batch_size_cache', [bs_range]),
            ('beam_width', [beam_width_range]),
            ('max_seq_len', [max_output_len_range]),
        ]),
    )

    layers_range = self.mapping.pp_layers(self.transformer.total_num_layers)
    num_pp_layers = len(layers_range)

    host_max_attention_window_sizes = None
    host_sink_token_length = None
    if use_gpt_attention_plugin:
        host_max_attention_window_sizes = Tensor(
            name=f'host_max_attention_window_sizes',
            dtype=trt.int32,
            shape=[num_pp_layers],
            dim_range=OrderedDict([('num_layers', [num_pp_layers])]))
        host_sink_token_length = Tensor(name='host_sink_token_length',
                                        dtype=trt.int32,
                                        shape=[1],
                                        dim_range=OrderedDict([('scalar',
                                                                [1])]))
    # TODO LoRA for mllama is not verified.
    lora_weights_pointers = None
    lora_ranks = None
    lora_params = None
    if use_lora_plugin:
        lora_weights_pointers = []
        lora_ranks = []
        missing_qkv_modules = []
        if any(x in lora_target_modules
               for x in ["attn_q", "attn_k", "attn_v"]):
            for lora_module in [
                    "attn_q",
                    "attn_k",
                    "attn_v",
            ]:
                if lora_module not in lora_target_modules:
                    missing_qkv_modules.append(lora_module)
        if any(x in lora_target_modules
               for x in ["cross_attn_q", "cross_attn_k", "cross_attn_v"]):
            for lora_module in [
                    "cross_attn_q", "cross_attn_k", "cross_attn_v"
            ]:
                if lora_module not in lora_target_modules:
                    missing_qkv_modules.append(lora_module)

        # For LoRA
        for i in layers_range:
            lora_weight_pointer_dict = {}
            lora_rank_dict = {}
            for lora_module in (lora_target_modules + missing_qkv_modules):
                lora_weight_pointer = Tensor(
                    name=f'{lora_module}_lora_weights_pointers_{i}',
                    dtype=trt.int64,
                    shape=[-1, 3],
                    dim_range=OrderedDict([('batch_size_beam_width',
                                            [bb_range]),
                                           ('in_out_scales', [3])]))
                lora_weight_pointer_dict.update({
                    f'{lora_module}_lora_weights_pointers':
                    lora_weight_pointer
                })

                lora_rank = Tensor(name=f'{lora_module}_lora_ranks_{i}',
                                   dtype=trt.int32,
                                   shape=[-1],
                                   dim_range=OrderedDict([
                                       ('batch_size_beam_width', [bb_range])
                                   ]))
                lora_rank_dict.update(
                    {f'{lora_module}_lora_ranks': lora_rank})

            lora_weights_pointers.append(lora_weight_pointer_dict)
            lora_ranks.append(lora_rank_dict)

        # For cross attention, we need to use encoder_input_lengths (in CPU) to pass
        # as the host_context_lengths to the lora_plugin. But for self attention, we
        # should keep using the original host_context_lengths. Therefore, we keep both
        # of them in the lora_params.
        host_encoder_input_lengths = None
        if remove_input_padding:
            host_encoder_input_lengths = Tensor(
                name="host_encoder_input_lengths",
                dtype=trt.int32,
                shape=[-1],
                dim_range=OrderedDict([("batch_size_beam_width", [bb_range])
                                       ]),
            )

        lora_params = LoraParams(
            lora_ranks=lora_ranks,
            lora_weights_pointers=lora_weights_pointers,
            host_context_lengths=host_context_lengths,
            max_context_length=max_decoder_input_len,
            max_encoder_context_length=max_encoder_input_len,
            host_request_types=host_request_types,
            host_encoder_input_lengths=host_encoder_input_lengths,
        )

    kv_cache_block_offsets = None
    host_kv_cache_block_offsets = None
    host_kv_cache_pool_pointers = None
    host_kv_cache_pool_mapping = None

    cross_kv_cache_block_offsets = None
    host_cross_kv_cache_block_offsets = None
    host_cross_kv_cache_pool_pointers = None
    host_cross_kv_cache_pool_mapping = None

    if use_cache:
        if not paged_kv_cache:
            for i in layers_range:
                kv_dim_range = OrderedDict([
                    ('batch_size_beam_width', [bb_range]),
                    ('kv', [2]),
                    ('num_heads', [num_kv_heads]),
                    ('past_key_len', [max_output_len_range]),
                    ('head_size', [head_size]),
                ])
                kv = Tensor(name=f'past_key_value_{i}',
                            dtype=self._kv_dtype,
                            shape=[-1, 2, num_kv_heads, -1, head_size],
                            dim_range=kv_dim_range)

                past_key_value.append(kv)

                if i in self.transformer.cross_attention_layers:
                    xa_layer_id = self.transformer.cross_attention_layers.index(
                        i) + layers_range[-1]
                    cross_kv_dim_range = OrderedDict([
                        ('batch_size_beam_width', [bb_range]),
                        ('kv', [2]),
                        ('cross_num_heads', [encoder_num_kv_heads]),
                        ('cross_past_key_len', [encoder_input_len_range]),
                        ('cross_head_size', [encoder_head_size]),
                    ])
                    cross_kv = Tensor(
                        name=f'cross_past_key_value_{xa_layer_id}',
                        dtype=self._kv_dtype,
                        shape=[
                            -1, 2, encoder_num_kv_heads, -1, encoder_head_size
                        ],
                        dim_range=cross_kv_dim_range)
                    past_key_value.append(cross_kv)

            # TODO: Remove this when TRT fix the named dimension
            if not remove_input_padding:
                assertion(
                    shape(
                        input_ids if self.mapping.is_first_pp_rank() else
                        hidden_states, 0) == shape(kv, 0), 'batch size')

        else:  # paged_kv_cache == True
            # PagedKV setup for KV cache of self-attention
            max_blocks_per_seq_range = [[
                math.ceil(max_output_len_range[0] / tokens_per_block),
                math.ceil(max_output_len_range[1] / tokens_per_block),
                math.ceil(max_output_len_range[2] / tokens_per_block)
            ]]
            max_blocks_per_seq_range = [[
                x for x in max_blocks_per_seq_range[0]
            ]]

            # PagedKV setup for KV cache of cross-attention
            max_cross_blocks_per_seq_range = [[
                math.ceil(encoder_input_len_range[0] / tokens_per_block),
                math.ceil(encoder_input_len_range[1] / tokens_per_block),
                math.ceil(encoder_input_len_range[2] / tokens_per_block)
            ]]
            max_cross_blocks_per_seq_range = [[
                x for x in max_cross_blocks_per_seq_range[0]
            ]]

            num_kv_cache_pools = 2
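            # Two KV-cache pools: one for the self-attention layers and one for
            # the cross-attention layers of the CrossAttentionTransformerBlocks.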

            kv_cache_block_offsets = Tensor(
                name=f'kv_cache_block_offsets',
                dtype=trt.int32,
                shape=[num_kv_cache_pools, -1, 2, -1],
                dim_range=OrderedDict([
                    ('num_kv_cache_pools', [num_kv_cache_pools]),
                    ('batch_size_beam_width', [bb_range]),
                    ('kv', [2]),
                    ('max_blocks_per_seq', max_blocks_per_seq_range),
                ]))
            host_kv_cache_block_offsets = Tensor(
                name=f'host_kv_cache_block_offsets',
                dtype=trt.int32,
                shape=[num_kv_cache_pools, -1, 2, -1],
                dim_range=OrderedDict([
                    ('num_kv_cache_pools', [num_kv_cache_pools]),
                    ('batch_size_beam_width', [bb_range]),
                    ('kv', [2]),
                    ('max_blocks_per_seq', max_blocks_per_seq_range),
                ]))
            host_kv_cache_pool_pointers = Tensor(
                name=f'host_kv_cache_pool_pointers',
                dtype=trt.int64,
                shape=[num_kv_cache_pools, 2],
                dim_range=OrderedDict([
                    ('num_kv_cache_pools', [num_kv_cache_pools]),
                    ('num_pools', [2]),
                ]))
            host_kv_cache_pool_mapping = Tensor(
                name=f"host_kv_cache_pool_mapping",
                dtype=trt.int32,
                # 2: (Index of pool, Index of layer within pool)
                shape=[num_pp_layers, 2],
                dim_range=OrderedDict([
                    ('pools_mapping', [num_pp_layers]),
                    ('layer_cache_pool_locator', [2]),
                ]))

            # paged blocks for cross kv
            cross_kv_cache_block_offsets = Tensor(
                name=f'cross_kv_cache_block_offsets',
                dtype=trt.int32,
                shape=[num_kv_cache_pools, -1, 2, -1],
                dim_range=OrderedDict([
                    ('num_kv_cache_pools', [num_kv_cache_pools]),
                    ('batch_size_beam_width', [bb_range]),
                    ('kv', [2]),
                    ('max_cross_blocks_per_seq',
                     max_cross_blocks_per_seq_range),
                ]))
            host_cross_kv_cache_block_offsets = Tensor(
                name=f'host_cross_kv_cache_block_offsets',
                dtype=trt.int32,
                shape=[num_kv_cache_pools, -1, 2, -1],
                dim_range=OrderedDict([
                    ('num_kv_cache_pools', [num_kv_cache_pools]),
                    ('batch_size_beam_width', [bb_range]),
                    ('kv', [2]),
                    ('max_cross_blocks_per_seq',
                     max_cross_blocks_per_seq_range),
                ]))
            host_cross_kv_cache_pool_pointers = Tensor(
                name=f'host_cross_kv_cache_pool_pointers',
                dtype=trt.int64,
                shape=[num_kv_cache_pools, 2],
                dim_range=OrderedDict([
                    ('num_kv_cache_pools', [num_kv_cache_pools]),
                    ('num_pools', [2]),
                ]))
            host_cross_kv_cache_pool_mapping = Tensor(
                name=f"host_cross_kv_cache_pool_mapping",
                dtype=trt.int32,
                # 2: (Index of pool, Index of layer within pool)
                shape=[num_pp_layers, 2],
                dim_range=OrderedDict([
                    ('pools_mapping', [num_pp_layers]),
                    ('layer_cache_pool_locator', [2]),
                ]))

            for i in layers_range:
                past_key_value.append(None)

        kv_cache_params = KeyValueCacheParams(
            past_key_value=past_key_value,
            host_past_key_value_lengths=host_past_key_value_lengths,
            host_max_attention_window_sizes=host_max_attention_window_sizes,
            host_sink_token_length=host_sink_token_length,
            cache_indirection=cache_indirection,
            kv_cache_block_offsets=kv_cache_block_offsets,
            host_kv_cache_block_offsets=host_kv_cache_block_offsets,
            host_kv_cache_pool_pointers=host_kv_cache_pool_pointers,
            host_kv_cache_pool_mapping=host_kv_cache_pool_mapping,
            cross_kv_cache_block_offsets=cross_kv_cache_block_offsets,
            host_cross_kv_cache_block_offsets=
            host_cross_kv_cache_block_offsets,
            host_cross_kv_cache_pool_pointers=
            host_cross_kv_cache_pool_pointers,
            host_cross_kv_cache_pool_mapping=
            host_cross_kv_cache_pool_mapping,
        )

        attention_params = AttentionParams(
            sequence_length=sequence_length,
            context_lengths=context_lengths,
            host_context_lengths=host_context_lengths,
            max_context_length=max_decoder_input_len,
            host_request_types=host_request_types,
            host_runtime_perf_knobs=host_runtime_perf_knobs,
            host_context_progress=host_context_progress,
            encoder_input_lengths=encoder_input_lengths,
            encoder_max_input_length=encoder_max_input_length,
        )

    cross_kv_cache_gen = Tensor(name='cross_kv_cache_gen',
                                dtype=trt.bool,
                                shape=[1],
                                dim_range=OrderedDict([
                                    ('boolean', [1]),
                                ]))
    cross_kv_reuse = None
    num_heads = (self.transformer.num_heads + self.mapping.tp_size -
                 1) // self.mapping.tp_size
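    # Cross K and V are packed along the channel dimension, hence the factor 2.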
    cross_kv_out_dim = 2 * num_kv_heads * self.transformer.head_size
    if self.transformer.skip_cross_kv:
        if remove_input_padding:
            cross_kv_reuse = Tensor(
                name="cross_kv_reuse",
                dtype=self._dtype,
                shape=[-1, cross_kv_out_dim],
                dim_range=OrderedDict([
                    ("encoder_num_tokens", [encoder_num_tokens_range]),
                    ("encoder_kv_size", [cross_kv_out_dim]),
                ]),
            )
        else:
            cross_kv_reuse = Tensor(
                name="cross_kv_reuse",
                dtype=self._dtype,
                shape=[-1, -1, cross_kv_out_dim],
                dim_range=OrderedDict([
                    ("batch_size_beam_width_encoder", [bb_range]),
                    ("encoder_input_len", [encoder_input_len_range]),
                    ("encoder_kv_size", [cross_kv_out_dim]),
                ]),
            )

    skip_cross_attn_blocks = None
    if self.config.skip_cross_attn_blocks:
        skip_cross_attn_blocks = Tensor(name='skip_cross_attn_blocks',
                                        dtype=trt.bool,
                                        shape=[1],
                                        dim_range=OrderedDict([
                                            ('boolean', [1]),
                                        ]))

    prompt_embedding_table = None
    tasks = None
    prompt_vocab_size = None

    if self.mapping.is_first_pp_rank() and prompt_embedding_table_size > 0:
        p_embedding_range = [[
            1, prompt_embedding_table_size // 2, prompt_embedding_table_size
        ]]

        prompt_embedding_table = Tensor(
            name='prompt_embedding_table',
            dtype=self._dtype,
            shape=[-1, self.transformer.hidden_size],
            dim_range=OrderedDict([
                ('prompt_embedding_table_size', p_embedding_range),
                ('hidden_size', [self.transformer.hidden_size]),
            ]))
        if remove_input_padding:
            num_tokens_range = [
                1,
                (max_decoder_input_len * max_batch_size + 1) // 2,
                max_decoder_input_len * max_batch_size,
            ]
            tasks = Tensor(name='tasks',
                           dtype=trt.int32,
                           shape=[-1],
                           dim_range=OrderedDict([
                               ('decoder_num_tokens',
                                [decoder_num_tokens_range]),
                           ]))
        else:
            tasks = Tensor(name='tasks',
                           dtype=trt.int32,
                           shape=[-1, 1],
                           dim_range=OrderedDict([
                               ('batch_size', [bs_range]),
                               ('broadcast_dim', [1]),
                           ]))
        prompt_vocab_size = Tensor(name='prompt_vocab_size',
                                   dtype=trt.int32,
                                   shape=[1],
                                   dim_range=OrderedDict([('size', [1])]))

    result = {
        'decoder_input_ids': input_ids,
        'encoder_output': encoder_output,
        'use_cache': True,
        'attention_mask_params': attention_mask_params,
        'last_token_ids': last_token_ids,
        'kv_cache_params': kv_cache_params,
        'attention_params': attention_params,
        'hidden_states': hidden_states,
        'lora_params': lora_params,
        'cross_kv_cache_gen': cross_kv_cache_gen,
        'cross_kv_reuse': cross_kv_reuse,
        'prompt_embedding_table': prompt_embedding_table,
        'prompt_tasks': tasks,
        'prompt_vocab_size': prompt_vocab_size,
        'skip_cross_attn_blocks': skip_cross_attn_blocks,
    }

    return result

def use_lora(self, lora_config: LoraConfig):
    use_lora(self, lora_config, self.trtllm_modules_to_hf_modules)

@classmethod
def from_hugging_face(
        cls,
        hf_model_or_dir: Union[str, 'transformers.PreTrainedModel'],
        dtype: str = 'auto',
        mapping: Optional[Mapping] = None,
        quant_config: Optional[QuantConfig] = None,
        **kwargs):
    ''' Create a MLLaMAForCausalLM object from the given parameters '''
    import transformers

    kwargs.pop('load_by_shard', False)
    kwargs.pop('load_model_on_cpu', False)
    quant_ckpt_path = kwargs.pop('quant_ckpt_path', None)

    assert hf_model_or_dir is not None
    use_preloading = isinstance(hf_model_or_dir,
                                transformers.PreTrainedModel)
    if use_preloading:
        hf_model = hf_model_or_dir
        hf_config_or_dir = hf_model.config
    else:
        hf_model_dir = hf_model_or_dir
        hf_config_or_dir = hf_model_or_dir

    config = MLLaMAConfig.from_hugging_face(hf_config_or_dir,
                                            dtype=dtype,
                                            mapping=mapping,
                                            quant_config=quant_config,
                                            **kwargs)

    custom_dict = {
        "lm_head": "language_model.lm_head",
        "transformer.ln_f": "language_model.model.norm",
        "transformer": "language_model.model",
        "self_attention": "self_attn",
        "cross_attention": "cross_attn",
        "vocab_embedding": "embed_tokens",
        "gate_attn": "cross_attn_attn_gate",
        "gate_ffwd": "cross_attn_mlp_gate",
        "q_layernorm": "q_norm",
        "k_layernorm": "k_norm",
    }

    if quant_ckpt_path is not None:
        hf_model_dir = quant_ckpt_path

    loader = ModelWeightsLoader(hf_model_dir, custom_dict)
    model = cls(config)
    loader.generate_tllm_weights(model)

    return model
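
# A minimal build-time sketch (illustrative only; the real wiring is done by the
# trtllm-build flow outside this module). Values passed to prepare_inputs() are
# placeholders:
#
#   model = MLLaMAForCausalLM.from_hugging_face('<hf_model_dir>')
#   inputs = model.prepare_inputs(max_batch_size=8,
#                                 max_beam_width=1,
#                                 max_decoder_input_len=256,
#                                 max_seq_len=1024,
#                                 max_encoder_input_len=4096)
#   outputs = model(**inputs)  # traces the decoder network for engine building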