From 343aa7a1971840df020524dc2b4943ca5b83714a Mon Sep 17 00:00:00 2001
From: drbh
Date: Thu, 29 Feb 2024 05:17:42 -0500
Subject: [PATCH] fix: Handle concurrent grammar requests (#1610)

This PR fixes parallel grammar requests. Currently, grammar states are not
concatenated correctly when a new request is added to the batch, which
results in incorrect generation. This PR updates the `concatenate` function
to correctly carry over the previous states.

fixes: #1601
---
 .../test_grammar_llama/test_flash_llama_grammar_load.json | 8 ++++----
 server/text_generation_server/models/flash_causal_lm.py   | 3 +++
 server/text_generation_server/utils/tokens.py             | 5 ++++-
 3 files changed, 11 insertions(+), 5 deletions(-)

diff --git a/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_load.json b/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_load.json
index b7b26a2c..f6bc6e56 100644
--- a/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_load.json
+++ b/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_load.json
@@ -61,7 +61,7 @@
         },
         {
           "id": 29906,
-          "logprob": -0.2376709,
+          "logprob": -0.33666992,
           "special": false,
           "text": "2"
         },
@@ -180,7 +180,7 @@
         },
         {
           "id": 29906,
-          "logprob": -0.23840332,
+          "logprob": -0.33740234,
           "special": false,
           "text": "2"
         },
@@ -299,7 +299,7 @@
         },
         {
           "id": 29906,
-          "logprob": -0.23840332,
+          "logprob": -0.33740234,
           "special": false,
           "text": "2"
         },
@@ -418,7 +418,7 @@
         },
         {
           "id": 29906,
-          "logprob": -0.23840332,
+          "logprob": -0.33740234,
           "special": false,
           "text": "2"
         },
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 988637d4..acd97f45 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -530,6 +530,7 @@ class FlashCausalLMBatch(Batch):
         read_offsets = []
 
         next_token_chooser_parameters = []
+        fsm_grammar_states = []
         stopping_criterias = []
         top_n_tokens = []
 
@@ -578,6 +579,7 @@ class FlashCausalLMBatch(Batch):
             read_offsets.extend(batch.read_offsets)
 
             next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
+            fsm_grammar_states.extend(batch.next_token_chooser.fsm_grammar_states)
             stopping_criterias.extend(batch.stopping_criterias)
 
             top_n_tokens.extend(batch.top_n_tokens)
@@ -593,6 +595,7 @@ class FlashCausalLMBatch(Batch):
             dtype=batches[0].next_token_chooser.dtype,
             device=batches[0].next_token_chooser.device,
             tokenizer=batches[0].next_token_chooser.tokenizer,
+            fsm_grammar_states=fsm_grammar_states,
         )
 
         speculative_ids = (
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index 32789850..7c8a18f0 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -466,6 +466,7 @@ class HeterogeneousNextTokenChooser:
         dtype: torch.dtype,
         device: torch.device,
         tokenizer: PreTrainedTokenizerBase,
+        fsm_grammar_states: Optional[List[int]] = None,
     ) -> "HeterogeneousNextTokenChooser":
         return HeterogeneousNextTokenChooser(
             watermark=[pb_.watermark for pb_ in pb],
@@ -482,7 +483,9 @@ class HeterogeneousNextTokenChooser:
             tokenizer=tokenizer,
             grammars=[pb_.grammar for pb_ in pb],
             grammar_types=[pb_.grammar_type for pb_ in pb],
-            fsm_grammar_states=[0] * len(pb),
+            fsm_grammar_states=(
+                fsm_grammar_states if fsm_grammar_states else [0] * len(pb)
+            ),
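
Note for reviewers: the sketch below reproduces the failure mode in isolation. When running batches are concatenated and the merged next-token chooser is rebuilt with `fsm_grammar_states=[0] * len(pb)`, every in-flight grammar-constrained request is reset to the grammar's start state; extending with each batch's current states, as the patch does, preserves them. The classes and function names below are simplified stand-ins for illustration only, not the real `HeterogeneousNextTokenChooser` / `FlashCausalLMBatch` API.

# Minimal sketch (not the actual TGI code): why the merged chooser must
# receive the existing per-request FSM grammar states.
from dataclasses import dataclass
from typing import List


@dataclass
class Chooser:
    # One grammar FSM state per request in the batch.
    fsm_grammar_states: List[int]


@dataclass
class Batch:
    next_token_chooser: Chooser


def concatenate_old(batches: List[Batch]) -> Batch:
    # Previous behaviour: rebuild the chooser with every state reset to 0,
    # so in-flight grammar requests jump back to the start state.
    total = sum(len(b.next_token_chooser.fsm_grammar_states) for b in batches)
    return Batch(next_token_chooser=Chooser(fsm_grammar_states=[0] * total))


def concatenate_fixed(batches: List[Batch]) -> Batch:
    # Patched behaviour: carry each batch's current states into the merged chooser.
    states: List[int] = []
    for b in batches:
        states.extend(b.next_token_chooser.fsm_grammar_states)
    return Batch(next_token_chooser=Chooser(fsm_grammar_states=states))


if __name__ == "__main__":
    running = Batch(Chooser(fsm_grammar_states=[3, 7]))  # two requests mid-generation
    incoming = Batch(Chooser(fsm_grammar_states=[0]))    # newly added request

    print(concatenate_old([running, incoming]).next_token_chooser.fsm_grammar_states)
    # -> [0, 0, 0]  (running requests lose their position in the grammar)
    print(concatenate_fixed([running, incoming]).next_token_chooser.fsm_grammar_states)
    # -> [3, 7, 0]  (running requests keep their position in the grammar)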