Merge pull request #54 from EvolvingLMMs-Lab/add_llava_sglang · MichalCiesiolka/lmms-eval-llmzszl@95df9fe

import torch

torch.backends.cuda.matmul.allow_tf32 = True


import logging
from tqdm import tqdm
from datetime import timedelta

from lmms_eval import utils
from lmms_eval.api.instance import Instance
from lmms_eval.api.model import lmms
from lmms_eval.api.registry import register_model

from accelerate import Accelerator, InitProcessGroupKwargs
from typing import List, Optional, Union, Tuple
import warnings

warnings.filterwarnings("ignore")
from concurrent.futures import ThreadPoolExecutor, as_completed
import tempfile

eval_logger = logging.getLogger("lmms-eval")

try:
    import sglang as sgl
    from sglang.lang.chat_template import get_chat_template
except ImportError:
    eval_logger.error("SGLang is not installed. If you want to use llava_sglang, please install it using pip install 'sglang[all]' ")

if torch.__version__ > "2.1.2":
    best_fit_attn_implementation = "sdpa"
else:
    best_fit_attn_implementation = "eager"


@register_model("llava_sglang")
class LlavaSglang(lmms):
    """
    Llava Sglang Model
    """

    def __init__(
        self,
        pretrained: str = "liuhaotian/llava-v1.5-7b",
        tokenizer: str = "llava-hf/llava-1.5-7b-hf",
        tp_size: int = 1,
        parallel: Optional[Union[int, str]] = 64,
        conv_template="vicuna_v1.1",
        **kwargs,
    ) -> None:
        super().__init__()
        self.pretrained = pretrained
        self.tokenizer = tokenizer
        self.tp_size = tp_size
        self.conv_template = conv_template
        torch.multiprocessing.set_start_method("spawn")

        accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
        accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
        assert accelerator.num_processes == 1, "Llava-sglang does not support multi-processes yet (it does support tensor parallelism)."
        self._rank = 0
        self._world_size = 1
        self.parallel = parallel

    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
        raise NotImplementedError("Llava-sglang does not support loglikelihood evaluation yet")

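    # generate_until: spins up an SGLang runtime, groups the requests by their
    # generation kwargs, renders each (image, question) pair through the image_qa
    # program below, and returns the decoded answers in the original request order.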
    def generate_until(self, requests: List[Instance]) -> List[str]:

        runtime = sgl.Runtime(model_path=self.pretrained, tokenizer_path=self.tokenizer, tp_size=self.tp_size)
        runtime.endpoint.chat_template = get_chat_template(self.conv_template)
        sgl.set_default_backend(runtime)

        @sgl.function
        def image_qa(s, image_file, question):
            s += sgl.user(sgl.image(image_file) + question)
            s += sgl.assistant(sgl.gen("answer"))

        res = []

        def _collate(x):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
            toks = x[0].split(" ")
            return -len(toks), x[0]

        # we group requests by their generation_kwargs,
        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
        # in the same batch.
        re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True)
        chunks = re_ords.get_batched(n=self.parallel, batch_fn=None)
        num_iters = len(requests) // self.parallel if len(requests) % self.parallel == 0 else len(requests) // self.parallel + 1
        pbar = tqdm(total=num_iters, disable=(self.rank != 0), desc="Model Responding")
        for chunk in chunks:
            contexts, all_gen_kwargs, doc_to_visuals, doc_id, tasks, splits = zip(*chunk)
            batched_visuals = [doc_to_visual(self.task_dict[task][split][ids]) for ids, task, split, doc_to_visual in zip(doc_id, tasks, splits, doc_to_visuals)]  # [B, N]
            # we assume all gen kwargs in the batch are the same
            # this is safe to assume because the `grouper` object ensures it.
            gen_kwargs = all_gen_kwargs[0]
            if "max_new_tokens" not in gen_kwargs:
                gen_kwargs["max_new_tokens"] = 1024
            if "temperature" not in gen_kwargs:
                gen_kwargs["temperature"] = 0
            if "top_p" not in gen_kwargs:
                gen_kwargs["top_p"] = 1.0
            if "num_beams" not in gen_kwargs:
                gen_kwargs["num_beams"] = 1
            if gen_kwargs["top_p"] == 0.0:
                gen_kwargs["top_p"] = 1.0
                gen_kwargs["temperature"] = 0.0
            assert gen_kwargs["num_beams"] == 1

117
`+
def save_image_to_temp_file(image):
`
``
118
`+
temp_file = tempfile.NamedTemporaryFile(suffix=".jpeg", delete=True)
`
``
119
`+
image.save(temp_file.name)
`
``
120
`+
return temp_file
`
``
121
+
``
122
`+
def prepare_arguments_parallel(contexts, batched_visuals, max_workers=64):
`
``
123
`+
arguments = [None] * len(contexts) # Initialize with placeholders
`
``
124
`+
tmp_files = [None] * len(contexts) # Initialize with placeholders
`
``
125
+
``
126
`+
with ThreadPoolExecutor(max_workers=max_workers) as executor:
`
``
127
`+
Associate each future with its index and content
`
``
128
`+
future_to_info = {executor.submit(save_image_to_temp_file, pil_list[0]): (index, context, pil_list) for index, (context, pil_list) in enumerate(zip(contexts, batched_visuals))}
`
``
129
+
``
130
`+
for future in as_completed(future_to_info):
`
``
131
`+
index, context, pil_list = future_to_info[future]
`
``
132
`+
if len(pil_list) > 1:
`
``
133
`+
eval_logger.warning("Llava-sglang only supports one visual input per question. Using the first visual input.")
`
``
134
`+
try:
`
``
135
`+
temp_file = future.result()
`
``
136
`+
arguments[index] = {
`
``
137
`+
"image_file": temp_file.name,
`
``
138
`+
"question": context,
`
``
139
`+
}
`
``
140
`+
tmp_files[index] = temp_file
`
``
141
`+
except Exception as exc:
`
``
142
`+
print(f"Generated an exception: {exc}")
`
``
143
+
``
144
`+
Filter out any None values in case of exceptions
`
``
145
`+
arguments = [arg for arg in arguments if arg is not None]
`
``
146
`+
tmp_files = [tmp_file for tmp_file in tmp_files if tmp_file is not None]
`
``
147
+
``
148
`+
return arguments, tmp_files
`
``
149
+
``
150
`+
arguments, tmp_files = prepare_arguments_parallel(contexts, batched_visuals, self.parallel)
`
``
151
`+
states = image_qa.run_batch(arguments, temperature=gen_kwargs["temperature"], max_new_tokens=gen_kwargs["max_new_tokens"], top_p=gen_kwargs["top_p"], num_threads=self.parallel, progress_bar=False)
`
``
152
+
``
153
`+
text_outputs = [state["answer"].strip() for state in states]
`
``
154
`+
clean up the temporary files
`
``
155
`+
for tmp_file in tmp_files:
`
``
156
`+
tmp_file.close()
`
``
157
`+
res.extend(text_outputs)
`
``
158
`+
pbar.update(1)
`
``
159
`+
reorder this group of results back to original unsorted form
`
``
160
`+
res = re_ords.get_original(res)
`
``
161
+
``
162
`+
pbar.close()
`
``
163
`+
runtime.shutdown()
`
``
164
`+
return res
`
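
For reference, the SGLang pattern wrapped by generate_until can be exercised on its own. Below is a minimal sketch, assuming sglang[all] is installed and the default checkpoints above are reachable; the image path "example.jpeg" and the question text are placeholders, and every call mirrors one already made in the class rather than documenting the SGLang API independently.

import sglang as sgl
from sglang.lang.chat_template import get_chat_template

# Same defaults as LlavaSglang: pretrained weights, HF tokenizer, single-GPU tensor parallelism.
runtime = sgl.Runtime(model_path="liuhaotian/llava-v1.5-7b", tokenizer_path="llava-hf/llava-1.5-7b-hf", tp_size=1)
runtime.endpoint.chat_template = get_chat_template("vicuna_v1.1")
sgl.set_default_backend(runtime)


@sgl.function
def image_qa(s, image_file, question):
    # one image per question, matching the constraint enforced in generate_until
    s += sgl.user(sgl.image(image_file) + question)
    s += sgl.assistant(sgl.gen("answer"))


# run_batch takes one keyword-argument dict per prompt, as built by prepare_arguments_parallel.
states = image_qa.run_batch(
    [{"image_file": "example.jpeg", "question": "What is shown in this image?"}],  # placeholder inputs
    temperature=0,
    max_new_tokens=1024,
    top_p=1.0,
    num_threads=1,
    progress_bar=False,
)
print(states[0]["answer"].strip())
runtime.shutdown()

Within lmms-eval itself, the wrapper is selected by the name registered above, llava_sglang.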