Merge pull request #54 from EvolvingLMMs-Lab/add_llava_sglang · MichalCiesiolka/lmms-eval-llmzszl@95df9fe

```python
import torch

torch.backends.cuda.matmul.allow_tf32 = True

import logging
from tqdm import tqdm
from datetime import timedelta

from lmms_eval import utils
from lmms_eval.api.instance import Instance
from lmms_eval.api.model import lmms
from lmms_eval.api.registry import register_model

from accelerate import Accelerator, InitProcessGroupKwargs
from typing import List, Optional, Union, Tuple
import warnings

warnings.filterwarnings("ignore")
from concurrent.futures import ThreadPoolExecutor, as_completed
import tempfile

eval_logger = logging.getLogger("lmms-eval")

try:
    import sglang as sgl
    from sglang.lang.chat_template import get_chat_template
except ImportError:
    eval_logger.error("SGLang is not installed. If you want to use llava_sglang, please install it using pip install 'sglang[all]' ")

if torch.__version__ > "2.1.2":
    best_fit_attn_implementation = "sdpa"
else:
    best_fit_attn_implementation = "eager"


@register_model("llava_sglang")
class LlavaSglang(lmms):
    """
    Llava Sglang Model
    """

    def __init__(
        self,
        pretrained: str = "liuhaotian/llava-v1.5-7b",
        tokenizer: str = "llava-hf/llava-1.5-7b-hf",
        tp_size: int = 1,
        parallel: Optional[Union[int, str]] = 64,
        conv_template="vicuna_v1.1",
        **kwargs,
    ) -> None:
        super().__init__()
        self.pretrained = pretrained
        self.tokenizer = tokenizer
        self.tp_size = tp_size
        self.conv_template = conv_template
        torch.multiprocessing.set_start_method("spawn")

        accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
        accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
        assert accelerator.num_processes == 1, "Llava-sglang does not support multi-processes yet (it does support tensor parallelism)."
        self._rank = 0
        self._world_size = 1
        self.parallel = parallel

    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
        raise NotImplementedError("Llava-sglang does not support loglikelihood evaluation yet")

    def generate_until(self, requests: List[Instance]) -> List[str]:
        runtime = sgl.Runtime(model_path=self.pretrained, tokenizer_path=self.tokenizer, tp_size=self.tp_size)
        runtime.endpoint.chat_template = get_chat_template(self.conv_template)
        sgl.set_default_backend(runtime)

        @sgl.function
        def image_qa(s, image_file, question):
            s += sgl.user(sgl.image(image_file) + question)
            s += sgl.assistant(sgl.gen("answer"))

        res = []

        def _collate(x):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
            toks = x[0].split(" ")
            return -len(toks), x[0]

        # we group requests by their generation_kwargs,
        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
        # in the same batch.
        re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True)
        chunks = re_ords.get_batched(n=self.parallel, batch_fn=None)
        num_iters = len(requests) // self.parallel if len(requests) % self.parallel == 0 else len(requests) // self.parallel + 1
        pbar = tqdm(total=num_iters, disable=(self.rank != 0), desc="Model Responding")
        for chunk in chunks:
            contexts, all_gen_kwargs, doc_to_visuals, doc_id, tasks, splits = zip(*chunk)
            batched_visuals = [doc_to_visual(self.task_dict[task][split][ids]) for ids, task, split, doc_to_visual in zip(doc_id, tasks, splits, doc_to_visuals)]  # [B, N]
            # we assume all gen kwargs in the batch are the same
            # this is safe to assume because the grouper object ensures it.
            gen_kwargs = all_gen_kwargs[0]
            if "max_new_tokens" not in gen_kwargs:
                gen_kwargs["max_new_tokens"] = 1024
            if "temperature" not in gen_kwargs:
                gen_kwargs["temperature"] = 0
            if "top_p" not in gen_kwargs:
                gen_kwargs["top_p"] = 1.0
            if "num_beams" not in gen_kwargs:
                gen_kwargs["num_beams"] = 1
            if gen_kwargs["top_p"] == 0.0:
                gen_kwargs["top_p"] = 1.0
                gen_kwargs["temperature"] = 0.0
            assert gen_kwargs["num_beams"] == 1

            def save_image_to_temp_file(image):
                temp_file = tempfile.NamedTemporaryFile(suffix=".jpeg", delete=True)
                image.save(temp_file.name)
                return temp_file

            def prepare_arguments_parallel(contexts, batched_visuals, max_workers=64):
                arguments = [None] * len(contexts)  # Initialize with placeholders
                tmp_files = [None] * len(contexts)  # Initialize with placeholders

                with ThreadPoolExecutor(max_workers=max_workers) as executor:
                    # Associate each future with its index and content
                    future_to_info = {executor.submit(save_image_to_temp_file, pil_list[0]): (index, context, pil_list) for index, (context, pil_list) in enumerate(zip(contexts, batched_visuals))}

                    for future in as_completed(future_to_info):
                        index, context, pil_list = future_to_info[future]
                        if len(pil_list) > 1:
                            eval_logger.warning("Llava-sglang only supports one visual input per question. Using the first visual input.")
                        try:
                            temp_file = future.result()
                            arguments[index] = {
                                "image_file": temp_file.name,
                                "question": context,
                            }
                            tmp_files[index] = temp_file
                        except Exception as exc:
                            print(f"Generated an exception: {exc}")

                # Filter out any None values in case of exceptions
                arguments = [arg for arg in arguments if arg is not None]
                tmp_files = [tmp_file for tmp_file in tmp_files if tmp_file is not None]

                return arguments, tmp_files

            arguments, tmp_files = prepare_arguments_parallel(contexts, batched_visuals, self.parallel)
            states = image_qa.run_batch(arguments, temperature=gen_kwargs["temperature"], max_new_tokens=gen_kwargs["max_new_tokens"], top_p=gen_kwargs["top_p"], num_threads=self.parallel, progress_bar=False)

            text_outputs = [state["answer"].strip() for state in states]
            # clean up the temporary files
            for tmp_file in tmp_files:
                tmp_file.close()
            res.extend(text_outputs)
            pbar.update(1)
        # reorder this group of results back to original unsorted form
        res = re_ords.get_original(res)

        pbar.close()
        runtime.shutdown()
        return res
```
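Once registered via `@register_model("llava_sglang")`, the model should be selectable by that name through the lmms-eval launcher (e.g. `--model llava_sglang` with `--model_args pretrained=...,tp_size=...,parallel=...`, assuming the standard CLI flags).

The descending-length ordering used by `_collate` is easy to see in isolation. Below is a minimal sketch with plain `sorted()` and hypothetical prompt strings; the real code routes this ordering through `utils.Collator`, which also groups requests by their generation kwargs:

```python
# Hypothetical prompts used only to illustrate the _collate sort key.
prompts = [
    "short one",
    "a much longer prompt with many more tokens in it",
    "mid sized prompt",
]

# Sort by (-token_count, text): longest prompts come first, ties broken by the text itself,
# mirroring the (-len(toks), x[0]) key returned by _collate.
ordered = sorted(prompts, key=lambda p: (-len(p.split(" ")), p))
print(ordered[0])  # -> the longest prompt
```

Scheduling the longest prompts first makes the progress-bar time estimate conservative and surfaces any out-of-memory failure in the earliest batches rather than at the end of a long run, which is exactly the rationale given in the comments above.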