fix: reset grammar state when generation stops
This commit is contained in:
parent
2d0a7173d4
commit
e6259d9fc0
|
@ -1221,7 +1221,9 @@ class FlashCausalLM(Model):
|
|||
# have more than one new token per request with speculative decoding
|
||||
for next_token_id in _next_token_ids:
|
||||
batch.next_token_chooser = (
|
||||
batch.next_token_chooser.advance_grammar_single(i, next_token_id)
|
||||
batch.next_token_chooser.advance_grammar_single(
|
||||
i, next_token_id, stopped
|
||||
)
|
||||
)
|
||||
|
||||
# Update values
|
||||
|
|
|
@ -593,6 +593,10 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
|
|||
next_token_id, fsm_grammar_state, self.fsms[index]
|
||||
)
|
||||
|
||||
def reset_at_index(self, index):
|
||||
self.fsms[index] = None
|
||||
return -1
|
||||
|
||||
def filter(self, indices):
|
||||
new_fsms = []
|
||||
for i in indices:
|
||||
|
|
|
@ -406,15 +406,22 @@ class HeterogeneousNextTokenChooser:
|
|||
self.fsm_grammar_states = other_new_states
|
||||
return self
|
||||
|
||||
def advance_grammar_single(self, grammar_state_index: int, next_id: int):
|
||||
def advance_grammar_single(
|
||||
self, grammar_state_index: int, next_id: int, stopped: bool
|
||||
):
|
||||
if self.grammar_processor is not None:
|
||||
self.fsm_grammar_states[grammar_state_index] = (
|
||||
self.grammar_processor.advance_at_index(
|
||||
next_id,
|
||||
self.fsm_grammar_states[grammar_state_index],
|
||||
grammar_state_index,
|
||||
if stopped:
|
||||
self.fsm_grammar_states[grammar_state_index] = (
|
||||
self.grammar_processor.reset_at_index(grammar_state_index)
|
||||
)
|
||||
else:
|
||||
self.fsm_grammar_states[grammar_state_index] = (
|
||||
self.grammar_processor.advance_at_index(
|
||||
next_id,
|
||||
self.fsm_grammar_states[grammar_state_index],
|
||||
grammar_state_index,
|
||||
)
|
||||
)
|
||||
)
|
||||
return self
|
||||
|
||||
def filter(self, indices):
|
||||
|
|
Loading…
Reference in New Issue