fix: handle batches with and without grammars (#1676)
This PR correctly handles batches with a mixture of constrained and non constrained generations. Currently if batch contains mixed generations the generation will throw an error because it will incorrectly attempt to constrain a request with an empty grammar. We now handled `None` grammars and only apply the mask if needed Fixes: https://github.com/huggingface/text-generation-inference/issues/1643
This commit is contained in:
parent
818aee37e5
commit
762dbf3f19
|
@ -555,6 +555,9 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
|
|||
self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
|
||||
self.fsms = []
|
||||
for grammar, grammar_type in zip(grammars, grammar_types):
|
||||
if len(grammar) == 0:
|
||||
self.fsms.append(None)
|
||||
continue
|
||||
fsm = GrammarLogitProcessor._cached_compile_fsm(
|
||||
grammar_type, grammar, self.tokenizer
|
||||
)
|
||||
|
@ -572,7 +575,7 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
|
|||
continue
|
||||
allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i])
|
||||
mask[i, allowed_tokens] = 0
|
||||
logits += mask
|
||||
logits[i] += mask[i]
|
||||
return logits
|
||||
|
||||
def advance_batch(self, next_token_ids, fsm_grammar_states):
|
||||
|
@ -584,6 +587,8 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
|
|||
]
|
||||
|
||||
def advance_at_index(self, next_token_id, fsm_grammar_state, index):
|
||||
if self.fsms[index] is None:
|
||||
return fsm_grammar_state
|
||||
return GrammarLogitProcessor._advance(
|
||||
next_token_id, fsm_grammar_state, self.fsms[index]
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue