Source code for tensorrt_llm.models.chatglm.config
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union
from ...mapping import Mapping
from ..convert_utils import infer_dtype
from ..modeling_utils import PretrainedConfig, QuantConfig

GLM_VERSIONS = ['glm4', 'chatglm3', 'chatglm2', 'chatglm', 'glm']
GLM_ARCH1_VERSIONS = ['chatglm', 'glm']
GLM_ARCH2_VERSIONS = ['glm4', 'chatglm3', 'chatglm2']
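# GLM_ARCH1_VERSIONS cover the original GLM/ChatGLM-6B architecture, while
# GLM_ARCH2_VERSIONS cover the ChatGLM2-style architecture (grouped KV heads,
# SwiGLU MLP, RoPE) that from_hugging_face() normalizes below.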

class ChatGLMConfig(PretrainedConfig):

    def __init__(self,
                 *,
                 chatglm_version: str = 'chatglm3',
                 add_bias_linear: bool = False,
                 add_qkv_bias: bool = True,
                 apply_query_key_layer_scaling: bool = False,
                 apply_residual_connection_post_layernorm: bool = False,
                 rmsnorm: bool = True,
                 rotary_pct: float = 0.5,
                 rotary_base: float = 10000.0,
                 rotary_scaling: Optional[dict] = None,
                 **kwargs):
        self.chatglm_version = chatglm_version
        self.add_bias_linear = add_bias_linear
        self.add_qkv_bias = add_qkv_bias
        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
        self.rmsnorm = rmsnorm
        self.rotary_pct = rotary_pct
        self.rotary_base = rotary_base
        self.rotary_scaling = rotary_scaling
        super().__init__(**kwargs)

    def to_dict(self):
        output = super().to_dict()
        # Serialize the fields added in ChatGLMConfig
        output['chatglm_version'] = self.chatglm_version
        output['add_bias_linear'] = self.add_bias_linear
        output['add_qkv_bias'] = self.add_qkv_bias
        output['apply_query_key_layer_scaling'] = self.apply_query_key_layer_scaling
        output['apply_residual_connection_post_layernorm'] = self.apply_residual_connection_post_layernorm
        output['rmsnorm'] = self.rmsnorm
        output['rotary_pct'] = self.rotary_pct
        output['rotary_base'] = self.rotary_base
        output['rotary_scaling'] = self.rotary_scaling
        return output

    @classmethod
    def from_hugging_face(
            cls,
            hf_config_or_dir: Union[str, 'transformers.PretrainedConfig'],
            dtype: str = 'auto',
            mapping: Optional[Mapping] = None,
            quant_config: Optional[QuantConfig] = None,
            **kwargs):
        import transformers

        trust_remote_code = kwargs.pop('trust_remote_code', True)
        # load hugging face config
        if isinstance(hf_config_or_dir, transformers.PretrainedConfig):
            hf_config = hf_config_or_dir
        else:
            hf_config_dir = str(hf_config_or_dir)
            hf_config = transformers.AutoConfig.from_pretrained(
                hf_config_dir, trust_remote_code=trust_remote_code)
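        # Optional conversion knobs that callers may override through **kwargs.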
        logits_dtype = kwargs.pop('logits_dtype', 'float32')
        use_parallel_embedding = kwargs.pop('use_parallel_embedding', False)
        embedding_sharding_dim = kwargs.pop('embedding_sharding_dim', 0)
        chatglm_version = kwargs.pop('chatglm_version', None)

        # get chatglm version
        if chatglm_version is None:
            print("Inferring chatglm version from path...")
            for v in GLM_VERSIONS:
                if v in hf_config._name_or_path:
                    chatglm_version = v
                    break
            if 'glm_4' in hf_config._name_or_path.replace("-", "_"):
                chatglm_version = 'glm4'
        assert chatglm_version in GLM_VERSIONS
        print(f"Chatglm version: {chatglm_version}")
        if chatglm_version == 'glm':
            hf_config.num_kv_heads = hf_config.num_attention_heads
            hf_config.ffn_hidden_size = hf_config.hidden_size * 4
            hf_config.hidden_act = 'gelu'
            hf_config.layernorm_epsilon = 1e-5
            hf_config.max_position_embeddings = hf_config.max_sequence_length
            hf_config.add_bias_linear = True
            hf_config.add_qkv_bias = True
            hf_config.apply_query_key_layer_scaling = False
            hf_config.apply_residual_connection_post_layernorm = False
            hf_config.rmsnorm = False
            hf_config.rope_ratio = 1.0
        elif chatglm_version == 'chatglm':
            hf_config.num_kv_heads = hf_config.num_attention_heads
            hf_config.ffn_hidden_size = hf_config.inner_hidden_size
            hf_config.hidden_act = 'gelu'
            hf_config.max_position_embeddings = hf_config.max_sequence_length
            hf_config.add_bias_linear = True
            hf_config.add_qkv_bias = True
            hf_config.apply_query_key_layer_scaling = False
            hf_config.apply_residual_connection_post_layernorm = False
            hf_config.rmsnorm = False
            hf_config.rope_ratio = 1.0
        else:
            hf_config.vocab_size = hf_config.padded_vocab_size
            hf_config.num_kv_heads = hf_config.multi_query_group_num
            hf_config.hidden_act = 'swiglu'
            hf_config.max_position_embeddings = hf_config.seq_length
            hf_config.rmsnorm = getattr(hf_config, 'rmsnorm', 1.0)
            hf_config.rope_ratio = getattr(hf_config, 'rope_ratio', 1.0)
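        # Select the position embedding scheme: learned absolute embeddings for
        # GLM, the model-specific 'chatglm' scheme for the original ChatGLM, and
        # GPT-J-style RoPE for the ChatGLM2-family models.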
        if chatglm_version == 'glm':
            position_embedding_type = 'learned_absolute'
        elif chatglm_version == 'chatglm':
            position_embedding_type = 'chatglm'
        elif chatglm_version in GLM_ARCH2_VERSIONS:
            position_embedding_type = 'rope_gptj'

        rotary_base = 10000.0
        rotary_embedding_scaling = None
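        # rope_ratio becomes linear RoPE scaling for chatglm2, but is folded
        # into the rotary base for chatglm3 and glm4.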
        if chatglm_version == 'chatglm2':
            if hf_config.rope_ratio > 1:
                rotary_embedding_scaling = {
                    'type': 'linear',
                    'factor': hf_config.rope_ratio
                }
        elif chatglm_version == 'chatglm3' or chatglm_version == 'glm4':
            rotary_base *= hf_config.rope_ratio
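        # Resolve the target dtype (an 'auto' request falls back to the
        # checkpoint's torch_dtype) before assembling the config.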
        dtype = infer_dtype(dtype, getattr(hf_config, 'torch_dtype', None))

        return cls(
            architecture=hf_config.architectures[0],
            dtype=dtype,
            logits_dtype=logits_dtype,
            num_hidden_layers=hf_config.num_layers,
            num_attention_heads=hf_config.num_attention_heads,
            num_key_value_heads=hf_config.num_kv_heads,
            hidden_size=hf_config.hidden_size,
            intermediate_size=hf_config.ffn_hidden_size,
            norm_epsilon=hf_config.layernorm_epsilon,
            vocab_size=hf_config.vocab_size,
            position_embedding_type=position_embedding_type,
            max_position_embeddings=hf_config.max_position_embeddings,
            rotary_pct=0.5,
            rotary_base=rotary_base,
            rotary_scaling=rotary_embedding_scaling,
            hidden_act=hf_config.hidden_act,
            use_parallel_embedding=use_parallel_embedding,
            embedding_sharding_dim=embedding_sharding_dim,
            quantization=quant_config,
            mapping=mapping,
            chatglm_version=chatglm_version,
            add_bias_linear=hf_config.add_bias_linear,
            add_qkv_bias=hf_config.add_qkv_bias,
            apply_query_key_layer_scaling=False,
            apply_residual_connection_post_layernorm=hf_config.apply_residual_connection_post_layernorm,
            rmsnorm=hf_config.rmsnorm,
        )
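
A minimal usage sketch of the converter above; the checkpoint id, the two-way tensor-parallel Mapping, and the printed fields are illustrative assumptions, not part of this module.

from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.chatglm.config import ChatGLMConfig

# Illustrative checkpoint: any GLM/ChatGLM Hugging Face directory or hub id
# whose name contains one of the GLM_VERSIONS strings can be used here.
checkpoint = "THUDM/chatglm3-6b"

# Illustrative two-way tensor parallelism; omit `mapping` for a single GPU.
mapping = Mapping(world_size=2, tp_size=2)

config = ChatGLMConfig.from_hugging_face(
    checkpoint,
    dtype='auto',            # falls back to the checkpoint's torch_dtype
    mapping=mapping,
    trust_remote_code=True,  # popped from **kwargs by from_hugging_face()
)

# to_dict() returns the base PretrainedConfig fields plus the ChatGLM-specific
# fields serialized in ChatGLMConfig.to_dict().
d = config.to_dict()
print(d['chatglm_version'], d['rotary_base'], d['rotary_scaling'])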