diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py index b4ffb863..6d8cb71a 100644 --- a/server/text_generation_server/utils/logits_process.py +++ b/server/text_generation_server/utils/logits_process.py @@ -555,6 +555,9 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor): self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer) self.fsms = [] for grammar, grammar_type in zip(grammars, grammar_types): + if len(grammar) == 0: + self.fsms.append(None) + continue fsm = GrammarLogitProcessor._cached_compile_fsm( grammar_type, grammar, self.tokenizer ) @@ -572,7 +575,7 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor): continue allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i]) mask[i, allowed_tokens] = 0 - logits += mask + logits[i] += mask[i] return logits def advance_batch(self, next_token_ids, fsm_grammar_states): @@ -584,6 +587,8 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor): ] def advance_at_index(self, next_token_id, fsm_grammar_state, index): + if self.fsms[index] is None: + return fsm_grammar_state return GrammarLogitProcessor._advance( next_token_id, fsm_grammar_state, self.fsms[index] )