tensorrt_llm.models.mmdit_sd3.model — TensorRT-LLM
# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from typing import Any, Dict, List, Optional

from ..._utils import str_dtype_to_torch
from ...functional import (Tensor, allgather, chunk, concat, einsum, pad,
                           shape, unsqueeze)
from ...layers import LayerNorm, Linear
from ...layers.attention import DiffusersAttention
from ...layers.embedding import (CombinedTimestepTextProjEmbeddings,
                                 SD3PatchEmbed)
from ...layers.mlp import (LinearActivation, LinearApproximateGELU,
                           LinearGEGLU, LinearGELU, LinearSwiGLU)
from ...layers.normalization import (AdaLayerNormContinuous, AdaLayerNormZero,
                                     SD35AdaLayerNormZeroX)
from ...logger import logger
from ...mapping import Mapping
from ...module import Module, ModuleList
from ..model_weights_loader import ModelWeightsLoader
from ..modeling_utils import PretrainedModel
from .config import SD3Transformer2DModelConfig
class FeedForward(Module):
def __init__(
self,
dim: int,
dim_out: Optional[int] = None,
mult: int = 4,
activation_fn: str = "geglu",
inner_dim=None,
bias: bool = True,
mapping=Mapping(),
dtype=None,
):
super().__init__()
self.mapping = mapping
self.dtype = dtype
if inner_dim is None:
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
if activation_fn == "gelu":
            raise NotImplementedError(
                'GELU only supports the tanh approximation for now.')
if activation_fn == "gelu-approximate":
act_fn = LinearGELU(dim,
inner_dim,
approximate="tanh",
bias=bias,
mapping=mapping,
dtype=dtype)
elif activation_fn == "geglu":
act_fn = LinearGEGLU(dim,
inner_dim,
approximate="tanh",
bias=bias,
mapping=mapping,
dtype=dtype)
elif activation_fn == "geglu-approximate":
act_fn = LinearApproximateGELU(dim,
inner_dim,
bias=bias,
mapping=mapping,
dtype=dtype)
elif activation_fn == "swiglu":
act_fn = LinearSwiGLU(dim,
inner_dim,
bias=bias,
mapping=mapping,
dtype=dtype)
elif activation_fn == "linear-silu":
act_fn = LinearActivation(dim,
inner_dim,
bias=bias,
activation="silu",
mapping=mapping,
dtype=dtype)
self.net = ModuleList([
act_fn,
Linear(inner_dim,
dim_out,
bias=bias,
tp_group=self.mapping.tp_group,
tp_size=self.mapping.tp_size,
dtype=self.dtype)
])
def forward(self, hidden_states: Tensor):
for module in self.net:
hidden_states = module(hidden_states)
return hidden_states
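# Illustrative usage of FeedForward (a sketch; the dimension below is just an
# example, not taken from any particular SD3 config):
#
#   ff = FeedForward(dim=1536, activation_fn="gelu-approximate",
#                    mapping=Mapping(), dtype=None)
#   out = ff(hidden_states)  # hidden_states: [batch, seq_len, 1536]
#
# The first module projects dim -> inner_dim (4 * dim by default) and applies
# the chosen activation; the final Linear projects inner_dim -> dim_out.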
class JointTransformerBlock(Module):
def __init__(self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
context_pre_only: bool = False,
qk_norm: Optional[str] = None,
use_dual_attention: bool = False,
mapping=Mapping(),
dtype=None):
super().__init__()
self.use_dual_attention = use_dual_attention
self.context_pre_only = context_pre_only
context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"
if use_dual_attention:
self.norm1 = SD35AdaLayerNormZeroX(dim,
mapping=mapping,
dtype=dtype)
else:
self.norm1 = AdaLayerNormZero(dim, mapping=mapping, dtype=dtype)
if context_norm_type == "ada_norm_continous":
self.norm1_context = AdaLayerNormContinuous(
dim,
dim,
elementwise_affine=False,
eps=1e-6,
bias=True,
norm_type="layer_norm",
dtype=dtype)
elif context_norm_type == "ada_norm_zero":
self.norm1_context = AdaLayerNormZero(dim, dtype=dtype)
else:
raise ValueError(
f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
)
self.attn = DiffusersAttention(
query_dim=dim,
cross_attention_dim=None,
added_kv_proj_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
context_pre_only=context_pre_only,
bias=True,
qk_norm=qk_norm,
eps=1e-6,
mapping=mapping,
dtype=dtype,
)
if use_dual_attention:
self.attn2 = DiffusersAttention(
query_dim=dim,
cross_attention_dim=None,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
bias=True,
qk_norm=qk_norm,
eps=1e-6,
mapping=mapping,
dtype=dtype,
)
else:
self.attn2 = None
self.norm2 = LayerNorm(dim,
elementwise_affine=False,
eps=1e-6,
dtype=dtype)
self.ff = FeedForward(dim=dim,
dim_out=dim,
activation_fn="gelu-approximate",
mapping=mapping,
dtype=dtype)
if not context_pre_only:
self.norm2_context = LayerNorm(dim,
elementwise_affine=False,
eps=1e-6,
dtype=dtype)
self.ff_context = FeedForward(dim=dim,
dim_out=dim,
activation_fn="gelu-approximate",
mapping=mapping,
dtype=dtype)
else:
self.norm2_context = None
self.ff_context = None
# let chunk size default to None
self._chunk_size = None
self._chunk_dim = 0
def set_chunk_feed_forward(self,
chunk_size: Optional[int] = None,
dim: int = 0):
# Sets chunk feed-forward
self._chunk_size = chunk_size
self._chunk_dim = dim
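    # Example (illustrative): calling block.set_chunk_feed_forward(
    # chunk_size=1024, dim=1) makes the feed-forward calls in forward() run
    # over 1024-token slices along the sequence dimension (dim 1 of a
    # [batch, seq_len, dim] tensor) instead of the full sequence at once.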
@staticmethod
def _chunked_feed_forward(ff: Module, hidden_states: Tensor, chunk_dim: int,
chunk_size: int):
# "feed_forward_chunk_size" can be used to save memory
if hidden_states.shape[chunk_dim] % chunk_size != 0:
raise ValueError(
f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
)
num_chunks = hidden_states.shape[chunk_dim] // chunk_size
ff_output = concat(
[
ff(hid_slice)
for hid_slice in chunk(hidden_states, num_chunks, dim=chunk_dim)
],
dim=chunk_dim,
)
return ff_output
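    # Illustration of the chunked path: with seq_len 4096 and chunk_size 1024,
    # the input is split into 4 slices along `chunk_dim`, each slice is passed
    # through `ff` independently, and the outputs are concatenated back along
    # the same dimension. Because the feed-forward acts pointwise over that
    # dimension, the result matches the unchunked call at lower peak memory.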
def forward(self,
hidden_states: Tensor,
encoder_hidden_states: Tensor,
temb: Tensor,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
*args,
**kwargs):
joint_attention_kwargs = joint_attention_kwargs or {}
if self.use_dual_attention:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
hidden_states, emb=temb)
else:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
hidden_states, emb=temb)
if self.context_pre_only:
norm_encoder_hidden_states = self.norm1_context(
encoder_hidden_states, temb)
else:
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
encoder_hidden_states, emb=temb)
# Attention.
attn_output, context_attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
**joint_attention_kwargs,
)
# Process attention outputs for the `hidden_states`.
attn_output = unsqueeze(gate_msa, 1) * attn_output
hidden_states = hidden_states + attn_output
if self.use_dual_attention:
attn_output2 = self.attn2(hidden_states=norm_hidden_states2,
**joint_attention_kwargs)
attn_output2 = unsqueeze(gate_msa2, 1) * attn_output2
hidden_states = hidden_states + attn_output2
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (
1 + unsqueeze(scale_mlp, 1)) + unsqueeze(shift_mlp, 1)
if self._chunk_size is not None:
# "feed_forward_chunk_size" can be used to save memory
ff_output = self._chunked_feed_forward(self.ff, norm_hidden_states,
self._chunk_dim,
self._chunk_size)
else:
ff_output = self.ff(norm_hidden_states)
ff_output = unsqueeze(gate_mlp, 1) * ff_output
hidden_states = hidden_states + ff_output
# Process attention outputs for the `encoder_hidden_states`.
if self.context_pre_only:
encoder_hidden_states = None
else:
context_attn_output = unsqueeze(c_gate_msa, 1) * context_attn_output
encoder_hidden_states = encoder_hidden_states + context_attn_output
norm_encoder_hidden_states = self.norm2_context(
encoder_hidden_states)
norm_encoder_hidden_states = norm_encoder_hidden_states * (
1 + unsqueeze(c_scale_mlp, 1)) + unsqueeze(c_shift_mlp, 1)
if self._chunk_size is not None:
# "feed_forward_chunk_size" can be used to save memory
context_ff_output = self._chunked_feed_forward(
self.ff_context, norm_encoder_hidden_states,
self._chunk_dim, self._chunk_size)
else:
context_ff_output = self.ff_context(norm_encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states + unsqueeze(
c_gate_mlp, 1) * context_ff_output
return encoder_hidden_states, hidden_states
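# Note on the joint (MM-DiT) block above: the image stream (`hidden_states`)
# and the text stream (`encoder_hidden_states`) are modulated separately by
# their AdaLayerNorm layers using `temb`, attend to each other in a single
# joint attention call, and then go through separate feed-forward networks.
# In a block built with `context_pre_only=True` the text stream is not
# updated and `None` is returned in its place.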
class SD3Transformer2DModel(PretrainedModel):
    config_class = SD3Transformer2DModelConfig
def __init__(self, config: SD3Transformer2DModelConfig):
super().__init__(config)
self.quant_mode = config.quant_mode
self.mapping = config.mapping
self.dtype = config.dtype
self.in_channels = config.in_channels
default_out_channels = config.in_channels
self.out_channels = config.out_channels if config.out_channels is not None else default_out_channels
self.inner_dim = config.num_attention_heads * config.attention_head_dim
self.pos_embed = SD3PatchEmbed(
height=config.sample_size,
width=config.sample_size,
patch_size=config.patch_size,
in_channels=self.in_channels,
embed_dim=self.inner_dim,
            pos_embed_max_size=config.
            pos_embed_max_size,  # hard-coded to match the HF implementation
dtype=self.dtype)
self.time_text_embed = CombinedTimestepTextProjEmbeddings(
embedding_dim=self.inner_dim,
pooled_projection_dim=config.pooled_projection_dim,
mapping=self.mapping,
dtype=self.dtype)
self.context_embedder = Linear(config.joint_attention_dim,
config.caption_projection_dim,
tp_group=self.mapping.tp_group,
tp_size=self.mapping.tp_size,
dtype=self.dtype)
self.transformer_blocks = ModuleList([
JointTransformerBlock(
dim=self.inner_dim,
num_attention_heads=config.num_attention_heads,
attention_head_dim=config.attention_head_dim,
context_pre_only=(i == config.num_layers - 1),
qk_norm=config.qk_norm,
use_dual_attention=True
if i in config.dual_attention_layers else False,
mapping=self.mapping,
dtype=self.dtype) for i in range(config.num_layers)
])
self.norm_out = AdaLayerNormContinuous(self.inner_dim,
self.inner_dim,
elementwise_affine=False,
eps=1e-6,
dtype=self.dtype)
self.proj_out = Linear(self.inner_dim,
config.patch_size * config.patch_size *
self.out_channels,
bias=True,
tp_group=self.mapping.tp_group,
tp_size=self.mapping.tp_size,
dtype=self.dtype)
self.skip_layers = config.skip_layers
self.use_pretrained_pos_emb = config.use_pretrained_pos_emb
self.config = config
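    # Notes on the constructor above: the last transformer block is created
    # with `context_pre_only=True` (its text-stream output is discarded), and
    # every block index listed in `config.dual_attention_layers` uses the
    # SD3.5-style dual-attention variant, which adds a second self-attention
    # branch on the image stream.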
    def forward(self,
                hidden_states: Tensor,
                encoder_hidden_states: Optional[Tensor] = None,
                pooled_projections: Optional[Tensor] = None,
                timestep: Optional[Tensor] = None,
                block_controlnet_hidden_states: List[Tensor] = None,
                joint_attention_kwargs: Optional[Dict[str, Any]] = None):
        height, width = hidden_states.shape[-2:]

        hidden_states = self.pos_embed(
            hidden_states)  # takes care of adding positional embeddings too.
        temb = self.time_text_embed(timestep, pooled_projections)
        encoder_hidden_states = self.context_embedder(encoder_hidden_states)
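        # Context parallelism (cp_size > 1): both the image tokens and the
        # text tokens are split along the sequence dimension, one shard per
        # CP rank. The text sequence is padded first so its length divides
        # cp_size evenly; the image shards are gathered back with allgather
        # after proj_out.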
if self.mapping.cp_size > 1:
hidden_states = chunk(hidden_states,
chunks=self.mapping.cp_size,
dim=1)[self.mapping.cp_rank]
encoder_redundant = encoder_hidden_states.shape[
1] % self.mapping.cp_size
encoder_padding_index = tuple(
[0, 0] * (encoder_hidden_states.ndim() - 2) +
[0, self.mapping.cp_size - encoder_redundant])
if encoder_redundant != 0:
encoder_hidden_states = pad(encoder_hidden_states,
pad=encoder_padding_index)
encoder_hidden_states = chunk(encoder_hidden_states,
chunks=self.mapping.cp_size,
dim=1)[self.mapping.cp_rank]
for index_block, block in enumerate(self.transformer_blocks):
# Skip specified layers
            is_skip = self.skip_layers is not None and index_block in self.skip_layers
if not is_skip:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
joint_attention_kwargs=joint_attention_kwargs,
)
# controlnet residual
if block_controlnet_hidden_states is not None and block.context_pre_only is False:
interval_control = len(self.transformer_blocks) / len(
block_controlnet_hidden_states)
hidden_states = hidden_states + block_controlnet_hidden_states[
int(index_block / interval_control)]
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)
if self.mapping.cp_size > 1:
hidden_states = allgather(hidden_states,
group=self.mapping.cp_group,
gather_dim=1)
# unpatchify
patch_size = self.config.patch_size
height = height // patch_size
width = width // patch_size
hidden_states = hidden_states.view(
concat([
shape(hidden_states, 0), height, width, patch_size, patch_size,
self.out_channels
]))
hidden_states = einsum("nhwpqc->nchpwq", [hidden_states])
output = hidden_states.view(
concat([
shape(hidden_states, 0), self.out_channels, height * patch_size,
width * patch_size
]))
output.mark_output("output")
return output
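    # Unpatchify shapes, for reference: proj_out yields
    # [N, (H/p)*(W/p), p*p*C]; it is viewed as [N, H/p, W/p, p, p, C],
    # permuted with the "nhwpqc->nchpwq" einsum, and viewed again as
    # [N, C, H, W], i.e. back to the input latent layout.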
    def prepare_inputs(self, max_batch_size, **kwargs):
def sd3_default_range(max_batch_size):
return [1, max(1, (max_batch_size + 1) // 2), max_batch_size]
default_range = sd3_default_range
prompt_embeds_len = 256 + 77 # [NOTE] tokenizer_max_length = 77; max_sequence_length = 256
hidden_states = Tensor(name='hidden_states',
dtype=self.dtype,
shape=[
-1, self.in_channels,
self.config.sample_size,
self.config.sample_size
],
dim_range=OrderedDict([
('batch_size',
[default_range(max_batch_size)]),
('in_channels', [[self.in_channels] * 3]),
('height', [[self.config.sample_size] * 3]),
('width', [[self.config.sample_size] * 3]),
]))
encoder_hidden_states = Tensor(
name='encoder_hidden_states',
dtype=self.dtype,
shape=[-1, prompt_embeds_len, self.config.joint_attention_dim],
dim_range=OrderedDict([
('batch_size', [default_range(max_batch_size)]),
('txt_len', [[prompt_embeds_len] * 3]),
('joint_attention_dim', [[self.config.joint_attention_dim] * 3
]),
]))
pooled_projections = Tensor(
name='pooled_projections',
dtype=self.dtype,
shape=[-1, self.config.pooled_projection_dim],
dim_range=OrderedDict([
('batch_size', [default_range(max_batch_size)]),
('pooled_projection_dim',
[[self.config.pooled_projection_dim] * 3]),
]))
timestep = Tensor(name='timestep',
dtype=self.dtype,
shape=[-1],
dim_range=OrderedDict([
('batch_size', [default_range(max_batch_size)]),
]))
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"pooled_projections": pooled_projections,
"timestep": timestep,
}
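    # Each `dim_range` entry above is a [min, opt, max] triple that becomes
    # the TensorRT optimization profile for that dimension; only the batch
    # dimension is dynamic here. The text length 256 + 77 = 333 is the
    # concatenation of the T5 prompt embeddings (max_sequence_length = 256)
    # and the CLIP prompt embeddings (tokenizer_max_length = 77) produced by
    # the SD3 pipeline.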
    @classmethod
    def from_pretrained(cls,
                        pretrained_model_name_or_path: str,
                        dtype='float16',
                        mapping=Mapping(),
                        **kwargs):
        quant_ckpt_path = kwargs.pop('quant_ckpt_path', None)
from diffusers import StableDiffusion3Pipeline
transformer = StableDiffusion3Pipeline.from_pretrained(
pretrained_model_name_or_path,
torch_dtype=str_dtype_to_torch(dtype)).transformer
config = SD3Transformer2DModelConfig.from_hugging_face_config(
transformer.config, dtype=dtype, mapping=mapping, **kwargs)
hf_model_dir = transformer.config._name_or_path
custom_dict = {}
if quant_ckpt_path is not None:
hf_model_dir = quant_ckpt_path
loader = SD3ModelWeightsLoader(hf_model_dir, custom_dict)
model = cls(config)
loader.generate_tllm_weights(model)
return model
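    # Example usage (a sketch; the checkpoint id below is an assumption and
    # any SD3 diffusers checkpoint directory works the same way):
    #
    #   model = SD3Transformer2DModel.from_pretrained(
    #       "stabilityai/stable-diffusion-3-medium-diffusers",
    #       dtype="float16",
    #       mapping=Mapping(world_size=1, tp_size=1, pp_size=1))
    #
    # The diffusers pipeline is loaded only to obtain the transformer config
    # and the checkpoint location; the weights themselves are converted via
    # SD3ModelWeightsLoader.generate_tllm_weights().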
    def load(self, weights, from_pruned=False):
        required_names = set()
        for name, param in self.named_parameters():
            if self.use_pretrained_pos_emb and 'pos_embed' in name:
                required_names.add(name)
                continue
            if param.is_inited():
                continue
            if name not in weights:
                # Exemption for embedding sharing
                if name.endswith('lm_head.weight') and any(
                        k.endswith('vocab_embedding.weight')
                        for k in weights.keys()):
                    continue
                if name.endswith('lm_head.per_channel_scale') and any(
                        k.endswith('vocab_embedding.per_channel_scale')
                        for k in weights.keys()):
                    continue
            required_names.add(name)
provided_names = set(weights.keys())
if not required_names.issubset(provided_names):
raise RuntimeError(
f"Required but not provided tensors:{required_names.difference(provided_names)}"
)
if not provided_names.issubset(required_names):
logger.warning(
f"Provided but not required tensors: {provided_names.difference(required_names)}"
)
for name, param in self.named_parameters():
if name in provided_names:
if not from_pruned:
try:
param.value = weights[name]
except Exception as e:
raise RuntimeError(
f"Encounter error '{e}' for parameter '{name}'")
else:
param.set_value_or_dummy(weights[name])
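    # `load` above checks coverage in both directions: missing required
    # tensors raise a RuntimeError, while extra provided tensors only emit a
    # warning. With `from_pruned=True`, values are assigned through
    # set_value_or_dummy instead of direct assignment.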
    def enable_forward_chunking(self,
                                chunk_size: Optional[int] = None,
                                dim: int = 0):
        raise NotImplementedError()
    def disable_forward_chunking(self):
        raise NotImplementedError()
@property
def attn_processors(self):
return None
    def set_attn_processor(self, processor):
        raise NotImplementedError()

    def fuse_qkv_projections(self):
        raise NotImplementedError()

    def unfuse_qkv_projections(self):
        raise NotImplementedError()
def _set_gradient_checkpointing(self, module, value=False):
raise NotImplementedError()
class SD3ModelWeightsLoader(ModelWeightsLoader):
def translate_to_external_key(self, tllm_key: str,
tllm_to_externel_key_dict: dict):
"""Convert and load external checkpoint into a TensorRT-LLM model.
"""
trtllm_to_hf_name = {
r"transformer_blocks.(\d+).ff(\w*).net.1.weight":
"transformer_blocks.*.ff*.net.2.weight",
r"transformer_blocks.(\d+).ff(\w*).net.1.bias":
"transformer_blocks.*.ff*.net.2.bias",
}
import re
for k, v in trtllm_to_hf_name.items():
m = re.match(k, tllm_key)
if m is not None:
matched_pos = m.groups()
placeholders = v.count('*')
assert len(matched_pos) == placeholders
for i in range(len(matched_pos)):
v = v.replace('*', matched_pos[i], 1)
return v
return tllm_key
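        # Example (illustrative): the TensorRT-LLM key
        #   "transformer_blocks.3.ff_context.net.1.weight"
        # matches the first pattern with groups ("3", "_context") and is
        # rewritten to the diffusers key
        #   "transformer_blocks.3.ff_context.net.2.weight".
        # The index shifts because diffusers' FeedForward keeps a Dropout at
        # net[1], whereas the FeedForward above holds only [act_fn, Linear].
        # Keys matching no pattern are returned unchanged.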