fix: reset grammar state when generation stops

This commit is contained in:
drbh 2024-04-18 17:05:52 +00:00
parent 2d0a7173d4
commit e6259d9fc0
3 changed files with 21 additions and 8 deletions

View File

@ -1221,7 +1221,9 @@ class FlashCausalLM(Model):
# have more than one new token per request with speculative decoding
for next_token_id in _next_token_ids:
batch.next_token_chooser = (
batch.next_token_chooser.advance_grammar_single(i, next_token_id)
batch.next_token_chooser.advance_grammar_single(
i, next_token_id, stopped
)
)
# Update values

View File

@ -593,6 +593,10 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
next_token_id, fsm_grammar_state, self.fsms[index]
)
def reset_at_index(self, index):
    """Drop the FSM stored for one slot and return the state to record for it.

    Args:
        index: Position in ``self.fsms`` of the entry to clear.

    Returns:
        The integer ``-1``.
        NOTE(review): presumably the sentinel for a grammar that is no
        longer being followed -- confirm against ``advance_at_index``.
    """
    reset_state = -1
    # Clear the slot so this entry no longer refers to an FSM.
    self.fsms[index] = None
    return reset_state
def filter(self, indices):
new_fsms = []
for i in indices:

View File

@ -406,15 +406,22 @@ class HeterogeneousNextTokenChooser:
self.fsm_grammar_states = other_new_states
return self
def advance_grammar_single(
    self, grammar_state_index: int, next_id: int, stopped: bool = False
):
    """Update the grammar state of a single request after a token is chosen.

    Args:
        grammar_state_index: Position of the request in
            ``self.fsm_grammar_states``.
        next_id: Token id that was just selected for that request.
        stopped: When true, the processor's entry for this request is
            reset instead of advanced. Defaults to ``False`` so that
            two-argument calls written against the earlier signature
            keep working.

    Returns:
        ``self``, so the caller can rebind the chooser to the result.
    """
    processor = self.grammar_processor
    if processor is None:
        # No grammar processor is configured; leave the states untouched.
        return self

    if stopped:
        # Generation ended for this request: reset rather than advance.
        new_state = processor.reset_at_index(grammar_state_index)
    else:
        new_state = processor.advance_at_index(
            next_id,
            self.fsm_grammar_states[grammar_state_index],
            grammar_state_index,
        )
    self.fsm_grammar_states[grammar_state_index] = new_state
    return self
def filter(self, indices):