tensorrt_llm.models.chatglm.model — TensorRT-LLM

# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Union

import torch
from transformers import AutoModel

from ..._common import default_net
from ..._utils import pad_vocab_size
from ...functional import Tensor, concat, shape
from ...layers import (MLP, Attention, AttentionMaskType, AttentionParams,
                       ColumnLinear, Embedding, KeyValueCacheParams, LayerNorm,
                       RmsNorm)
from ...mapping import Mapping
from ...module import Module
from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM,
                              QuantConfig)
from .config import GLM_ARCH1_VERSIONS, GLM_ARCH2_VERSIONS, ChatGLMConfig
from .convert import load_weights_from_hf_model

class ChatGLMDecoderLayer(Module):

def __init__(self, config: ChatGLMConfig, layer_idx: int):
    super().__init__()
    self.layer_idx = layer_idx
    self.config = config
    self.chatglm_version = config.chatglm_version

    hidden_size = config.hidden_size
    dtype = config.dtype
    tp_group = config.mapping.tp_group
    tp_size = config.mapping.tp_size
    tp_rank = config.mapping.tp_rank
    layernorm_epsilon = config.norm_epsilon

    self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
    self.alpha = (2 * config.num_hidden_layers)**0.5
    norm_cls = RmsNorm if config.rmsnorm else LayerNorm

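    # Attention mask type depends on the architecture: GLM ('glm') and
    # ChatGLM-6B ('chatglm') use bidirectional-style masks, while the
    # ChatGLM2/3 family (GLM_ARCH2_VERSIONS) is a standard causal decoder.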
    if config.chatglm_version == 'glm':
        attention_mask_type = AttentionMaskType.bidirectionalglm
    elif config.chatglm_version == 'chatglm':
        attention_mask_type = AttentionMaskType.bidirectional
    elif config.chatglm_version in GLM_ARCH2_VERSIONS:
        attention_mask_type = AttentionMaskType.causal

    self.input_layernorm = norm_cls(
        normalized_shape=hidden_size,
        eps=layernorm_epsilon,
        elementwise_affine=True,
        dtype=dtype,
    )

    layers_range = config.mapping.pp_layers(config.num_hidden_layers)
    local_layer_idx = layer_idx - layers_range[0]
    self.attention = Attention(
        local_layer_idx=local_layer_idx,
        hidden_size=hidden_size,
        num_attention_heads=config.num_attention_heads,
        num_kv_heads=config.num_key_value_heads,
        max_position_embeddings=config.max_position_embeddings,
        num_layers=config.num_hidden_layers,
        apply_query_key_layer_scaling=config.apply_query_key_layer_scaling,
        attention_mask_type=attention_mask_type,
        bias=config.add_qkv_bias,
        dense_bias=config.add_bias_linear,
        dtype=config.dtype,
        position_embedding_type=config.position_embedding_type,
        rotary_embedding_base=config.rotary_base,
        rotary_embedding_scaling=config.rotary_scaling,
        rotary_embedding_percentage=config.rotary_pct,
        tp_group=tp_group,
        tp_size=tp_size,
        tp_rank=tp_rank,
        quant_mode=config.quant_mode,
        q_scaling=1.0,
        cross_attention=False,
        relative_attention=False,
        max_distance=0,
        num_buckets=0,
        cp_rank=config.mapping.cp_rank,
        cp_size=config.mapping.cp_size,
        cp_group=config.mapping.cp_group,
    )

    mlp_hidden_size = hidden_size * 4 if config.intermediate_size is None else config.intermediate_size

    self.mlp = MLP(
        hidden_size=hidden_size,
        ffn_hidden_size=mlp_hidden_size,
        hidden_act=config.hidden_act,
        bias=config.add_bias_linear,
        dtype=dtype,
        tp_group=tp_group,
        tp_size=tp_size,
        quant_mode=config.quant_mode,
    )

    self.post_layernorm = norm_cls(
        normalized_shape=hidden_size,
        eps=layernorm_epsilon,
        elementwise_affine=True,
        dtype=dtype,
    )

def forward(
    self,
    hidden_states: Tensor,
    attention_mask: Tensor = None,
    position_ids: Tensor = None,  # only used in ChatGLM-6B
    use_cache: bool = False,
    kv_cache_params: KeyValueCacheParams = None,
    attention_params: AttentionParams = None,
):
    norm_output = self.input_layernorm(hidden_states)

    attention_output = self.attention(
        hidden_states=norm_output,
        attention_mask=attention_mask,
        use_cache=use_cache,
        kv_cache_params=kv_cache_params,
        attention_params=attention_params,
        encoder_output=None,
        position_embedding=position_ids,
    )

    if use_cache:
        attention_output, presents = attention_output

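    # ChatGLM-6B ('chatglm') scales the residual branch by
    # alpha = sqrt(2 * num_hidden_layers) (set in __init__); the other
    # variants use a plain residual, taken after the layernorm when
    # apply_residual_connection_post_layernorm is enabled.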
    if self.chatglm_version == 'chatglm':
        residual = norm_output

        norm_input = residual * self.alpha + attention_output

        norm_output = self.post_layernorm(norm_input)

        mlp_output = self.mlp(norm_output)

        residual = norm_output

        output = residual * self.alpha + mlp_output

    else:
        residual = norm_output if self.apply_residual_connection_post_layernorm else hidden_states

        norm_input = residual + attention_output

        norm_output = self.post_layernorm(norm_input)

        mlp_output = self.mlp(norm_output)

        residual = norm_output if self.apply_residual_connection_post_layernorm else norm_input

        output = residual + mlp_output

    if use_cache:
        return (output, presents)
    return output

class ChatGLMModel(Module):

def __init__(self, config: ChatGLMConfig):
    super().__init__()
    self.chatglm_version = config.chatglm_version
    norm_cls = RmsNorm if config.rmsnorm else LayerNorm

    self.vocab_embedding = Embedding(config.vocab_size,
                                     config.hidden_size,
                                     dtype=config.dtype)

    if config.chatglm_version == 'glm':
        self.position_embedding = Embedding(
            config.max_position_embeddings + 1,
            config.hidden_size,
            dtype=config.dtype,
        )
        self.block_embedding = Embedding(
            config.max_position_embeddings + 1,
            config.hidden_size,
            dtype=config.dtype,
        )

    self.layers = DecoderLayerList(ChatGLMDecoderLayer, config)

    self.ln_f = norm_cls(
        normalized_shape=config.hidden_size,
        eps=config.norm_epsilon,
        elementwise_affine=True,
        dtype=config.dtype,
    )

def forward(
    self,
    input_ids: Tensor = None,
    position_ids: Tensor = None,  # only used in ChatGLM-6B
    use_cache: bool = False,
    attention_mask: Tensor = None,
    kv_cache_params: KeyValueCacheParams = None,
    attention_params: AttentionParams = None,
):
    hidden_states = self.vocab_embedding(input_ids)

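    # GLM checkpoints use 2D position ids: one row of absolute positions and
    # one row of block (mask-span) positions. The two embeddings are summed
    # and added to the token embeddings.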
    if self.chatglm_version == 'glm':
        if default_net().plugin_config.remove_input_padding:
            position_ids_list = position_ids.split(1, dim=0)
        else:
            position_ids_list = position_ids.split(1, dim=1)

        position_embedding = self.position_embedding(position_ids_list[0])
        block_embedding = self.block_embedding(position_ids_list[1])
        position_embedding = position_embedding + block_embedding

        if default_net().plugin_config.remove_input_padding:
            position_embedding = position_embedding.view(
                concat([
                    shape(position_embedding, 1),
                    shape(position_embedding, 2)
                ]))
        else:
            position_embedding = position_embedding.view(
                concat([
                    shape(position_embedding, 0),
                    shape(position_embedding, 2),
                    shape(position_embedding, 3),
                ]))

        hidden_states = hidden_states + position_embedding

    hidden_states = self.layers(hidden_states,
                                use_cache=use_cache,
                                attention_mask=attention_mask,
                                kv_cache_params=kv_cache_params,
                                attention_params=attention_params,
                                position_ids=position_ids)

    if use_cache:
        hidden_states, presents = hidden_states

    hidden_states = self.ln_f(hidden_states)

    if use_cache:
        return (hidden_states, tuple(presents))
    return hidden_states

class ChatGLMForCausalLM(DecoderModelForCausalLM):
config_class = ChatGLMConfig

def __init__(self, config: ChatGLMConfig):
    transformer = ChatGLMModel(config)
    vocab_size_padded = pad_vocab_size(config.vocab_size,
                                       config.mapping.tp_size)

    lm_head = ColumnLinear(config.hidden_size,
                           vocab_size_padded,
                           bias=False,
                           dtype=config.dtype,
                           tp_group=config.mapping.tp_group,
                           tp_size=config.mapping.tp_size,
                           gather_output=True)
    super().__init__(config, transformer, lm_head)

@classmethod
def from_hugging_face(
        cls,
        hf_model_or_dir: Union[str, 'transformers.PreTrainedModel'],
        dtype: str = 'auto',
        mapping: Optional[Mapping] = None,
        quant_config: Optional[QuantConfig] = None,
        **kwargs):
    ''' Create a ChatGLMForCausalLM object from the given parameters
    '''
    load_model_on_cpu = kwargs.pop('load_model_on_cpu', False)
    trust_remote_code = kwargs.pop('trust_remote_code', True)

    config = ChatGLMConfig.from_hugging_face(hf_model_or_dir,
                                             dtype=dtype,
                                             mapping=mapping,
                                             quant_config=quant_config,
                                             **kwargs)
    if config.chatglm_version == 'glm':
        device_map = 'cuda' if not load_model_on_cpu else 'cpu'
    else:
        device_map = 'auto' if not load_model_on_cpu else 'cpu'
    hf_model = AutoModel.from_pretrained(
        hf_model_or_dir,
        trust_remote_code=trust_remote_code,
        torch_dtype='auto' if config.chatglm_version != 'glm' else getattr(
            torch, config.dtype),
        device_map=device_map)
    weights = load_weights_from_hf_model(hf_model, config)

    model = cls(config)
    model.load(weights)
    return model
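
# A minimal usage sketch (illustrative only; the checkpoint id and dtype are
# placeholder assumptions, not values required by this method):
#
#   model = ChatGLMForCausalLM.from_hugging_face(
#       'THUDM/chatglm3-6b',  # local directory or Hugging Face repo id
#       dtype='float16')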

@classmethod
def quantize(
    cls,
    hf_model_dir: str,
    output_dir: str,
    dtype: str = 'auto',
    mapping: Optional[Mapping] = None,
    quant_config: Optional[QuantConfig] = None,
    *,
    device: str = 'cuda',
    calib_dataset: str = 'cnn_dailymail',
    calib_batches: int = 512,
    calib_batch_size: int = 1,
    calib_max_seq_length: int = 512,
    random_seed: int = 1234,
    tokenizer_max_seq_length: int = 2048,
    **kwargs,
):
    if quant_config._requires_modelopt_quantization:
        # modelopt quantization flow
        super().quantize(hf_model_dir,
                         output_dir,
                         dtype=dtype,
                         mapping=mapping,
                         quant_config=quant_config,
                         device=device,
                         calib_dataset=calib_dataset,
                         calib_batches=calib_batches,
                         calib_batch_size=calib_batch_size,
                         calib_max_seq_length=calib_max_seq_length,
                         random_seed=random_seed,
                         tokenizer_max_seq_length=tokenizer_max_seq_length)
    elif quant_config._requires_calibration:
        # non-modelopt quantization flow
        from . import convert

        config = ChatGLMConfig.from_hugging_face(hf_model_dir,
                                                 dtype=dtype,
                                                 mapping=mapping,
                                                 quant_config=quant_config,
                                                 **kwargs)
        convert.quantize(hf_model_dir,
                         output_dir,
                         config=config,
                         calib_dataset=calib_dataset,
                         device=device)
    else:
        raise ValueError(
            f"The quant_config ({quant_config}) does not require calibration, try {cls.__name__}.from_hugging_face instead."
        )
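
# A hedged sketch of the calibration-based flow (assumes a SmoothQuant-style
# QuantConfig, for which _requires_calibration is True; the checkpoint id,
# output directory and algorithm choice are placeholder assumptions):
#
#   from tensorrt_llm.quantization import QuantAlgo
#   ChatGLMForCausalLM.quantize(
#       'THUDM/chatglm3-6b',
#       './chatglm3-6b-sq-ckpt',
#       dtype='float16',
#       quant_config=QuantConfig(quant_algo=QuantAlgo.W8A8_SQ_PER_CHANNEL),
#       calib_dataset='cnn_dailymail')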

def prepare_inputs(self, *args, **kwargs):
    """See PretrainedModel.prepare_inputs for the detailed parameter list.
    """
    if self.transformer.chatglm_version in GLM_ARCH1_VERSIONS:
        position_encoding_2d = True
    else:
        position_encoding_2d = False
    return super().prepare_inputs(*args,
                                  **kwargs,
                                  position_encoding_2d=position_encoding_2d)