# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from collections import OrderedDict
from typing import List, Optional, Union
import tensorrt as trt
import torch
from tensorrt_llm._common import default_net
from tensorrt_llm._utils import numpy_to_torch, str_dtype_to_torch
from tensorrt_llm.bindings import KVCacheType
from tensorrt_llm.functional import (Conditional, LayerNormPositionType,
                                     LayerNormType, MLPType,
                                     PositionEmbeddingType, Tensor, assertion,
                                     gather_last_token_logits, maximum,
                                     minimum, recv, reduce, send, shape, tanh)
from tensorrt_llm.layers import (MLP, Attention, AttentionMaskParams,
                                 AttentionMaskType, AttentionParams,
                                 ColumnLinear, Embedding, FusedGatedMLP,
                                 GatedMLP, GroupNorm, KeyValueCacheParams,
                                 LayerNorm, LoraParams, RmsNorm)
from tensorrt_llm.lora_manager import (LoraConfig,
                                       get_default_trtllm_modules_to_hf_modules,
                                       use_lora)
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.model_weights_loader import ModelWeightsLoader
from tensorrt_llm.models.modeling_utils import PretrainedModel, QuantConfig
from tensorrt_llm.module import Module, ModuleList
from tensorrt_llm.parameter import Parameter
from tensorrt_llm.quantization import QuantMode
from .config import MLLaMAConfig
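# This module implements the text decoder of MLLaMA (Llama 3.2 Vision) for
# TensorRT-LLM: a decoder stack in which selected layers are gated
# cross-attention blocks attending to the vision encoder output, plus
# MLLaMAForCausalLM, which adds the LM head and the TensorRT input/profile
# definitions (prepare_inputs).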
layernorm_map = {
    LayerNormType.LayerNorm: LayerNorm,
    LayerNormType.RmsNorm: RmsNorm,
    LayerNormType.GroupNorm: GroupNorm,
}
mlp_map = {
    MLPType.MLP: MLP,
    MLPType.GatedMLP: GatedMLP,
    MLPType.FusedGatedMLP: FusedGatedMLP,
}
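# When True, intermediate tensors of every block are marked as network outputs
# so they can be inspected when debugging the engine.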
ADD_DEBUG_TENSOR = False
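# A decoder block with gated cross-attention over the vision encoder output.
# tanh(gate_attn) scales the cross-attention branch and tanh(gate_ffwd) scales
# the MLP branch before each residual add; both gates are learned per-layer
# scalars, and skip_cross_attn_blocks lets the whole block be bypassed at runtime.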
class CrossAttentionTransformerBlock(Module):
def __init__(
self,
*,
local_layer_idx,
hidden_size,
ffn_hidden_size,
num_attention_heads,
num_kv_heads,
head_size,
max_position_embeddings=None,
q_scaling=1.0,
has_attention_qkvo_bias=False,
has_mlp_bias=False,
layernorm_position=LayerNormPositionType.pre_layernorm,
layernorm_type=LayerNormType.RmsNorm,
layernorm_eps=1e-5,
hidden_act="gated-silu",
mlp_type=MLPType.GatedMLP,
mapping=Mapping(),
dtype=None,
residual_scaling=1.0,
relative_attention=False,
max_distance=0,
num_buckets=0,
fp16_clamping=False,
skip_cross_kv=False,
use_implicit_relative_attention=False,
rotary_embedding_base=None,
rotary_embedding_scaling=None,
quant_mode=QuantMode(0),
):
super().__init__()
self.local_layer_idx = local_layer_idx
self.layernorm_type = layernorm_type
ln_type = layernorm_map[layernorm_type]
self.layernorm_position = layernorm_position
assert self.layernorm_position == LayerNormPositionType.pre_layernorm
self.cross_attention = Attention(
local_layer_idx=local_layer_idx,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
attention_head_size=head_size,
num_kv_heads=num_kv_heads,
max_position_embeddings=max_position_embeddings,
q_scaling=q_scaling,
bias=has_attention_qkvo_bias,
attention_mask_type=AttentionMaskType.causal,
tp_group=mapping.tp_group,
tp_size=mapping.tp_size,
tp_rank=mapping.tp_rank,
dtype=dtype,
cross_attention=True,
relative_attention=
False, # Cross attention has no relative attention bias
max_distance=max_distance,
num_buckets=num_buckets,
position_embedding_type=PositionEmbeddingType.
learned_absolute, # we don't use rope for cross attn
skip_cross_kv=skip_cross_kv,
qk_layernorm=True,
layernorm_type=layernorm_type,
quant_mode=quant_mode,
)
self.input_layernorm = ln_type(normalized_shape=hidden_size,
eps=layernorm_eps,
dtype=dtype)
self.gate_attn = Parameter(shape=tuple((1, )), dtype=dtype)
self.mlp_type = mlp_type
mlp_f = mlp_map[mlp_type]
self.mlp = mlp_f(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
hidden_act=hidden_act,
bias=has_mlp_bias,
tp_group=mapping.tp_group,
tp_size=mapping.tp_size,
dtype=dtype,
)
self.post_layernorm = ln_type(normalized_shape=hidden_size,
eps=layernorm_eps,
dtype=dtype)
self.gate_ffwd = Parameter(shape=tuple((1, )), dtype=dtype)
self.residual_scaling = residual_scaling
self.fp16_clamping = fp16_clamping
self.no_ffn = False
def forward(self,
hidden_states: Tensor,
encoder_output: Optional[Tensor] = None,
attention_mask_params=None,
use_cache=False,
kv_cache_params=None,
attention_params=None,
lora_layer_params=None,
cross_kv_cache_gen: Optional[Tensor] = None,
cross_kv_reuse: Optional[Tensor] = None,
full_text_row_masked_out_mask: Tensor = None,
skip_cross_attn_blocks: Tensor = None):
assert isinstance(hidden_states, Tensor)
if encoder_output:
assert isinstance(encoder_output, Tensor)
if ADD_DEBUG_TENSOR:
hidden_states.mark_output(
f'{self.local_layer_idx:2d}/1.0: hidden_states',
hidden_states.dtype)
# cross attention
residual = hidden_states * self.residual_scaling
# skip input_layernorm
if skip_cross_attn_blocks is not None:
input_ln_conditional = Conditional(skip_cross_attn_blocks)
skip_result = input_ln_conditional.add_input(hidden_states)
hidden_states = input_ln_conditional.add_input(hidden_states)
hidden_states = self.input_layernorm(hidden_states)
hidden_states = input_ln_conditional.add_output(
skip_result, hidden_states)
else:
hidden_states = self.input_layernorm(hidden_states)
if ADD_DEBUG_TENSOR:
hidden_states.mark_output(
f'{self.local_layer_idx:2d}/2.1: normed_input',
hidden_states.dtype)
# pass full_text_row_masked_out_mask and xattn_mask
attention_output = self.cross_attention(
hidden_states=hidden_states,
attention_mask=attention_mask_params.cross_attention_mask,
attention_packed_mask=attention_mask_params.
cross_attention_packed_mask,
encoder_output=encoder_output,
use_cache=use_cache,
kv_cache_params=kv_cache_params,
attention_params=attention_params,
lora_layer_params=lora_layer_params,
cross_kv_cache_gen=cross_kv_cache_gen,
cross_kv_reuse=cross_kv_reuse,
skip_attn=skip_cross_attn_blocks,
)
if use_cache:
attention_output, presents_cross = attention_output
if ADD_DEBUG_TENSOR:
attention_output.mark_output(
f'{self.local_layer_idx:2d}/3.1: cross_attention_output',
attention_output.dtype)
attn_residual_scale = tanh(self.gate_attn.value.cast(trt.float32)).cast(
attention_output.dtype)
attention_input = hidden_states
hidden_states = residual + attn_residual_scale * attention_output
        # Used to skip attention_output + residual when cross attention is disabled.
        # Since Conditional does not work with gpt_attention_plugin, we instead
        # replace attention_output with hidden_states (the input of the attention) here.
if skip_cross_attn_blocks is not None:
attn_conditional = Conditional(skip_cross_attn_blocks)
skip_result = attn_conditional.add_input(attention_input)
hidden_states = attn_conditional.add_input(hidden_states)
hidden_states = attn_conditional.add_output(skip_result,
hidden_states)
if ADD_DEBUG_TENSOR:
hidden_states.mark_output(
f'{self.local_layer_idx:2d}/3.2: cross_attn_output_with_residual',
hidden_states.dtype)
if self.fp16_clamping:
hidden_states = maximum(-64000.0, hidden_states)
hidden_states = minimum(64000.0, hidden_states)
# MLP
# skip post_layernorm and mlp
if skip_cross_attn_blocks is not None:
mlp_conditional = Conditional(skip_cross_attn_blocks)
skip_case = mlp_conditional.add_input(hidden_states)
hidden_states = mlp_conditional.add_input(hidden_states)
attention_output = attention_output * full_text_row_masked_out_mask # TODO should move this mask into attention?
residual = hidden_states * self.residual_scaling
hidden_states = self.post_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states,
lora_layer_params=lora_layer_params)
if ADD_DEBUG_TENSOR:
hidden_states.mark_output(
f'{self.local_layer_idx:2d}/4.1: mlp_output',
hidden_states.dtype)
hidden_states = hidden_states * full_text_row_masked_out_mask
if ADD_DEBUG_TENSOR:
hidden_states.mark_output(
f'{self.local_layer_idx:2d}/4.2: masked_mlp_output',
hidden_states.dtype)
ffn_residual_scale = tanh(self.gate_ffwd.value.cast(trt.float32)).cast(
hidden_states.dtype)
hidden_states = residual + ffn_residual_scale * hidden_states * float(
not self.no_ffn)
if skip_cross_attn_blocks is not None:
hidden_states = mlp_conditional.add_output(skip_case, hidden_states)
if self.fp16_clamping:
hidden_states = maximum(-64000.0, hidden_states)
hidden_states = minimum(64000.0, hidden_states)
if ADD_DEBUG_TENSOR:
hidden_states.mark_output(
f'{self.local_layer_idx:2d}/4.4: transformer_out',
hidden_states.dtype)
if use_cache:
return (hidden_states, presents_cross)
return hidden_states
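# A standard pre-norm decoder block: causal self-attention (RoPE by default,
# relative position bias when relative_attention is set) followed by the MLP.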
class TransformerBlock(Module):
def __init__(
self,
*,
local_layer_idx,
hidden_size,
ffn_hidden_size,
num_attention_heads,
num_kv_heads,
head_size,
max_position_embeddings=None,
q_scaling=1.0,
has_attention_qkvo_bias=False,
has_mlp_bias=False,
layernorm_position=LayerNormPositionType.pre_layernorm,
layernorm_type=LayerNormType.RmsNorm,
layernorm_eps=1e-5,
hidden_act="gated-silu",
mlp_type=MLPType.GatedMLP,
mapping=Mapping(),
dtype=None,
residual_scaling=1.0,
relative_attention=False,
max_distance=0,
num_buckets=0,
fp16_clamping=False,
skip_cross_kv=False,
use_implicit_relative_attention=False,
rotary_embedding_base=None,
rotary_embedding_scaling=None,
quant_mode=QuantMode(0),
):
super().__init__()
self.local_layer_idx = local_layer_idx
self.layernorm_type = layernorm_type
ln_type = layernorm_map[layernorm_type]
self.layernorm_position = layernorm_position
assert self.layernorm_position == LayerNormPositionType.pre_layernorm
self.self_attention = Attention(
local_layer_idx=local_layer_idx,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
attention_head_size=head_size,
num_kv_heads=num_kv_heads,
max_position_embeddings=max_position_embeddings,
q_scaling=q_scaling,
bias=has_attention_qkvo_bias,
attention_mask_type=AttentionMaskType.causal,
tp_group=mapping.tp_group,
tp_size=mapping.tp_size,
tp_rank=mapping.tp_rank,
dtype=dtype,
cross_attention=False,
relative_attention=relative_attention,
max_distance=max_distance if use_implicit_relative_attention else 0,
num_buckets=num_buckets,
position_embedding_type=PositionEmbeddingType.relative
if relative_attention else PositionEmbeddingType.rope_gpt_neox,
use_implicit_relative_attention=use_implicit_relative_attention,
rotary_embedding_base=rotary_embedding_base,
rotary_embedding_scaling=rotary_embedding_scaling,
quant_mode=quant_mode,
)
self.input_layernorm = ln_type(normalized_shape=hidden_size,
eps=layernorm_eps,
dtype=dtype)
self.mlp_type = mlp_type
mlp_f = mlp_map[mlp_type]
self.mlp = mlp_f(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
hidden_act=hidden_act,
bias=has_mlp_bias,
tp_group=mapping.tp_group,
tp_size=mapping.tp_size,
dtype=dtype,
)
self.post_layernorm = ln_type(normalized_shape=hidden_size,
eps=layernorm_eps,
dtype=dtype)
self.residual_scaling = residual_scaling
self.fp16_clamping = fp16_clamping
def forward(
self,
hidden_states: Tensor,
encoder_output: Optional[Tensor] = None, # not used
attention_mask_params=None,
use_cache=False,
kv_cache_params=None,
attention_params=None,
lora_layer_params=None,
cross_kv_cache_gen: Optional[Tensor] = None,
cross_kv_reuse: Optional[Tensor] = None,
full_text_row_masked_out_mask: Tensor = None, # not used
skip_cross_attn_blocks=None,
):
assert isinstance(hidden_states, Tensor)
# self-attention
residual = hidden_states * self.residual_scaling
if ADD_DEBUG_TENSOR:
hidden_states.mark_output(
f'{self.local_layer_idx:2d}/1.0: hidden_states',
hidden_states.dtype)
hidden_states = self.input_layernorm(hidden_states)
if ADD_DEBUG_TENSOR:
hidden_states.mark_output(
f'{self.local_layer_idx:2d}/2.1: normed attn_input',
hidden_states.dtype)
attention_output = self.self_attention(
hidden_states=hidden_states,
attention_mask=attention_mask_params.self_attention_mask,
use_cache=use_cache,
kv_cache_params=kv_cache_params,
attention_params=attention_params,
lora_layer_params=lora_layer_params)
if use_cache:
attention_output, presents_self = attention_output
if ADD_DEBUG_TENSOR:
attention_output.mark_output(
f'{self.local_layer_idx:2d}/3.1: self_attention_output',
attention_output.dtype)
hidden_states = residual + attention_output
if ADD_DEBUG_TENSOR:
hidden_states.mark_output(
f'{self.local_layer_idx:2d}/3.1: attention_output_with_residual',
hidden_states.dtype)
if self.fp16_clamping:
hidden_states = maximum(-64000.0, hidden_states)
hidden_states = minimum(64000.0, hidden_states)
# MLP
residual = hidden_states * self.residual_scaling
hidden_states = self.post_layernorm(hidden_states)
if ADD_DEBUG_TENSOR:
hidden_states.mark_output(
f'{self.local_layer_idx:2d}/3.2: normed_mlp_input',
hidden_states.dtype)
hidden_states = self.mlp(hidden_states,
lora_layer_params=lora_layer_params)
if ADD_DEBUG_TENSOR:
hidden_states.mark_output(
f'{self.local_layer_idx:2d}/4.1: mlp_output',
hidden_states.dtype)
hidden_states = residual + hidden_states
if ADD_DEBUG_TENSOR:
hidden_states.mark_output(
f'{self.local_layer_idx:2d}/4.2: mlp_output_residual',
hidden_states.dtype)
if self.fp16_clamping:
hidden_states = maximum(-64000.0, hidden_states)
hidden_states = minimum(64000.0, hidden_states)
if use_cache:
return (hidden_states, presents_self)
return hidden_states
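# The MLLaMA decoder stack: vocab embedding on the first pipeline rank, a list of
# decoder blocks, and an optional final layernorm on the last pipeline rank.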
class MLLaMAModel(Module):
def __init__(self, config: MLLaMAConfig) -> None:
super().__init__()
self.config = config
self.position_embedding_type = config.position_embedding_type
self.mapping = self.config.mapping
self.layernorm_type = self.config.layernorm_type
ln_type = layernorm_map[self.layernorm_type]
self.has_attention_qkvo_bias = self.config.has_attention_qkvo_bias
self.has_mlp_bias = self.config.has_mlp_bias
self.has_model_final_layernorm = self.config.has_model_final_layernorm
self._dtype = self.config.dtype
# no quantization considered for now
self._kv_dtype = self._dtype
self._logits_dtype = self.config.logits_dtype
self.total_num_layers = self.config.num_hidden_layers
self.num_layers = self.config.num_hidden_layers // self.mapping.pp_size
self.hidden_size = self.config.hidden_size
self.encoder_hidden_size = self.config.hidden_size
self.num_heads = self.config.num_attention_heads
# num_kv_heads = self.num_heads
num_kv_heads = self.config.num_key_value_heads
if num_kv_heads is None or num_kv_heads <= 0:
num_kv_heads = self.num_heads
self.num_kv_heads = num_kv_heads
self.head_size = self.hidden_size // self.num_heads if self.config.head_size is None else self.config.head_size
self.fp16_clamping = False
self.skip_cross_kv = self.config.skip_cross_kv
self.mlp_type = MLPType.MLP if not hasattr(
self.config, "mlp_type") else self.config.mlp_type
self.use_implicit_relative_attention = self.config.use_implicit_relative_attention if hasattr(
self.config, "use_implicit_relative_attention") else False
self.cross_attention_layers = self.config.cross_attention_layers
if self.mapping.is_first_pp_rank():
self.vocab_embedding = Embedding(
self.config.embed_vocab_size,
self.config.hidden_size,
dtype=self._dtype,
tp_size=self.mapping.tp_size
if self.config.use_parallel_embedding else 1,
tp_group=self.mapping.tp_group
if self.config.use_parallel_embedding else None,
sharding_dim=self.config.embedding_sharding_dim,
tp_rank=self.mapping.tp_rank)
layers_range = self.mapping.pp_layers(self.total_num_layers)
_layers = []
for layer_idx in layers_range:
local_layer_idx = layer_idx - layers_range[0]
args = {
"local_layer_idx": local_layer_idx,
"hidden_size": self.config.hidden_size,
"ffn_hidden_size": self.config.intermediate_size,
"num_attention_heads": self.num_heads,
"num_kv_heads": self.num_kv_heads,
"head_size": self.head_size,
"max_position_embeddings": self.config.max_position_embeddings,
"layernorm_position": self.config.layernorm_position,
"layernorm_eps": self.config.norm_epsilon,
"layernorm_type": self.config.layernorm_type,
"hidden_act": self.config.hidden_act,
"mlp_type": self.mlp_type,
"mapping": self.mapping,
"dtype": self._dtype,
"residual_scaling": self.config.residual_scaling,
"max_distance": self.config.max_distance,
"num_buckets": self.config.num_buckets,
"fp16_clamping": self.fp16_clamping,
"skip_cross_kv": self.skip_cross_kv,
"rotary_embedding_base": self.config.rotary_base,
"rotary_embedding_scaling": self.config.rotary_scaling,
"quant_mode": self.config.quant_mode,
}
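            # Layers listed in config.cross_attention_layers become gated
            # cross-attention blocks; every other layer is a regular
            # self-attention block.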
if layer_idx in self.cross_attention_layers:
assert layers_range[0] == 0, "not support PP now"
_layers.append(CrossAttentionTransformerBlock(**args))
else:
_layers.append(TransformerBlock(**args))
self.layers = ModuleList(_layers)
if self.mapping.is_last_pp_rank():
self.ln_f = None
if self.has_model_final_layernorm:
self.ln_f = ln_type(normalized_shape=self.config.hidden_size,
eps=self.config.norm_epsilon,
dtype=self.config.dtype)
if self.config.relative_attention and not self.use_implicit_relative_attention:
self.rel_attn_table = Parameter(
shape=(self.config.num_attention_heads // self.mapping.tp_size,
self.config.num_buckets),
dtype=self._dtype)
def forward(
self,
decoder_input_ids: Tensor,
encoder_output: Tensor,
use_cache=False,
attention_mask_params=None,
last_token_ids=None,
kv_cache_params=None,
attention_params=None,
hidden_states=None,
lora_params: LoraParams = None,
cross_kv_cache_gen: Optional[Tensor] = None,
cross_kv_reuse: Optional[Tensor] = None,
prompt_embedding_table: Optional[Tensor] = None,
prompt_tasks: Optional[Tensor] = None,
prompt_vocab_size: Optional[Tensor] = None,
skip_cross_attn_blocks: Optional[Tensor] = None,
):
if self.mapping.is_first_pp_rank():
assert isinstance(decoder_input_ids, Tensor)
else:
assert isinstance(hidden_states, Tensor)
# In PP, layer 0 has ids as inputs, all other layers have hidden_states as inputs
if self.mapping.is_first_pp_rank():
hidden_states = self.vocab_embedding(decoder_input_ids)
self.register_network_output('embedding_layer_output',
hidden_states)
else:
hidden_states = recv(hidden_states, self.mapping.prev_pp_rank())
kv_cache_params.fill_none_tensor_list(len(self.layers))
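        # Rows of the cross-attention mask that are entirely False belong to text
        # tokens with no image tokens to attend to; reducing with MAX over the
        # encoder dimension gives a per-row 0/1 mask that is later used to zero
        # out the cross-attention and MLP outputs of those rows.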
full_text_row_masked_out_mask = reduce(
(attention_mask_params.cross_attention_mask).cast(
hidden_states.dtype),
trt.ReduceOperation.MAX,
dim=-1,
keepdim=True)
if ADD_DEBUG_TENSOR:
full_text_row_masked_out_mask.mark_output(
"full_text_row_masked_out_mask",
full_text_row_masked_out_mask.dtype)
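        # Re-encode the cross-attention mask as
        # full_text_row_masked_out_mask * (1 - mask): rows with no visible encoder
        # tokens become all zeros, while the remaining rows keep a flipped copy of
        # their per-element mask.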
cross_attention_mask_type = attention_mask_params.cross_attention_mask.dtype
attention_mask_params.cross_attention_mask = (
attention_mask_params.cross_attention_mask.cast(
full_text_row_masked_out_mask.dtype) *
full_text_row_masked_out_mask).cast(cross_attention_mask_type)
invert_mask = 1.0 - attention_mask_params.cross_attention_mask.cast(
hidden_states.dtype)
invert_full_text_row_masked_out_mask = 1.0 - full_text_row_masked_out_mask
final_mask = invert_mask - invert_full_text_row_masked_out_mask
attention_mask_params.cross_attention_mask = final_mask.cast(
cross_attention_mask_type)
if ADD_DEBUG_TENSOR:
attention_mask_params.cross_attention_mask.mark_output(
"attention_mask_params.cross_attention_mask",
attention_mask_params.cross_attention_mask.dtype)
if use_cache:
presents = []
for i, (decoder_layer, past) in enumerate(
zip(self.layers, kv_cache_params.past_key_value)):
lora_layer_params = None
if lora_params is not None and lora_params.lora_ranks is not None:
lora_layer_params = lora_params.get_layer_params(i)
hidden_states = decoder_layer(
hidden_states,
encoder_output=encoder_output,
attention_mask_params=attention_mask_params,
use_cache=use_cache,
kv_cache_params=KeyValueCacheParams(
past_key_value=past,
host_past_key_value_lengths=kv_cache_params.
host_past_key_value_lengths,
host_max_attention_window_sizes=kv_cache_params.
host_max_attention_window_sizes,
host_sink_token_length=kv_cache_params.
host_sink_token_length,
cache_indirection=kv_cache_params.cache_indirection,
kv_cache_block_offsets=kv_cache_params.
kv_cache_block_offsets,
host_kv_cache_block_offsets=kv_cache_params.
host_cross_kv_cache_block_offsets,
host_kv_cache_pool_pointers=kv_cache_params.
host_kv_cache_pool_pointers,
host_kv_cache_pool_mapping=kv_cache_params.
host_kv_cache_pool_mapping,
cross_kv_cache_block_offsets=kv_cache_params.
cross_kv_cache_block_offsets,
host_cross_kv_cache_block_offsets=kv_cache_params.
host_cross_kv_cache_block_offsets,
host_cross_kv_cache_pool_pointers=kv_cache_params.
host_cross_kv_cache_pool_pointers,
host_cross_kv_cache_pool_mapping=kv_cache_params.
host_cross_kv_cache_pool_mapping,
),
skip_cross_attn_blocks=skip_cross_attn_blocks if isinstance(
decoder_layer, CrossAttentionTransformerBlock) else None,
attention_params=attention_params,
lora_layer_params=lora_layer_params,
cross_kv_cache_gen=cross_kv_cache_gen,
cross_kv_reuse=cross_kv_reuse,
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
)
if use_cache:
present = hidden_states[1]
presents.append((present))
hidden_states = hidden_states[0]
if self.mapping.is_last_pp_rank():
if self.ln_f:
hidden_states = self.ln_f(hidden_states)
else:
hidden_states = send(hidden_states, self.mapping.next_pp_rank())
if use_cache:
return (hidden_states, tuple(presents))
return hidden_states
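    # Precompute the dense relative-attention bias table once at build time so
    # each self-attention layer can read it directly instead of computing
    # relative-position buckets on the fly.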
def precompute_relative_attention_bias(self, build_config):
if self.config.relative_attention and not self.use_implicit_relative_attention:
relative_attention_bias_builder = torch.ops.tensorrt_llm.relative_attention_bias
rel_attn_precomputed = torch.zeros(
(self.config.num_attention_heads // self.mapping.tp_size,
build_config.max_seq_len + 1, build_config.max_seq_len + 1),
dtype=str_dtype_to_torch(self.config.dtype),
device='cuda')
rel_attn_table = numpy_to_torch(
self.rel_attn_table.raw_value).to('cuda')
relative_attention_bias_builder(
rel_attn_precomputed,
rel_attn_table,
self.config.num_attention_heads // self.mapping.tp_size,
build_config.max_seq_len,
self.config.num_buckets,
False,
self.config.max_distance,
)
for layer_idx in range(self.num_layers):
self.layers[layer_idx].self_attention.set_rel_attn_table(
build_config.max_seq_len, rel_attn_precomputed)
# TODO: try to inherit the DecoderModelForCausalLM
class MLLaMAForCausalLM(PretrainedModel):
    config_class = MLLaMAConfig
def __init__(self, config: MLLaMAConfig):
super().__init__(config)
Attention.create_attention_const_params(self, config)
self.position_embedding_type = config.position_embedding_type
self.transformer = MLLaMAModel(config)
self.mapping = self.config.mapping
self.has_model_final_layernorm = self.config.has_model_final_layernorm
self._dtype = self.config.dtype
self._kv_dtype = self._dtype
self._logits_dtype = self.config.logits_dtype
if self.mapping.is_last_pp_rank():
self.lm_head = ColumnLinear(
self.config.hidden_size,
self.config.vocab_size,
bias=False if not hasattr(self.config, "has_lm_head_bias") else
self.config.has_lm_head_bias,
dtype=self.config.dtype,
tp_group=self.config.mapping.tp_group,
tp_size=self.config.mapping.tp_size,
gather_output=True,
)
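        # Mapping from TRT-LLM LoRA module names to the corresponding HF submodule
        # names, including the cross-attention (encoder_attn) projections.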
self.trtllm_modules_to_hf_modules = {
**get_default_trtllm_modules_to_hf_modules(),
"attn_q": "self_attn.q_proj",
"attn_k": "self_attn.k_proj",
"attn_v": "self_attn.v_proj",
"attn_dense": "self_attn.o_proj",
"cross_attn_q": "encoder_attn.q_proj",
"cross_attn_k": "encoder_attn.k_proj",
"cross_attn_v": "encoder_attn.v_proj",
"cross_attn_dense": "encoder_attn.o_proj",
}
    def forward(
        self,
        decoder_input_ids: Tensor,
        encoder_output: Tensor,
        use_cache=False,
        attention_mask_params=None,
        last_token_ids=None,
        kv_cache_params=None,
        attention_params=None,
        hidden_states=None,
        lora_params: LoraParams = None,
        cross_kv_cache_gen: Optional[Tensor] = None,
        cross_kv_reuse: Optional[Tensor] = None,
        prompt_embedding_table: Optional[Tensor] = None,
        prompt_tasks: Optional[Tensor] = None,
        prompt_vocab_size: Optional[Tensor] = None,
        skip_cross_attn_blocks: Optional[Tensor] = None,
    ):
        if self.mapping.is_first_pp_rank():
            assert isinstance(decoder_input_ids, Tensor)
        else:
            assert isinstance(hidden_states, Tensor)
        attention_params = Attention.fill_attention_params(
            self, attention_params)
        hidden_states = self.transformer(
            decoder_input_ids=decoder_input_ids,
            encoder_output=encoder_output,
            use_cache=use_cache,
            attention_mask_params=attention_mask_params,
            last_token_ids=last_token_ids,
            kv_cache_params=kv_cache_params,
            attention_params=attention_params,
            hidden_states=hidden_states,
            lora_params=lora_params,
            cross_kv_cache_gen=cross_kv_cache_gen,
            cross_kv_reuse=cross_kv_reuse,
            prompt_embedding_table=prompt_embedding_table,
            prompt_tasks=prompt_tasks,
            prompt_vocab_size=prompt_vocab_size,
            skip_cross_attn_blocks=skip_cross_attn_blocks,
        )
if use_cache:
hidden_states, presents = hidden_states
if self.mapping.is_last_pp_rank():
# [bs, seq, hidden_size] or [num_tokens, hidden_size] -> [bs, hidden_size]
hidden_states = gather_last_token_logits(
hidden_states, last_token_ids,
default_net().plugin_config.remove_input_padding)
# [bs, hidden_size] -> [bs, vocab_size]
lm_logits = self.lm_head(hidden_states)
lm_logits.mark_output(f'logits', self._logits_dtype)
else:
hidden_states = send(hidden_states, self.mapping.next_pp_rank())
hidden_states.mark_output(f'hidden_states_output', self._dtype)
if use_cache and default_net().plugin_config.paged_kv_cache == False:
            for i, present in zip(
                    self.mapping.pp_layers(self.transformer.total_num_layers),
presents):
present[0].mark_output(f'present_key_value_{i}', self._kv_dtype)
if default_net().plugin_config.gpt_attention_plugin:
present[1].mark_output(f'cross_present_key_value_{i}',
self._kv_dtype)
if self.mapping.is_last_pp_rank():
return (lm_logits, tuple(presents))
return (hidden_states, tuple(presents))
else:
if self.mapping.is_last_pp_rank():
return lm_logits
return hidden_states
    def prepare_inputs(self,
                       max_batch_size,
                       max_beam_width,
                       max_decoder_input_len,
                       max_seq_len,
                       max_encoder_input_len,
                       gather_context_logits: bool = False,
                       gather_generation_logits: bool = False,
                       lora_target_modules: List[str] = None,
                       prompt_embedding_table_size: int = 0,
                       use_cache=True,
                       *args,
                       **kwargs):
        '''@brief: Prepare input Tensors for the model. The given sizes are used
            to determine the ranges of the dimensions when using TRT dynamic shapes.

        @return: a dict of values that can be fed into self.forward()
        '''
# Prepare inputs
max_output_len = max_decoder_input_len + max_seq_len
head_size = self.transformer.head_size
num_kv_heads = (self.transformer.num_kv_heads + self.mapping.tp_size -
1) // self.mapping.tp_size
encoder_head_size = head_size
encoder_num_kv_heads = num_kv_heads
bb_range = [
1, (max_batch_size * max_beam_width + 1) // 2,
max_batch_size * max_beam_width
]
bs_range = [1, (max_batch_size + 1) // 2, max_batch_size]
beam_width_range = [1, (max_beam_width + 1) // 2, max_beam_width]
inlen_range = [
1, 1, max_decoder_input_len
] # context phase >= 1 (if forced_input_ids), generation phase = 1
encoder_inlen_range = [
1, (max_encoder_input_len + 1) // 2, max_encoder_input_len
]
mask_len_range = [1, (max_output_len + 1) // 2 + 1, max_output_len + 1]
max_output_len_range = [0, (max_output_len + 1) // 2, max_output_len]
encoder_num_tokens_range = [
0, # 0 for generation phase, >0 for context phase
(max_encoder_input_len * max_batch_size + 1) // 2,
max_encoder_input_len * max_batch_size,
]
decoder_num_tokens_range = [
1,
max_batch_size * max_beam_width,
max(max_decoder_input_len * max_batch_size,
max_beam_width * max_batch_size),
]
# No enable_two_optimization_profiles support yet
encoder_input_len_range = [
0, # 0 for generation phase, >0 for context phase
(max_encoder_input_len + 1) // 2,
max_encoder_input_len
]
# pack masks into bits (store as int32).
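        # dim0: decoder rows per sequence padded to a multiple of 128, times batch size;
        # dim1: encoder positions padded to a multiple of 256 mask bits, stored as
        # 32-bit words (hence the final // 32).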
max_cross_packed_mask_dim0 = max_batch_size * (
(max_decoder_input_len + 128 - 1) // 128) * 128
max_cross_packed_mask_dim1 = (
(max_encoder_input_len + 256 - 1) // 256) * 256 // 32
cross_packed_mask_dim0_range = [
1, (max_cross_packed_mask_dim0 + 1) // 2, max_cross_packed_mask_dim0
]
cross_packed_mask_dim1_range = [
0, # 0 for generation phase, >0 for context phase
(max_cross_packed_mask_dim1 + 1) // 2,
max_cross_packed_mask_dim1
]
past_key_value = []
sequence_length = None
host_past_key_value_lengths = None
attention_mask_params = AttentionMaskParams()
use_gpt_attention_plugin = default_net(
).plugin_config.gpt_attention_plugin
remove_input_padding = default_net().plugin_config.remove_input_padding
paged_kv_cache = default_net().plugin_config.paged_kv_cache
tokens_per_block = default_net().plugin_config.tokens_per_block
use_lora_plugin = default_net().plugin_config.lora_plugin
kv_cache_type = None
if not use_cache:
kv_cache_type = KVCacheType.DISABLED
else:
if paged_kv_cache:
kv_cache_type = KVCacheType.PAGED
else:
kv_cache_type = KVCacheType.CONTINUOUS
input_ids, hidden_states = None, None
if remove_input_padding:
if self.mapping.is_first_pp_rank():
input_ids = Tensor(name='input_ids',
dtype=trt.int32,
shape=[-1],
dim_range=OrderedDict([
('decoder_num_tokens',
[decoder_num_tokens_range]),
]))
else:
hidden_states = Tensor(name='hidden_states_input',
dtype=self._dtype,
shape=[-1, self.hidden_size],
dim_range=OrderedDict([
('decoder_num_tokens',
[decoder_num_tokens_range]),
('hidden_size', [self.hidden_size]),
]))
else:
if self.mapping.is_first_pp_rank():
input_ids = Tensor(name='input_ids',
dtype=trt.int32,
shape=[-1, -1],
dim_range=OrderedDict([
('batch_size_beam_width', [bb_range]),
('input_len', [inlen_range]),
]))
else:
hidden_states = Tensor(name='hidden_states_input',
dtype=self._dtype,
shape=[-1, -1, self.hidden_size],
dim_range=OrderedDict([
('batch_size_beam_width', [bb_range
]),
('input_len', [inlen_range]),
('hidden_size', [self.hidden_size]),
]))
encoder_input_lengths = Tensor(
name="encoder_input_lengths",
dtype=trt.int32,
shape=[-1],
dim_range=OrderedDict([("batch_size_beam_width", [bb_range])]),
)
encoder_max_input_length = Tensor(
name="encoder_max_input_length",
dtype=trt.int32,
shape=[-1],
dim_range=OrderedDict([("encoder_max_input_length",
[encoder_inlen_range])]),
)
if remove_input_padding:
encoder_output = Tensor(
name="encoder_output",
dtype=self._dtype,
shape=[-1, self.config.hidden_size],
dim_range=OrderedDict([
("encoder_num_tokens", [encoder_num_tokens_range]),
("hidden_size", [self.config.hidden_size]),
]),
)
else:
encoder_output = Tensor(
name="encoder_output",
dtype=self._dtype,
shape=[-1, -1, self.config.hidden_size],
dim_range=OrderedDict([
("batch_size_beam_width_encoder", [bb_range]),
("encoder_input_len", [encoder_input_len_range]),
("hidden_size", [self.config.hidden_size]),
]),
)
context_lengths = None
host_context_lengths = None
host_request_types = None
host_runtime_perf_knobs = None
host_context_progress = None
if use_gpt_attention_plugin and remove_input_padding:
host_context_lengths = Tensor(name='host_context_lengths',
dtype=trt.int32,
shape=[-1],
dim_range=OrderedDict([
('batch_size_beam_width',
[bb_range])
]))
if use_gpt_attention_plugin:
if kv_cache_type != KVCacheType.DISABLED:
sequence_length = Tensor(
name='sequence_length',
dtype=trt.int32,
shape=[-1],
dim_range=OrderedDict([('batch_size_beam_width', [bb_range])
]),
)
host_past_key_value_lengths = Tensor(
name='host_past_key_value_lengths',
dtype=trt.int32,
shape=[-1],
dim_range=OrderedDict([('batch_size_beam_width', [bb_range])
]),
)
context_lengths = Tensor(name='context_lengths',
dtype=trt.int32,
shape=[-1],
dim_range=OrderedDict([
('batch_size_beam_width', [bb_range])
]))
host_request_types = Tensor(name='host_request_types',
dtype=trt.int32,
shape=[-1],
dim_range=OrderedDict([
('batch_size_beam_width',
[bb_range])
]))
host_runtime_perf_knobs = Tensor(name='host_runtime_perf_knobs',
dtype=trt.int64,
shape=[16],
dim_range=OrderedDict([
('perf_knob_size', [16])
]))
host_context_progress = Tensor(name='host_context_progress',
dtype=trt.int64,
shape=[1],
dim_range=OrderedDict([
('context_progress_size', [1])
]))
last_token_ids = None
if self.mapping.is_last_pp_rank() and not gather_context_logits:
last_token_ids = Tensor(
name="last_token_ids",
dtype=trt.int32,
shape=[-1],
dim_range=OrderedDict([("batch_size_last_token_ids", [bb_range])
]),
)
attention_mask = None
if not use_gpt_attention_plugin:
attention_mask = Tensor(
name='attention_mask',
dtype=trt.int32,
shape=[-1, -1],
dim_range=OrderedDict([
('batch_size_beam_width', [bb_range]),
('mask_len', [mask_len_range]),
]),
)
assert False, "not support non-attention-plugin case now"
cross_attention_mask = Tensor(
name='cross_attention_mask',
dtype=trt.bool,
shape=[-1, -1],
dim_range=OrderedDict([
('decoder_num_tokens_2',
[decoder_num_tokens_range
]), # TODO should use same name as input_ids
('encoder_input_len_2', [encoder_input_len_range]),
]),
)
cross_attention_packed_mask = Tensor(
name='cross_attention_packed_mask',
dtype=trt.int32,
shape=[-1, -1],
dim_range=OrderedDict([
('cross_packed_mask_dim0', [cross_packed_mask_dim0_range]),
('cross_packed_mask_dim1', [cross_packed_mask_dim1_range]),
]),
)
# create the attention_mask_params.
attention_mask_params = AttentionMaskParams(
attention_mask, None, cross_attention_mask,
cross_attention_packed_mask)
cache_indirection = Tensor(
name='cache_indirection',
dtype=trt.int32,
shape=[-1, -1, -1],
dim_range=OrderedDict([
('batch_size_cache', [bs_range]),
('beam_width', [beam_width_range]),
('max_seq_len', [max_output_len_range]),
]),
)
layers_range = self.mapping.pp_layers(self.transformer.total_num_layers)
num_pp_layers = len(layers_range)
host_max_attention_window_sizes = None
host_sink_token_length = None
if use_gpt_attention_plugin:
host_max_attention_window_sizes = Tensor(
name=f'host_max_attention_window_sizes',
dtype=trt.int32,
shape=[num_pp_layers],
dim_range=OrderedDict([('num_layers', [num_pp_layers])]))
host_sink_token_length = Tensor(name='host_sink_token_length',
dtype=trt.int32,
shape=[1],
dim_range=OrderedDict([('scalar',
[1])]))
# TODO LoRA for mllama is not verified.
lora_weights_pointers = None
lora_ranks = None
lora_params = None
if use_lora_plugin:
lora_weights_pointers = []
lora_ranks = []
missing_qkv_modules = []
if any(x in lora_target_modules
for x in ["attn_q", "attn_k", "attn_v"]):
for lora_module in [
"attn_q",
"attn_k",
"attn_v",
]:
if lora_module not in lora_target_modules:
missing_qkv_modules.append(lora_module)
if any(x in lora_target_modules
for x in ["cross_attn_q", "cross_attn_k", "cross_attn_v"]):
for lora_module in [
"cross_attn_q", "cross_attn_k", "cross_attn_v"
]:
if lora_module not in lora_target_modules:
missing_qkv_modules.append(lora_module)
# For LoRA
for i in layers_range:
lora_weight_pointer_dict = {}
lora_rank_dict = {}
for lora_module in (lora_target_modules + missing_qkv_modules):
lora_weight_pointer = Tensor(
name=f'{lora_module}_lora_weights_pointers_{i}',
dtype=trt.int64,
shape=[-1, 3],
dim_range=OrderedDict([('batch_size_beam_width',
[bb_range]),
('in_out_scales', [3])]))
lora_weight_pointer_dict.update({
f'{lora_module}_lora_weights_pointers':
lora_weight_pointer
})
lora_rank = Tensor(name=f'{lora_module}_lora_ranks_{i}',
dtype=trt.int32,
shape=[-1],
dim_range=OrderedDict([
('batch_size_beam_width', [bb_range])
]))
lora_rank_dict.update(
{f'{lora_module}_lora_ranks': lora_rank})
lora_weights_pointers.append(lora_weight_pointer_dict)
lora_ranks.append(lora_rank_dict)
# For cross attention, we need to use encoder_input_lengths (in CPU) to pass
# as the host_context_lengths to the lora_plugin. But for self attention, we
# should keep using the original host_context_lengths. Therefore, we keep both
# of them in the lora_params.
host_encoder_input_lengths = None
if remove_input_padding:
host_encoder_input_lengths = Tensor(
name="host_encoder_input_lengths",
dtype=trt.int32,
shape=[-1],
dim_range=OrderedDict([("batch_size_beam_width", [bb_range])
]),
)
lora_params = LoraParams(
lora_ranks=lora_ranks,
lora_weights_pointers=lora_weights_pointers,
host_context_lengths=host_context_lengths,
max_context_length=max_decoder_input_len,
max_encoder_context_length=max_encoder_input_len,
host_request_types=host_request_types,
host_encoder_input_lengths=host_encoder_input_lengths,
)
kv_cache_block_offsets = None
host_kv_cache_block_offsets = None
host_kv_cache_pool_pointers = None
host_kv_cache_pool_mapping = None
cross_kv_cache_block_offsets = None
host_cross_kv_cache_block_offsets = None
host_cross_kv_cache_pool_pointers = None
host_cross_kv_cache_pool_mapping = None
if use_cache:
if not paged_kv_cache:
for i in layers_range:
kv_dim_range = OrderedDict([
('batch_size_beam_width', [bb_range]),
('kv', [2]),
('num_heads', [num_kv_heads]),
('past_key_len', [max_output_len_range]),
('head_size', [head_size]),
])
kv = Tensor(name=f'past_key_value_{i}',
dtype=self._kv_dtype,
shape=[-1, 2, num_kv_heads, -1, head_size],
dim_range=kv_dim_range)
past_key_value.append(kv)
if i in self.transformer.cross_attention_layers:
xa_layer_id = self.transformer.cross_attention_layers.index(
i) + layers_range[-1]
cross_kv_dim_range = OrderedDict([
('batch_size_beam_width', [bb_range]),
('kv', [2]),
('cross_num_heads', [encoder_num_kv_heads]),
('cross_past_key_len', [encoder_input_len_range]),
('cross_head_size', [encoder_head_size]),
])
cross_kv = Tensor(
name=f'cross_past_key_value_{xa_layer_id}',
dtype=self._kv_dtype,
shape=[
-1, 2, encoder_num_kv_heads, -1, encoder_head_size
],
dim_range=cross_kv_dim_range)
                        past_key_value.append(cross_kv)
# TODO: Remove this when TRT fix the named dimension
if not remove_input_padding:
assertion(
shape(
input_ids if self.mapping.is_first_pp_rank() else
hidden_states, 0) == shape(kv, 0), 'batch size')
else: # paged_kv_cache == True
# PagedKV setup for KV cache of self-attention
max_blocks_per_seq_range = [[
math.ceil(max_output_len_range[0] / tokens_per_block),
math.ceil(max_output_len_range[1] / tokens_per_block),
math.ceil(max_output_len_range[2] / tokens_per_block)
]]
max_blocks_per_seq_range = [[
x for x in max_blocks_per_seq_range[0]
]]
# PagedKV setup for KV cache of cross-attention
max_cross_blocks_per_seq_range = [[
math.ceil(encoder_input_len_range[0] / tokens_per_block),
math.ceil(encoder_input_len_range[1] / tokens_per_block),
math.ceil(encoder_input_len_range[2] / tokens_per_block)
]]
max_cross_blocks_per_seq_range = [[
x for x in max_cross_blocks_per_seq_range[0]
]]
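                # Two KV-cache pools are exposed so that the self-attention and
                # cross-attention caches can be managed separately.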
num_kv_cache_pools = 2
kv_cache_block_offsets = Tensor(
name=f'kv_cache_block_offsets',
dtype=trt.int32,
shape=[num_kv_cache_pools, -1, 2, -1],
dim_range=OrderedDict([
('num_kv_cache_pools', [num_kv_cache_pools]),
('batch_size_beam_width', [bb_range]),
('kv', [2]),
('max_blocks_per_seq', max_blocks_per_seq_range),
]))
host_kv_cache_block_offsets = Tensor(
name=f'host_kv_cache_block_offsets',
dtype=trt.int32,
shape=[num_kv_cache_pools, -1, 2, -1],
dim_range=OrderedDict([
('num_kv_cache_pools', [num_kv_cache_pools]),
('batch_size_beam_width', [bb_range]),
('kv', [2]),
('max_blocks_per_seq', max_blocks_per_seq_range),
]))
host_kv_cache_pool_pointers = Tensor(
name=f'host_kv_cache_pool_pointers',
dtype=trt.int64,
shape=[num_kv_cache_pools, 2],
dim_range=OrderedDict([
('num_kv_cache_pools', [num_kv_cache_pools]),
('num_pools', [2]),
]))
host_kv_cache_pool_mapping = Tensor(
name=f"host_kv_cache_pool_mapping",
dtype=trt.int32,
# 2: (Index of pool, Index of layer within pool)
shape=[num_pp_layers, 2],
dim_range=OrderedDict([
('pools_mapping', [num_pp_layers]),
('layer_cache_pool_locator', [2]),
]))
# paged blocks for cross kv
cross_kv_cache_block_offsets = Tensor(
name=f'cross_kv_cache_block_offsets',
dtype=trt.int32,
shape=[num_kv_cache_pools, -1, 2, -1],
dim_range=OrderedDict([
('num_kv_cache_pools', [num_kv_cache_pools]),
('batch_size_beam_width', [bb_range]),
('kv', [2]),
('max_cross_blocks_per_seq',
max_cross_blocks_per_seq_range),
]))
host_cross_kv_cache_block_offsets = Tensor(
name=f'host_cross_kv_cache_block_offsets',
dtype=trt.int32,
shape=[num_kv_cache_pools, -1, 2, -1],
dim_range=OrderedDict([
('num_kv_cache_pools', [num_kv_cache_pools]),
('batch_size_beam_width', [bb_range]),
('kv', [2]),
('max_cross_blocks_per_seq',
max_cross_blocks_per_seq_range),
]))
host_cross_kv_cache_pool_pointers = Tensor(
name=f'host_cross_kv_cache_pool_pointers',
dtype=trt.int64,
shape=[num_kv_cache_pools, 2],
dim_range=OrderedDict([
('num_kv_cache_pools', [num_kv_cache_pools]),
('num_pools', [2]),
]))
host_cross_kv_cache_pool_mapping = Tensor(
name=f"host_cross_kv_cache_pool_mapping",
dtype=trt.int32,
# 2: (Index of pool, Index of layer within pool)
shape=[num_pp_layers, 2],
dim_range=OrderedDict([
('pools_mapping', [num_pp_layers]),
('layer_cache_pool_locator', [2]),
]))
for i in layers_range:
past_key_value.append(None)
kv_cache_params = KeyValueCacheParams(
past_key_value=past_key_value,
host_past_key_value_lengths=host_past_key_value_lengths,
host_max_attention_window_sizes=host_max_attention_window_sizes,
host_sink_token_length=host_sink_token_length,
cache_indirection=cache_indirection,
kv_cache_block_offsets=kv_cache_block_offsets,
host_kv_cache_block_offsets=host_kv_cache_block_offsets,
host_kv_cache_pool_pointers=host_kv_cache_pool_pointers,
host_kv_cache_pool_mapping=host_kv_cache_pool_mapping,
cross_kv_cache_block_offsets=cross_kv_cache_block_offsets,
host_cross_kv_cache_block_offsets=
host_cross_kv_cache_block_offsets,
host_cross_kv_cache_pool_pointers=
host_cross_kv_cache_pool_pointers,
host_cross_kv_cache_pool_mapping=
host_cross_kv_cache_pool_mapping,
)
attention_params = AttentionParams(
sequence_length=sequence_length,
context_lengths=context_lengths,
host_context_lengths=host_context_lengths,
max_context_length=max_decoder_input_len,
host_request_types=host_request_types,
host_runtime_perf_knobs=host_runtime_perf_knobs,
host_context_progress=host_context_progress,
encoder_input_lengths=encoder_input_lengths,
encoder_max_input_length=encoder_max_input_length,
)
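        # Runtime boolean: True during the context phase so the cross-attention K/V
        # are computed from encoder_output, False during generation so the cached
        # (or reused) cross K/V are consumed instead.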
cross_kv_cache_gen = Tensor(name='cross_kv_cache_gen',
dtype=trt.bool,
shape=[1],
dim_range=OrderedDict([
('boolean', [1]),
]))
cross_kv_reuse = None
num_heads = (self.transformer.num_heads + self.mapping.tp_size -
1) // self.mapping.tp_size
cross_kv_out_dim = 2 * num_kv_heads * self.transformer.head_size
if self.transformer.skip_cross_kv:
if remove_input_padding:
cross_kv_reuse = Tensor(
name="cross_kv_reuse",
dtype=self._dtype,
shape=[-1, cross_kv_out_dim],
dim_range=OrderedDict([
("encoder_num_tokens", [encoder_num_tokens_range]),
("encoder_kv_size", [cross_kv_out_dim]),
]),
)
else:
cross_kv_reuse = Tensor(
name="cross_kv_reuse",
dtype=self._dtype,
shape=[-1, -1, cross_kv_out_dim],
dim_range=OrderedDict([
("batch_size_beam_width_encoder", [bb_range]),
("encoder_input_len", [encoder_input_len_range]),
("encoder_kv_size", [cross_kv_out_dim]),
]),
)
skip_cross_attn_blocks = None
if self.config.skip_cross_attn_blocks:
skip_cross_attn_blocks = Tensor(name='skip_cross_attn_blocks',
dtype=trt.bool,
shape=[1],
dim_range=OrderedDict([
('boolean', [1]),
]))
prompt_embedding_table = None
tasks = None
prompt_vocab_size = None
if self.mapping.is_first_pp_rank() and prompt_embedding_table_size > 0:
p_embedding_range = [[
1, prompt_embedding_table_size // 2, prompt_embedding_table_size
]]
prompt_embedding_table = Tensor(
name='prompt_embedding_table',
dtype=self._dtype,
shape=[-1, self.transformer.hidden_size],
dim_range=OrderedDict([
('prompt_embedding_table_size', p_embedding_range),
('hidden_size', [self.transformer.hidden_size]),
]))
if remove_input_padding:
num_tokens_range = [
1,
(max_decoder_input_len * max_batch_size + 1) // 2,
max_decoder_input_len * max_batch_size,
]
tasks = Tensor(name='tasks',
dtype=trt.int32,
shape=[-1],
dim_range=OrderedDict([
('decoder_num_tokens',
[decoder_num_tokens_range]),
]))
else:
tasks = Tensor(name='tasks',
dtype=trt.int32,
shape=[-1, 1],
dim_range=OrderedDict([
('batch_size', bs_range),
('broadcast_dim', [1]),
]))
prompt_vocab_size = Tensor(name='prompt_vocab_size',
dtype=trt.int32,
shape=[1],
dim_range=OrderedDict([('size', [1])]))
result = {
'decoder_input_ids': input_ids,
'encoder_output': encoder_output,
'use_cache': True,
'attention_mask_params': attention_mask_params,
'last_token_ids': last_token_ids,
'kv_cache_params': kv_cache_params,
'attention_params': attention_params,
'hidden_states': hidden_states,
'lora_params': lora_params,
'cross_kv_cache_gen': cross_kv_cache_gen,
'cross_kv_reuse': cross_kv_reuse,
'prompt_embedding_table': prompt_embedding_table,
'prompt_tasks': tasks,
'prompt_vocab_size': prompt_vocab_size,
'skip_cross_attn_blocks': skip_cross_attn_blocks,
}
return result
    def use_lora(self, lora_config: LoraConfig):
        use_lora(self, lora_config, self.trtllm_modules_to_hf_modules)
    @classmethod
    def from_hugging_face(
            cls,
            hf_model_or_dir: Union[str, 'transformers.PreTrainedModel'],
            dtype: str = 'auto',
            mapping: Optional[Mapping] = None,
            quant_config: Optional[QuantConfig] = None,
            **kwargs):
        '''Create a MLLaMAForCausalLM object from the given parameters.'''
        import transformers
kwargs.pop('load_by_shard', False)
kwargs.pop('load_model_on_cpu', False)
quant_ckpt_path = kwargs.pop('quant_ckpt_path', None)
assert hf_model_or_dir is not None
use_preloading = isinstance(hf_model_or_dir,
transformers.PreTrainedModel)
if use_preloading:
hf_model = hf_model_or_dir
hf_config_or_dir = hf_model.config
else:
hf_model_dir = hf_model_or_dir
hf_config_or_dir = hf_model_or_dir
config = MLLaMAConfig.from_hugging_face(hf_config_or_dir,
dtype=dtype,
mapping=mapping,
quant_config=quant_config,
**kwargs)
custom_dict = {
"lm_head": "language_model.lm_head",
"transformer.ln_f": "language_model.model.norm",
"transformer": "language_model.model",
"self_attention": "self_attn",
"cross_attention": "cross_attn",
"vocab_embedding": "embed_tokens",
"gate_attn": "cross_attn_attn_gate",
"gate_ffwd": "cross_attn_mlp_gate",
"q_layernorm": "q_norm",
"k_layernorm": "k_norm",
}
if quant_ckpt_path is not None:
hf_model_dir = quant_ckpt_path
loader = ModelWeightsLoader(hf_model_dir, custom_dict)
model = cls(config)
loader.generate_tllm_weights(model)
return model
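# A minimal usage sketch (the model id/dir, sizes, and build flow below are
# illustrative assumptions, not part of this module): weights are loaded via
# from_hugging_face(), prepare_inputs() declares the dynamic-shape input tensors,
# and forward() is traced while a TensorRT-LLM network is active (e.g. inside
# net_guard(network) during engine build):
#
#   model = MLLaMAForCausalLM.from_hugging_face("<hf-model-or-checkpoint-dir>",
#                                               dtype="bfloat16")
#   inputs = model.prepare_inputs(max_batch_size=1,
#                                 max_beam_width=1,
#                                 max_decoder_input_len=128,
#                                 max_seq_len=512,
#                                 max_encoder_input_len=4100)
#   model.forward(**inputs)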