tensorrt_llm.models.multimodal_encoders.config — TensorRT-LLM
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union
import torch
from ..._utils import torch_dtype_to_str
from ...logger import logger
from ...mapping import Mapping
from ..modeling_utils import PretrainedConfig, QuantConfig


class LlavaNextVisionConfig(PretrainedConfig):

    def __init__(self,
                 *,
                 image_size: int,
                 patch_size: int,
                 text_hidden_size: int,
                 projector_hidden_act: str = 'gelu',
                 num_channels: int = 3,
                 vision_model_type: str = 'clip_vision_model',
                 **kwargs):
        self.image_size = image_size
        self.patch_size = patch_size
        self.text_hidden_size = text_hidden_size
        self.num_channels = num_channels
        self.projector_hidden_act = projector_hidden_act
        self.vision_model_type = vision_model_type
        super().__init__(**kwargs)

    @classmethod
    def from_hugging_face(
            cls,
            hf_config_or_dir: Union[str, 'transformers.PretrainedConfig'],
            dtype: str = 'auto',
            mapping: Optional[Mapping] = None,
            quant_config: Optional[QuantConfig] = None,
            **kwargs):
        import transformers
        if isinstance(hf_config_or_dir, transformers.PretrainedConfig):
            hf_config = hf_config_or_dir
        else:
            hf_config_dir = str(hf_config_or_dir)
            hf_config = transformers.AutoConfig.from_pretrained(
                hf_config_dir, trust_remote_code=True)
            if hf_config.model_type == "llava_next":
                from transformers import LlavaNextConfig
                hf_config = LlavaNextConfig.from_pretrained(hf_config_dir)
            else:
                logger.error("Provided model type is not llava_next.")
        text_hidden_size = hf_config.text_config.hidden_size

        # Extract only the vision config
        llava_next_vision_config = hf_config.vision_config

        # llava-next uses the second-to-last layer as the vision output:
        # vision_feature_layer is a negative index (typically -2), so this
        # keeps only the layers up to and including the feature layer.
        num_feature_layers = llava_next_vision_config.num_hidden_layers + hf_config.vision_feature_layer + 1
        vision_model_type = getattr(llava_next_vision_config,
                                    "vision_model_type", "clip_vision_model")
        num_key_value_heads = getattr(
            llava_next_vision_config, "num_key_value_heads",
            llava_next_vision_config.num_attention_heads)

        # Default configs from HF
        hidden_act = 'quick_gelu'
        norm_epsilon = 1e-5

        head_size = llava_next_vision_config.hidden_size // llava_next_vision_config.num_attention_heads
        if dtype == 'auto':
            dtype = getattr(hf_config, 'torch_dtype', None)
            if dtype is None:
                dtype = 'float16'
            if isinstance(dtype, torch.dtype):
                dtype = torch_dtype_to_str(dtype)
            if dtype == 'float32':
                dtype = 'float16'
        return cls(
            image_size=llava_next_vision_config.image_size,
            patch_size=llava_next_vision_config.patch_size,
            text_hidden_size=text_hidden_size,
            projector_hidden_act=hf_config.projector_hidden_act,
            vision_model_type=vision_model_type,
            architecture=hf_config.architectures[0],
            dtype=dtype,
            num_hidden_layers=num_feature_layers,
            num_attention_heads=llava_next_vision_config.num_attention_heads,
            hidden_size=llava_next_vision_config.hidden_size,
            intermediate_size=llava_next_vision_config.intermediate_size,
            num_key_value_heads=num_key_value_heads,
            head_size=head_size,
            vocab_size=llava_next_vision_config.vocab_size,
            hidden_act=hidden_act,
            norm_epsilon=norm_epsilon,
            mapping=mapping,
            quantization=quant_config,
            **kwargs)
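

# --- Usage sketch (illustrative addition, not part of the original module) ---
# A minimal example of building the vision config from a Hugging Face
# LLaVA-NeXT checkpoint. The repo id below is only an example; any local
# directory or hub id whose model_type is "llava_next" should work, and the
# dtype/mapping/quant_config arguments follow the from_hugging_face()
# signature above.
if __name__ == '__main__':
    config = LlavaNextVisionConfig.from_hugging_face(
        'llava-hf/llava-v1.6-mistral-7b-hf',  # example checkpoint id (assumption)
        dtype='float16')
    print(config.vision_model_type, config.hidden_size, config.num_hidden_layers)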