Check if allowed tokens is None (#2694)

* Upgrade outlines to 0.1.1

* Update for new API

* Check if allowed tokens is None

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
Alex Weston 2024-10-28 00:10:55 -04:00 committed by GitHub
parent b49cff3f07
commit 7bc2c97bd9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 6 additions and 4 deletions

2
server/poetry.lock generated
View File

@ -3976,4 +3976,4 @@ torch = ["torch"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = ">=3.9,<3.13" python-versions = ">=3.9,<3.13"
content-hash = "5b2536579dd1d4013da43c5260666cab1c6575c8c1dd6ebb1ccfd5dc0f2874fd" content-hash = "40be820ced080c2457b0794ed61fdd5340615f0fe75420985eaaca7b2b6c3968"

View File

@ -498,9 +498,10 @@ class GrammarLogitProcessor(LogitsProcessor):
): ):
if fsm_grammar_state == -1 or self.fsm is None: if fsm_grammar_state == -1 or self.fsm is None:
return logits return logits
allowed_tokens = self.fsm.allowed_token_ids(fsm_grammar_state) allowed_tokens = self.fsm.get_next_instruction(fsm_grammar_state).tokens
mask = torch.full_like(logits, -math.inf) mask = torch.full_like(logits, -math.inf)
mask[:, allowed_tokens] = 0 if allowed_tokens is not None:
mask[:, allowed_tokens] = 0
biased_scores = logits + mask biased_scores = logits + mask
return biased_scores return biased_scores
@ -589,7 +590,8 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
if fsm_grammar_states[i] == -1 or fsm is None: if fsm_grammar_states[i] == -1 or fsm is None:
continue continue
allowed_tokens = fsm.get_next_instruction(fsm_grammar_states[i]).tokens allowed_tokens = fsm.get_next_instruction(fsm_grammar_states[i]).tokens
mask[i, allowed_tokens] = 0 if allowed_tokens is not None:
mask[i, allowed_tokens] = 0
logits[i] += mask[i] logits[i] += mask[i]
return logits return logits