tensorrt_llm.models.llama.model — TensorRT-LLM

# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Optional, Union

import transformers

from ..._common import default_net
from ..._utils import pad_vocab_size
from ...functional import (AllReduceFusionOp, AllReduceParams, Tensor,
                           allgather, concat, constant, div,
                           non_gated_version, recv, send, unsqueeze)
from ...layers import (MOE, Attention, AttentionMaskType, ColumnLinear,
                       Embedding, FusedGatedMLP, GatedMLP,
                       PositionEmbeddingType, RmsNorm)
from ...lora_manager import LoraConfig, use_lora
from ...mapping import Mapping
from ...module import Module
from ...quantization.functional import fused_layernorm
from ..convert_utils import has_safetensors
from ..model_weights_loader import ModelWeightsLoader
from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM,
                              QuantConfig)
from .config import LLaMAConfig
from .convert import (load_hf_llama, load_weights_from_deepcompressor,
                      load_weights_from_gptq, load_weights_from_hf_by_shard,
                      load_weights_from_hf_model,
                      load_weights_from_hf_safetensors,
                      load_weights_from_meta_ckpt)

class LLaMADecoderLayer(Module):

def __init__(self, config: LLaMAConfig, layer_idx: int):
    super().__init__()
    self.layer_idx = layer_idx
    layer_idx += config.layer_idx_offset
    self.config = config
    self.mapping = config.mapping

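    # The input RMSNorm on the very first decoder layer is optional
    # (config.use_input_layernorm_in_first_layer); every later layer always has one.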
    if (self.config.use_input_layernorm_in_first_layer
            and self.layer_idx == 0) or self.layer_idx > 0:
        self.input_layernorm = RmsNorm(normalized_shape=config.hidden_size,
                                       eps=config.norm_epsilon,
                                       dtype=config.dtype)

    layers_range = config.mapping.pp_layers(config.num_hidden_layers)
    self.local_layer_idx = layer_idx - layers_range[0]
    self.is_last_local_layer = layer_idx == layers_range[-1]
    self.attention = Attention(
        local_layer_idx=self.local_layer_idx,
        hidden_size=config.hidden_size,
        attention_head_size=config.head_size,
        num_attention_heads=config.num_attention_heads,
        num_kv_heads=config.num_key_value_heads,
        max_position_embeddings=config.max_position_embeddings,
        dtype=config.dtype,
        attention_mask_type=AttentionMaskType.causal,
        bias=config.attn_bias,
        position_embedding_type=PositionEmbeddingType.rope_gpt_neox,
        rotary_embedding_base=config.rotary_base,
        rotary_embedding_scaling=config.rotary_scaling,
        tp_group=config.mapping.tp_group,
        tp_size=config.mapping.tp_size,
        tp_rank=config.mapping.tp_rank,
        q_scaling=1.0 / config.attention_multiplier,
        quant_mode=config.quant_mode,
        cp_group=config.mapping.cp_group,
        cp_size=config.mapping.cp_size,
        cp_rank=config.mapping.cp_rank)

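    # Fall back to the conventional 4x hidden size when the config does not set
    # an explicit intermediate (FFN) size.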
    mlp_hidden_size = config.hidden_size * 4 if config.intermediate_size is None else config.intermediate_size

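    # Default to a dense GatedMLP; swap in a Mixture-of-Experts block when the
    # config enables MoE.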
    ClsMLP = GatedMLP
    mlp_kwargs = {}
    if config.moe.has_moe():
        ClsMLP = MOE
        mlp_kwargs = {
            "moe_config": config.moe,
            "mapping": config.mapping,
        }
    self.mlp = ClsMLP(hidden_size=config.hidden_size,
                      ffn_hidden_size=mlp_hidden_size,
                      hidden_act=config.hidden_act,
                      dtype=config.dtype,
                      bias=config.mlp_bias,
                      tp_group=config.mapping.tp_group,
                      tp_size=config.mapping.tp_size,
                      quant_mode=config.quant_mode,
                      **mlp_kwargs)

    self.post_layernorm = RmsNorm(normalized_shape=config.hidden_size,
                                  eps=config.norm_epsilon,
                                  dtype=config.dtype)

    # Residual MLP that applies on pre-attention input
    # TODO: change to self.has_residual_mlp = self.config.residual_mlp after ModelOpt quantize config is updated
    self.has_residual_mlp = False
    if hasattr(self.config,
               "residual_mlp") and self.config.residual_mlp is True:
        self.has_residual_mlp = True

    if self.has_residual_mlp:
        self.residual_layernorm = RmsNorm(
            normalized_shape=config.hidden_size,
            eps=config.norm_epsilon,
            dtype=config.dtype)
        ClsMLP = GatedMLP  # TODO: may use FusedGatedMLP to further speedup
        self.residual_mlp = ClsMLP(
            hidden_size=config.hidden_size,
            ffn_hidden_size=config.hidden_size,  # residual mlp uses hidden_size
            hidden_act=non_gated_version(config.hidden_act),  # back to non-gated
            dtype=config.dtype,
            bias=config.mlp_bias,
            tp_group=config.mapping.tp_group,
            tp_size=config.mapping.tp_size,
            quant_mode=config.quant_mode)

def forward(self,
            hidden_states,
            attention_mask=None,
            use_cache=False,
            spec_decoding_params=None,
            kv_cache_params=None,
            attention_params=None,
            lora_layer_params=None,
            next_layer_input_layernorm_args=None):
    assert not (
        default_net().plugin_config.reduce_fusion and self.has_residual_mlp
    ), "Custom all reduce and residual mlp can't be enabled at the same time."
    assert not (
        default_net().plugin_config.reduce_fusion
        and default_net().plugin_config.user_buffer
        and default_net().plugin_config.pp_reduce_scatter
    ), "User buffer reduce fusion enabled with PP reduce scatter is not supported now."
    assert not (
        default_net().plugin_config.reduce_fusion
        and default_net().plugin_config.norm_quant_fusion
    ), "Reduce fusion and quant fusion can't be enabled at the same time."
    if (default_net().plugin_config.reduce_fusion
            and self.local_layer_idx > 0):
        hidden_states, residual = hidden_states
    elif (default_net().plugin_config.norm_quant_fusion
          and self.local_layer_idx > 0):
        hidden_states, residual = hidden_states
    else:
        residual = hidden_states
        if (self.config.use_input_layernorm_in_first_layer
                and self.layer_idx == 0) or self.layer_idx > 0:
            hidden_states = self.input_layernorm(hidden_states)

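    # Pick the all-reduce fusion variant for the attention output: with user
    # buffers, the residual add + RMSNorm + FP8/NVFP4 quantization are fused into
    # the all-reduce; otherwise only residual add + RMSNorm are fused.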
    reduce_fusion_op = AllReduceFusionOp.NONE
    if default_net().plugin_config.reduce_fusion:
        if default_net().plugin_config.user_buffer:
            if self.config.quant_mode.has_fp8_qdq():
                reduce_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8
            elif self.config.quant_mode.has_nvfp4():
                assert default_net().plugin_config.gemm_plugin == "nvfp4", \
                    "UB with an nvfp4 model must use the nvfp4 gemm plugin"
                reduce_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4
            else:
                assert False, "UB must be enabled with an fp8 or nvfp4 model"
        else:
            reduce_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM

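    # With user-buffer fusion, the activation quantization scale of the following
    # MLP GEMM is also folded into the fused all-reduce.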
    reduce_fusion_scale = None
    if default_net().plugin_config.reduce_fusion and default_net(
    ).plugin_config.user_buffer:
        if isinstance(self.mlp, FusedGatedMLP):
            if self.config.quant_mode.has_fp8_qdq():
                reduce_fusion_scale = constant(
                    self.mlp.fused_fc.activation_scaling_factor.raw_value.
                    copy())
            elif self.config.quant_mode.has_nvfp4():
                reduce_fusion_scale = constant(
                    [1.0] / self.mlp.fused_fc.
                    activation_global_scaling_factor.raw_value)
        else:
            if self.config.quant_mode.has_fp8_qdq():
                reduce_fusion_scale = constant(
                    self.mlp.fc.activation_scaling_factor.raw_value.copy())
            elif self.config.quant_mode.has_nvfp4():
                reduce_fusion_scale = constant(
                    [1.0] /
                    self.mlp.fc.activation_global_scaling_factor.raw_value)
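    # Run attention; when reduce fusion is enabled, the residual add and
    # post_layernorm (and optional quantization) are folded into the attention
    # all-reduce via AllReduceParams.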
    attention_output = self.attention(
        hidden_states,
        attention_mask=attention_mask,
        use_cache=use_cache,
        spec_decoding_params=spec_decoding_params,
        kv_cache_params=kv_cache_params,
        attention_params=attention_params,
        lora_layer_params=lora_layer_params,
        all_reduce_params=AllReduceParams(
            fusion_op=reduce_fusion_op,
            residual=residual,
            norm_weight=self.post_layernorm.weight.value,
            scale=reduce_fusion_scale,
            eps=self.post_layernorm.eps))
    if use_cache:
        attention_output, presents = attention_output

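    # Arctic-style layer: a residual MLP on the post-attention stream plus a
    # parallel MoE branch that consumes the pre-attention residual.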
    if self.has_residual_mlp:
        hidden_states = residual + attention_output
        residual_attn = hidden_states
        # arctic layer w/ residual mlp

        # residual mlp
        hidden_states = self.residual_layernorm(hidden_states)
        hidden_states = self.residual_mlp(hidden_states)
        residual_mlp = residual_attn + hidden_states

        # parallel moe
        # the parallel MoE branch applies to the PRE-ATTENTION input residual,
        # which enables pre-fetching and better parallelism
        hidden_states = self.post_layernorm(residual)
        hidden_states = self.mlp(hidden_states,
                                 lora_layer_params=lora_layer_params)
        hidden_states = residual_mlp + hidden_states
    else:
        if default_net().plugin_config.reduce_fusion:
            hidden_states, residual = attention_output
        elif default_net().plugin_config.norm_quant_fusion:
            hidden_states, residual_attn, act_per_block_scale = fused_layernorm(
                input=attention_output,
                normalized_shape=self.config.hidden_size,
                residual=residual,
                weight=self.post_layernorm.weight.value,
                scale=div(
                    1, self.mlp.fc.activation_global_scaling_factor.value)
                if self.mlp.fc.activation_global_scaling_factor.value else
                None,
                eps=self.post_layernorm.eps,
                p_dtype=self.config.dtype)

            hidden_states, residual_attn = (
                hidden_states, act_per_block_scale), residual_attn
            assert isinstance(hidden_states, tuple)
        else:
            hidden_states = residual + attention_output * self.config.residual_multiplier
            residual = hidden_states
            hidden_states = self.post_layernorm(hidden_states)
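        # If the next layer's input layernorm parameters were passed in, its
        # residual add + RMSNorm (and optional quantization) are fused after this
        # layer's MLP.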
        if next_layer_input_layernorm_args is not None:
            # this is a middle layer: the next layer's input layernorm is fused here
            hidden_states = self.mlp(
                hidden_states,
                lora_layer_params=lora_layer_params,
                all_reduce_params=AllReduceParams(
                    fusion_op=reduce_fusion_op,
                    residual=residual_attn
                    if default_net().plugin_config.norm_quant_fusion else
                    residual,
                    norm_weight=next_layer_input_layernorm_args[0],
                    scale=next_layer_input_layernorm_args[2],
                    eps=next_layer_input_layernorm_args[1]))
            if default_net().plugin_config.norm_quant_fusion:
                hidden_states, residual, act_per_block_scale = fused_layernorm(
                    input=hidden_states,
                    normalized_shape=self.config.hidden_size,
                    residual=residual_attn,
                    weight=next_layer_input_layernorm_args[0],
                    scale=div(1, next_layer_input_layernorm_args[2])
                    if next_layer_input_layernorm_args[2] else None,
                    eps=next_layer_input_layernorm_args[1],
                    p_dtype=self.config.dtype)
                hidden_states = (hidden_states,
                                 act_per_block_scale), residual
        else:
            if default_net(
            ).plugin_config.pp_reduce_scatter and self.is_last_local_layer and not self.mapping.is_last_pp_rank(
            ):
                hidden_states = self.mlp(
                    hidden_states,
                    lora_layer_params=lora_layer_params,
                    last_local_layer_residual=residual)
            else:
                if (default_net().plugin_config.reduce_fusion
                        and default_net().plugin_config.user_buffer):
                    hidden_states, residual = self.mlp(
                        hidden_states,
                        lora_layer_params=lora_layer_params,
                        all_reduce_params=AllReduceParams(
                            fusion_op=AllReduceFusionOp.LAST_PROCESS_FOR_UB,
                            residual=residual))
                else:
                    hidden_states = self.mlp(
                        hidden_states, lora_layer_params=lora_layer_params)
                hidden_states = residual + hidden_states * self.config.residual_multiplier
    if use_cache:
        return (hidden_states, presents)
    return hidden_states

class LLaMAModel(Module):

def __init__(self, config: LLaMAConfig) -> None:
    super().__init__()

    self.mapping = config.mapping
    self.vocab_size = config.vocab_size
    self.has_partial_lora_mask = config.has_partial_lora_mask
    self.hidden_size = config.hidden_size
    if self.mapping.is_first_pp_rank():
        self.vocab_embedding = Embedding(config.vocab_size,
                                         config.hidden_size,
                                         dtype=config.dtype)
        self.embedding_multiplier = config.embedding_multiplier

    self.layers = DecoderLayerList(LLaMADecoderLayer, config)

    if config.fc_after_embed:
        self.fc = ColumnLinear(2 * config.hidden_size,
                               config.hidden_size,
                               bias=True,
                               dtype=config.dtype,
                               tp_group=config.mapping.tp_group,
                               tp_size=config.mapping.tp_size,
                               gather_output=True)

    if self.mapping.is_last_pp_rank():
        self.ln_f = None
        if config.use_last_layernorm:
            self.ln_f = RmsNorm(normalized_shape=config.hidden_size,
                                eps=config.norm_epsilon,
                                dtype=config.dtype)

def forward(self,
            input_ids,
            position_ids=None,
            use_cache=False,
            attention_mask=None,
            spec_decoding_params=None,
            kv_cache_params=None,
            attention_params=None,
            hidden_states=None,
            hidden_states_for_embed=None,
            prompt_embedding_table: Optional[Tensor] = None,
            prompt_tasks: Optional[Tensor] = None,
            prompt_vocab_size: Optional[Tensor] = None,
            lora_params=None):

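    # Prompt-tuning (p-tuning) arguments are only forwarded to the embedding when
    # a prompt embedding table is provided.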
    ptuning_args = [
        prompt_embedding_table, prompt_tasks, prompt_vocab_size
    ] if prompt_embedding_table is not None else []

    if self.mapping.is_first_pp_rank():
        hidden_states = self.vocab_embedding(input_ids, *ptuning_args)
        hidden_states *= self.embedding_multiplier
    else:
        hidden_states = recv(hidden_states, self.mapping.prev_pp_rank())
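        # With PP reduce-scatter, the previous pipeline stage sends a scattered
        # shard, so allgather it across the TP group and reshape back to
        # (-1, hidden_size).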
        if default_net().plugin_config.pp_reduce_scatter:
            hidden_states = allgather(hidden_states,
                                      self.mapping.tp_group,
                                      gather_dim=0)
            # reshape to (-1, hidden_size)
            hidden_states = hidden_states.view(
                concat([-1, self.hidden_size]))

    if hidden_states_for_embed is not None:
        hidden_states = concat([hidden_states, hidden_states_for_embed],
                               dim=-1)
        hidden_states = self.fc(hidden_states)

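    # When partial LoRA masking is enabled, build a mask flagging token ids
    # outside the base vocabulary (id >= vocab_size) and attach it to lora_params.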
    if lora_params is not None and self.has_partial_lora_mask:
        partial_lora_mask = input_ids > (self.vocab_size - 1)
        lora_params.partial_lora_mask = unsqueeze(partial_lora_mask, -1)

    hidden_states = self.layers.forward(
        hidden_states,
        use_cache=use_cache,
        attention_mask=attention_mask,
        kv_cache_params=kv_cache_params,
        attention_params=attention_params,
        lora_params=lora_params,
        spec_decoding_params=spec_decoding_params)

    if use_cache:
        hidden_states, presents = hidden_states

    if self.mapping.is_last_pp_rank():
        if self.ln_f:
            hidden_states = self.ln_f(hidden_states)
    else:
        hidden_states = send(hidden_states, self.mapping.next_pp_rank())

    if use_cache:
        return (hidden_states, tuple(presents))
    return hidden_states

class LLaMAForCausalLM(DecoderModelForCausalLM):

    config_class = LLaMAConfig

def __init__(self, config: LLaMAConfig):
    transformer = LLaMAModel(config)
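    # Pad the vocabulary so the lm_head weight splits evenly across
    # tensor-parallel ranks.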
    vocab_size_padded = pad_vocab_size(config.vocab_size,
                                       config.mapping.tp_size)
    if config.mapping.is_last_pp_rank():
        lm_head = ColumnLinear(config.hidden_size,
                               vocab_size_padded,
                               bias=False,
                               dtype=config.dtype,
                               tp_group=config.mapping.tp_group,
                               tp_size=config.mapping.tp_size,
                               gather_output=True)
    else:
        lm_head = None
    self.quant_mode = config.quant_mode
    self.mapping = config.mapping
    super().__init__(config, transformer, lm_head)

@classmethod
def from_hugging_face(
        cls,
        hf_model_or_dir: Union[str, 'transformers.PreTrainedModel'],
        dtype: str = 'auto',
        mapping: Optional[Mapping] = None,
        quant_config: Optional[QuantConfig] = None,
        **kwargs):
    ''' Create a LLaMAForCausalLM object from the given parameters '''
    import transformers

    load_by_shard = kwargs.pop('load_by_shard', False)
    load_model_on_cpu = kwargs.pop('load_model_on_cpu', False)
    quant_ckpt_path = kwargs.pop('quant_ckpt_path', None)
    use_autoawq = kwargs.pop('use_autoawq', None)
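    # Setting TRTLLM_DISABLE_UNIFIED_CONVERTER selects the legacy per-model
    # weight conversion path below instead of the unified ModelWeightsLoader.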
    if os.environ.get("TRTLLM_DISABLE_UNIFIED_CONVERTER"
                      ) is not None and not isinstance(
                          hf_model_or_dir, transformers.PreTrainedModel):
        if "vila" in hf_model_or_dir or "llava" in hf_model_or_dir:
            hf_model_or_dir = load_hf_llama(hf_model_or_dir,
                                            load_model_on_cpu)
        elif not load_by_shard and not has_safetensors(
                hf_model_or_dir) and (
                    quant_config is None
                    or not quant_config.quant_mode.has_any_quant()):
            hf_model_or_dir = load_hf_llama(hf_model_or_dir,
                                            load_model_on_cpu)

    assert hf_model_or_dir is not None
    use_preloading = isinstance(hf_model_or_dir,
                                transformers.PreTrainedModel)
    if use_preloading:
        hf_model = hf_model_or_dir
        hf_config_or_dir = hf_model.config
    else:
        hf_model_dir = hf_model_or_dir
        hf_config_or_dir = hf_model_or_dir

    config = LLaMAConfig.from_hugging_face(hf_config_or_dir,
                                           dtype=dtype,
                                           mapping=mapping,
                                           quant_config=quant_config,
                                           **kwargs)
    if config.remove_duplicated_kv_heads:
        config.num_key_value_heads = config.num_key_value_heads // 2
    if os.environ.get("TRTLLM_DISABLE_UNIFIED_CONVERTER") is None:
        custom_dict = {}
        model_name = hf_model.config.model_type if use_preloading else hf_model_or_dir
        if "llava" in model_name:
            custom_dict = {
                "transformer": "language_model.model",
                "lm_head": "language_model.lm_head"
            }
        elif "vila" in model_name:
            hf_model_dir += "/llm"
        elif "exaone" in model_name.lower():
            custom_dict = {
                "transformer": "transformer",
                "layers": "h",
                "vocab_embedding": "wte",
                "lm_head": "lm_head",
                "ln_f": "ln_f",
                "attention": "attn.attention",
                "dense": "out_proj",
                "gate": "c_fc_1",
                "proj": "c_proj",
                "fc": "c_fc_0",
                "input_layernorm": "ln_1",
                "post_layernorm": "ln_2",
            }
        elif config.tie_word_embeddings:
            custom_dict = {"lm_head": "model.embed_tokens"}

        if quant_ckpt_path is not None:
            hf_model_dir = quant_ckpt_path
        arg_dict = {"use_autoawq": True} if use_autoawq else {}

        loader = ModelWeightsLoader(hf_model_dir, custom_dict)
        model = cls(config)
        loader.generate_tllm_weights(model, arg_dict)
    else:
        if use_preloading:
            assert not load_by_shard
            weights = load_weights_from_hf_model(hf_model, config)
        elif load_by_shard:
            weights = load_weights_from_hf_by_shard(hf_model_dir, config)
        elif has_safetensors(
                hf_model_dir) and not config.quant_mode.has_any_quant():
            weights = load_weights_from_hf_safetensors(hf_model_dir, config)
        elif quant_ckpt_path is not None:
            if quant_config.quant_mode.is_int4_weight_only():
                weights = load_weights_from_gptq(quant_ckpt_path, config)
            elif quant_config.quant_mode.is_qserve_w4a8():
                weights = load_weights_from_deepcompressor(
                    quant_ckpt_path, config)
            else:
                raise ValueError(
                    "quant_ckpt_path should be specified only for GPTQ or QServe"
                )
        else:
            hf_model = load_hf_llama(hf_model_dir, load_model_on_cpu)
            weights = load_weights_from_hf_model(hf_model, config)
        model = cls(config)
        model.load(weights)
    return model

def default_plugin_config(self, **kwargs):
    plugin_config = super().default_plugin_config(**kwargs)
    if self.quant_mode.is_int4_weight_only_per_group():
        plugin_config.weight_only_groupwise_quant_matmul_plugin = 'auto'
    return plugin_config

@classmethod
def from_meta_ckpt(cls,
                   meta_ckpt_dir: str,
                   dtype: str = 'auto',
                   mapping: Optional[Mapping] = None,
                   quant_config: Optional[QuantConfig] = None,
                   **kwargs):
    config = LLaMAConfig.from_meta_ckpt(meta_ckpt_dir,
                                        dtype=dtype,
                                        mapping=mapping,
                                        quant_config=quant_config,
                                        **kwargs)

    weights = load_weights_from_meta_ckpt(meta_ckpt_dir, config)

    model = cls(config)
    model.load(weights)
    return model

@classmethod
def quantize(
    cls,
    hf_model_dir: str,
    output_dir: str,
    dtype: str = 'auto',
    mapping: Optional[Mapping] = None,
    quant_config: Optional[QuantConfig] = None,
    *,
    device: str = 'cuda',
    calib_dataset: str = 'cnn_dailymail',
    calib_batches: int = 512,
    calib_batch_size: int = 1,
    calib_max_seq_length: int = 512,
    random_seed: int = 1234,
    tokenizer_max_seq_length: int = 2048,
    **kwargs,
):
    if quant_config._requires_modelopt_quantization:
        # modelopt quantization flow
        super().quantize(hf_model_dir,
                         output_dir,
                         dtype=dtype,
                         mapping=mapping,
                         quant_config=quant_config,
                         device=device,
                         calib_dataset=calib_dataset,
                         calib_batches=calib_batches,
                         calib_batch_size=calib_batch_size,
                         calib_max_seq_length=calib_max_seq_length,
                         random_seed=random_seed,
                         tokenizer_max_seq_length=tokenizer_max_seq_length)
    elif quant_config._requires_calibration:
        # non-modelopt quantization flow
        from . import convert

        config = LLaMAConfig.from_hugging_face(hf_model_dir,
                                               dtype=dtype,
                                               mapping=mapping,
                                               quant_config=quant_config,
                                               **kwargs)
        trust_remote_code = kwargs.pop("trust_remote_code", True)

        convert.quantize(hf_model_dir,
                         output_dir,
                         config=config,
                         device=device,
                         calib_dataset=calib_dataset,
                         trust_remote_code=trust_remote_code,
                         calib_batches=calib_batches,
                         calib_max_seq_length=calib_max_seq_length)
    else:
        raise ValueError(
            f"The quant_config ({quant_config}) does not require calibration, try {cls.__name__}.from_hugging_face instead."
        )

def use_lora(self, lora_config: LoraConfig):
    use_lora(self, lora_config)
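
# Example usage (a minimal sketch, not part of the module): the checkpoint path
# './llama-hf' and the single-GPU Mapping below are placeholder assumptions;
# adjust dtype, paths, and parallelism to your setup.
#
#   from tensorrt_llm.mapping import Mapping
#   from tensorrt_llm.models.llama.model import LLaMAForCausalLM
#
#   mapping = Mapping(world_size=1, rank=0, tp_size=1, pp_size=1)
#   model = LLaMAForCausalLM.from_hugging_face('./llama-hf',
#                                              dtype='float16',
#                                              mapping=mapping)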