Fix endless warning for llava_hf generation · dadwadw233/lmms-eval@7fbdaf7

@@ -313,14 +313,15 @@ def _collate(x):
                     num_beams=gen_kwargs["num_beams"],
                     max_new_tokens=gen_kwargs["max_new_tokens"],
                     use_cache=self.use_cache,
+                    pad_token_id=self.tokenizer.eos_token_id,
                 )
             except Exception as e:
                 eval_logger.error(f"Error {e} in generating")
                 cont = ""
             text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True)[0]
             if "1.5" in self.pretrained:
                 text_outputs = text_outputs.split("ASSISTANT:")[-1].strip()
-            elif "1.6" in self.pretrained:
+            elif "mistral" in self.pretrained:
                 text_outputs = text_outputs.split("[/INST]")[-1].strip()
             else:
                 text_outputs = text_outputs.split("ASSISTANT:")[-1].strip()
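Note on the pad_token_id change: when a tokenizer has no pad token set, transformers logs a "Setting `pad_token_id` to `eos_token_id`" warning on every generate() call, which floods the eval log during batched inference. Passing pad_token_id explicitly, as the patch does, avoids that. The following standalone sketch illustrates the same pattern; it is not the repository's code, and the gpt2 model name and prompt are placeholder assumptions for the demo.

# minimal sketch, assuming any causal LM whose tokenizer lacks a pad token
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "gpt2"  # assumption: placeholder model for illustration
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

inputs = tokenizer("Describe the image:", return_tensors="pt")
outputs = model.generate(
    **inputs,
    max_new_tokens=16,
    pad_token_id=tokenizer.eos_token_id,  # explicit pad id, so no warning per call
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])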