Fix llava conv template for llama3 · EvolvingLMMs-Lab/lmms-eval@fa3ff92

@@ -223,7 +223,11 @@ def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
                 image_tokens = " ".join(image_tokens)
             prompts_input = image_tokens + "\n" + (contexts[0] if isinstance(contexts, list) else contexts)

-            conv = conv_templates[self.conv_template].copy()
+            # This is safer for llama_3, whose template now carries object-typed fields
+            # that a shallow copy would share across requests
+            if "llama_3" in self.conv_template:
+                conv = copy.deepcopy(conv_templates[self.conv_template])
+            else:
+                conv = conv_templates[self.conv_template].copy()
             conv.append_message(conv.roles[0], prompts_input)
             conv.append_message(conv.roles[1], None)
             prompt = conv.get_prompt()
@@ -331,7 +335,11 @@ def _collate(x):
             else:
                 question = context

-            conv = conv_templates[self.conv_template].copy()
+            # This is safer for llama_3, whose template now carries object-typed fields
+            # that a shallow copy would share across requests
+            if "llama_3" in self.conv_template:
+                conv = copy.deepcopy(conv_templates[self.conv_template])
+            else:
+                conv = conv_templates[self.conv_template].copy()
             conv.append_message(conv.roles[0], question)
             conv.append_message(conv.roles[1], None)
             prompt_question = conv.get_prompt()
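For context, a minimal sketch (not the lmms-eval code) of why the shallow `.copy()` becomes unsafe once a template carries object-typed fields: the `Conversation` class, its `tokenizer` field, and the template contents below are hypothetical stand-ins for the llava conversation template.

```python
import copy
from dataclasses import dataclass, field
from typing import Any, List, Optional, Tuple


@dataclass
class Conversation:
    """Hypothetical stand-in for llava's conversation template object."""

    roles: Tuple[str, str] = ("user", "assistant")
    messages: List[Tuple[str, Optional[str]]] = field(default_factory=list)
    tokenizer: Any = None  # object-typed field, as in the llama_3 template

    def copy(self) -> "Conversation":
        # Template-style shallow copy: fresh messages list, but object-typed
        # fields are carried over by reference.
        return Conversation(roles=self.roles, messages=list(self.messages), tokenizer=self.tokenizer)

    def append_message(self, role: str, message: Optional[str]) -> None:
        self.messages.append((role, message))


conv_templates = {"llama_3": Conversation(tokenizer={"bos": "<|begin_of_text|>"})}

shallow = conv_templates["llama_3"].copy()
deep = copy.deepcopy(conv_templates["llama_3"])

# Appending to either copy leaves the shared template's history clean.
shallow.append_message(shallow.roles[0], "hi")
assert conv_templates["llama_3"].messages == []

# But the shallow copy still aliases the object-typed field, so mutating it
# in place would bleed into every later request; deepcopy fully isolates it.
assert shallow.tokenizer is conv_templates["llama_3"].tokenizer
assert deep.tokenizer is not conv_templates["llama_3"].tokenizer
```

This is why the commit branches on `"llama_3" in self.conv_template`: only that template needs the heavier `copy.deepcopy`, while the cheaper shallow `.copy()` remains sufficient for the others.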