fix: handle batches with and without grammars (#1676)

This PR correctly handles batches with a mixture of constrained and non
constrained generations.

Currently if batch contains mixed generations the generation will throw
an error because it will incorrectly attempt to constrain a request with
an empty grammar.

We now handled `None` grammars and only apply the mask if needed

Fixes:
https://github.com/huggingface/text-generation-inference/issues/1643
This commit is contained in:
drbh 2024-03-28 12:02:01 -04:00 committed by GitHub
parent 818aee37e5
commit 762dbf3f19
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 6 additions and 1 deletions

View File

@ -555,6 +555,9 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer) self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
self.fsms = [] self.fsms = []
for grammar, grammar_type in zip(grammars, grammar_types): for grammar, grammar_type in zip(grammars, grammar_types):
if len(grammar) == 0:
self.fsms.append(None)
continue
fsm = GrammarLogitProcessor._cached_compile_fsm( fsm = GrammarLogitProcessor._cached_compile_fsm(
grammar_type, grammar, self.tokenizer grammar_type, grammar, self.tokenizer
) )
@ -572,7 +575,7 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
continue continue
allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i]) allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i])
mask[i, allowed_tokens] = 0 mask[i, allowed_tokens] = 0
logits += mask logits[i] += mask[i]
return logits return logits
def advance_batch(self, next_token_ids, fsm_grammar_states): def advance_batch(self, next_token_ids, fsm_grammar_states):
@ -584,6 +587,8 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
] ]
def advance_at_index(self, next_token_id, fsm_grammar_state, index): def advance_at_index(self, next_token_id, fsm_grammar_state, index):
if self.fsms[index] is None:
return fsm_grammar_state
return GrammarLogitProcessor._advance( return GrammarLogitProcessor._advance(
next_token_id, fsm_grammar_state, self.fsms[index] next_token_id, fsm_grammar_state, self.fsms[index]
) )