diff --git a/benchmark/src/lib.rs b/benchmark/src/lib.rs index 6deae48d..638c6514 100644 --- a/benchmark/src/lib.rs +++ b/benchmark/src/lib.rs @@ -8,7 +8,7 @@ use crate::app::App; use crate::event::Event; use crossterm::ExecutableCommand; use std::io; -use text_generation_client::{NextTokenChooserParameters, ShardedClient}; +use text_generation_client::{GrammarType, NextTokenChooserParameters, ShardedClient}; use tokenizers::Tokenizer; use tokio::sync::{broadcast, mpsc}; use tui::backend::CrosstermBackend; @@ -45,6 +45,8 @@ pub async fn run( repetition_penalty: repetition_penalty.unwrap_or(1.0), frequency_penalty: frequency_penalty.unwrap_or(0.0), watermark, + grammar: String::new(), + grammar_type: GrammarType::None as i32, }; // Initialize terminal properties diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py index 0bf80f8c..bbccbf1d 100644 --- a/clients/python/text_generation/client.py +++ b/clients/python/text_generation/client.py @@ -10,6 +10,7 @@ from text_generation.types import ( Response, Request, Parameters, + Grammar, ) from text_generation.errors import parse_error @@ -76,6 +77,7 @@ class Client: watermark: bool = False, decoder_input_details: bool = False, top_n_tokens: Optional[int] = None, + grammar: Optional[Grammar] = None, ) -> Response: """ Given a prompt, generate the following text @@ -138,6 +140,7 @@ class Client: watermark=watermark, decoder_input_details=decoder_input_details, top_n_tokens=top_n_tokens, + grammar=grammar, ) request = Request(inputs=prompt, stream=False, parameters=parameters) @@ -169,6 +172,7 @@ class Client: typical_p: Optional[float] = None, watermark: bool = False, top_n_tokens: Optional[int] = None, + grammar: Optional[Grammar] = None, ) -> Iterator[StreamResponse]: """ Given a prompt, generate the following stream of tokens @@ -227,6 +231,7 @@ class Client: typical_p=typical_p, watermark=watermark, top_n_tokens=top_n_tokens, + grammar=grammar, ) request = Request(inputs=prompt, stream=True, parameters=parameters) @@ -326,6 +331,7 @@ class AsyncClient: watermark: bool = False, decoder_input_details: bool = False, top_n_tokens: Optional[int] = None, + grammar: Optional[Grammar] = None, ) -> Response: """ Given a prompt, generate the following text asynchronously @@ -370,6 +376,7 @@ class AsyncClient: Returns: Response: generated response """ + # Validate parameters parameters = Parameters( best_of=best_of, @@ -388,6 +395,7 @@ class AsyncClient: typical_p=typical_p, watermark=watermark, top_n_tokens=top_n_tokens, + grammar=grammar, ) request = Request(inputs=prompt, stream=False, parameters=parameters) @@ -417,6 +425,7 @@ class AsyncClient: typical_p: Optional[float] = None, watermark: bool = False, top_n_tokens: Optional[int] = None, + grammar: Optional[Grammar] = None, ) -> AsyncIterator[StreamResponse]: """ Given a prompt, generate the following stream of tokens asynchronously @@ -475,6 +484,7 @@ class AsyncClient: typical_p=typical_p, watermark=watermark, top_n_tokens=top_n_tokens, + grammar=grammar, ) request = Request(inputs=prompt, stream=True, parameters=parameters) diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py index aa02d8d8..3426411b 100644 --- a/clients/python/text_generation/types.py +++ b/clients/python/text_generation/types.py @@ -1,10 +1,24 @@ from enum import Enum from pydantic import BaseModel, validator -from typing import Optional, List +from typing import Optional, List, Union from text_generation.errors import ValidationError +# enum for grammar type +class GrammarType(str, Enum): + Json = "json" + Regex = "regex" + + +# Grammar type and value +class Grammar(BaseModel): + # Grammar type + type: GrammarType + # Grammar value + value: Union[str, dict] + + class Parameters(BaseModel): # Activate logits sampling do_sample: bool = False @@ -41,6 +55,8 @@ class Parameters(BaseModel): decoder_input_details: bool = False # Return the N most likely tokens at each step top_n_tokens: Optional[int] = None + # grammar to use for generation + grammar: Optional[Grammar] = None @validator("best_of") def valid_best_of(cls, field_value, values): @@ -109,6 +125,14 @@ class Parameters(BaseModel): raise ValidationError("`top_n_tokens` must be strictly positive") return v + @validator("grammar") + def valid_grammar(cls, v): + if v is not None: + if v.type == GrammarType.Regex and not v.value: + raise ValidationError("`value` cannot be empty for `regex` grammar") + if v.type == GrammarType.Json and not v.value: + raise ValidationError("`value` cannot be empty for `json` grammar") + return v class Request(BaseModel): # Prompt @@ -157,7 +181,7 @@ class Token(BaseModel): # Token text text: str # Logprob - logprob: float + logprob: Optional[float] = None # Is the token a special token # Can be used to ignore tokens when concatenating special: bool diff --git a/docs/source/basic_tutorials/launcher.md b/docs/source/basic_tutorials/launcher.md index be31a7a4..36fa1241 100644 --- a/docs/source/basic_tutorials/launcher.md +++ b/docs/source/basic_tutorials/launcher.md @@ -378,6 +378,14 @@ Options: [env: TOKENIZER_CONFIG_PATH=] +``` +## DISABLE_GRAMMAR_SUPPORT +```shell + --disable-grammar-support + Disable outlines grammar constrained generation. This is a feature that allows you to generate text that follows a specific grammar + + [env: DISABLE_GRAMMAR_SUPPORT=] + ``` ## ENV ```shell diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index efeda08d..e0228894 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -16,7 +16,14 @@ from syrupy.extensions.json import JSONSnapshotExtension from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError from text_generation import AsyncClient -from text_generation.types import Response, Details, InputToken, Token, BestOfSequence +from text_generation.types import ( + Response, + Details, + InputToken, + Token, + BestOfSequence, + Grammar, +) DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None) HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN", None) @@ -224,6 +231,7 @@ def launcher(event_loop): quantize: Optional[str] = None, trust_remote_code: bool = False, use_flash_attention: bool = True, + disable_grammar_support: bool = False, dtype: Optional[str] = None, ): port = random.randint(8000, 10_000) @@ -247,6 +255,8 @@ def launcher(event_loop): env = os.environ + if disable_grammar_support: + args.append("--disable-grammar-support") if num_shard is not None: args.extend(["--num-shard", str(num_shard)]) if quantize is not None: @@ -287,12 +297,15 @@ def launcher(event_loop): quantize: Optional[str] = None, trust_remote_code: bool = False, use_flash_attention: bool = True, + disable_grammar_support: bool = False, dtype: Optional[str] = None, ): port = random.randint(8000, 10_000) args = ["--model-id", model_id, "--env"] + if disable_grammar_support: + args.append("--disable-grammar-support") if num_shard is not None: args.extend(["--num-shard", str(num_shard)]) if quantize is not None: @@ -370,11 +383,22 @@ def launcher(event_loop): @pytest.fixture(scope="module") def generate_load(): async def generate_load_inner( - client: AsyncClient, prompt: str, max_new_tokens: int, n: int + client: AsyncClient, + prompt: str, + max_new_tokens: int, + n: int, + seed: Optional[int] = None, + grammar: Optional[Grammar] = None, + stop_sequences: Optional[List[str]] = None, ) -> List[Response]: futures = [ client.generate( - prompt, max_new_tokens=max_new_tokens, decoder_input_details=True + prompt, + max_new_tokens=max_new_tokens, + decoder_input_details=True, + seed=seed, + grammar=grammar, + stop_sequences=stop_sequences, ) for _ in range(n) ] diff --git a/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar.json b/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar.json new file mode 100644 index 00000000..0e87f59e --- /dev/null +++ b/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -13.90625, + "text": "Test" + }, + { + "id": 2009, + "logprob": -12.328125, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -2.0566406, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -1.5253906, + "special": false, + "text": "\n" + }, + { + "id": 29902, + "logprob": -2.7578125, + "special": false, + "text": "I" + }, + { + "id": 4966, + "logprob": -1.9033203, + "special": false, + "text": " hope" + }, + { + "id": 445, + "logprob": -0.5019531, + "special": false, + "text": " this" + }, + { + "id": 6911, + "logprob": -0.21264648, + "special": false, + "text": " helps" + }, + { + "id": 29991, + "logprob": -0.5991211, + "special": false, + "text": "!" + }, + { + "id": 2803, + "logprob": -0.37475586, + "special": false, + "text": " Let" + }, + { + "id": 592, + "logprob": -0.018463135, + "special": false, + "text": " me" + }, + { + "id": 1073, + "logprob": -0.0008597374, + "special": false, + "text": " know" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nI hope this helps! Let me know" +} diff --git a/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_json.json b/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_json.json new file mode 100644 index 00000000..7b12b158 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_json.json @@ -0,0 +1,274 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "eos_token", + "generated_tokens": 30, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 5235, + "logprob": -10.0625, + "text": "info" + }, + { + "id": 29901, + "logprob": -3.2324219, + "text": ":" + }, + { + "id": 13260, + "logprob": -10.625, + "text": "dav" + }, + { + "id": 333, + "logprob": -0.08276367, + "text": "id" + }, + { + "id": 8753, + "logprob": -7.5273438, + "text": "hol" + }, + { + "id": 17559, + "logprob": -3.8476562, + "text": "tz" + }, + { + "id": 763, + "logprob": -10.140625, + "text": "like" + }, + { + "id": 10697, + "logprob": -10.1953125, + "text": "trees" + }, + { + "id": 322, + "logprob": -2.5742188, + "text": "and" + }, + { + "id": 756, + "logprob": -7.4882812, + "text": "has" + }, + { + "id": 1023, + "logprob": -5.0507812, + "text": "two" + }, + { + "id": 274, + "logprob": -5.3164062, + "text": "c" + }, + { + "id": 1446, + "logprob": -0.6694336, + "text": "ats" + }, + { + "id": 29889, + "logprob": -0.9995117, + "text": "." + }, + { + "id": 29871, + "logprob": -4.2421875, + "text": "" + } + ], + "seed": null, + "tokens": [ + { + "id": 6377, + "logprob": -0.14916992, + "special": false, + "text": "{\"" + }, + { + "id": 29888, + "logprob": -0.13598633, + "special": false, + "text": "f" + }, + { + "id": 12935, + "logprob": -0.017669678, + "special": false, + "text": "irs" + }, + { + "id": 29873, + "logprob": -0.00085639954, + "special": false, + "text": "t" + }, + { + "id": 1170, + "logprob": -0.0054016113, + "special": false, + "text": "Name" + }, + { + "id": 4710, + "logprob": -0.13549805, + "special": false, + "text": "\":\"" + }, + { + "id": 19504, + "logprob": -0.8852539, + "special": false, + "text": "David" + }, + { + "id": 3284, + "logprob": -0.16394043, + "special": false, + "text": "\",\"" + }, + { + "id": 4230, + "logprob": -0.020492554, + "special": false, + "text": "last" + }, + { + "id": 1170, + "logprob": -0.0013818741, + "special": false, + "text": "Name" + }, + { + "id": 4710, + "logprob": -0.0067749023, + "special": false, + "text": "\":\"" + }, + { + "id": 29950, + "logprob": -0.11578369, + "special": false, + "text": "H" + }, + { + "id": 14339, + "logprob": -0.004131317, + "special": false, + "text": "olt" + }, + { + "id": 29920, + "logprob": -0.0033359528, + "special": false, + "text": "z" + }, + { + "id": 3284, + "logprob": -0.20471191, + "special": false, + "text": "\",\"" + }, + { + "id": 29882, + "logprob": -0.0069274902, + "special": false, + "text": "h" + }, + { + "id": 20838, + "logprob": -0.19580078, + "special": false, + "text": "obb" + }, + { + "id": 29891, + "logprob": -2.2649765e-06, + "special": false, + "text": "y" + }, + { + "id": 4710, + "logprob": -0.32080078, + "special": false, + "text": "\":\"" + }, + { + "id": 29911, + "logprob": -2.1035156, + "special": false, + "text": "T" + }, + { + "id": 11003, + "logprob": -0.020767212, + "special": false, + "text": "rees" + }, + { + "id": 3284, + "logprob": -0.6010742, + "special": false, + "text": "\",\"" + }, + { + "id": 29876, + "logprob": -0.57666016, + "special": false, + "text": "n" + }, + { + "id": 398, + "logprob": -0.0061073303, + "special": false, + "text": "um" + }, + { + "id": 29907, + "logprob": -0.45703125, + "special": false, + "text": "C" + }, + { + "id": 1446, + "logprob": -0.0002872944, + "special": false, + "text": "ats" + }, + { + "id": 1115, + "logprob": -0.0021018982, + "special": false, + "text": "\":" + }, + { + "id": 29906, + "logprob": -0.08996582, + "special": false, + "text": "2" + }, + { + "id": 29913, + "logprob": -0.021697998, + "special": false, + "text": "}" + }, + { + "id": 2, + "logprob": 0.0, + "special": true, + "text": "" + } + ], + "top_tokens": null + }, + "generated_text": "{\"firstName\":\"David\",\"lastName\":\"Holtz\",\"hobby\":\"Trees\",\"numCats\":2}" +} diff --git a/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_load.json b/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_load.json new file mode 100644 index 00000000..b7b26a2c --- /dev/null +++ b/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_load.json @@ -0,0 +1,478 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1024, + "logprob": -10.578125, + "text": "name" + }, + { + "id": 29901, + "logprob": -3.03125, + "text": ":" + }, + { + "id": 13260, + "logprob": -9.171875, + "text": "dav" + }, + { + "id": 333, + "logprob": -0.04244995, + "text": "id" + }, + { + "id": 29889, + "logprob": -2.4863281, + "text": "." + }, + { + "id": 4876, + "logprob": -10.7890625, + "text": "email" + }, + { + "id": 29901, + "logprob": -0.32714844, + "text": ":" + }, + { + "id": 259, + "logprob": -9.4921875, + "text": " " + } + ], + "seed": null, + "tokens": [ + { + "id": 29896, + "logprob": -0.7685547, + "special": false, + "text": "1" + }, + { + "id": 29906, + "logprob": -0.2376709, + "special": false, + "text": "2" + }, + { + "id": 29941, + "logprob": -0.01008606, + "special": false, + "text": "3" + }, + { + "id": 29946, + "logprob": -0.64160156, + "special": false, + "text": "4" + }, + { + "id": 29945, + "logprob": -0.5, + "special": false, + "text": "5" + }, + { + "id": 29953, + "logprob": -0.46557617, + "special": false, + "text": "6" + }, + { + "id": 29992, + "logprob": -0.5341797, + "special": false, + "text": "@" + }, + { + "id": 21980, + "logprob": -0.5361328, + "special": false, + "text": "gmail" + }, + { + "id": 29889, + "logprob": -0.00088739395, + "special": false, + "text": "." + }, + { + "id": 510, + "logprob": -0.0022907257, + "special": false, + "text": "com" + } + ], + "top_tokens": null + }, + "generated_text": "123456@gmail.com" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1024, + "logprob": -10.578125, + "text": "name" + }, + { + "id": 29901, + "logprob": -3.0332031, + "text": ":" + }, + { + "id": 13260, + "logprob": -9.171875, + "text": "dav" + }, + { + "id": 333, + "logprob": -0.04257202, + "text": "id" + }, + { + "id": 29889, + "logprob": -2.4785156, + "text": "." + }, + { + "id": 4876, + "logprob": -10.7890625, + "text": "email" + }, + { + "id": 29901, + "logprob": -0.32495117, + "text": ":" + }, + { + "id": 259, + "logprob": -9.4921875, + "text": " " + } + ], + "seed": null, + "tokens": [ + { + "id": 29896, + "logprob": -0.7709961, + "special": false, + "text": "1" + }, + { + "id": 29906, + "logprob": -0.23840332, + "special": false, + "text": "2" + }, + { + "id": 29941, + "logprob": -0.00995636, + "special": false, + "text": "3" + }, + { + "id": 29946, + "logprob": -0.64208984, + "special": false, + "text": "4" + }, + { + "id": 29945, + "logprob": -0.4970703, + "special": false, + "text": "5" + }, + { + "id": 29953, + "logprob": -0.46533203, + "special": false, + "text": "6" + }, + { + "id": 29992, + "logprob": -0.5336914, + "special": false, + "text": "@" + }, + { + "id": 21980, + "logprob": -0.5361328, + "special": false, + "text": "gmail" + }, + { + "id": 29889, + "logprob": -0.00088739395, + "special": false, + "text": "." + }, + { + "id": 510, + "logprob": -0.0022735596, + "special": false, + "text": "com" + } + ], + "top_tokens": null + }, + "generated_text": "123456@gmail.com" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1024, + "logprob": -10.578125, + "text": "name" + }, + { + "id": 29901, + "logprob": -3.0332031, + "text": ":" + }, + { + "id": 13260, + "logprob": -9.171875, + "text": "dav" + }, + { + "id": 333, + "logprob": -0.04257202, + "text": "id" + }, + { + "id": 29889, + "logprob": -2.4785156, + "text": "." + }, + { + "id": 4876, + "logprob": -10.7890625, + "text": "email" + }, + { + "id": 29901, + "logprob": -0.32495117, + "text": ":" + }, + { + "id": 259, + "logprob": -9.4921875, + "text": " " + } + ], + "seed": null, + "tokens": [ + { + "id": 29896, + "logprob": -0.7709961, + "special": false, + "text": "1" + }, + { + "id": 29906, + "logprob": -0.23840332, + "special": false, + "text": "2" + }, + { + "id": 29941, + "logprob": -0.00995636, + "special": false, + "text": "3" + }, + { + "id": 29946, + "logprob": -0.64208984, + "special": false, + "text": "4" + }, + { + "id": 29945, + "logprob": -0.4970703, + "special": false, + "text": "5" + }, + { + "id": 29953, + "logprob": -0.46533203, + "special": false, + "text": "6" + }, + { + "id": 29992, + "logprob": -0.5336914, + "special": false, + "text": "@" + }, + { + "id": 21980, + "logprob": -0.5361328, + "special": false, + "text": "gmail" + }, + { + "id": 29889, + "logprob": -0.00088739395, + "special": false, + "text": "." + }, + { + "id": 510, + "logprob": -0.0022735596, + "special": false, + "text": "com" + } + ], + "top_tokens": null + }, + "generated_text": "123456@gmail.com" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1024, + "logprob": -10.578125, + "text": "name" + }, + { + "id": 29901, + "logprob": -3.0332031, + "text": ":" + }, + { + "id": 13260, + "logprob": -9.171875, + "text": "dav" + }, + { + "id": 333, + "logprob": -0.04257202, + "text": "id" + }, + { + "id": 29889, + "logprob": -2.4785156, + "text": "." + }, + { + "id": 4876, + "logprob": -10.7890625, + "text": "email" + }, + { + "id": 29901, + "logprob": -0.32495117, + "text": ":" + }, + { + "id": 259, + "logprob": -9.4921875, + "text": " " + } + ], + "seed": null, + "tokens": [ + { + "id": 29896, + "logprob": -0.7709961, + "special": false, + "text": "1" + }, + { + "id": 29906, + "logprob": -0.23840332, + "special": false, + "text": "2" + }, + { + "id": 29941, + "logprob": -0.00995636, + "special": false, + "text": "3" + }, + { + "id": 29946, + "logprob": -0.64208984, + "special": false, + "text": "4" + }, + { + "id": 29945, + "logprob": -0.4970703, + "special": false, + "text": "5" + }, + { + "id": 29953, + "logprob": -0.46533203, + "special": false, + "text": "6" + }, + { + "id": 29992, + "logprob": -0.5336914, + "special": false, + "text": "@" + }, + { + "id": 21980, + "logprob": -0.5361328, + "special": false, + "text": "gmail" + }, + { + "id": 29889, + "logprob": -0.00088739395, + "special": false, + "text": "." + }, + { + "id": 510, + "logprob": -0.0022735596, + "special": false, + "text": "com" + } + ], + "top_tokens": null + }, + "generated_text": "123456@gmail.com" + } +] diff --git a/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_regex.json b/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_regex.json new file mode 100644 index 00000000..1ba9ae1e --- /dev/null +++ b/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_regex.json @@ -0,0 +1,109 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 806, + "logprob": -11.890625, + "text": "Wh" + }, + { + "id": 1446, + "logprob": -3.6699219, + "text": "ats" + }, + { + "id": 2921, + "logprob": -7.8203125, + "text": "Go" + }, + { + "id": 468, + "logprob": -8.0703125, + "text": "og" + }, + { + "id": 793, + "logprob": -2.1875, + "text": "les" + }, + { + "id": 16332, + "logprob": -9.7109375, + "text": "DNS" + } + ], + "seed": null, + "tokens": [ + { + "id": 29946, + "logprob": -1.4765625, + "special": false, + "text": "4" + }, + { + "id": 29906, + "logprob": -0.9199219, + "special": false, + "text": "2" + }, + { + "id": 29889, + "logprob": 0.0, + "special": false, + "text": "." + }, + { + "id": 29896, + "logprob": -1.1367188, + "special": false, + "text": "1" + }, + { + "id": 29889, + "logprob": -1.4648438, + "special": false, + "text": "." + }, + { + "id": 29896, + "logprob": -0.40722656, + "special": false, + "text": "1" + }, + { + "id": 29889, + "logprob": -0.17419434, + "special": false, + "text": "." + }, + { + "id": 29896, + "logprob": -0.20251465, + "special": false, + "text": "1" + }, + { + "id": 29900, + "logprob": -1.5527344, + "special": false, + "text": "0" + }, + { + "id": 29896, + "logprob": -1.3710938, + "special": false, + "text": "1" + } + ], + "top_tokens": null + }, + "generated_text": "42.1.1.101" +} diff --git a/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_single_load_instance.json b/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_single_load_instance.json new file mode 100644 index 00000000..7ffb17cb --- /dev/null +++ b/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_single_load_instance.json @@ -0,0 +1,73 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [], + "seed": null, + "tokens": [ + { + "id": 29896, + "logprob": -0.7685547, + "special": false, + "text": "1" + }, + { + "id": 29906, + "logprob": -0.33666992, + "special": false, + "text": "2" + }, + { + "id": 29941, + "logprob": -0.009979248, + "special": false, + "text": "3" + }, + { + "id": 29946, + "logprob": -0.64208984, + "special": false, + "text": "4" + }, + { + "id": 29945, + "logprob": -0.4970703, + "special": false, + "text": "5" + }, + { + "id": 29953, + "logprob": -0.46533203, + "special": false, + "text": "6" + }, + { + "id": 29992, + "logprob": -0.5336914, + "special": false, + "text": "@" + }, + { + "id": 21980, + "logprob": -0.53759766, + "special": false, + "text": "gmail" + }, + { + "id": 29889, + "logprob": -0.0008878708, + "special": false, + "text": "." + }, + { + "id": 510, + "logprob": -0.002275467, + "special": false, + "text": "com" + } + ], + "top_tokens": null + }, + "generated_text": "123456@gmail.com" +} diff --git a/integration-tests/models/test_grammar_llama.py b/integration-tests/models/test_grammar_llama.py new file mode 100644 index 00000000..f068496c --- /dev/null +++ b/integration-tests/models/test_grammar_llama.py @@ -0,0 +1,151 @@ +import pytest +import json + +from text_generation.types import GrammarType + + +@pytest.fixture(scope="module") +def flash_llama_grammar_handle(launcher): + with launcher( + "TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, disable_grammar_support=False + ) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_llama_grammar(flash_llama_grammar_handle): + await flash_llama_grammar_handle.health(300) + return flash_llama_grammar_handle.client + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_grammar(flash_llama_grammar, response_snapshot): + response = await flash_llama_grammar.generate( + "Test request", max_new_tokens=10, decoder_input_details=True + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_grammar_regex(flash_llama_grammar, response_snapshot): + response = await flash_llama_grammar.generate( + "Whats Googles DNS", + max_new_tokens=10, + decoder_input_details=True, + seed=0, + grammar={ + "type": GrammarType.Regex, # "regex" + "value": "((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)", + }, + ) + + assert response.details.generated_tokens == 10 + assert response.generated_text == "42.1.1.101" + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_grammar_json(flash_llama_grammar, response_snapshot): + response = await flash_llama_grammar.generate( + "info: david holtz like trees and has two cats. ", + max_new_tokens=100, + decoder_input_details=True, + seed=0, + grammar={ + "type": GrammarType.Json, # "json" + "value": json.dumps( + { + "type": "object", + "$id": "https://example.com/person.schema.json", + "$schema": "https://json-schema.org/draft/2020-12/schema", + "title": "Person", + "properties": { + "firstName": { + "type": "string", + "description": "The person'''s first name.", + }, + "lastName": { + "type": "string", + "description": "The person'''s last name.", + }, + "hobby": { + "description": "The person'''s hobby.", + "type": "string", + }, + "numCats": { + "description": "The number of cats the person has.", + "type": "integer", + "minimum": 0, + }, + }, + "required": ["firstName", "lastName", "hobby", "numCats"], + } + ), + }, + ) + + assert response.details.generated_tokens == 30 + assert ( + response.generated_text + == '{"firstName":"David","lastName":"Holtz","hobby":"Trees","numCats":2}' + ) + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_grammar_load( + flash_llama_grammar, generate_load, response_snapshot +): + responses = await generate_load( + flash_llama_grammar, + "name: david. email: ", + max_new_tokens=10, + n=4, + stop_sequences=[".com"], + seed=0, + grammar={ + "type": GrammarType.Regex, # "regex" + "value": "[\\w-]+@([\\w-]+\\.)+[\\w-]+", # email regex + }, + ) + + assert len(responses) == 4 + + expected = "123456@gmail.com" + + for response in responses: + assert response.generated_text == expected + + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot + + +# this is the same as the above test, but only fires off a single request +# this is only to ensure that the parallel and single inference produce the same result +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_grammar_single_load_instance( + flash_llama_grammar, generate_load, response_snapshot +): + response = await flash_llama_grammar.generate( + "name: david. email: ", + max_new_tokens=10, + stop_sequences=[".com"], + seed=0, + grammar={ + "type": GrammarType.Regex, # "regex" + "value": "[\\w-]+@([\\w-]+\\.)+[\\w-]+", # email regex + }, + ) + + # assert response.details.generated_tokens == 30 + assert response.generated_text == "123456@gmail.com" + + assert response == response_snapshot diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 8367ef81..d52e2669 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -382,6 +382,11 @@ struct Args { #[clap(long, env)] tokenizer_config_path: Option, + /// Disable outlines grammar constrained generation. + /// This is a feature that allows you to generate text that follows a specific grammar. + #[clap(long, env)] + disable_grammar_support: bool, + /// Display a lot of information about your runtime environment #[clap(long, short, action)] env: bool, @@ -1051,6 +1056,11 @@ fn spawn_webserver( args.model_id, ]; + // Grammar support + if args.disable_grammar_support { + router_args.push("--disable-grammar-support".to_string()); + } + // Tokenizer config path if let Some(ref tokenizer_config_path) = args.tokenizer_config_path { router_args.push("--tokenizer-config-path".to_string()); diff --git a/proto/generate.proto b/proto/generate.proto index 5140fdaa..0490029f 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -51,6 +51,12 @@ message ClearCacheRequest { /// Empty response message ClearCacheResponse {} +enum GrammarType { + GRAMMAR_TYPE_NONE = 0; + GRAMMAR_TYPE_JSON = 1; + GRAMMAR_TYPE_REGEX = 2; +} + message NextTokenChooserParameters { /// exponential scaling output probability distribution float temperature = 1; @@ -70,6 +76,10 @@ message NextTokenChooserParameters { float frequency_penalty = 9; /// token watermarking using "A Watermark for Large Language Models" bool watermark = 8; + /// grammar (applied if not empty) + string grammar = 10; + /// grammar type + GrammarType grammar_type = 11; } message StoppingCriteriaParameters { diff --git a/router/client/src/client.rs b/router/client/src/client.rs index 7b9f90fb..f8658318 100644 --- a/router/client/src/client.rs +++ b/router/client/src/client.rs @@ -128,6 +128,8 @@ impl Client { repetition_penalty: 1.2, frequency_penalty: 0.1, watermark: true, + grammar: String::new(), + grammar_type: GrammarType::None as i32, }), stopping_parameters: Some(StoppingCriteriaParameters { max_new_tokens: max_total_tokens - truncate, diff --git a/router/client/src/lib.rs b/router/client/src/lib.rs index c38b931b..6782d9ff 100644 --- a/router/client/src/lib.rs +++ b/router/client/src/lib.rs @@ -9,8 +9,8 @@ pub use client::Client; pub use pb::generate::v2::HealthResponse; pub use pb::generate::v2::InfoResponse as ShardInfo; pub use pb::generate::v2::{ - Batch, CachedBatch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters, - Request, StoppingCriteriaParameters, Tokens, + Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, + NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens, }; pub use sharded_client::ShardedClient; use thiserror::Error; diff --git a/router/src/health.rs b/router/src/health.rs index e830a3c3..b05b3094 100644 --- a/router/src/health.rs +++ b/router/src/health.rs @@ -1,5 +1,6 @@ use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; +use text_generation_client::GrammarType as ProtoGrammarType; use text_generation_client::{ Batch, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters, }; @@ -45,6 +46,8 @@ impl Health { repetition_penalty: 1.0, frequency_penalty: 0.0, watermark: false, + grammar: String::new(), + grammar_type: ProtoGrammarType::None as i32, }), stopping_parameters: Some(StoppingCriteriaParameters { max_new_tokens: 1, diff --git a/router/src/lib.rs b/router/src/lib.rs index a9d783bb..87873821 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -45,6 +45,43 @@ impl HubTokenizerConfig { } } +mod json_object_or_string_to_string { + use serde::{Deserialize, Deserializer}; + use serde_json::Value; + + // A custom deserializer that treats both strings and objects as strings. + // This provides flexibility with input formats for the 'grammar' field. + pub fn deserialize<'de, D>(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let value = Value::deserialize(deserializer)?; + + match value { + Value::String(s) => Ok(s), + // Safely handle serialization and return an error if it fails + Value::Object(o) => { + serde_json::to_string(&o).map_err(|e| serde::de::Error::custom(e.to_string())) + } + _ => Err(serde::de::Error::custom( + "expected string or object for grammar", + )), + } + } +} + +#[derive(Clone, Debug, Deserialize)] +#[serde(tag = "type", content = "value")] +pub(crate) enum GrammarType { + #[serde( + rename = "json", + deserialize_with = "json_object_or_string_to_string::deserialize" + )] + Json(String), + #[serde(rename = "regex")] + Regex(String), +} + mod token_serde { use super::*; use serde::de; @@ -201,6 +238,8 @@ pub(crate) struct GenerateParameters { #[serde(default)] #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)] pub top_n_tokens: Option, + #[serde(default)] + pub grammar: Option, } fn default_max_new_tokens() -> Option { @@ -226,6 +265,7 @@ fn default_parameters() -> GenerateParameters { decoder_input_details: false, seed: None, top_n_tokens: None, + grammar: None, } } diff --git a/router/src/main.rs b/router/src/main.rs index a1f8d97b..457bca8e 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -75,6 +75,8 @@ struct Args { ngrok_edge: Option, #[clap(long, env, default_value_t = false)] messages_api_enabled: bool, + #[clap(long, env, default_value_t = false)] + disable_grammar_support: bool, } #[tokio::main] @@ -108,6 +110,7 @@ async fn main() -> Result<(), RouterError> { ngrok_authtoken, ngrok_edge, messages_api_enabled, + disable_grammar_support, } = args; // Launch Tokio runtime @@ -359,6 +362,7 @@ async fn main() -> Result<(), RouterError> { ngrok_edge, tokenizer_config, messages_api_enabled, + disable_grammar_support, ) .await?; Ok(()) diff --git a/router/src/queue.rs b/router/src/queue.rs index 3675e0f5..52ea16ca 100644 --- a/router/src/queue.rs +++ b/router/src/queue.rs @@ -343,7 +343,9 @@ enum QueueCommand { #[cfg(test)] mod tests { use super::*; - use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters}; + use text_generation_client::{ + GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters, + }; use tracing::info_span; fn default_entry() -> ( @@ -354,7 +356,7 @@ mod tests { let entry = Entry { request: ValidGenerateRequest { - inputs: "".to_string(), + inputs: String::new(), input_length: 0, truncate: 0, decoder_input_details: false, @@ -368,6 +370,8 @@ mod tests { repetition_penalty: 0.0, frequency_penalty: 0.0, watermark: false, + grammar: String::new(), + grammar_type: ProtoGrammarType::None as i32, }, stopping_parameters: StoppingCriteriaParameters { ignore_eos_token: false, diff --git a/router/src/server.rs b/router/src/server.rs index 00b793e3..0fc76916 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -614,6 +614,7 @@ async fn chat_completions( decoder_input_details: !stream, seed, top_n_tokens: None, + grammar: None, }, }; @@ -779,6 +780,7 @@ pub async fn run( ngrok_edge: Option, tokenizer_config: HubTokenizerConfig, messages_api_enabled: bool, + grammar_support: bool, ) -> Result<(), axum::BoxError> { // OpenAPI documentation #[derive(OpenApi)] @@ -840,6 +842,7 @@ pub async fn run( max_top_n_tokens, max_input_length, max_total_tokens, + grammar_support, ); let generation_health = Arc::new(AtomicBool::new(false)); let health_ext = Health::new(client.clone(), generation_health.clone()); diff --git a/router/src/validation.rs b/router/src/validation.rs index e6874b11..7801f4e3 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -1,8 +1,10 @@ /// Payload validation logic use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput}; -use crate::{GenerateParameters, GenerateRequest}; +use crate::{GenerateParameters, GenerateRequest, GrammarType}; use rand::{thread_rng, Rng}; -use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters}; +use text_generation_client::{ + GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters, +}; use thiserror::Error; use tokenizers::tokenizer::Tokenizer; use tokenizers::TruncationDirection; @@ -19,6 +21,7 @@ pub struct Validation { max_top_n_tokens: u32, max_input_length: usize, max_total_tokens: usize, + disable_grammar_support: bool, /// Channel to communicate with the background tokenization task sender: Option>, } @@ -32,6 +35,7 @@ impl Validation { max_top_n_tokens: u32, max_input_length: usize, max_total_tokens: usize, + disable_grammar_support: bool, ) -> Self { // If we have a fast tokenizer let sender = if let Some(tokenizer) = tokenizer { @@ -66,6 +70,7 @@ impl Validation { max_top_n_tokens, max_input_length, max_total_tokens, + disable_grammar_support, } } @@ -182,6 +187,7 @@ impl Validation { watermark, decoder_input_details, top_n_tokens, + grammar, .. } = request.parameters; @@ -292,6 +298,28 @@ impl Validation { .validate_input(request.inputs, truncate, max_new_tokens) .await?; + // TODO: we should build the FSM here and pass the compiled FSM instead of the grammar + // NOTE: this is currently difficult because we need the tokenizer in Python to build + // the FSM and we'd have to load a copy of the tokenizer into our Pyo3 instance which + // may be slow and memory intensive. Best case is to have a Rust implementation of the FSM + // compiler and use that to build the FSM here. + + // Validate grammar and unpack the grammar and type for the proto message + let (grammar, grammar_type) = match grammar { + Some(grammar) => { + // Ensure that grammar is not set if it's not supported + if self.disable_grammar_support { + return Err(ValidationError::Grammar); + } + match grammar { + // currently both are handled the same way since compilation is done in Python + GrammarType::Json(json) => (json, ProtoGrammarType::Json.into()), + GrammarType::Regex(regex) => (regex, ProtoGrammarType::Regex.into()), + } + } + None => (String::new(), ProtoGrammarType::None.into()), + }; + let parameters = NextTokenChooserParameters { temperature, repetition_penalty, @@ -302,6 +330,8 @@ impl Validation { do_sample, seed, watermark, + grammar, + grammar_type, }; let stopping_parameters = StoppingCriteriaParameters { max_new_tokens, @@ -453,6 +483,8 @@ pub enum ValidationError { StopSequence(usize, usize), #[error("tokenizer error {0}")] Tokenizer(String), + #[error("grammar is not supported")] + Grammar, } #[cfg(test)] @@ -470,6 +502,7 @@ mod tests { let max_input_length = 5; let max_total_tokens = 6; let workers = 1; + let disable_grammar_support = true; let validation = Validation::new( workers, tokenizer, @@ -478,6 +511,7 @@ mod tests { max_top_n_tokens, max_input_length, max_total_tokens, + disable_grammar_support, ); let max_new_tokens = 10; @@ -498,6 +532,7 @@ mod tests { let max_top_n_tokens = 4; let max_input_length = 5; let max_total_tokens = 6; + let disable_grammar_support = true; let workers = 1; let validation = Validation::new( workers, @@ -507,6 +542,7 @@ mod tests { max_top_n_tokens, max_input_length, max_total_tokens, + disable_grammar_support, ); let max_new_tokens = 10; @@ -528,6 +564,7 @@ mod tests { let max_input_length = 5; let max_total_tokens = 6; let workers = 1; + let disable_grammar_support = true; let validation = Validation::new( workers, tokenizer, @@ -536,6 +573,7 @@ mod tests { max_top_n_tokens, max_input_length, max_total_tokens, + disable_grammar_support, ); match validation .validate(GenerateRequest { @@ -562,6 +600,7 @@ mod tests { let max_input_length = 5; let max_total_tokens = 106; let workers = 1; + let disable_grammar_support = true; let validation = Validation::new( workers, tokenizer, @@ -570,6 +609,7 @@ mod tests { max_top_n_tokens, max_input_length, max_total_tokens, + disable_grammar_support, ); match validation .validate(GenerateRequest { @@ -625,6 +665,7 @@ mod tests { let max_input_length = 5; let max_total_tokens = 106; let workers = 1; + let disable_grammar_support = true; let validation = Validation::new( workers, tokenizer, @@ -633,6 +674,7 @@ mod tests { max_top_n_tokens, max_input_length, max_total_tokens, + disable_grammar_support, ); match validation .validate(GenerateRequest { diff --git a/server/pyproject.toml b/server/pyproject.toml index b8ebf2e3..566eda7a 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -34,6 +34,7 @@ peft = { version = "^0.8.2", optional = true } torch = { version = "^2.1.1", optional = true } scipy = "^1.11.1" pillow = "^10.0.0" +outlines="^0.0.27" [tool.poetry.extras] torch = ["torch"] diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index a7a16212..a0f0c9e8 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -87,7 +87,9 @@ class CausalLMBatch(Batch): for i, r in enumerate(pb.requests): requests_idx_mapping[r.id] = i inputs.append(r.inputs) - next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device)) + next_token_choosers.append( + NextTokenChooser.from_pb(r.parameters, device, tokenizer) + ) stopping_criteria = StoppingCriteria.from_pb( r.stopping_parameters, tokenizer ) @@ -413,14 +415,14 @@ class CausalLMBatch(Batch): # We slice the keys to remove the padding from previous batches past_seq_len = batch.max_input_length - 1 if batch.keys_head_dim_last: - padded_past_keys[ - start_index:end_index, :, -past_seq_len:, : - ] = past_keys[:, :, -past_seq_len:, :] + padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = ( + past_keys[:, :, -past_seq_len:, :] + ) else: # BLOOM case - padded_past_keys[ - start_index:end_index, :, :, -past_seq_len: - ] = past_keys[:, :, :, -past_seq_len:] + padded_past_keys[start_index:end_index, :, :, -past_seq_len:] = ( + past_keys[:, :, :, -past_seq_len:] + ) del past_keys start_index = end_index @@ -438,9 +440,9 @@ class CausalLMBatch(Batch): end_index = start_index + len(batch) # We slice the past values to remove the padding from previous batches past_seq_len = batch.max_input_length - 1 - padded_past_values[ - start_index:end_index, :, -past_seq_len:, : - ] = past_values[:, :, -past_seq_len:, :] + padded_past_values[start_index:end_index, :, -past_seq_len:, :] = ( + past_values[:, :, -past_seq_len:, :] + ) del past_values # Update values @@ -504,9 +506,11 @@ class CausalLM(Model): model_id, revision=revision, torch_dtype=dtype, - device_map="auto" - if torch.cuda.is_available() and torch.cuda.device_count() > 1 - else None, + device_map=( + "auto" + if torch.cuda.is_available() and torch.cuda.device_count() > 1 + else None + ), load_in_8bit=quantize == "bitsandbytes", trust_remote_code=trust_remote_code, ) @@ -696,7 +700,7 @@ class CausalLM(Model): if top_n_tokens > 0: all_top_tokens = [] - for (top_token_ids, top_token_logprobs) in zip( + for top_token_ids, top_token_logprobs in zip( top_token_ids, top_token_logprobs ): toptoken_texts = self.tokenizer.batch_decode( @@ -735,6 +739,9 @@ class CausalLM(Model): generations.append(generation) # Update values + batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar( + next_token_id_squeezed.item() + ) batch.input_ids[i, 0] = next_token_id batch.all_input_ids[i] = all_input_ids batch.input_lengths[i] = new_input_length diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 886fe486..7ec8c2fc 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -237,7 +237,7 @@ class FlashCausalLMBatch(Batch): ) next_token_chooser = HeterogeneousNextTokenChooser.from_pb( - next_token_chooser_parameters, dtype, device + next_token_chooser_parameters, dtype, device, tokenizer ) start_slots = torch.tensor(start_slots, dtype=torch.int64) @@ -593,6 +593,7 @@ class FlashCausalLMBatch(Batch): next_token_chooser_parameters, dtype=batches[0].next_token_chooser.dtype, device=batches[0].next_token_chooser.device, + tokenizer=batches[0].next_token_chooser.tokenizer, ) speculative_ids = ( @@ -869,7 +870,11 @@ class FlashCausalLM(Model): # Try to find an associated cuda graph cuda_graph = self.cuda_graphs.get(padded_bs, None) - if cu_seqlen_prefill is not None or cuda_graph is None or batch.speculative_ids is not None: + if ( + cu_seqlen_prefill is not None + or cuda_graph is None + or batch.speculative_ids is not None + ): return self.model.forward( input_ids=input_ids, position_ids=position_ids, @@ -1013,9 +1018,9 @@ class FlashCausalLM(Model): # Copy batch.input_ids to prefill_token_indices if prefill_logprobs: if len(batch) > 1: - prefill_tokens_indices[ - out_start_index : out_end_index - 1 - ] = batch.input_ids[start_index + 1 : start_index + out_length] + prefill_tokens_indices[out_start_index : out_end_index - 1] = ( + batch.input_ids[start_index + 1 : start_index + out_length] + ) else: # Set prefill_tokens_indices to the correct slice prefill_tokens_indices = batch.input_ids[ @@ -1028,6 +1033,7 @@ class FlashCausalLM(Model): cumulative_length += input_length + # Update values batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1] batch.speculative_ids = speculative_ids batch.position_ids = next_position_ids + accepted_ids @@ -1166,7 +1172,7 @@ class FlashCausalLM(Model): if top_n_tokens > 0: all_top_tokens = [] - for (top_token_ids, top_token_logprobs) in zip( + for top_token_ids, top_token_logprobs in zip( top_token_ids, top_token_logprobs ): toptoken_texts = self.tokenizer.batch_decode( @@ -1204,6 +1210,12 @@ class FlashCausalLM(Model): generations.append(generation) + # accept each new token for this specific request since we may + # have more than one new token per request with speculative decoding + for next_token_id in _next_token_ids: + batch.next_token_chooser = batch.next_token_chooser.advance_grammar_single(i, next_token_id) + + # Update values batch.input_lengths[i] = input_length + n_accepted_ids if batch.input_lengths[i] > batch.max_seqlen: diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py index 34a50194..70669c8d 100644 --- a/server/text_generation_server/models/flash_mistral.py +++ b/server/text_generation_server/models/flash_mistral.py @@ -192,7 +192,7 @@ class FlashMistralBatch(FlashCausalLMBatch): ) next_token_chooser = HeterogeneousNextTokenChooser.from_pb( - next_token_chooser_parameters, dtype, device + next_token_chooser_parameters, dtype, device, tokenizer ) start_slots = torch.tensor(start_slots, dtype=torch.int64) diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py index 42ff1c80..a2c30255 100644 --- a/server/text_generation_server/models/galactica.py +++ b/server/text_generation_server/models/galactica.py @@ -92,7 +92,7 @@ class GalacticaCausalLMBatch(CausalLMBatch): requests_idx_mapping[r.id] = i # Add escape_custom_split_sequence to the CausalLMBatch logic inputs.append(escape_custom_split_sequence(r.inputs)) - next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device)) + next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer)) stopping_criteria = StoppingCriteria.from_pb( r.stopping_parameters, tokenizer ) diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py index 2f28688d..5ea2db87 100644 --- a/server/text_generation_server/models/idefics_causal_lm.py +++ b/server/text_generation_server/models/idefics_causal_lm.py @@ -114,7 +114,7 @@ class IdeficsCausalLMBatch(Batch): for i, r in enumerate(pb.requests): requests_idx_mapping[r.id] = i inputs.append(r.inputs) - next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device)) + next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer)) stopping_criteria = StoppingCriteria.from_pb( r.stopping_parameters, tokenizer ) @@ -815,6 +815,9 @@ class IdeficsCausalLM(Model): generations.append(generation) # Update values + batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar( + next_token_id_squeezed.item() + ) batch.input_ids[i, 0] = next_token_id batch.all_input_ids[i] = all_input_ids batch.input_lengths[i] = new_input_length diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py index 868db6aa..4585f4b9 100644 --- a/server/text_generation_server/models/mamba.py +++ b/server/text_generation_server/models/mamba.py @@ -124,7 +124,7 @@ class MambaBatch(Batch): for i, r in enumerate(pb.requests): requests_idx_mapping[r.id] = i inputs.append(r.inputs) - next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device)) + next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer)) stopping_criteria = StoppingCriteria.from_pb( r.stopping_parameters, tokenizer ) @@ -694,6 +694,9 @@ class Mamba(Model): generations.append(generation) # Update values + batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar( + next_token_id_squeezed.item() + ) batch.input_ids[i, 0] = next_token_id batch.all_input_ids[i] = all_input_ids batch.input_lengths[i] = new_input_length diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index 25042a32..459f4256 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -96,7 +96,7 @@ class Seq2SeqLMBatch(Batch): inputs.append(r.inputs) requests_idx_mapping[r.id] = i decoder_input_lengths.append(1) - next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device)) + next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer)) stopping_criteria = StoppingCriteria.from_pb( r.stopping_parameters, tokenizer ) @@ -789,6 +789,9 @@ class Seq2SeqLM(Model): generations.append(generation) # Update values + batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar( + next_token_id_squeezed.item() + ) batch.decoder_input_ids[i] = next_token_id batch.all_decoder_input_ids[i] = all_decoder_input_ids batch.input_lengths[i] = input_length diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py index 291c522f..73fcf53f 100644 --- a/server/text_generation_server/utils/logits_process.py +++ b/server/text_generation_server/utils/logits_process.py @@ -1,8 +1,17 @@ import math import torch +import json +from loguru import logger from functools import lru_cache from typing import Optional, List, Dict, Union +from text_generation_server.pb.generate_pb2 import GrammarType + +from outlines.fsm.fsm import RegexFSM +from outlines.fsm.json_schema import build_regex_from_object +from functools import lru_cache +from typing import List, Optional, DefaultDict +import time from transformers import ( LogitsWarper, @@ -135,9 +144,7 @@ class FrequencyPenaltyLogitsProcessor(LogitsProcessor): ) -> torch.FloatTensor: score = torch.gather(scores, 1, input_ids) # if score < 0 then penalty has to be multiplied to reduce the previous token probability - score = -torch.where( - score < 0, score * self.penalty, score / self.penalty - ) + score = -torch.where(score < 0, score * self.penalty, score / self.penalty) return scores.scatter_add_(1, input_ids, score) @@ -464,3 +471,132 @@ class HeterogeneousProcessorWrapper(LogitsProcessor): self.processors = new_processors return self return None + + +class GrammarLogitProcessor(LogitsProcessor): + fsm_state: DefaultDict[int, int] + fsm: RegexFSM + + def __init__(self, tokenizer, device, grammar, grammar_type): + self.device = device + self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer) + self.fsm = GrammarLogitProcessor._cached_compile_fsm( + grammar_type, grammar, self.tokenizer + ) + + def __call__( + self, + logits: torch.Tensor, + fsm_grammar_state: int, + ): + if fsm_grammar_state == -1 or self.fsm is None: + return logits + allowed_tokens = self.fsm.allowed_token_ids(fsm_grammar_state) + mask = torch.full((logits.shape[-1],), -math.inf, device=self.device) + mask[allowed_tokens] = 0 + biased_scores = logits + mask + return biased_scores + + def advance(self, next_token_id, fsm_grammar_state): + return GrammarLogitProcessor._advance( + next_token_id, fsm_grammar_state, self.fsm + ) + + @staticmethod + def _advance(next_token_id, fsm_grammar_state, fsm): + if fsm_grammar_state == -1: + return fsm_grammar_state + return fsm.next_state(fsm_grammar_state, next_token_id) + + # TODO: move grammar compilation into the router + @staticmethod + @lru_cache(maxsize=32, typed=True) + def _cached_compile_fsm(grammar_type, schema, tokenizer): + start_time = time.time() + if grammar_type == GrammarType.GRAMMAR_TYPE_JSON: + schema = build_regex_from_object(schema) + elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX: + pass # schema is already a regex just here for clarity + fsm = RegexFSM(schema, tokenizer) + logger.debug(f"Compiled FSM in {time.time() - start_time:.2f}s") + return fsm + + @staticmethod + @lru_cache(maxsize=32, typed=True) + def _cached_adapt_tokenizer(tokenizer): + """Adapt tokenizer to work with the FSM. + + The API of Outlines tokenizers is slightly different to that of + `transformers`. In addition we need to handle the missing spaces to + Llama's tokenizer to be able to compile FSMs for this model. + + """ + start_time = time.time() + tokenizer.vocabulary = tokenizer.get_vocab() + tokenizer.special_tokens = set(tokenizer.all_special_tokens) + + def convert_token_to_string(token: str) -> str: + from transformers.file_utils import SPIECE_UNDERLINE + + string = tokenizer.convert_tokens_to_string([token]) + + # A hack to handle missing spaces to HF's Llama tokenizers + if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>": + return " " + string + + return string + + tokenizer.convert_token_to_string = convert_token_to_string + logger.debug(f"Adapted tokenizer in {time.time() - start_time:.2f}s") + return tokenizer + + def filter(self, indices): + new_fsms = [] + for i in indices: + new_fsms.append(self.fsms[i]) + self.fsms = new_fsms + return self + + +class HeterogeneousGrammarLogitProcessor(LogitsProcessor): + def __init__(self, tokenizer, device, grammars, grammar_type): + self.device = device + self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer) + self.fsms = [] + for i in range(len(grammars)): + fsm = GrammarLogitProcessor._cached_compile_fsm( + grammar_type[i], grammars[i], self.tokenizer + ) + self.fsms.append(fsm) + + def __call__( + self, + logits: torch.Tensor, + fsm_grammar_states: List[int], + mask: torch.Tensor, + ): + mask = torch.full_like(logits, -math.inf) + for i in range(logits.shape[0]): + fsm = self.fsms[i] + if fsm_grammar_states[i] == -1 or fsm is None: + continue + allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i]) + mask[i, allowed_tokens] = 0 + logits += mask + return logits + + def advance_batch(self, next_token_ids, fsm_grammar_states, grammars): + return [ + GrammarLogitProcessor._advance( + next_token_ids[i], fsm_grammar_states[i], self.fsms[i] + ) + for i in range(len(next_token_ids)) + ] + + def advance_at_index(self, next_token_id, fsm_grammar_state, index): + return GrammarLogitProcessor._advance( + next_token_id, fsm_grammar_state, self.fsms[index] + ) + + def filter(self, indices): + return GrammarLogitProcessor.filter(self, indices) diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index d6ca10c7..2784585e 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -1,11 +1,13 @@ import re from typing import List, Optional, Tuple +import math import torch from text_generation_server.pb import generate_pb2 -from text_generation_server.pb.generate_pb2 import FinishReason +from text_generation_server.pb.generate_pb2 import FinishReason, GrammarType from text_generation_server.utils.logits_process import ( FrequencyPenaltyLogitsProcessor, + GrammarLogitProcessor, HeterogeneousProcessorWrapper, HeterogeneousRepetitionPenaltyLogitsProcessor, HeterogeneousFrequencyPenaltyLogitsProcessor, @@ -13,6 +15,7 @@ from text_generation_server.utils.logits_process import ( HeterogeneousTopKLogitsWarper, HeterogeneousTopPLogitsWarper, HeterogeneousTypicalLogitsWarper, + HeterogeneousGrammarLogitProcessor, static_warper, ) from text_generation_server.utils.watermark import WatermarkLogitsProcessor @@ -22,16 +25,20 @@ from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcess class NextTokenChooser: def __init__( self, - watermark=False, - temperature=1.0, - repetition_penalty=1.0, - frequency_penalty=0.0, - top_k=None, - top_p=None, - typical_p=None, - do_sample=False, - seed=0, - device="cpu", + watermark: bool = False, + temperature: float = 1.0, + repetition_penalty: float = 1.0, + frequency_penalty: float = 0.0, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + typical_p: Optional[float] = None, + do_sample: bool = False, + seed: int = 0, + device: str = "cpu", + tokenizer: Optional[PreTrainedTokenizerBase] = None, + grammar: str = "", + grammar_type: GrammarType = GrammarType.GRAMMAR_TYPE_NONE, + fsm_grammar_state: int = 0, ): self.watermark_processor = ( WatermarkLogitsProcessor(device=device) if watermark else None @@ -46,6 +53,12 @@ class NextTokenChooser: if frequency_penalty and frequency_penalty != 0.0 else None ) + self.grammar_processor = ( + GrammarLogitProcessor(tokenizer, device, grammar, grammar_type) + if grammar != "" + else None + ) + self.tokenizer = tokenizer has_warpers = ( (temperature is not None and temperature != 1.0) @@ -61,7 +74,10 @@ class NextTokenChooser: self.static_warper = None sampling = do_sample or has_warpers + self.choice = Sampling(seed, device) if sampling else Greedy() + self.fsm_grammar_state = fsm_grammar_state + self.grammar = grammar def __call__(self, input_ids, scores): if self.watermark_processor is not None: @@ -70,6 +86,8 @@ class NextTokenChooser: scores = self.repetition_processor(input_ids, scores) if self.frequency_processor is not None: scores = self.frequency_processor(input_ids, scores) + if self.grammar_processor is not None: + scores = self.grammar_processor(scores, self.fsm_grammar_state) if self.static_warper is None: next_logprob = torch.log_softmax(scores, -1) @@ -80,11 +98,19 @@ class NextTokenChooser: return next_id, next_logprob + def advance_grammar(self, next_id: int): + if self.grammar_processor is not None: + self.fsm_grammar_state = self.grammar_processor.advance( + next_id, self.fsm_grammar_state + ) + return self + @classmethod def from_pb( cls, pb: generate_pb2.NextTokenChooserParameters, device: torch.device, + tokenizer: PreTrainedTokenizerBase, ) -> "NextTokenChooser": return NextTokenChooser( watermark=pb.watermark, @@ -97,6 +123,9 @@ class NextTokenChooser: do_sample=pb.do_sample, seed=pb.seed, device=device, + tokenizer=tokenizer, + grammar=pb.grammar, + grammar_type=pb.grammar_type, ) @@ -201,6 +230,10 @@ class HeterogeneousNextTokenChooser: typical_p: List[float], do_sample: List[bool], seeds: List[int], + tokenizer: PreTrainedTokenizerBase, + grammars: List[str], + grammar_types: List[int], + fsm_grammar_states=List[int], ): warpers = [] @@ -232,6 +265,14 @@ class HeterogeneousNextTokenChooser: else None ) + self.grammar_processor = ( + HeterogeneousGrammarLogitProcessor( + tokenizer, device, grammars, grammar_types + ) + if any([grammar != "" for grammar in grammars]) + else None + ) + if any([x != 1.0 for x in temperature]): do_sample = [ sample or x != 1.0 for x, sample in zip(temperature, do_sample) @@ -263,6 +304,10 @@ class HeterogeneousNextTokenChooser: self.do_sample = do_sample self.dtype = dtype self.device = device + self.tokenizer = tokenizer + self.fsm_grammar_states = fsm_grammar_states + self.grammars = grammars + self.grammar_types = grammar_types def __call__( self, @@ -283,6 +328,8 @@ class HeterogeneousNextTokenChooser: scores = scores.view(B, S, -1) next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long) + mask = torch.full((scores.shape[-1],), -math.inf, device=self.device) + for j in range(S): _scores = scores[:, j] if self.watermark_processor is not None: @@ -291,10 +338,10 @@ class HeterogeneousNextTokenChooser: _scores = self.repetition_processor(input_ids, _scores) if self.frequency_processor is not None: _scores = self.frequency_processor(input_ids, _scores) - for warper in self.warpers: _scores = warper(input_ids, _scores) - + if self.grammar_processor is not None: + _scores = self.grammar_processor(_scores, self.fsm_grammar_states, mask) _next_ids = self.choice(_scores) scores[:, j] = _scores next_ids[:, j] = _next_ids @@ -352,6 +399,21 @@ class HeterogeneousNextTokenChooser: return next_ids, next_logprobs, alllogprobs, accepted_ids, speculative_ids + def advance_grammar(self, next_ids: List[int]): + if self.grammar_processor is not None: + other_new_states = self.grammar_processor.advance_batch( + next_ids, self.fsm_grammar_states, self.grammars + ) + self.fsm_grammar_states = other_new_states + return self + + def advance_grammar_single(self, grammar_state_index: int, next_id: int): + if self.grammar_processor is not None: + self.fsm_grammar_states[grammar_state_index] = self.grammar_processor.advance_at_index( + next_id, self.fsm_grammar_states[grammar_state_index], grammar_state_index + ) + return self + def filter(self, indices): if self.watermark_processor is not None: self.watermark_processor = self.watermark_processor.filter(indices) @@ -362,6 +424,9 @@ class HeterogeneousNextTokenChooser: if self.frequency_processor is not None: self.frequency_processor = self.frequency_processor.filter(indices) + if self.grammar_processor is not None: + self.grammar_processor = self.grammar_processor.filter(indices) + filtered_warpers = [] for warper in self.warpers: filtered_warper = warper.filter(indices) @@ -372,6 +437,18 @@ class HeterogeneousNextTokenChooser: self.seeds = [self.seeds[i] for i in indices] self.do_sample = [self.do_sample[i] for i in indices] + new_grammars = [] + new_fsm_grammar_states = [] + new_grammar_types = [] + for i in indices: + new_grammars.append(self.grammars[i]) + new_fsm_grammar_states.append(self.fsm_grammar_states[i]) + new_grammar_types.append(self.grammar_types[i]) + + self.grammars = new_grammars + self.fsm_grammar_states = new_fsm_grammar_states + self.grammar_types = new_grammar_types + if any(self.do_sample): self.choice.filter(indices) else: @@ -385,6 +462,7 @@ class HeterogeneousNextTokenChooser: pb: List[generate_pb2.NextTokenChooserParameters], dtype: torch.dtype, device: torch.device, + tokenizer: PreTrainedTokenizerBase, ) -> "HeterogeneousNextTokenChooser": return HeterogeneousNextTokenChooser( watermark=[pb_.watermark for pb_ in pb], @@ -398,6 +476,10 @@ class HeterogeneousNextTokenChooser: seeds=[pb_.seed for pb_ in pb], device=device, dtype=dtype, + tokenizer=tokenizer, + grammars=[pb_.grammar for pb_ in pb], + grammar_types=[pb_.grammar_type for pb_ in pb], + fsm_grammar_states=[0] * len(pb), )