fix: reset grammar state when generation stops

This commit is contained in:
drbh 2024-04-18 17:05:52 +00:00
parent 2d0a7173d4
commit e6259d9fc0
3 changed files with 21 additions and 8 deletions

View File

@ -1221,7 +1221,9 @@ class FlashCausalLM(Model):
# have more than one new token per request with speculative decoding # have more than one new token per request with speculative decoding
for next_token_id in _next_token_ids: for next_token_id in _next_token_ids:
batch.next_token_chooser = ( batch.next_token_chooser = (
batch.next_token_chooser.advance_grammar_single(i, next_token_id) batch.next_token_chooser.advance_grammar_single(
i, next_token_id, stopped
)
) )
# Update values # Update values

View File

@ -593,6 +593,10 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
next_token_id, fsm_grammar_state, self.fsms[index] next_token_id, fsm_grammar_state, self.fsms[index]
) )
def reset_at_index(self, index):
self.fsms[index] = None
return -1
def filter(self, indices): def filter(self, indices):
new_fsms = [] new_fsms = []
for i in indices: for i in indices:

View File

@ -406,15 +406,22 @@ class HeterogeneousNextTokenChooser:
self.fsm_grammar_states = other_new_states self.fsm_grammar_states = other_new_states
return self return self
def advance_grammar_single(self, grammar_state_index: int, next_id: int): def advance_grammar_single(
self, grammar_state_index: int, next_id: int, stopped: bool
):
if self.grammar_processor is not None: if self.grammar_processor is not None:
self.fsm_grammar_states[grammar_state_index] = ( if stopped:
self.grammar_processor.advance_at_index( self.fsm_grammar_states[grammar_state_index] = (
next_id, self.grammar_processor.reset_at_index(grammar_state_index)
self.fsm_grammar_states[grammar_state_index], )
grammar_state_index, else:
self.fsm_grammar_states[grammar_state_index] = (
self.grammar_processor.advance_at_index(
next_id,
self.fsm_grammar_states[grammar_state_index],
grammar_state_index,
)
) )
)
return self return self
def filter(self, indices): def filter(self, indices):