Check if allowed tokens is None (#2694)
* Upgrade outlines to 0.1.1 * Update for new API * Check if allowed tokens is None --------- Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
parent
b49cff3f07
commit
7bc2c97bd9
|
@ -3976,4 +3976,4 @@ torch = ["torch"]
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = ">=3.9,<3.13"
|
python-versions = ">=3.9,<3.13"
|
||||||
content-hash = "5b2536579dd1d4013da43c5260666cab1c6575c8c1dd6ebb1ccfd5dc0f2874fd"
|
content-hash = "40be820ced080c2457b0794ed61fdd5340615f0fe75420985eaaca7b2b6c3968"
|
|
@ -498,8 +498,9 @@ class GrammarLogitProcessor(LogitsProcessor):
|
||||||
):
|
):
|
||||||
if fsm_grammar_state == -1 or self.fsm is None:
|
if fsm_grammar_state == -1 or self.fsm is None:
|
||||||
return logits
|
return logits
|
||||||
allowed_tokens = self.fsm.allowed_token_ids(fsm_grammar_state)
|
allowed_tokens = self.fsm.get_next_instruction(fsm_grammar_state).tokens
|
||||||
mask = torch.full_like(logits, -math.inf)
|
mask = torch.full_like(logits, -math.inf)
|
||||||
|
if allowed_tokens is not None:
|
||||||
mask[:, allowed_tokens] = 0
|
mask[:, allowed_tokens] = 0
|
||||||
biased_scores = logits + mask
|
biased_scores = logits + mask
|
||||||
return biased_scores
|
return biased_scores
|
||||||
|
@ -589,6 +590,7 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
|
||||||
if fsm_grammar_states[i] == -1 or fsm is None:
|
if fsm_grammar_states[i] == -1 or fsm is None:
|
||||||
continue
|
continue
|
||||||
allowed_tokens = fsm.get_next_instruction(fsm_grammar_states[i]).tokens
|
allowed_tokens = fsm.get_next_instruction(fsm_grammar_states[i]).tokens
|
||||||
|
if allowed_tokens is not None:
|
||||||
mask[i, allowed_tokens] = 0
|
mask[i, allowed_tokens] = 0
|
||||||
logits[i] += mask[i]
|
logits[i] += mask[i]
|
||||||
return logits
|
return logits
|
||||||
|
|
Loading…
Reference in New Issue