Adding a New Model in PyTorch Backend — TensorRT-LLM

Table of Contents#

  1. Introduction
  2. Prerequisites
  3. Step-by-Step Guide
    1. Model Configuration
    2. Model Definition
    3. Weight Loading
    4. Model Registration
      1. Core Models
      2. Out-of-Tree Models

Introduction#

This guide provides a step-by-step process for adding a new model to the PyTorch backend.

Prerequisites#

Before you begin, ensure you have the following:

Step-by-Step Guide#

Model Configuration#

Suppose you want to support a new model named MyModel. If the model is already supported in HuggingFace’s transformers, you should bring the PyTorch modeling code and reuse HuggingFace’s configuration class. For example, our tensorrt_llm/_torch/models/modeling_llama.py was adapted from HuggingFace’s modeling_llama.py; in the modeling code, we reuse the configuration class:

from transformers import LlamaConfig

If the model is not registered in HuggingFace’s transformers, you need to define the configuration class in your configuration_mymodel.py following HuggingFace’s configuration_llama.py:

from transformers.configuration_utils import PretrainedConfig

class MyConfig(PretrainedConfig):
    def __init__(self, ...):
        ...
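
For example, a minimal configuration could look like the sketch below. The hyperparameter names (vocab_size, hidden_size, num_hidden_layers) and their default values are illustrative assumptions, chosen to match the attributes accessed through model_config.pretrained_config later in this guide:

from transformers.configuration_utils import PretrainedConfig

class MyConfig(PretrainedConfig):
    model_type = "mymodel"

    def __init__(self,
                 vocab_size: int = 32000,
                 hidden_size: int = 4096,
                 num_hidden_layers: int = 32,
                 **kwargs):
        # Store the architecture hyperparameters on the config object
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        # Forward any remaining arguments (e.g., generation defaults) to PretrainedConfig
        super().__init__(**kwargs)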

Model Definition#

Remove any unnecessary code (e.g., training-specific code), and then rewrite some PyTorch modules. For a typical Transformer decoder model, you need to implement your modeling_mymodel.py like this:

from typing import Optional

import torch
from torch import nn

from tensorrt_llm._torch.attention_backend import AttentionMetadata
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_utils import DecoderModel, DecoderModelForCausalLM
from tensorrt_llm._torch.modules.attention import Attention
from tensorrt_llm._torch.modules.decoder_layer import DecoderLayer

from configuration_mymodel import MyConfig

class MyAttention(Attention):
    def __init__(self, model_config: ModelConfig[MyConfig], layer_idx: Optional[int] = None):
        # Use model_config to initialize the Attention module
        super().__init__(...)

class MyDecoderLayer(DecoderLayer):
    def __init__(self, model_config: ModelConfig[MyConfig], layer_idx: int):
        super().__init__()
        # Use model_config to initialize the submodules
        self.input_layernorm = ...
        self.self_attn = MyAttention(model_config, layer_idx)
        self.post_attention_layernorm = ...
        self.mlp = ...

    def forward(self, hidden_states: torch.Tensor, attn_metadata: AttentionMetadata, **kwargs):
        # Define the forward computation of a single decoder layer
        ...

class MyModel(DecoderModel):
    def __init__(self, model_config: ModelConfig[MyConfig]):
        super().__init__(model_config)
        # Use model_config to initialize the submodules
        self.embed_tokens = ...
        self.layers = nn.ModuleList([
            MyDecoderLayer(model_config, layer_idx)
            for layer_idx in range(model_config.pretrained_config.num_hidden_layers)
        ])

    def forward(self,
                attn_metadata: AttentionMetadata,
                input_ids: Optional[torch.IntTensor] = None,
                position_ids: Optional[torch.IntTensor] = None,
                inputs_embeds: Optional[torch.FloatTensor] = None):
        # Define the forward computation of the model
        ...

class MyModelForCausalLM(DecoderModelForCausalLM[MyModel, MyConfig]):
    def __init__(self, model_config: ModelConfig[MyConfig]):
        super().__init__(MyModel(model_config),
                         config=model_config,
                         hidden_size=model_config.pretrained_config.hidden_size,
                         vocab_size=model_config.pretrained_config.vocab_size)

Note that MyAttention inherits from our Attention module (in tensorrt_llm/_torch/modules/attention.py), so that the attention computation is compatible with our PyTorch runtime. Related to this, the module inputs also need to be adapted; in particular, the forward methods accept an attn_metadata: AttentionMetadata argument, which carries the metadata of the batched requests required by the attention backend.

Additionally, MyDecoderLayer, MyModel, and MyModelForCausalLM are subclasses of DecoderLayer, DecoderModel, and DecoderModelForCausalLM respectively. The base classes define interfaces and provide a generic scaffolding to define model layers, load weights, etc.

Optionally, you may replace the native PyTorch modules with our implementations (for example, the Linear and Embedding modules under tensorrt_llm/_torch/modules) to enable features such as tensor parallelism and quantization, or to achieve higher performance.

For a concrete reference, check out tensorrt_llm/_torch/models/modeling_llama.py.

Weight Loading#

The base class DecoderModelForCausalLM provides a load_weights method that loads the weights from the checkpoint file and assigns them to the corresponding layers in the model. However, if the default method does not work for MyModelForCausalLM, you need to implement your own load_weights:

class MyModelForCausalLM(DecoderModelForCausalLM[MyModel, MyConfig]):

    def load_weights(self, weights: dict):
        # Define the weight loading logic
        ...

For example, HuggingFace's LLaMA model uses three linear layers for the Q/K/V projections, resulting in three weight tensors in the checkpoint:

weights = {
    ...,
    "model.layers.0.self_attn.q_proj.weight": torch.Tensor([hidden_size, hidden_size]),
    "model.layers.0.self_attn.k_proj.weight": torch.Tensor([hidden_size, hidden_size]),
    "model.layers.0.self_attn.v_proj.weight": torch.Tensor([hidden_size, hidden_size]),
    ...,
}

However, our LLaMA model fuses the three layers into one linear layer:

llama.model.layers[0].self_attn.qkv_proj.weight.data  # torch.Tensor([hidden_size * 3, hidden_size])

Hence, load_weights needs to collect the three weight tensors from the original checkpoint, concatenate them, and assign the result to the fused linear layer. When tensor parallelism or quantization is involved, the process becomes more complicated. We recommend calling the predefined module-level load_weights methods (e.g., those of Linear and Embedding) when implementing your model-level load_weights.
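
As a simplified illustration of this fusion, a manual version of load_weights might look like the sketch below. It ignores tensor parallelism and quantization, handles only the attention weights, and assumes the attribute names shown above (model.layers[i].self_attn.qkv_proj) as well as the imports from modeling_mymodel.py; in practice, prefer the predefined module-level load_weights recommended above:

class MyModelForCausalLM(DecoderModelForCausalLM[MyModel, MyConfig]):

    def load_weights(self, weights: dict):
        for idx, layer in enumerate(self.model.layers):
            prefix = f"model.layers.{idx}.self_attn"
            # Collect the separate Q/K/V tensors from the HuggingFace checkpoint ...
            q = weights[f"{prefix}.q_proj.weight"]
            k = weights[f"{prefix}.k_proj.weight"]
            v = weights[f"{prefix}.v_proj.weight"]
            # ... and concatenate them along the output dimension for the fused QKV layer.
            layer.self_attn.qkv_proj.weight.data.copy_(torch.cat([q, k, v], dim=0))
        # The remaining weights (embeddings, layernorms, MLP, lm_head) would be
        # assigned similarly, ideally via the module-level load_weights methods.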

Overall, load_weights should handle any discrepancy between MyModelForCausalLM and the weights loaded from the checkpoint, so that MyModelForCausalLM can perform forward computation equivalent to the original model.

Model Registration#

The new model needs to be registered so that it can be recognized by the PyTorch runtime. Registration is done by adding the register_auto_model decorator to MyModelForCausalLM:

from tensorrt_llm._torch.models.modeling_utils import register_auto_model

@register_auto_model("MyModelForCausalLM")
class MyModelForCausalLM(DecoderModelForCausalLM[MyModel, MyConfig]):
    def __init__(self, model_config: ModelConfig[MyConfig]):
        ...

Core Models#

To add the new model to core models, modeling_mymodel.py (and potentially configuration_mymodel.py) should be placed in tensorrt_llm/_torch/models. Then, you need to import the modeling code in tensorrt_llm/_torch/models/__init__.py:

from .modeling_mymodel import MyModelForCausalLM

__all__ = [
    ...,
    "MyModelForCausalLM",
]

Out-of-Tree Models#

Alternatively, you can register the new model as an out-of-tree model, so that you can use the new model without touching the TensorRT-LLM codebase. To do so, place modeling_mymodel.py (and potentially configuration_mymodel.py) in your working directory, and import the modeling code in your script:

from tensorrt_llm._torch import LLM

import modeling_mymodel


def main():
    llm = LLM(...)


if __name__ == '__main__':
    main()
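
Filling in the skeleton above, a runnable script might look like the following sketch; the checkpoint path, prompts, and sampling settings are placeholders, and the exact LLM constructor arguments may differ for your model:

from tensorrt_llm import SamplingParams
from tensorrt_llm._torch import LLM

import modeling_mymodel


def main():
    prompts = ["Hello, my name is"]
    sampling_params = SamplingParams(max_tokens=32)

    # Importing modeling_mymodel above is what registers MyModelForCausalLM,
    # so the PyTorch backend can resolve the architecture from the checkpoint.
    llm = LLM(model="path/to/mymodel/checkpoint")
    for output in llm.generate(prompts, sampling_params):
        print(output.outputs[0].text)


if __name__ == '__main__':
    main()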

We provide an out-of-tree modeling example in examples/pytorch/out_of_tree_example. The model is implemented in modeling_opt.py, and you can run the example with:

python examples/pytorch/out_of_tree_example/main.py