fix: reset grammar state when generation stops
This commit is contained in:
parent
2d0a7173d4
commit
e6259d9fc0
|
@ -1221,7 +1221,9 @@ class FlashCausalLM(Model):
|
||||||
# have more than one new token per request with speculative decoding
|
# have more than one new token per request with speculative decoding
|
||||||
for next_token_id in _next_token_ids:
|
for next_token_id in _next_token_ids:
|
||||||
batch.next_token_chooser = (
|
batch.next_token_chooser = (
|
||||||
batch.next_token_chooser.advance_grammar_single(i, next_token_id)
|
batch.next_token_chooser.advance_grammar_single(
|
||||||
|
i, next_token_id, stopped
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update values
|
# Update values
|
||||||
|
|
|
@ -593,6 +593,10 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
|
||||||
next_token_id, fsm_grammar_state, self.fsms[index]
|
next_token_id, fsm_grammar_state, self.fsms[index]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def reset_at_index(self, index):
|
||||||
|
self.fsms[index] = None
|
||||||
|
return -1
|
||||||
|
|
||||||
def filter(self, indices):
|
def filter(self, indices):
|
||||||
new_fsms = []
|
new_fsms = []
|
||||||
for i in indices:
|
for i in indices:
|
||||||
|
|
|
@ -406,15 +406,22 @@ class HeterogeneousNextTokenChooser:
|
||||||
self.fsm_grammar_states = other_new_states
|
self.fsm_grammar_states = other_new_states
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def advance_grammar_single(self, grammar_state_index: int, next_id: int):
|
def advance_grammar_single(
|
||||||
|
self, grammar_state_index: int, next_id: int, stopped: bool
|
||||||
|
):
|
||||||
if self.grammar_processor is not None:
|
if self.grammar_processor is not None:
|
||||||
self.fsm_grammar_states[grammar_state_index] = (
|
if stopped:
|
||||||
self.grammar_processor.advance_at_index(
|
self.fsm_grammar_states[grammar_state_index] = (
|
||||||
next_id,
|
self.grammar_processor.reset_at_index(grammar_state_index)
|
||||||
self.fsm_grammar_states[grammar_state_index],
|
)
|
||||||
grammar_state_index,
|
else:
|
||||||
|
self.fsm_grammar_states[grammar_state_index] = (
|
||||||
|
self.grammar_processor.advance_at_index(
|
||||||
|
next_id,
|
||||||
|
self.fsm_grammar_states[grammar_state_index],
|
||||||
|
grammar_state_index,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def filter(self, indices):
|
def filter(self, indices):
|
||||||
|
|
Loading…
Reference in New Issue