Check if allowed tokens is None (#2694)
* Upgrade outlines to 0.1.1 * Update for new API * Check if allowed tokens is None --------- Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
parent
b49cff3f07
commit
7bc2c97bd9
|
@ -3976,4 +3976,4 @@ torch = ["torch"]
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.9,<3.13"
|
||||
content-hash = "5b2536579dd1d4013da43c5260666cab1c6575c8c1dd6ebb1ccfd5dc0f2874fd"
|
||||
content-hash = "40be820ced080c2457b0794ed61fdd5340615f0fe75420985eaaca7b2b6c3968"
|
|
@ -498,9 +498,10 @@ class GrammarLogitProcessor(LogitsProcessor):
|
|||
):
|
||||
if fsm_grammar_state == -1 or self.fsm is None:
|
||||
return logits
|
||||
allowed_tokens = self.fsm.allowed_token_ids(fsm_grammar_state)
|
||||
allowed_tokens = self.fsm.get_next_instruction(fsm_grammar_state).tokens
|
||||
mask = torch.full_like(logits, -math.inf)
|
||||
mask[:, allowed_tokens] = 0
|
||||
if allowed_tokens is not None:
|
||||
mask[:, allowed_tokens] = 0
|
||||
biased_scores = logits + mask
|
||||
return biased_scores
|
||||
|
||||
|
@ -589,7 +590,8 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
|
|||
if fsm_grammar_states[i] == -1 or fsm is None:
|
||||
continue
|
||||
allowed_tokens = fsm.get_next_instruction(fsm_grammar_states[i]).tokens
|
||||
mask[i, allowed_tokens] = 0
|
||||
if allowed_tokens is not None:
|
||||
mask[i, allowed_tokens] = 0
|
||||
logits[i] += mask[i]
|
||||
return logits
|
||||
|
||||
|
|
Loading…
Reference in New Issue