From 289aa4855449fe9162c8da7e13aaaaef81755a3b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Mon, 25 Nov 2024 18:47:34 +0100 Subject: [PATCH] Move JSON grammar -> regex grammar conversion to the router (#2772) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Move JSON grammar -> regex grammar conversion to the router This change moves the JSON grammar -> regex grammar conversion to the router by adding a dependency on the `outlines-core` Rust crate. In contrast to the Python implementation, the conversions are not LRU-cached since they seem to be fast enough: simple schema time: [5.8293 µs 5.8307 µs 5.8320 µs] change: [-13.166% -12.884% -12.641%] (p = 0.00 < 0.05) Performance has improved. complex schema time: [14.875 µs 14.881 µs 14.887 µs] change: [-2.1637% -1.9914% -1.7852%] (p = 0.00 < 0.05) Performance has improved. Using the schemas from: https://github.com/dottxt-ai/outlines-core/blob/main/benchmarks/bench_json_schema.py --- .pre-commit-config.yaml | 1 + Cargo.lock | 24 ++++++++++++ crate-hashes.json | 3 ++ ...st_grammar_response_format_llama_json.json | 14 +++---- .../test_flash_llama_grammar_tools.json | 20 +++++----- .../test_flash_llama_grammar_tools_auto.json | 20 +++++----- ..._sea_creatures_stream_function_object.json | 4 +- ...r_tools_sea_creatures_stream_required.json | 4 +- ...test_flash_llama_grammar_tools_stream.json | 13 ++++--- .../test_grammar_response_format_llama.py | 2 +- integration-tests/models/test_tools_llama.py | 10 ++--- router/Cargo.toml | 2 + router/src/infer/chat_template.rs | 2 +- router/src/validation.rs | 15 +++++--- .../utils/logits_process.py | 38 +++++++++++-------- 15 files changed, 108 insertions(+), 64 deletions(-) create mode 100644 crate-hashes.json diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0c8b6885..b5d939cd 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,6 +4,7 @@ repos: hooks: - id: check-yaml - id: end-of-file-fixer + exclude: crate-hashes.json - id: trailing-whitespace exclude: docs/source/reference/launcher.md - repo: https://github.com/psf/black diff --git a/Cargo.lock b/Cargo.lock index 77c49408..72f70fdc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3005,6 +3005,17 @@ dependencies = [ "num-traits", ] +[[package]] +name = "outlines-core" +version = "0.1.0" +source = "git+https://github.com/dottxt-ai/outlines-core.git?rev=ba10c619fc9bf3c487e43f49bdecb95a24bb465c#ba10c619fc9bf3c487e43f49bdecb95a24bb465c" +dependencies = [ + "anyhow", + "regex", + "serde-pyobject", + "serde_json", +] + [[package]] name = "overload" version = "0.1.1" @@ -3952,6 +3963,16 @@ dependencies = [ "serde_derive", ] +[[package]] +name = "serde-pyobject" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca4b0aad8b225845739a0030a0d5cc2ae949c56a86a7daf9226c7df7c2016d16" +dependencies = [ + "pyo3", + "serde", +] + [[package]] name = "serde_cbor" version = "0.11.2" @@ -3979,6 +4000,7 @@ version = "1.0.133" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c7fceb2473b9166b2294ef05efcb65a3db80803f0b03ef86a5fc88a2b85ee377" dependencies = [ + "indexmap 2.6.0", "itoa", "memchr", "ryu", @@ -4430,6 +4452,7 @@ dependencies = [ name = "text-generation-router" version = "2.4.2-dev0" dependencies = [ + "anyhow", "async-stream", "async-trait", "axum 0.7.9", @@ -4453,6 +4476,7 @@ dependencies = [ "once_cell", "opentelemetry 0.20.0", "opentelemetry-otlp", + "outlines-core", "pyo3", "rand", "regex", diff --git a/crate-hashes.json b/crate-hashes.json new file mode 100644 index 00000000..2694759c --- /dev/null +++ b/crate-hashes.json @@ -0,0 +1,3 @@ +{ + "git+https://github.com/dottxt-ai/outlines-core.git?rev=ba10c619fc9bf3c487e43f49bdecb95a24bb465c#outlines-core@0.1.0": "1j9dcd831b0bmmjk2n4aag3x47qnqmkpg4gqpvwwyic7744llbfm" +} \ No newline at end of file diff --git a/integration-tests/models/__snapshots__/test_grammar_response_format_llama/test_grammar_response_format_llama_json.json b/integration-tests/models/__snapshots__/test_grammar_response_format_llama/test_grammar_response_format_llama_json.json index 2bd79b1d..38229e0a 100644 --- a/integration-tests/models/__snapshots__/test_grammar_response_format_llama/test_grammar_response_format_llama_json.json +++ b/integration-tests/models/__snapshots__/test_grammar_response_format_llama/test_grammar_response_format_llama_json.json @@ -1,23 +1,23 @@ { "choices": [ { - "finish_reason": "eos_token", + "finish_reason": "stop", "index": 0, "logprobs": null, "message": { - "content": "{ \"temperature\": [ 26, 30, 33, 29 ] ,\"unit\": \"Fahrenheit\" }", + "content": "{ \"unit\": \"fahrenheit\", \"temperature\": [ 72, 79, 88 ] }", "role": "assistant" } } ], - "created": 1718044128, + "created": 1732525803, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", - "object": "text_completion", - "system_fingerprint": "2.0.5-dev0-native", + "object": "chat.completion", + "system_fingerprint": "2.4.1-dev0-native", "usage": { - "completion_tokens": 39, + "completion_tokens": 29, "prompt_tokens": 136, - "total_tokens": 175 + "total_tokens": 165 } } diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools.json index a4c34a10..33e223ba 100644 --- a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools.json +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools.json @@ -1,7 +1,7 @@ { "choices": [ { - "finish_reason": "eos_token", + "finish_reason": "stop", "index": 0, "logprobs": null, "message": { @@ -13,12 +13,12 @@ "function": { "arguments": { "format": "celsius", - "location": "Brooklyn" + "location": "Brooklyn, New York" }, "description": null, "name": "get_current_weather" }, - "id": 0, + "id": "0", "type": "function" } ] @@ -26,14 +26,14 @@ "usage": null } ], - "created": 1712782670, + "created": 1732293383, "id": "", - "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", - "object": "text_completion", - "system_fingerprint": "2.0.1-native", + "model": "meta-llama/Llama-3.1-8B-Instruct", + "object": "chat.completion", + "system_fingerprint": "2.4.1-dev0-native", "usage": { - "completion_tokens": 37, - "prompt_tokens": 524, - "total_tokens": 561 + "completion_tokens": 30, + "prompt_tokens": 615, + "total_tokens": 645 } } diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_auto.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_auto.json index 04bcdc4e..92ffbbc1 100644 --- a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_auto.json +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_auto.json @@ -1,7 +1,7 @@ { "choices": [ { - "finish_reason": "eos_token", + "finish_reason": "stop", "index": 0, "logprobs": null, "message": { @@ -13,12 +13,12 @@ "function": { "arguments": { "format": "celsius", - "location": "Brooklyn" + "location": "Brooklyn, New York" }, "description": null, "name": "get_current_weather" }, - "id": 0, + "id": "0", "type": "function" } ] @@ -26,14 +26,14 @@ "usage": null } ], - "created": 1712787937, + "created": 1732293384, "id": "", - "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", - "object": "text_completion", - "system_fingerprint": "2.0.1-native", + "model": "meta-llama/Llama-3.1-8B-Instruct", + "object": "chat.completion", + "system_fingerprint": "2.4.1-dev0-native", "usage": { - "completion_tokens": 37, - "prompt_tokens": 524, - "total_tokens": 561 + "completion_tokens": 30, + "prompt_tokens": 615, + "total_tokens": 645 } } diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_function_object.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_function_object.json index e64dd49d..bb8d61c8 100644 --- a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_function_object.json +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_function_object.json @@ -18,10 +18,10 @@ "logprobs": null } ], - "created": 1729084854, + "created": 1732293254, "id": "", "model": "meta-llama/Llama-3.1-8B-Instruct", "object": "chat.completion.chunk", - "system_fingerprint": "2.3.2-dev0-native", + "system_fingerprint": "2.4.1-dev0-native", "usage": null } diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_required.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_required.json index d8d538d6..dbced5b8 100644 --- a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_required.json +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_required.json @@ -19,10 +19,10 @@ "logprobs": null } ], - "created": 1729084850, + "created": 1732293246, "id": "", "model": "meta-llama/Llama-3.1-8B-Instruct", "object": "chat.completion.chunk", - "system_fingerprint": "2.3.2-dev0-native", + "system_fingerprint": "2.4.1-dev0-native", "usage": null } diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_stream.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_stream.json index f72a5d38..27d2f9ca 100644 --- a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_stream.json +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_stream.json @@ -6,7 +6,7 @@ "role": "assistant", "tool_calls": { "function": { - "arguments": "", + "arguments": "<|eot_id|>", "name": null }, "id": "", @@ -14,14 +14,15 @@ "type": "function" } }, - "finish_reason": "eos_token", + "finish_reason": "stop", "index": 0, "logprobs": null } ], - "created": 1712788218, + "created": 1732293235, "id": "", - "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", - "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "model": "meta-llama/Llama-3.1-8B-Instruct", + "object": "chat.completion.chunk", + "system_fingerprint": "2.4.1-dev0-native", + "usage": null } diff --git a/integration-tests/models/test_grammar_response_format_llama.py b/integration-tests/models/test_grammar_response_format_llama.py index 3c46cefe..f2a8a96d 100644 --- a/integration-tests/models/test_grammar_response_format_llama.py +++ b/integration-tests/models/test_grammar_response_format_llama.py @@ -55,7 +55,7 @@ async def test_grammar_response_format_llama_json(llama_grammar, response_snapsh called = chat_completion["choices"][0]["message"]["content"] assert response.status_code == 200 - assert called == '{ "temperature": [ 26, 30, 33, 29 ] ,"unit": "Fahrenheit" }' + assert called == '{ "unit": "fahrenheit", "temperature": [ 72, 79, 88 ] }' assert chat_completion == response_snapshot diff --git a/integration-tests/models/test_tools_llama.py b/integration-tests/models/test_tools_llama.py index b5821945..70c3aff0 100644 --- a/integration-tests/models/test_tools_llama.py +++ b/integration-tests/models/test_tools_llama.py @@ -101,7 +101,7 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna "function": { "description": None, "name": "get_current_weather", - "arguments": {"format": "celsius", "location": "Brooklyn, NY"}, + "arguments": {"format": "celsius", "location": "Brooklyn, New York"}, }, } ] @@ -138,7 +138,7 @@ async def test_flash_llama_grammar_tools_auto( "function": { "description": None, "name": "get_current_weather", - "arguments": {"format": "celsius", "location": "Brooklyn, NY"}, + "arguments": {"format": "celsius", "location": "Brooklyn, New York"}, }, } ] @@ -219,7 +219,7 @@ async def test_flash_llama_grammar_tools_stream( assert ( tool_calls_generated - == '{"function": {"_name": "get_current_weather", "format": "celsius", "location": "Paris, France"}}<|eot_id|>' + == '{"function": {"_name": "get_current_weather", "location": "Paris, France", "format": "celsius"}}<|eot_id|>' ) assert count == 28 assert last_response == response_snapshot @@ -366,7 +366,7 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_required( assert count == 29 assert ( tool_calls_generated - == '{"function": {"_name": "get_current_weather", "format": "celsius", "location": "San Francisco, CA"}}<|eot_id|>' + == '{"function": {"_name": "get_current_weather", "location": "San Francisco, CA", "format": "celsius"}}<|eot_id|>' ) assert last_response == response_snapshot @@ -465,6 +465,6 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_function_object( assert count == 39 assert ( tool_calls_generated - == '{"function": {"_name": "get_n_day_weather_forecast", "format": "celsius", "location": "San Francisco, CA", "num_days":3}}<|eot_id|>' + == '{"function": {"_name": "get_n_day_weather_forecast", "location": "San Francisco, CA", "format": "celsius", "num_days":3}}<|eot_id|>' ) assert last_response == response_snapshot diff --git a/router/Cargo.toml b/router/Cargo.toml index 83d85327..9258fe03 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -8,6 +8,7 @@ authors.workspace = true homepage.workspace = true [dependencies] +anyhow = "1" async-trait = "0.1.74" async-stream = "0.3.5" axum = { version = "0.7", features = ["json"] } @@ -22,6 +23,7 @@ metrics-exporter-prometheus = { workspace = true } nohash-hasher = "0.2.0" opentelemetry = { version = "0.20.0", features = ["rt-tokio"] } opentelemetry-otlp = "0.13.0" +outlines-core = { git = "https://github.com/dottxt-ai/outlines-core.git", rev = "ba10c619fc9bf3c487e43f49bdecb95a24bb465c" } rand = "0.8.5" reqwest = { version = "0.11.20", features = [] } serde = "1.0.188" diff --git a/router/src/infer/chat_template.rs b/router/src/infer/chat_template.rs index ceba14a6..f5f1dbca 100644 --- a/router/src/infer/chat_template.rs +++ b/router/src/infer/chat_template.rs @@ -804,7 +804,7 @@ mod tests { let tool_prompt = "This default prompt will be used".to_string(); let tools_and_prompt = Some((tools, tool_prompt)); let result = ct.apply(msgs, tools_and_prompt); - let expected = "[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today? [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"arguments\":{\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}}}]\nThis default prompt will be used [/INST]".to_string(); + let expected = "[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today? [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"arguments\":{\"type\":\"object\",\"properties\":{\"location\":{\"type\":\"string\",\"description\":\"The city and state, e.g. San Francisco, CA\"},\"format\":{\"type\":\"string\",\"enum\":[\"celsius\",\"fahrenheit\"],\"description\":\"The temperature unit to use. Infer this from the users location.\"}},\"required\":[\"location\",\"format\"]}}}]\nThis default prompt will be used [/INST]".to_string(); assert_eq!(result.unwrap(), expected); } diff --git a/router/src/validation.rs b/router/src/validation.rs index 3cd85a6e..032638ab 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -9,6 +9,7 @@ use crate::{PyTokenizer, Tokenizer}; use base64::{engine::general_purpose::STANDARD, Engine}; use image::{ImageFormat, ImageReader}; use jsonschema::{Draft, JSONSchema}; +use outlines_core::json_schema::to_regex as json_schema_to_regex; use rand::{thread_rng, Rng}; use serde_json::Value; use std::io::Cursor; @@ -351,11 +352,13 @@ impl Validation { "Grammar must have a 'properties' field".to_string(), ))?; - // Serialize json to string - ValidGrammar::Json( - serde_json::to_string(&json) - .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?, - ) + // Do compilation in the router for performance. In the future, we + // should also move regex -> automaton compilation in the router, + // but this is not yet supported in pure Rust by outlines-core. + let grammar_regex = json_schema_to_regex(&json, None, &json) + .map_err(ValidationError::RegexFromSchema)?; + + ValidGrammar::Regex(grammar_regex.to_string()) } GrammarType::Regex(regex) => ValidGrammar::Regex(regex), }; @@ -810,6 +813,8 @@ pub enum ValidationError { Grammar, #[error("grammar is not valid: {0}")] InvalidGrammar(String), + #[error("cannot compile regex from schema: {0}")] + RegexFromSchema(anyhow::Error), #[error("base64 encoding is invalid: {0}")] InvalidBase64(#[from] base64::DecodeError), #[error("invalid image: {0}")] diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py index d53f070c..132e441b 100644 --- a/server/text_generation_server/utils/logits_process.py +++ b/server/text_generation_server/utils/logits_process.py @@ -1,19 +1,19 @@ +from functools import lru_cache import math +import time import torch +from typing import List, Optional, DefaultDict from loguru import logger from typing import Dict, Union from text_generation_server.pb.generate_pb2 import GrammarType from outlines.fsm.guide import RegexGuide -from outlines.fsm.json_schema import build_regex_from_schema -from functools import lru_cache -from typing import List, Optional, DefaultDict -import time from transformers import ( LogitsWarper, LogitsProcessor, + PreTrainedTokenizerBase, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper, @@ -484,7 +484,13 @@ class GrammarLogitProcessor(LogitsProcessor): fsm_state: DefaultDict[int, int] fsm: RegexGuide - def __init__(self, tokenizer, device, grammar, grammar_type): + def __init__( + self, + tokenizer: Optional[PreTrainedTokenizerBase], + device: str, + grammar: str, + grammar_type: GrammarType, + ): self.device = device self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer) self.fsm = GrammarLogitProcessor._cached_compile_fsm( @@ -519,18 +525,20 @@ class GrammarLogitProcessor(LogitsProcessor): # TODO: move grammar compilation into the router @staticmethod @lru_cache(maxsize=32, typed=True) - def _cached_compile_fsm(grammar_type, schema, tokenizer): + def _cached_compile_fsm( + grammar_type: GrammarType, + schema: str, + tokenizer: Optional[PreTrainedTokenizerBase], + ): start_time = time.time() if grammar_type == GrammarType.GRAMMAR_TYPE_JSON: - try: - schema = build_regex_from_schema(schema) - # TODO: this is only here short term to avoid crashing the python server, mid term we want this in the rust/router layer - except Exception as e: - logger.error(f"Error compiling FSM, grammar won't be enforced \n{e}") - # allows everything - schema = "(.*?)" - elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX: - pass # schema is already a regex just here for clarity + # JSON schema is compiled by the v3 router. + logger.error( + "Non-regex grammars must be compiled by the router, grammar won't be enforced" + ) + # allows everything + schema = "(.*?)" + fsm = RegexGuide.from_regex(schema, tokenizer) logger.debug(f"Compiled FSM in {time.time() - start_time:.2f}s") return fsm