Merge pull request #130 from lscpku/vitatecs · EvolvingLMMs-Lab/lmms-eval@11fd7e3

```python
from decord import VideoReader, cpu
import numpy as np
import os
import sys
import datetime
import lmms_eval.tasks._task_utils.file_utils as file_utils
import json
import logging
import yaml
from pathlib import Path

import requests
import openai
from openai import OpenAI
import time
import ast
from tqdm import tqdm
import random

import re

with open(Path(__file__).parent / "_default_template_yaml", "r") as f:
    raw_data = f.readlines()
    safe_data = []
    for i, line in enumerate(raw_data):
        # remove function definitions since yaml.safe_load cannot handle them
        if "!function" not in line:
            safe_data.append(line)

    config = yaml.safe_load("".join(safe_data))
```
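An aside on the `!function` filter above: `yaml.safe_load` refuses custom tags such as the `!function` entries used in lmms-eval task templates, so those lines are dropped before parsing. A minimal sketch with an invented snippet (not the actual `_default_template_yaml`):

```python
import yaml

# Hypothetical config lines, shaped like an lmms-eval task template.
snippet = "dataset_kwargs:\n  cache_dir: vitatecs\nprocess_results: !function utils.vitatecs_process_results\n"

try:
    yaml.safe_load(snippet)
except yaml.YAMLError as err:
    print(type(err).__name__)  # ConstructorError: no constructor for tag '!function'

# Dropping the !function line makes the rest safely loadable.
filtered = "".join(line for line in snippet.splitlines(keepends=True) if "!function" not in line)
print(yaml.safe_load(filtered))  # {'dataset_kwargs': {'cache_dir': 'vitatecs'}}
```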

```python
API_TYPE = os.getenv("API_TYPE", "openai")

if API_TYPE == "openai":
    API_URL = os.getenv("OPENAI_API_URL", "https://api.openai.com/v1/chat/completions")
    API_KEY = os.getenv("OPENAI_API_KEY", "YOUR_API_KEY")
    headers = {
        "Authorization": f"Bearer {API_KEY}",
        "Content-Type": "application/json",
    }

# We unzip all the zip files to the HF_HOME cache dir
# and load the videos from there (requires HF_HOME to be set).
HF_HOME = os.environ["HF_HOME"]
cache_dir = config["dataset_kwargs"]["cache_dir"]
cache_dir = os.path.join(HF_HOME, cache_dir)

eval_logger = logging.getLogger("lmms-eval")


# Returns the video path for this doc.
# Can only work correctly with a video LLM.
def vitatecs_doc_to_visual(doc):
    video_path = os.path.join(cache_dir, doc["src_dataset"], doc["video_name"])
    if not os.path.exists(video_path):
        sys.exit(f"video path: {video_path} does not exist, please check")
    return [video_path]


# This is the place where you format your question.
def vitatecs_doc_to_text(doc, model_specific_prompt_kwargs=None):
    if model_specific_prompt_kwargs is None:
        model_specific_prompt_kwargs = {}
    pre_prompt = model_specific_prompt_kwargs.get("pre_prompt", "")
    post_prompt = model_specific_prompt_kwargs.get("post_prompt", "")

    question, _, _ = format_question_and_answer(doc)
    return f"{pre_prompt}{question}{post_prompt}"


def process_option_for_question(sent):
    if not sent.endswith("."):
        sent += "."
    return sent.capitalize()


def process_option_for_matching(sent):
    if sent.endswith("."):
        sent = sent[:-1]
    return sent.lower()


def format_question_and_answer(doc):
    # Seed the RNG from the caption/counterfactual pair so the (A)/(B)
    # ordering is deterministic for a given example.
    seed = sum(ord(c) for c in doc["caption"] + doc["counterfactual"]) % 100
    random.seed(seed)
    if random.random() > 0.5:
        option_a = process_option_for_question(doc["caption"])
        option_b = process_option_for_question(doc["counterfactual"])
        answer = "(A) " + option_a
    else:
        option_a = process_option_for_question(doc["counterfactual"])
        option_b = process_option_for_question(doc["caption"])
        answer = "(B) " + option_b
    options = [process_option_for_matching(doc["caption"]), process_option_for_matching(doc["counterfactual"])]

    question = f"Which of the following best describes the content of the video: \n(A) {option_a} \n(B) {option_b}"
    return question, answer, options
```
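Because the RNG is seeded from the caption/counterfactual pair, the (A)/(B) ordering is stable per example across runs. A quick illustration with an invented doc, assuming the function above is in scope:

```python
# Hypothetical record; only the two text fields are needed here.
doc = {"caption": "A dog runs across the yard.", "counterfactual": "A dog sleeps in the yard."}

q1, a1, opts1 = format_question_and_answer(doc)
q2, a2, opts2 = format_question_and_answer(doc)
assert (q1, a1, opts1) == (q2, a2, opts2)  # same doc -> same ordering
print(a1)  # the ground truth is always the caption side, labelled (A) or (B)
```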

```python
def vitatecs_doc_to_answer(doc):
    _, answer, _ = format_question_and_answer(doc)
    return answer


# Process a single model result against the gold answer.
def vitatecs_process_results(doc, result):
    pred = result[0]
    rating = 0
    match_success = True
    chatgpt_response = None
    chatgpt_prompt = None
    question, answer, options = format_question_and_answer(doc)

    # Some hand-crafted matching rules
    if options[0] in pred.lower() and options[1] not in pred.lower():
        rating = 1
    elif options[1] in pred.lower() and options[0] not in pred.lower():
        rating = 0
    elif pred in ["A", "B"]:
        rating = 1 if pred == answer[1] else 0
    elif any(pred.startswith(prefix) for prefix in ["A.", "B."]):
        rating = 1 if pred.split(".")[0] == answer[1] else 0
    elif any(pred.startswith(prefix) for prefix in ["A)", "B)"]):
        rating = 1 if pred.split(")")[0] == answer[1] else 0
    elif any(pred.startswith(prefix) for prefix in ["(A)", "(B)"]):
        # Compare the letter inside the parentheses, e.g. "(A) ..." -> "A"
        rating = 1 if pred[1] == answer[1] else 0
    else:
        # Failed to match an answer in the video-LLM response; use ChatGPT to evaluate.
        match_success = False

        base_prompt = """You will receive a caption matching question, the ground-truth answer and the prediction from a question answering (QA) model. Your task is to determine whether the QA model prediction is correct, based on the question and ground-truth answer. If the prediction is correct, respond "Correct". If the prediction is incorrect, respond "Incorrect". """
        chatgpt_prompt = f"""{base_prompt}\n\nCaption Matching Question: {question}\n\nGround-Truth Answer: {answer}\n\nModel Prediction: {pred}"""
        chatgpt_response, rating = get_eval_result(chatgpt_prompt)

    # Build the result dict once; the ChatGPT fields are only present when
    # rule-based matching failed and the fallback judge was used.
    result_dict = {
        "src_dataset": doc["src_dataset"],
        "video_id": doc["video_name"],
        "question": question,
        "gt-answer": answer,
        "video-llm-prediction": pred,
        "match_success": match_success,
        "rating": rating,
        "aspect": doc["aspect"],
    }
    if not match_success:
        result_dict["chatgpt_prompt"] = chatgpt_prompt
        result_dict["chatgpt_response"] = chatgpt_response
    return {"accuracy": result_dict}
```
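A hypothetical smoke test for the rule-based path, assuming the definitions above are in scope. The doc fields mirror the schema used by the task, but every value below is invented; echoing the gold answer back as the "prediction" always hits the first matching rule:

```python
doc = {
    "src_dataset": "MSRVTT",      # made-up source tag
    "video_name": "video0.mp4",   # made-up file name
    "caption": "A dog runs across the yard.",
    "counterfactual": "A dog sleeps in the yard.",
    "aspect": "Direction",        # made-up aspect label
}
_, gold, _ = format_question_and_answer(doc)
res = vitatecs_process_results(doc, [gold])  # feed the gold answer back as the prediction
assert res["accuracy"]["rating"] == 1 and res["accuracy"]["match_success"]
```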

```python
# Utility function for GPT evaluation when rule-based matching is unsuccessful.
def get_eval_result(prompt, maxtry=10, sys_prompt=None):
    llm_output = None
    while True:
        try:
            llm_output = get_llm_output(prompt, sys_prompt)
            rating = llm_output_to_rating(llm_output)
            return llm_output, rating
        except Exception:
            # Retry on request/parsing failures until maxtry is exhausted.
            if maxtry <= 0:
                return llm_output, 0
            maxtry -= 1
            print(f"GPT evaluation failed, {maxtry} retries remaining...")
            time.sleep(random.uniform(1, 2))


# Utility function for GPT evaluation.
def get_llm_output(prompt, sys_prompt, max_tokens=128):
    if sys_prompt is None:
        sys_prompt = "You are an AI assistant for question answering."
    data = {
        "max_tokens": max_tokens,
        "model": "gpt-3.5-turbo-1106",
        "temperature": 1.0,
        "top_p": 1,
        "presence_penalty": 1,
        "messages": [
            {"role": "system", "content": sys_prompt},
            {"role": "user", "content": prompt},
        ],
    }
    response = requests.post(API_URL, headers=headers, data=json.dumps(data).encode("utf-8"))
    result = response.content.decode("utf-8")
    dict_result = json.loads(result)
    llm_output = dict_result["choices"][0]["message"]["content"].strip()
    return llm_output
```
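`get_llm_output` posts to the endpoint with `requests` and parses the response body by hand rather than going through the imported `openai` client (those imports go unused in this file). An offline sanity check of just the parsing expression, using a canned body shaped like a Chat Completions response (values invented):

```python
import json

canned = '{"choices": [{"message": {"role": "assistant", "content": " Correct "}}]}'
dict_result = json.loads(canned)
assert dict_result["choices"][0]["message"]["content"].strip() == "Correct"
```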

```python
# Utility function that converts the GPT evaluation into a rating.
def llm_output_to_rating(llm_output):
    # The assert guarantees that one of the branches below assigns a rating.
    assert "Correct" in llm_output or "Incorrect" in llm_output
    if llm_output.startswith("Correct"):
        rating = 1
    elif llm_output.startswith("Incorrect"):
        rating = 0
    elif ("Correct" in llm_output) and ("Incorrect" not in llm_output):
        rating = 1
    elif "Incorrect" in llm_output:
        rating = 0
    return rating
```
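A few offline sanity checks for the parser above. They also show why the substring checks are safe: "Correct" (capital C) is never a substring of "Incorrect":

```python
assert llm_output_to_rating("Correct") == 1
assert llm_output_to_rating("Incorrect") == 0
assert llm_output_to_rating("The prediction is Correct.") == 1
assert llm_output_to_rating("The prediction is Incorrect.") == 0
```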

```python
# Aggregate the per-example ratings into an accuracy percentage.
def vitatecs_aggregate_rating(results, args):
    yes_count = 0

    # results is a list of dicts
    for answer_dict in results:
        if answer_dict["rating"] == 1:
            yes_count += 1

    accuracy = yes_count / len(results)

    return accuracy * 100
```
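A quick check of the aggregation with invented ratings (`args` is unused by the function, so `None` suffices):

```python
fake_results = [{"rating": r} for r in (1, 0, 1, 1)]
assert vitatecs_aggregate_rating(fake_results, None) == 75.0
```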