fix: only keep the stop sequence buffer when stop sequences are configured

This commit is contained in:
OlivierDehaene 2023-12-14 17:04:58 +01:00
parent 80a69204c1
commit 9b78a6eee3
1 changed file with 10 additions and 9 deletions

View File

@ -112,7 +112,7 @@ class StoppingCriteria:
self.stop_sequence_criterias = stop_sequence_criterias self.stop_sequence_criterias = stop_sequence_criterias
self.max_new_tokens = max_new_tokens self.max_new_tokens = max_new_tokens
self.current_tokens = 0 self.current_tokens = 0
self.current_output = "test" self.current_output = ""
self.ignore_eos_token = ignore_eos_token self.ignore_eos_token = ignore_eos_token
def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]: def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
@ -123,14 +123,15 @@ class StoppingCriteria:
if not self.ignore_eos_token and last_token == self.eos_token_id: if not self.ignore_eos_token and last_token == self.eos_token_id:
return True, FinishReason.FINISH_REASON_EOS_TOKEN return True, FinishReason.FINISH_REASON_EOS_TOKEN
self.current_output += last_output if self.stop_sequence_criterias:
# There is no need to keep an output that is too long self.current_output += last_output
if len(self.current_output) > 300: # There is no need to keep an output that is too long
# Slice to -200 to avoid doing it all the time if len(self.current_output) > 300:
self.current_output = self.current_output[-200:] # Slice to -200 to avoid doing it all the time
for stop_sequence_criteria in self.stop_sequence_criterias: self.current_output = self.current_output[-200:]
if stop_sequence_criteria(self.current_output): for stop_sequence_criteria in self.stop_sequence_criterias:
return True, FinishReason.FINISH_REASON_STOP_SEQUENCE if stop_sequence_criteria(self.current_output):
return True, FinishReason.FINISH_REASON_STOP_SEQUENCE
return False, None return False, None