diff --git a/server/poetry.lock b/server/poetry.lock index 4060842d..89231d21 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -3976,4 +3976,4 @@ torch = ["torch"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.13" -content-hash = "5b2536579dd1d4013da43c5260666cab1c6575c8c1dd6ebb1ccfd5dc0f2874fd" +content-hash = "40be820ced080c2457b0794ed61fdd5340615f0fe75420985eaaca7b2b6c3968" \ No newline at end of file diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py index e79a4bb0..ec2813a1 100644 --- a/server/text_generation_server/utils/logits_process.py +++ b/server/text_generation_server/utils/logits_process.py @@ -498,9 +498,10 @@ class GrammarLogitProcessor(LogitsProcessor): ): if fsm_grammar_state == -1 or self.fsm is None: return logits - allowed_tokens = self.fsm.allowed_token_ids(fsm_grammar_state) + allowed_tokens = self.fsm.get_next_instruction(fsm_grammar_state).tokens mask = torch.full_like(logits, -math.inf) - mask[:, allowed_tokens] = 0 + if allowed_tokens is not None: + mask[:, allowed_tokens] = 0 biased_scores = logits + mask return biased_scores @@ -589,7 +590,8 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor): if fsm_grammar_states[i] == -1 or fsm is None: continue allowed_tokens = fsm.get_next_instruction(fsm_grammar_states[i]).tokens - mask[i, allowed_tokens] = 0 + if allowed_tokens is not None: + mask[i, allowed_tokens] = 0 logits[i] += mask[i] return logits