tensorrt_llm.models.gemma.model — TensorRT-LLM (original) (raw)
SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
import math from typing import TYPE_CHECKING, Any, Dict, Optional
from tensorrt_llm.models.gemma.convert import (QuantizeModifiers, Weights, load_gemma_weights_from_hf_model, non_modelopt_quantize_if_needed) from tensorrt_llm.quantization.mode import (MODELOPT_FLOW_QUANTIZATIONS, QuantAlgo)
from ..._common import default_net from ..._utils import pad_vocab_size from ...functional import (AllReduceFusionOp, AllReduceParams, LayerNormType, Tensor, cast, recv, send) from ...layers import (Attention, AttentionMaskType, AttentionParams, ColumnLinear, Embedding, GatedMLP, KeyValueCacheParams, LoraParams, PositionEmbeddingType, RmsNorm) from ...lora_manager import LoraConfig, use_lora from ...mapping import Mapping from ...module import Module from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM, QuantConfig, save_checkpoint, save_config) from .config import GemmaConfig
if TYPE_CHECKING:
from .config import HfConfigOrDir
class GemmaDecoderLayer(Module):
    """A single Gemma decoder block: self-attention followed by a gated MLP.

    The attention settings are derived from the config:

    * When ``config.gemma2_config()`` is truthy, the query scaling is
      ``sqrt(query_pre_attn_scalar) / sqrt(head_size)`` and
      ``config.attn_logit_softcapping`` (or ``0.0``) is used as the maximum
      attention value.
    * Otherwise, when ``config.gemma3_config()`` is truthy, QK layer norm is
      enabled, the same query scaling formula is applied, the layer is marked
      local (sliding) unless ``(layer_idx + 1)`` is a multiple of
      ``sliding_window_pattern``, and ``config.rope_local_base_freq`` is
      passed as the local rotary base.

    When ``config.inter_layernorms`` is set, two extra RMS norms
    (``pre_feedforward_layernorm`` and ``post_feedforward_layernorm``) are
    created around the MLP.
    """

    def __init__(self, config: GemmaConfig, layer_idx: int):
        """Build the norms, attention and MLP sub-modules of this layer.

        Args:
            config: Model configuration the layer reads its sizes, dtype,
                mapping and quantization mode from.
            layer_idx: Global index of this layer in the model.
        """
        super().__init__()
        self.layer_idx = layer_idx
        self.config = config

        self.input_layernorm = RmsNorm(normalized_shape=config.hidden_size,
                                       eps=config.norm_epsilon,
                                       dtype=config.dtype)

        # Index of this layer relative to the first layer owned by the
        # current pipeline-parallel rank.
        layers_range = config.mapping.pp_layers(config.num_hidden_layers)
        self.local_layer_idx = layer_idx - layers_range[0]

        # Defaults, overridden below for Gemma2 / Gemma3 configs.
        q_scaling = 1.0
        max_attn_value = 0.0
        qk_layernorm = False
        is_sliding = False
        rotary_base = config.rotary_base
        rotary_base_local = None

        gemma2_config = config.gemma2_config()
        gemma3_config = config.gemma3_config()
        if gemma2_config:
            q_scaling = math.sqrt(
                gemma2_config.query_pre_attn_scalar) / math.sqrt(
                    config.head_size)
            max_attn_value = config.attn_logit_softcapping or 0.0
        elif gemma3_config:
            qk_layernorm = True
            q_scaling = math.sqrt(
                gemma3_config.query_pre_attn_scalar) / math.sqrt(
                    config.head_size)
            # Non-zero remainder -> sliding (local) layer; every
            # `sliding_window_pattern`-th layer is not sliding.
            is_sliding = bool(
                (layer_idx + 1) % gemma3_config.sliding_window_pattern)
            rotary_base_local = config.rope_local_base_freq

        self.attention = Attention(
            local_layer_idx=self.local_layer_idx,
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            attention_head_size=config.head_size,
            qk_layernorm=qk_layernorm,
            layernorm_type=LayerNormType.RmsNorm,
            max_position_embeddings=config.max_position_embeddings,
            dtype=config.dtype,
            attention_mask_type=AttentionMaskType.causal,
            bias=config.attn_bias,
            position_embedding_type=PositionEmbeddingType.rope_gpt_neox,
            rotary_embedding_base=rotary_base,
            rotary_embedding_base_local=rotary_base_local,
            rotary_embedding_scaling=config.rotary_scaling,
            tp_group=config.mapping.tp_group,
            tp_size=config.mapping.tp_size,
            quant_mode=config.quant_mode,
            q_scaling=q_scaling,
            max_attn_value=max_attn_value,
            is_local=is_sliding,
        )

        # Fall back to 4 * hidden_size when no intermediate size is given.
        mlp_hidden_size = (config.hidden_size * 4
                           if config.intermediate_size is None else
                           config.intermediate_size)

        self.mlp = GatedMLP(hidden_size=config.hidden_size,
                            ffn_hidden_size=mlp_hidden_size,
                            hidden_act=config.hidden_act,
                            dtype=config.dtype,
                            bias=config.mlp_bias,
                            tp_group=config.mapping.tp_group,
                            tp_size=config.mapping.tp_size,
                            quant_mode=config.quant_mode)

        if self.config.inter_layernorms:
            self.pre_feedforward_layernorm = RmsNorm(
                normalized_shape=config.hidden_size,
                eps=config.norm_epsilon,
                dtype=config.dtype)
            self.post_feedforward_layernorm = RmsNorm(
                normalized_shape=config.hidden_size,
                eps=config.norm_epsilon,
                dtype=config.dtype)

        self.post_layernorm = RmsNorm(normalized_shape=config.hidden_size,
                                      eps=config.norm_epsilon,
                                      dtype=config.dtype)

    def forward(self,
                hidden_states: Tensor,
                attention_mask: Optional[Tensor] = None,
                use_cache: bool = False,
                kv_cache_params: Optional[KeyValueCacheParams] = None,
                attention_params: Optional[AttentionParams] = None,
                lora_layer_params: Optional[LoraParams] = None,
                next_layer_input_layernorm_args=None):
        """Run attention and the MLP, with residual connections.

        Args:
            hidden_states: Layer input. When the ``reduce_fusion`` plugin
                option is on and this is not the first local layer, it is
                unpacked as a ``(hidden_states, residual)`` pair.
            attention_mask: Forwarded to the attention module.
            use_cache: When true, the attention output is unpacked as
                ``(output, presents)`` and ``presents`` is returned too.
            kv_cache_params: Forwarded to the attention module.
            attention_params: Forwarded to the attention module.
            lora_layer_params: Forwarded to the attention and MLP modules.
            next_layer_input_layernorm_args: Optional indexable whose item 0
                is used as the norm weight and item 1 as the epsilon of the
                all-reduce params passed to the MLP.

        Returns:
            ``(hidden_states, presents)`` when ``use_cache`` is true,
            otherwise ``hidden_states``.
        """
        reduce_fusion = default_net().plugin_config.reduce_fusion
        fusion_op = (AllReduceFusionOp.RESIDUAL_RMS_PREPOST_NORM
                     if reduce_fusion else AllReduceFusionOp.NONE)
        inter_layernorms = self.config.inter_layernorms

        if reduce_fusion and self.local_layer_idx > 0:
            # TODO(AN): confirm that the residual pulled out of the incoming
            # pair is the appropriate residual value.
            hidden_states, residual = hidden_states
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

        attention_output = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            use_cache=use_cache,
            kv_cache_params=kv_cache_params,
            attention_params=attention_params,
            norm_before_bmm1=True,
            lora_layer_params=lora_layer_params,
            all_reduce_params=AllReduceParams(
                fusion_op=fusion_op,
                residual=residual,
                norm_weight=self.post_layernorm.weight.value,
                norm_pre_residual_weight=self.pre_feedforward_layernorm.weight.
                value if inter_layernorms else None,
                eps=self.post_layernorm.eps))

        if use_cache:
            attention_output, presents = attention_output

        if reduce_fusion:
            # The attention call hands back a (hidden_states, residual) pair
            # in this mode, so no explicit norm/residual is applied here.
            hidden_states, residual = attention_output
        else:
            if inter_layernorms:
                attention_output = self.post_layernorm(attention_output)
            hidden_states = residual + attention_output
            residual = hidden_states
            if inter_layernorms:
                hidden_states = self.pre_feedforward_layernorm(hidden_states)
            else:
                hidden_states = self.post_layernorm(hidden_states)

        if next_layer_input_layernorm_args is not None:
            hidden_states = self.mlp(
                hidden_states,
                lora_layer_params=lora_layer_params,
                all_reduce_params=AllReduceParams(
                    fusion_op=fusion_op,
                    residual=residual,
                    norm_weight=next_layer_input_layernorm_args[0],
                    # `post_feedforward_layernorm` only exists when
                    # inter-layer norms are enabled; guard the lookup the
                    # same way the attention all-reduce params do.
                    norm_pre_residual_weight=self.post_feedforward_layernorm.
                    weight.value if inter_layernorms else None,
                    eps=next_layer_input_layernorm_args[1]))
        else:
            hidden_states = self.mlp(hidden_states,
                                     lora_layer_params=lora_layer_params)
            if inter_layernorms:
                hidden_states = self.post_feedforward_layernorm(hidden_states)
            hidden_states = residual + hidden_states

        if use_cache:
            return (hidden_states, presents)
        return hidden_states
class GemmaModel(Module):
    """Gemma transformer body: embedding, decoder layer stack and final norm.

    The vocabulary embedding is only created on the first pipeline-parallel
    rank and the final RMS norm only on the last one; other ranks exchange
    hidden states with their neighbours through ``recv`` / ``send``.
    """

    def __init__(self, config: GemmaConfig) -> None:
        """Create the sub-modules this pipeline rank is responsible for."""
        super().__init__()
        self.mapping = config.mapping

        owns_embedding = self.mapping.is_first_pp_rank()
        if owns_embedding:
            self.vocab_embedding = Embedding(
                config.vocab_size,
                config.hidden_size,
                dtype=config.dtype,
            )

        self.layers = DecoderLayerList(GemmaDecoderLayer, config)

        owns_final_norm = self.mapping.is_last_pp_rank()
        if owns_final_norm:
            self.ln_f = RmsNorm(
                normalized_shape=config.hidden_size,
                eps=config.norm_epsilon,
                dtype=config.dtype,
            )

        self.hidden_size = config.hidden_size

    def forward(self,
                input_ids,
                position_ids=None,
                use_cache=False,
                attention_mask=None,
                kv_cache_params=None,
                attention_params=None,
                hidden_states=None,
                prompt_embedding_table: Optional[Tensor] = None,
                prompt_tasks: Optional[Tensor] = None,
                prompt_vocab_size: Optional[Tensor] = None,
                lora_params=None):
        """Embed (or receive) hidden states, run the layers, then norm or send.

        On the first pipeline rank ``input_ids`` are embedded and scaled by
        ``sqrt(hidden_size)``; on other ranks ``hidden_states`` is received
        from the previous rank. On the last rank the final norm is applied;
        otherwise the result is sent to the next rank.

        Returns:
            ``(hidden_states, tuple(presents))`` when ``use_cache`` is true,
            otherwise ``hidden_states``.
        """
        # Prompt-tuning inputs are forwarded to the embedding only when a
        # prompt embedding table was supplied.
        if prompt_embedding_table is None:
            prompt_args = []
        else:
            prompt_args = [
                prompt_embedding_table, prompt_tasks, prompt_vocab_size
            ]

        if not self.mapping.is_first_pp_rank():
            hidden_states = recv(hidden_states, self.mapping.prev_pp_rank())
        else:
            embedded = self.vocab_embedding(input_ids, *prompt_args)
            scale = math.sqrt(self.hidden_size)
            # Cast back so the scaled tensor keeps the embedding's dtype.
            hidden_states = cast(embedded * scale, embedded.dtype)

        layers_output = self.layers.forward(
            hidden_states,
            use_cache=use_cache,
            attention_mask=attention_mask,
            kv_cache_params=kv_cache_params,
            attention_params=attention_params,
            lora_params=lora_params,
        )

        if use_cache:
            hidden_states, presents = layers_output
        else:
            hidden_states = layers_output

        if not self.mapping.is_last_pp_rank():
            hidden_states = send(hidden_states, self.mapping.next_pp_rank())
        else:
            hidden_states = self.ln_f(hidden_states)

        if not use_cache:
            return hidden_states
        return (hidden_states, tuple(presents))
class GemmaForCausalLM(DecoderModelForCausalLM):
    """Gemma decoder model with a language-model head.

    Wraps :class:`GemmaModel` and, on the last pipeline-parallel rank, a
    ``ColumnLinear`` head over the padded vocabulary. Also provides helpers
    to build the model from a Hugging Face checkpoint and to write quantized
    checkpoints.
    """

    config_class = GemmaConfig

    def __init__(self, config: GemmaConfig):
        """Build the transformer and (on the last pp rank) the LM head."""
        transformer = GemmaModel(config)
        vocab_size_padded = pad_vocab_size(config.vocab_size,
                                           config.mapping.tp_size)

        if config.mapping.is_last_pp_rank():
            lm_head = ColumnLinear(config.hidden_size,
                                   vocab_size_padded,
                                   bias=False,
                                   dtype=config.dtype,
                                   tp_group=config.mapping.tp_group,
                                   tp_size=config.mapping.tp_size,
                                   gather_output=True)
        else:
            lm_head = None
        self.quant_mode = config.quant_mode
        self.mapping = config.mapping
        super().__init__(config, transformer, lm_head)

    @staticmethod
    def _load_gemma_weights_from_hf(hf_model_dir: "HfConfigOrDir",
                                    trt_llm_config: GemmaConfig, *,
                                    load_model_on_cpu: bool) -> Weights:
        """Load a Hugging Face Gemma model and convert its weights.

        `AutoModelForCausalLM.from_pretrained` will parse the correct gemma,
        whether Gemma or Gemma2 or future versions.

        Args:
            hf_model_dir: Passed to ``from_pretrained`` as the model source.
            trt_llm_config: Config handed to the weight conversion routine.
            load_model_on_cpu: When true the model is loaded with
                ``device_map="cpu"``, otherwise with ``device_map="auto"``.

        Returns:
            The weights produced by ``load_gemma_weights_from_hf_model``.
        """
        import transformers
        hf_gemma = transformers.AutoModelForCausalLM.from_pretrained(
            hf_model_dir,
            device_map="cpu" if load_model_on_cpu else "auto",
            torch_dtype='auto',
        )
        weights = load_gemma_weights_from_hf_model(hf_gemma, trt_llm_config)
        # Drop the reference to the HF model once its weights are converted.
        del hf_gemma
        return weights

    @classmethod
    def from_hugging_face(cls,
                          hf_model_dir: "HfConfigOrDir",
                          dtype='float16',
                          mapping: Optional[Mapping] = None,
                          quant_config: Optional[QuantConfig] = None,
                          load_model_on_cpu: bool = True,
                          **kwargs):
        """Create a model and load its weights from a Hugging Face checkpoint.

        Args:
            hf_model_dir: Hugging Face config or model directory.
            dtype: Data type forwarded to ``GemmaConfig.from_hugging_face``.
            mapping: Optional parallelism mapping forwarded to the config.
            quant_config: Optional quantization config forwarded to the
                config.
            load_model_on_cpu: Whether the Hugging Face model is loaded on
                CPU (``device_map="cpu"``) rather than ``"auto"``.
            **kwargs: Extra arguments for ``GemmaConfig.from_hugging_face``.

        Returns:
            A ``GemmaForCausalLM`` with the converted weights loaded.
        """
        config = GemmaConfig.from_hugging_face(hf_config_or_dir=hf_model_dir,
                                               dtype=dtype,
                                               mapping=mapping,
                                               quant_config=quant_config,
                                               **kwargs)
        model = GemmaForCausalLM(config)
        weights = cls._load_gemma_weights_from_hf(
            hf_model_dir, config, load_model_on_cpu=load_model_on_cpu)
        model.load(weights)
        return model

    # Quantization algorithms handled by the native (non-ModelOpt) flow in
    # `quantize()`.
    NATIVE_QUANT_FLOW = {
        QuantAlgo.W8A16, QuantAlgo.W4A16,
        QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN,
        QuantAlgo.W8A8_SQ_PER_TENSOR_PLUGIN,
        QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN,
        QuantAlgo.W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN
    }

    @classmethod
    def assert_valid_quant_algo(cls, quant_algo: Optional[QuantAlgo]):
        """Assert that ``quant_algo`` is ``None`` or a supported algorithm.

        Supported values are the union of ``NATIVE_QUANT_FLOW`` and
        ``MODELOPT_FLOW_QUANTIZATIONS``.

        Raises:
            AssertionError: If ``quant_algo`` is not an allowed value.
        """
        allowed_quant_values = {
            None
        } | cls.NATIVE_QUANT_FLOW | MODELOPT_FLOW_QUANTIZATIONS
        assert quant_algo in allowed_quant_values, (
            f"{quant_algo} isn't in the allowed QuantAlgo values for this "
            f"model: {allowed_quant_values}")

    @classmethod
    def quantize(
        cls,
        hf_model_dir: str,
        output_dir: str,
        dtype: str = 'float16',
        mapping: Optional[Mapping] = None,
        quant_config: Optional[QuantConfig] = None,
        *,
        gemma_config_kwargs: Optional[Dict[str, Any]] = None,
        **quantize_kwargs: Dict[str, Any],
    ):
        """Quantize a Hugging Face Gemma checkpoint into ``output_dir``.

        Algorithms in ``MODELOPT_FLOW_QUANTIZATIONS`` are delegated to the
        parent class; algorithms in ``NATIVE_QUANT_FLOW`` are quantized here,
        one checkpoint per rank.

        Args:
            hf_model_dir: Hugging Face model directory.
            output_dir: Directory the config and checkpoints are written to.
            dtype: Data type forwarded to ``GemmaConfig.from_hugging_face``.
            mapping: Optional parallelism mapping.
            quant_config: Optional quantization config.
            gemma_config_kwargs: Extra keyword arguments for
                ``GemmaConfig.from_hugging_face``.
            **quantize_kwargs: Forwarded to the parent ``quantize`` in the
                ModelOpt flow.

        Raises:
            ValueError: If neither ``quant_algo`` nor
                ``kv_cache_quant_algo`` is set.
            AssertionError: If ``quant_algo`` is not supported.
        """
        config = GemmaConfig.from_hugging_face(hf_model_dir,
                                               dtype=dtype,
                                               mapping=mapping,
                                               quant_config=quant_config,
                                               **(gemma_config_kwargs or {}))

        quant_algo = config.quantization.quant_algo
        if quant_algo is None and config.quantization.kv_cache_quant_algo is None:
            raise ValueError(
                "There is no point in calling `quantize()` if both `quant_algo` and `kv_cache_quant_algo` are `None`"
            )
        elif quant_algo in MODELOPT_FLOW_QUANTIZATIONS:
            super().quantize(hf_model_dir,
                             output_dir,
                             dtype=config.dtype,
                             mapping=config.mapping,
                             quant_config=config.quantization,
                             **quantize_kwargs)
        elif quant_algo in cls.NATIVE_QUANT_FLOW:
            save_config(config, output_dir=output_dir, log=True)
            # A distinct loop name keeps the full config from being shadowed
            # by the per-rank one.
            for rank_config in config.for_each_rank():
                # `load_model_on_cpu` is a required keyword-only argument;
                # use the same default as `from_hugging_face`.
                hf_weights = cls._load_gemma_weights_from_hf(
                    hf_model_dir, rank_config, load_model_on_cpu=True)
                ranked_weights = non_modelopt_quantize_if_needed(
                    hf_weights,
                    model_dir=hf_model_dir,
                    quantize_modifiers=QuantizeModifiers(),
                    trt_llm_config=rank_config)
                save_checkpoint(
                    output_dir=output_dir,
                    weights=ranked_weights,
                    rank=rank_config.mapping.rank,
                )
                del hf_weights
        else:
            # NOTE(review): if `quant_algo` is None with only
            # `kv_cache_quant_algo` set, and None is not a member of
            # MODELOPT_FLOW_QUANTIZATIONS, this branch validates
            # successfully and writes nothing - confirm that is intended.
            cls.assert_valid_quant_algo(quant_algo)

    def use_lora(self, lora_config: LoraConfig) -> None:
        """Enable LoRA on this model using the module-level ``use_lora``."""
        return use_lora(
            self, lora_config)  # Use the default trtllm->hf module mapping