@@ -313,14 +313,15 @@ def _collate(x): |
|
|
313 |
313 |
num_beams=gen_kwargs["num_beams"], |
314 |
314 |
max_new_tokens=gen_kwargs["max_new_tokens"], |
315 |
315 |
use_cache=self.use_cache, |
|
316 |
+pad_token_id=self.tokenizer.eos_token_id, |
316 |
317 |
) |
317 |
318 |
except Exception as e: |
318 |
319 |
eval_logger.error(f"Error {e} in generating") |
319 |
320 |
cont = "" |
320 |
321 |
text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True)[0] |
321 |
322 |
if "1.5" in self.pretrained: |
322 |
323 |
text_outputs = text_outputs.split("ASSISTANT:")[-1].strip() |
323 |
|
-elif "1.6" in self.pretrained: |
|
324 |
+elif "mistral" in self.pretrained: |
324 |
325 |
text_outputs = text_outputs.split("[/INST]")[-1].strip() |
325 |
326 |
else: |
326 |
327 |
text_outputs = text_outputs.split("ASSISTANT:")[-1].strip() |