Add model_name parameter to Llava constructor · EvolvingLMMs-Lab/lmms-eval@8aaa828

```diff
@@ -57,6 +57,7 @@ def __init__(
         batch_size: Optional[Union[int, str]] = 1,
         trust_remote_code: Optional[bool] = False,
         revision=None,
+        model_name=None,
         attn_implementation=best_fit_attn_implementation,
         use_flash_attention_2=True,
         device_map="auto",
@@ -83,8 +84,20 @@ def __init__(
         llava_model_args["attn_implementation"] = attn_implementation
         if customized_config:
             llava_model_args["customized_config"] = customized_config
-        llava_model_args["use_flash_attention_2"] = False
-        self._tokenizer, self._model, self._image_processor, self._max_length = load_pretrained_model(pretrained, None, get_model_name_from_path(pretrained), device_map=self.device_map, **llava_model_args)
+        if attn_implementation is not None:
+            llava_model_args["attn_implementation"] = attn_implementation
+        if "use_flash_attention_2" in kwargs:
+            llava_model_args["use_flash_attention_2"] = kwargs["use_flash_attention_2"]
+
+        model_name = model_name if model_name is not None else get_model_name_from_path(pretrained)
+        try:
+            # Try to load the model with the multimodal argument
+            self._tokenizer, self._model, self._image_processor, self._max_length = load_pretrained_model(pretrained, None, model_name, device_map=self.device_map, **llava_model_args)
+        except TypeError:
+            # for older versions of LLaVA that don't have multimodal argument
+            llava_model_args.pop("multimodal", None)
+            self._tokenizer, self._model, self._image_processor, self._max_length = load_pretrained_model(pretrained, None, model_name, device_map=self.device_map, **llava_model_args)
+
         self._config = self._model.config
         self.model.eval()
         self.model.tie_weights()
```
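For context, a minimal usage sketch of the new parameter. The import path and checkpoint name below are illustrative assumptions, not part of the commit; what the diff does establish is that passing `model_name` explicitly overrides the name that `get_model_name_from_path(pretrained)` would otherwise derive, which matters for checkpoints whose path does not encode the model type.

```python
# Hypothetical usage sketch (import path and checkpoint name are assumptions).
from lmms_eval.models.llava import Llava

# Before this commit, the model name was always derived from the checkpoint
# path via get_model_name_from_path(pretrained). Passing model_name now
# overrides that, e.g. for local checkpoints with non-standard folder names.
model = Llava(
    pretrained="liuhaotian/llava-v1.5-7b",  # illustrative checkpoint
    model_name="llava-v1.5-7b",             # new parameter from this commit
    device_map="auto",
)
```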