# tensorrt_llm/models/bert/model.py
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, OrderedDict, Union
import numpy as np
import tensorrt as trt
import torch
import transformers
from tensorrt_llm.models.modeling_utils import PretrainedModel
from ..._common import default_net
from ...functional import (ACT2FN, Tensor, concat, constant, cumsum, expand,
                           index_select, select, shape, slice, unsqueeze)
from ...layers import MLP, BertAttention, Embedding, LayerNorm, Linear
from ...mapping import Mapping
from ...module import Module, ModuleList
from ..modeling_utils import QuantConfig
from .config import BERTConfig
from .convert import (load_hf_bert_base, load_hf_bert_cls, load_hf_bert_qa,
                      load_weights_from_hf_model)
class BertEmbedding(Module):
def __init__(self,
vocab_size,
hidden_size,
max_position_embeddings,
type_vocab_size,
dtype=None):
super().__init__()
self.vocab_embedding = Embedding(vocab_size, hidden_size, dtype=dtype)
self.position_embedding = Embedding(max_position_embeddings,
hidden_size,
dtype=dtype)
self.token_embedding = Embedding(type_vocab_size,
hidden_size,
dtype=dtype)
self.max_position_embeddings = max_position_embeddings
self.embedding_ln = LayerNorm(normalized_shape=hidden_size, dtype=dtype)
def forward(self, input_ids, position_ids, token_type_ids):
x = self.vocab_embedding(input_ids)
x = x + self.position_embedding(position_ids)
x = x + self.token_embedding(token_type_ids)
x = self.embedding_ln(x)
return x
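# Illustrative note (not part of the module): BertEmbedding sums three lookups
# and applies LayerNorm. With padded int32 inputs of shape [B, L]:
#
#   emb = BertEmbedding(vocab_size=config.vocab_size,
#                       hidden_size=config.hidden_size,
#                       max_position_embeddings=config.max_position_embeddings,
#                       type_vocab_size=config.type_vocab_size,
#                       dtype=config.dtype)
#   # output shape: [B, L, hidden_size]
#   x = emb(input_ids, position_ids, token_type_ids)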
class BertEncoderLayer(Module):
def __init__(self,
hidden_size,
num_attention_heads,
max_position_embeddings,
hidden_act='relu',
tp_group=None,
tp_size=1,
dtype=None):
super().__init__()
self.input_layernorm = LayerNorm(normalized_shape=hidden_size,
dtype=dtype)
self.attention = BertAttention(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
max_position_embeddings=max_position_embeddings,
tp_group=tp_group,
tp_size=tp_size,
dtype=dtype)
self.mlp = MLP(hidden_size=hidden_size,
ffn_hidden_size=hidden_size * 4,
hidden_act=hidden_act,
tp_group=tp_group,
tp_size=tp_size,
dtype=dtype)
self.post_layernorm = LayerNorm(normalized_shape=hidden_size,
dtype=dtype)
def forward(self,
hidden_states,
attention_mask=None,
input_lengths=None,
max_input_length=None):
residual = hidden_states
attention_output = self.attention(hidden_states,
attention_mask=attention_mask,
input_lengths=input_lengths,
max_input_length=max_input_length)
hidden_states = residual + attention_output
hidden_states = self.input_layernorm(hidden_states)
residual = hidden_states
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
hidden_states = self.post_layernorm(hidden_states)
return hidden_states
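# Dataflow of one encoder layer (post-LayerNorm ordering, as in the original
# BERT encoder), written schematically:
#
#   x = input_layernorm(x + attention(x))
#   x = post_layernorm(x + mlp(x))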
class BertBase(PretrainedModel):
    '''
    Base class that provides the from_hugging_face() and prepare_inputs() methods.
    '''
    config_class = BERTConfig
def __init__(self, config: BERTConfig):
super().__init__(config)
@classmethod
def load_hf_bert(cls, model_dir: str, load_model_on_cpu: bool,
dtype: torch.dtype):
"""
        Load the corresponding HF model for the concrete subclass.
        Acts like an abstract method: it must not be called on BertBase itself.
"""
        assert cls.__name__ != "BertBase", "Never call from the BertBase class!"
if cls.__name__ == "BertModel":
return load_hf_bert_base(model_dir, load_model_on_cpu, dtype)
elif cls.__name__ == "BertForQuestionAnswering":
return load_hf_bert_qa(model_dir, load_model_on_cpu, dtype)
elif cls.__name__ == "BertForSequenceClassification":
return load_hf_bert_cls(model_dir, load_model_on_cpu, dtype)
else:
assert False, f"Unknown class {cls.__name__}!"
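    # Illustrative call (hypothetical checkpoint path, not part of the module):
    # invoking this classmethod on a concrete subclass dispatches to the
    # matching HF loader above, e.g.
    #
    #   hf_model = BertForQuestionAnswering.load_hf_bert(
    #       model_dir='/path/to/hf_checkpoint',
    #       load_model_on_cpu=True,
    #       dtype=torch.float16)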
@classmethod
def from_hugging_face(
cls,
hf_model_or_dir: Union[str, 'transformers.PreTrainedModel'],
dtype: str = 'float16',
mapping: Optional[Mapping] = None,
quant_config: Optional[QuantConfig] = None,
**kwargs):
"""
        Create a TensorRT-LLM BERT model object from the given HuggingFace
        model (or checkpoint directory) and parameters.
"""
import transformers
assert hf_model_or_dir is not None
use_preloading = isinstance(hf_model_or_dir,
transformers.PreTrainedModel)
if use_preloading:
hf_model = hf_model_or_dir
hf_config_or_dir = hf_model.config
else:
hf_model_dir = hf_model_or_dir
hf_config_or_dir = hf_model_or_dir
load_model_on_cpu = kwargs.pop('load_model_on_cpu', False)
tllm_config = BERTConfig.from_hugging_face(
hf_config_or_dir=hf_config_or_dir,
dtype=dtype,
mapping=mapping,
quant_config=quant_config,
**kwargs)
#NOTE: override architecture info
RobertaCls_mapping = {
"BertModel": "RobertaModel",
"BertForQuestionAnswering": "RobertaForQuestionAnswering",
"BertForSequenceClassification": "RobertaForSequenceClassification",
}
if tllm_config.is_roberta:
setattr(tllm_config, 'architecture',
RobertaCls_mapping[cls.__name__])
else:
setattr(tllm_config, 'architecture', cls.__name__)
torch_dtype = torch.float16 if dtype == 'float16' else torch.float32
if not use_preloading:
hf_model = cls.load_hf_bert(model_dir=hf_model_dir,
load_model_on_cpu=load_model_on_cpu,
dtype=torch_dtype)
weights = load_weights_from_hf_model(hf_model=hf_model,
config=tllm_config)
model = cls(tllm_config)
model.load(weights)
return model
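    # Illustrative usage (hypothetical path, not part of the module):
    #
    #   model = BertForSequenceClassification.from_hugging_face(
    #       '/path/to/hf_checkpoint',   # or an already loaded transformers model
    #       dtype='float16',
    #       mapping=Mapping(world_size=1, rank=0, tp_size=1))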
    # Overrides PretrainedModel's method; can be unified in the future.
def prepare_inputs(self, max_batch_size, max_input_len, **kwargs):
remove_input_padding = default_net().plugin_config.remove_input_padding
        # The opt shape defaults to half of max_batch_size and max_input_len;
        # tune it according to the real data distribution.
bs_range = [1, (max_batch_size + 1) // 2, max_batch_size]
inlen_range = [1, (max_input_len + 1) // 2, max_input_len]
num_tokens_range = [
1,
(max_input_len * max_batch_size + 1) // 2,
max_input_len * max_batch_size,
]
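        # For example, with max_batch_size=8 and max_input_len=512 the
        # (min, opt, max) profile ranges above become:
        #   bs_range         = [1, 4, 8]
        #   inlen_range      = [1, 256, 512]
        #   num_tokens_range = [1, 2048, 4096]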
if not remove_input_padding:
input_ids = Tensor(
name='input_ids',
dtype=trt.int32,
shape=[-1, -1],
dim_range=OrderedDict([('batch_size', [bs_range]),
('input_len', [inlen_range])]),
)
# also called segment_ids
token_type_ids = Tensor(
name='token_type_ids',
dtype=trt.int32,
shape=[-1, -1],
dim_range=OrderedDict([('batch_size', [bs_range]),
('input_len', [inlen_range])]),
)
else:
input_ids = Tensor(
name="input_ids",
dtype=trt.int32,
shape=[-1],
dim_range=OrderedDict([("num_tokens", [num_tokens_range])]),
)
token_type_ids = Tensor(
name='token_type_ids',
dtype=trt.int32,
shape=[-1],
dim_range=OrderedDict([('num_tokens', [num_tokens_range])]),
)
position_ids = Tensor(
name='position_ids',
dtype=trt.int32,
shape=[-1],
dim_range=OrderedDict([('num_tokens', [num_tokens_range])]),
)
max_input_length = Tensor(
name="max_input_length",
dtype=trt.int32,
shape=[-1],
dim_range=OrderedDict([("max_input_length", [inlen_range])]),
)
input_lengths = Tensor(name='input_lengths',
dtype=trt.int32,
shape=[-1],
dim_range=OrderedDict([('batch_size', [bs_range])
]))
inputs = {
'input_ids': input_ids,
'input_lengths': input_lengths,
'token_type_ids': token_type_ids,
}
if remove_input_padding:
inputs['position_ids'] = position_ids
inputs['max_input_length'] = max_input_length
return inputs
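# Illustrative sketch (not part of the module) of how prepare_inputs() is used
# when tracing the network; the surrounding builder/network setup is assumed to
# follow the BERT example scripts and is omitted here:
#
#   inputs = model.prepare_inputs(max_batch_size=8, max_input_len=512)
#   output = model(**inputs)   # traces the graph with the dynamic shapes above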
class BertModel(BertBase):
def __init__(self, config: BERTConfig):
super().__init__(config)
self.config = config
self.max_position_embeddings = config.max_position_embeddings
self.padding_idx = config.pad_token_id
self.is_roberta = config.is_roberta
self.embedding = BertEmbedding(
vocab_size=config.vocab_size,
hidden_size=config.hidden_size,
max_position_embeddings=config.max_position_embeddings,
type_vocab_size=config.type_vocab_size,
dtype=config.dtype)
self.layers = ModuleList([
BertEncoderLayer(
hidden_size=config.hidden_size,
num_attention_heads=config.num_attention_heads,
max_position_embeddings=config.max_position_embeddings,
hidden_act=config.hidden_act,
tp_group=config.mapping.tp_group,
tp_size=config.mapping.tp_size,
dtype=config.dtype) for _ in range(config.num_hidden_layers)
])
    def forward(self,
                input_ids=None,
                input_lengths=None,
                position_ids=None,
                token_type_ids=None,
                hidden_states=None,
                max_input_length=None):
        # remove_input_padding requires these fields as explicit input
        mask = None
        if not default_net().plugin_config.remove_input_padding:
            seq_len_2d = concat([1, shape(input_ids, 1)])
# create position ids
position_ids_buffer = constant(
np.expand_dims(
np.arange(self.max_position_embeddings).astype(np.int32),
0))
tmp_position_ids = slice(position_ids_buffer,
starts=[0, 0],
sizes=seq_len_2d)
tmp_position_ids = expand(tmp_position_ids, shape(input_ids)) #BxL
tmp_input_lengths = unsqueeze(input_lengths, 1) #Bx1
tmp_input_lengths = expand(tmp_input_lengths,
shape(input_ids)) #BxL
mask = tmp_position_ids < tmp_input_lengths # BxL
mask = mask.cast('int32')
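            # Worked example (hypothetical values): with input_lengths = [3, 5]
            # and padded length L = 5, tmp_position_ids is
            #   [[0, 1, 2, 3, 4],
            #    [0, 1, 2, 3, 4]]
            # so the resulting attention mask is
            #   [[1, 1, 1, 0, 0],
            #    [1, 1, 1, 1, 1]]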
if position_ids is None:
if self.is_roberta:
# see create_position_ids_from_input_ids() in https://github.com/huggingface/transformers/blob/main/src/transformers/models/roberta/modeling_roberta.py
position_ids = (tmp_position_ids + 1) * mask
position_ids = position_ids + self.padding_idx
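                    # Continuing the worked example above with padding_idx = 1:
                    #   (pos + 1) * mask + 1 = [2, 3, 4, 1, 1]
                    # i.e. real tokens start counting from padding_idx + 1 and
                    # padded slots stay at padding_idx, matching HF's
                    # create_position_ids_from_input_ids().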
else:
position_ids = slice(position_ids_buffer,
starts=[0, 0],
sizes=seq_len_2d)
position_ids = expand(position_ids, shape(input_ids))
# create token_type_ids
if token_type_ids is None:
token_type_ids_buffer = constant(
np.expand_dims(
np.zeros(self.max_position_embeddings).astype(np.int32),
0))
token_type_ids = slice(token_type_ids_buffer,
starts=[0, 0],
sizes=seq_len_2d)
token_type_ids = expand(token_type_ids, shape(input_ids))
hidden_states = self.embedding(input_ids, position_ids, token_type_ids)
self.register_network_output('embedding_output', hidden_states)
for idx, layer in enumerate(self.layers):
hidden_states = layer(hidden_states=hidden_states,
input_lengths=input_lengths,
attention_mask=mask,
max_input_length=max_input_length)
# keep the last layer output name as hidden_states
if ((idx == (self.config.num_hidden_layers - 1)) and
(self.config.architecture in ["BertModel", "RobertaModel"])):
hidden_states.mark_output('hidden_states', self.config.dtype)
else:
self.register_network_output(f"layer_{idx}_output",
hidden_states)
return hidden_states
RobertaModel = BertModel
class BertForQuestionAnswering(BertBase):
def __init__(self, config: BERTConfig):
super().__init__(config)
self.bert = BertModel(config)
self.num_labels = config.num_labels
self.qa_outputs = Linear(config.hidden_size,
config.num_labels,
dtype=config.dtype)
    def forward(self,
                input_ids=None,
                input_lengths=None,
                token_type_ids=None,
                position_ids=None,
                hidden_states=None,
                max_input_length=None):
remove_input_padding = default_net().plugin_config.remove_input_padding
if remove_input_padding:
assert token_type_ids is not None and \
position_ids is not None and \
max_input_length is not None, \
"token_type_ids, position_ids, max_input_length is required " \
"in remove_input_padding mode"
hidden_states = self.bert.forward(input_ids=input_ids,
input_lengths=input_lengths,
token_type_ids=token_type_ids,
position_ids=position_ids,
hidden_states=hidden_states,
max_input_length=max_input_length)
logits = self.qa_outputs(hidden_states)
logits.mark_output('logits', self.config.logits_dtype)
return logits
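# Shape note: with num_labels == 2 the QA logits have shape [batch, seq_len, 2]
# (or [num_tokens, 2] with remove_input_padding). Splitting the last dimension
# yields start/end span logits, e.g. on the runtime side (hypothetical torch
# post-processing, not part of the module):
#
#   start_logits, end_logits = torch.split(logits, 1, dim=-1)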
RobertaForQuestionAnswering = BertForQuestionAnswering
class BertPooler(Module):
def __init__(self, hidden_size, dtype):
super().__init__()
self.dense = Linear(hidden_size, hidden_size, dtype=dtype)
self.activation = ACT2FN['tanh']
def forward(self, hidden_states, input_lengths, remove_input_padding):
if not remove_input_padding:
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = select(hidden_states, 1, 0)
else:
# when remove_input_padding is enabled, the shape of hidden_states is [num_tokens, hidden_size]
# We can take the first token of each sequence according to input_lengths,
# and then do pooling similar to padding mode.
# For example, if input_lengths is [8, 5, 6], then the indices of first tokens
# should be [0, 8, 13]
first_token_indices = cumsum(
concat([
0,
slice(input_lengths,
starts=[0],
sizes=(shape(input_lengths) -
constant(np.array([1], dtype=np.int32))))
]), 0)
first_token_tensor = index_select(hidden_states, 0,
first_token_indices)
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
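# Quick check of the index arithmetic above (hypothetical numpy sketch, not
# part of the module):
#
#   lengths = np.array([8, 5, 6], dtype=np.int32)
#   np.cumsum(np.concatenate([[0], lengths[:-1]]))   # -> array([ 0,  8, 13])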
class RobertaClassificationHead(Module):
    """Head for sentence-level classification tasks."""
def __init__(self, hidden_size, dtype, num_labels):
super().__init__()
self.dense = Linear(hidden_size, hidden_size, dtype=dtype)
        self.out_proj = Linear(hidden_size, num_labels, dtype=dtype)
def forward(self, hidden_states, input_lengths, remove_input_padding):
if not remove_input_padding:
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = select(hidden_states, 1, 0)
else:
# when remove_input_padding is enabled, the shape of hidden_states is [num_tokens, hidden_size]
# We can take the first token of each sequence according to input_lengths,
# and then do pooling similar to padding mode.
# For example, if input_lengths is [8, 5, 6], then the indices of first tokens
# should be [0, 8, 13]
first_token_indices = cumsum(
concat([
0,
slice(input_lengths,
starts=[0],
sizes=(shape(input_lengths) -
constant(np.array([1], dtype=np.int32))))
]), 0)
first_token_tensor = index_select(hidden_states, 0,
first_token_indices)
x = self.dense(first_token_tensor)
x = ACT2FN['tanh'](x)
x = self.out_proj(x)
return x
class BertForSequenceClassification(BertBase):
def __init__(self, config: BERTConfig):
super().__init__(config)
self.config = config
self.is_roberta = config.is_roberta
self.bert = BertModel(config)
self.num_labels = config.num_labels
if not config.is_roberta:
self.pooler = BertPooler(hidden_size=config.hidden_size,
dtype=config.dtype)
self.classifier = Linear(config.hidden_size,
config.num_labels,
dtype=config.dtype)
else:
self.classifier = RobertaClassificationHead(
hidden_size=config.hidden_size,
num_labels=config.num_labels,
dtype=config.dtype)
    def forward(self,
                input_ids,
                input_lengths,
                token_type_ids=None,
                position_ids=None,
                hidden_states=None,
                max_input_length=None):
remove_input_padding = default_net().plugin_config.remove_input_padding
# required as explicit input in remove_input_padding mode
# see examples/models/core/bert/run_remove_input_padding.py for how to create them from input_ids and input_lengths
if remove_input_padding:
assert token_type_ids is not None and \
position_ids is not None and \
max_input_length is not None, \
"token_type_ids, position_ids, max_input_length is required " \
"in remove_input_padding mode"
hidden_states = self.bert.forward(input_ids=input_ids,
input_lengths=input_lengths,
token_type_ids=token_type_ids,
position_ids=position_ids,
hidden_states=hidden_states,
max_input_length=max_input_length)
if not self.is_roberta:
pooled_output = self.pooler(
hidden_states=hidden_states,
input_lengths=input_lengths,
remove_input_padding=remove_input_padding)
logits = self.classifier(pooled_output)
else:
logits = self.classifier(hidden_states=hidden_states,
input_lengths=input_lengths,
remove_input_padding=remove_input_padding)
logits.mark_output('logits', self.config.logits_dtype)
return logits
RobertaForSequenceClassification = BertForSequenceClassification