fix: handle batches with and without grammars (#1676)
This PR correctly handles batches that contain a mixture of constrained and unconstrained generations. Currently, if a batch contains mixed generations, generation throws an error because it incorrectly attempts to constrain a request that has an empty grammar. We now handle `None` grammars and only apply the mask when needed. Fixes: https://github.com/huggingface/text-generation-inference/issues/1643
parent 818aee37e5
commit 762dbf3f19
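Conceptually, the change keeps the per-request FSM list index-aligned with the batch: requests without a grammar get a `None` placeholder, and the mask is only applied to rows that actually have an FSM. The sketch below mirrors that logic outside of the real TGI classes; `compile_fsm` is a hypothetical stand-in for `GrammarLogitProcessor._cached_compile_fsm`, and the FSM objects are assumed to expose the same `allowed_token_ids` method seen in the diff.

# Minimal sketch of the fixed behavior, not the actual TGI implementation.
# `compile_fsm` is a hypothetical stand-in for
# GrammarLogitProcessor._cached_compile_fsm.
import math
import torch


def build_fsms(grammars, compile_fsm):
    fsms = []
    for grammar in grammars:
        if len(grammar) == 0:
            # Unconstrained request: keep the slot so indices stay aligned
            # with the batch, but do not compile an FSM.
            fsms.append(None)
            continue
        fsms.append(compile_fsm(grammar))
    return fsms


def apply_grammar_masks(logits, fsms, fsm_grammar_states):
    mask = torch.full_like(logits, -math.inf)
    for i, fsm in enumerate(fsms):
        if fsm is None:
            # No grammar for this request: leave its logits untouched.
            continue
        allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i])
        mask[i, allowed_tokens] = 0
        logits[i] += mask[i]  # mask only this row, not the whole batch
    return logits

For a mixed batch such as grammars = ["", '{"type": "object"}'], the FSM list becomes [None, <fsm>], so only the second row of logits is constrained while the first request generates freely.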
@@ -555,6 +555,9 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
         self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
         self.fsms = []
         for grammar, grammar_type in zip(grammars, grammar_types):
+            if len(grammar) == 0:
+                self.fsms.append(None)
+                continue
             fsm = GrammarLogitProcessor._cached_compile_fsm(
                 grammar_type, grammar, self.tokenizer
             )
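In `__init__`, an empty grammar now maps to a `None` entry in `self.fsms` instead of being passed to `_cached_compile_fsm`, so the FSM list stays index-aligned with the batch and unconstrained requests no longer attempt to compile an empty grammar.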
@@ -572,7 +575,7 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
                 continue
             allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i])
             mask[i, allowed_tokens] = 0
-        logits += mask
+            logits[i] += mask[i]
         return logits

     def advance_batch(self, next_token_ids, fsm_grammar_states):
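In `__call__`, the mask is now applied per row (`logits[i] += mask[i]`) inside the loop rather than adding the whole mask tensor to the whole batch, so rows belonging to requests without a grammar are left untouched.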
@@ -584,6 +587,8 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
         ]

     def advance_at_index(self, next_token_id, fsm_grammar_state, index):
+        if self.fsms[index] is None:
+            return fsm_grammar_state
         return GrammarLogitProcessor._advance(
             next_token_id, fsm_grammar_state, self.fsms[index]
         )
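In `advance_at_index`, an unconstrained request (a `None` FSM) simply keeps its current grammar state instead of being passed to `GrammarLogitProcessor._advance`.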