diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 2c440083..24f3e243 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1221,7 +1221,9 @@ class FlashCausalLM(Model): # have more than one new token per request with speculative decoding for next_token_id in _next_token_ids: batch.next_token_chooser = ( - batch.next_token_chooser.advance_grammar_single(i, next_token_id) + batch.next_token_chooser.advance_grammar_single( + i, next_token_id, stopped + ) ) # Update values diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py index 6d8cb71a..79f8e6d4 100644 --- a/server/text_generation_server/utils/logits_process.py +++ b/server/text_generation_server/utils/logits_process.py @@ -593,6 +593,10 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor): next_token_id, fsm_grammar_state, self.fsms[index] ) + def reset_at_index(self, index): + self.fsms[index] = None + return -1 + def filter(self, indices): new_fsms = [] for i in indices: diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index 7c8a18f0..e3f6e1d4 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -406,15 +406,22 @@ class HeterogeneousNextTokenChooser: self.fsm_grammar_states = other_new_states return self - def advance_grammar_single(self, grammar_state_index: int, next_id: int): + def advance_grammar_single( + self, grammar_state_index: int, next_id: int, stopped: bool + ): if self.grammar_processor is not None: - self.fsm_grammar_states[grammar_state_index] = ( - self.grammar_processor.advance_at_index( - next_id, - self.fsm_grammar_states[grammar_state_index], - grammar_state_index, + if stopped: + self.fsm_grammar_states[grammar_state_index] = ( + self.grammar_processor.reset_at_index(grammar_state_index) + ) + else: + self.fsm_grammar_states[grammar_state_index] = ( + self.grammar_processor.advance_at_index( + next_id, + self.fsm_grammar_states[grammar_state_index], + grammar_state_index, + ) ) - ) return self def filter(self, indices):