From 762dbf3f198a9fbe7edffd60edc909e60e66878a Mon Sep 17 00:00:00 2001 From: drbh Date: Thu, 28 Mar 2024 12:02:01 -0400 Subject: [PATCH] fix: handle batches with and without grammars (#1676) This PR correctly handles batches with a mixture of constrained and non-constrained generations. Currently, if a batch contains mixed generations, the generation will throw an error because it will incorrectly attempt to constrain a request with an empty grammar. We now handle `None` grammars and only apply the mask if needed. Fixes: https://github.com/huggingface/text-generation-inference/issues/1643 --- server/text_generation_server/utils/logits_process.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py index b4ffb863..6d8cb71a 100644 --- a/server/text_generation_server/utils/logits_process.py +++ b/server/text_generation_server/utils/logits_process.py @@ -555,6 +555,9 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor): self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer) self.fsms = [] for grammar, grammar_type in zip(grammars, grammar_types): + if len(grammar) == 0: + self.fsms.append(None) + continue fsm = GrammarLogitProcessor._cached_compile_fsm( grammar_type, grammar, self.tokenizer ) @@ -572,7 +575,7 @@ continue allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i]) mask[i, allowed_tokens] = 0 - logits += mask + logits[i] += mask[i] return logits def advance_batch(self, next_token_ids, fsm_grammar_states): @@ -584,6 +587,8 @@ ] def advance_at_index(self, next_token_id, fsm_grammar_state, index): + if self.fsms[index] is None: + return fsm_grammar_state return GrammarLogitProcessor._advance( next_token_id, fsm_grammar_state, self.fsms[index] )