Move JSON grammar -> regex grammar conversion to the router (#2772)

* Move JSON grammar -> regex grammar conversion to the router

This change moves the JSON grammar -> regex grammar conversion to the
router by adding a dependency on the `outlines-core` Rust crate. In
contrast to the Python implementation, the conversions are not LRU-cached
since they seem to be fast enough:

simple schema           time:   [5.8293 µs 5.8307 µs 5.8320 µs]
                        change: [-13.166% -12.884% -12.641%] (p = 0.00 < 0.05)
                        Performance has improved.

complex schema          time:   [14.875 µs 14.881 µs 14.887 µs]
                        change: [-2.1637% -1.9914% -1.7852%] (p = 0.00 < 0.05)
                        Performance has improved.

Using the schemas from:
https://github.com/dottxt-ai/outlines-core/blob/main/benchmarks/bench_json_schema.py
This commit is contained in:
Daniël de Kok 2024-11-25 18:47:34 +01:00 committed by GitHub
parent c637d68d74
commit 289aa48554
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 108 additions and 64 deletions

View File

@@ -4,6 +4,7 @@ repos:
hooks: hooks:
- id: check-yaml - id: check-yaml
- id: end-of-file-fixer - id: end-of-file-fixer
exclude: crate-hashes.json
- id: trailing-whitespace - id: trailing-whitespace
exclude: docs/source/reference/launcher.md exclude: docs/source/reference/launcher.md
- repo: https://github.com/psf/black - repo: https://github.com/psf/black

24
Cargo.lock generated
View File

@@ -3005,6 +3005,17 @@ dependencies = [
"num-traits", "num-traits",
] ]
[[package]]
name = "outlines-core"
version = "0.1.0"
source = "git+https://github.com/dottxt-ai/outlines-core.git?rev=ba10c619fc9bf3c487e43f49bdecb95a24bb465c#ba10c619fc9bf3c487e43f49bdecb95a24bb465c"
dependencies = [
"anyhow",
"regex",
"serde-pyobject",
"serde_json",
]
[[package]] [[package]]
name = "overload" name = "overload"
version = "0.1.1" version = "0.1.1"
@@ -3952,6 +3963,16 @@ dependencies = [
"serde_derive", "serde_derive",
] ]
[[package]]
name = "serde-pyobject"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ca4b0aad8b225845739a0030a0d5cc2ae949c56a86a7daf9226c7df7c2016d16"
dependencies = [
"pyo3",
"serde",
]
[[package]] [[package]]
name = "serde_cbor" name = "serde_cbor"
version = "0.11.2" version = "0.11.2"
@@ -3979,6 +4000,7 @@ version = "1.0.133"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7fceb2473b9166b2294ef05efcb65a3db80803f0b03ef86a5fc88a2b85ee377" checksum = "c7fceb2473b9166b2294ef05efcb65a3db80803f0b03ef86a5fc88a2b85ee377"
dependencies = [ dependencies = [
"indexmap 2.6.0",
"itoa", "itoa",
"memchr", "memchr",
"ryu", "ryu",
@@ -4430,6 +4452,7 @@ dependencies = [
name = "text-generation-router" name = "text-generation-router"
version = "2.4.2-dev0" version = "2.4.2-dev0"
dependencies = [ dependencies = [
"anyhow",
"async-stream", "async-stream",
"async-trait", "async-trait",
"axum 0.7.9", "axum 0.7.9",
@@ -4453,6 +4476,7 @@ dependencies = [
"once_cell", "once_cell",
"opentelemetry 0.20.0", "opentelemetry 0.20.0",
"opentelemetry-otlp", "opentelemetry-otlp",
"outlines-core",
"pyo3", "pyo3",
"rand", "rand",
"regex", "regex",

3
crate-hashes.json Normal file
View File

@@ -0,0 +1,3 @@
{
"git+https://github.com/dottxt-ai/outlines-core.git?rev=ba10c619fc9bf3c487e43f49bdecb95a24bb465c#outlines-core@0.1.0": "1j9dcd831b0bmmjk2n4aag3x47qnqmkpg4gqpvwwyic7744llbfm"
}

View File

@@ -1,23 +1,23 @@
{ {
"choices": [ "choices": [
{ {
"finish_reason": "eos_token", "finish_reason": "stop",
"index": 0, "index": 0,
"logprobs": null, "logprobs": null,
"message": { "message": {
"content": "{ \"temperature\": [ 26, 30, 33, 29 ] ,\"unit\": \"Fahrenheit\" }", "content": "{ \"unit\": \"fahrenheit\", \"temperature\": [ 72, 79, 88 ] }",
"role": "assistant" "role": "assistant"
} }
} }
], ],
"created": 1718044128, "created": 1732525803,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "chat.completion",
"system_fingerprint": "2.0.5-dev0-native", "system_fingerprint": "2.4.1-dev0-native",
"usage": { "usage": {
"completion_tokens": 39, "completion_tokens": 29,
"prompt_tokens": 136, "prompt_tokens": 136,
"total_tokens": 175 "total_tokens": 165
} }
} }

View File

@@ -1,7 +1,7 @@
{ {
"choices": [ "choices": [
{ {
"finish_reason": "eos_token", "finish_reason": "stop",
"index": 0, "index": 0,
"logprobs": null, "logprobs": null,
"message": { "message": {
@@ -13,12 +13,12 @@
"function": { "function": {
"arguments": { "arguments": {
"format": "celsius", "format": "celsius",
"location": "Brooklyn" "location": "Brooklyn, New York"
}, },
"description": null, "description": null,
"name": "get_current_weather" "name": "get_current_weather"
}, },
"id": 0, "id": "0",
"type": "function" "type": "function"
} }
] ]
@@ -26,14 +26,14 @@
"usage": null "usage": null
} }
], ],
"created": 1712782670, "created": 1732293383,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "text_completion", "object": "chat.completion",
"system_fingerprint": "2.0.1-native", "system_fingerprint": "2.4.1-dev0-native",
"usage": { "usage": {
"completion_tokens": 37, "completion_tokens": 30,
"prompt_tokens": 524, "prompt_tokens": 615,
"total_tokens": 561 "total_tokens": 645
} }
} }

View File

@@ -1,7 +1,7 @@
{ {
"choices": [ "choices": [
{ {
"finish_reason": "eos_token", "finish_reason": "stop",
"index": 0, "index": 0,
"logprobs": null, "logprobs": null,
"message": { "message": {
@@ -13,12 +13,12 @@
"function": { "function": {
"arguments": { "arguments": {
"format": "celsius", "format": "celsius",
"location": "Brooklyn" "location": "Brooklyn, New York"
}, },
"description": null, "description": null,
"name": "get_current_weather" "name": "get_current_weather"
}, },
"id": 0, "id": "0",
"type": "function" "type": "function"
} }
] ]
@@ -26,14 +26,14 @@
"usage": null "usage": null
} }
], ],
"created": 1712787937, "created": 1732293384,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "text_completion", "object": "chat.completion",
"system_fingerprint": "2.0.1-native", "system_fingerprint": "2.4.1-dev0-native",
"usage": { "usage": {
"completion_tokens": 37, "completion_tokens": 30,
"prompt_tokens": 524, "prompt_tokens": 615,
"total_tokens": 561 "total_tokens": 645
} }
} }

View File

@@ -18,10 +18,10 @@
"logprobs": null "logprobs": null
} }
], ],
"created": 1729084854, "created": 1732293254,
"id": "", "id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct", "model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk", "object": "chat.completion.chunk",
"system_fingerprint": "2.3.2-dev0-native", "system_fingerprint": "2.4.1-dev0-native",
"usage": null "usage": null
} }

View File

@@ -19,10 +19,10 @@
"logprobs": null "logprobs": null
} }
], ],
"created": 1729084850, "created": 1732293246,
"id": "", "id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct", "model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk", "object": "chat.completion.chunk",
"system_fingerprint": "2.3.2-dev0-native", "system_fingerprint": "2.4.1-dev0-native",
"usage": null "usage": null
} }

View File

@@ -6,7 +6,7 @@
"role": "assistant", "role": "assistant",
"tool_calls": { "tool_calls": {
"function": { "function": {
"arguments": "</s>", "arguments": "<|eot_id|>",
"name": null "name": null
}, },
"id": "", "id": "",
@@ -14,14 +14,15 @@
"type": "function" "type": "function"
} }
}, },
"finish_reason": "eos_token", "finish_reason": "stop",
"index": 0, "index": 0,
"logprobs": null "logprobs": null
} }
], ],
"created": 1712788218, "created": 1732293235,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "text_completion", "object": "chat.completion.chunk",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.4.1-dev0-native",
"usage": null
} }

View File

@@ -55,7 +55,7 @@ async def test_grammar_response_format_llama_json(llama_grammar, response_snapsh
called = chat_completion["choices"][0]["message"]["content"] called = chat_completion["choices"][0]["message"]["content"]
assert response.status_code == 200 assert response.status_code == 200
assert called == '{ "temperature": [ 26, 30, 33, 29 ] ,"unit": "Fahrenheit" }' assert called == '{ "unit": "fahrenheit", "temperature": [ 72, 79, 88 ] }'
assert chat_completion == response_snapshot assert chat_completion == response_snapshot

View File

@@ -101,7 +101,7 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
"function": { "function": {
"description": None, "description": None,
"name": "get_current_weather", "name": "get_current_weather",
"arguments": {"format": "celsius", "location": "Brooklyn, NY"}, "arguments": {"format": "celsius", "location": "Brooklyn, New York"},
}, },
} }
] ]
@@ -138,7 +138,7 @@ async def test_flash_llama_grammar_tools_auto(
"function": { "function": {
"description": None, "description": None,
"name": "get_current_weather", "name": "get_current_weather",
"arguments": {"format": "celsius", "location": "Brooklyn, NY"}, "arguments": {"format": "celsius", "location": "Brooklyn, New York"},
}, },
} }
] ]
@@ -219,7 +219,7 @@ async def test_flash_llama_grammar_tools_stream(
assert ( assert (
tool_calls_generated tool_calls_generated
== '{"function": {"_name": "get_current_weather", "format": "celsius", "location": "Paris, France"}}<|eot_id|>' == '{"function": {"_name": "get_current_weather", "location": "Paris, France", "format": "celsius"}}<|eot_id|>'
) )
assert count == 28 assert count == 28
assert last_response == response_snapshot assert last_response == response_snapshot
@@ -366,7 +366,7 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_required(
assert count == 29 assert count == 29
assert ( assert (
tool_calls_generated tool_calls_generated
== '{"function": {"_name": "get_current_weather", "format": "celsius", "location": "San Francisco, CA"}}<|eot_id|>' == '{"function": {"_name": "get_current_weather", "location": "San Francisco, CA", "format": "celsius"}}<|eot_id|>'
) )
assert last_response == response_snapshot assert last_response == response_snapshot
@@ -465,6 +465,6 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_function_object(
assert count == 39 assert count == 39
assert ( assert (
tool_calls_generated tool_calls_generated
== '{"function": {"_name": "get_n_day_weather_forecast", "format": "celsius", "location": "San Francisco, CA", "num_days":3}}<|eot_id|>' == '{"function": {"_name": "get_n_day_weather_forecast", "location": "San Francisco, CA", "format": "celsius", "num_days":3}}<|eot_id|>'
) )
assert last_response == response_snapshot assert last_response == response_snapshot

View File

@@ -8,6 +8,7 @@ authors.workspace = true
homepage.workspace = true homepage.workspace = true
[dependencies] [dependencies]
anyhow = "1"
async-trait = "0.1.74" async-trait = "0.1.74"
async-stream = "0.3.5" async-stream = "0.3.5"
axum = { version = "0.7", features = ["json"] } axum = { version = "0.7", features = ["json"] }
@@ -22,6 +23,7 @@ metrics-exporter-prometheus = { workspace = true }
nohash-hasher = "0.2.0" nohash-hasher = "0.2.0"
opentelemetry = { version = "0.20.0", features = ["rt-tokio"] } opentelemetry = { version = "0.20.0", features = ["rt-tokio"] }
opentelemetry-otlp = "0.13.0" opentelemetry-otlp = "0.13.0"
outlines-core = { git = "https://github.com/dottxt-ai/outlines-core.git", rev = "ba10c619fc9bf3c487e43f49bdecb95a24bb465c" }
rand = "0.8.5" rand = "0.8.5"
reqwest = { version = "0.11.20", features = [] } reqwest = { version = "0.11.20", features = [] }
serde = "1.0.188" serde = "1.0.188"

View File

@@ -804,7 +804,7 @@ mod tests {
let tool_prompt = "This default prompt will be used".to_string(); let tool_prompt = "This default prompt will be used".to_string();
let tools_and_prompt = Some((tools, tool_prompt)); let tools_and_prompt = Some((tools, tool_prompt));
let result = ct.apply(msgs, tools_and_prompt); let result = ct.apply(msgs, tools_and_prompt);
let expected = "<s>[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today?</s> [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"arguments\":{\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}}}]\nThis default prompt will be used [/INST]".to_string(); let expected = "<s>[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today?</s> [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"arguments\":{\"type\":\"object\",\"properties\":{\"location\":{\"type\":\"string\",\"description\":\"The city and state, e.g. San Francisco, CA\"},\"format\":{\"type\":\"string\",\"enum\":[\"celsius\",\"fahrenheit\"],\"description\":\"The temperature unit to use. Infer this from the users location.\"}},\"required\":[\"location\",\"format\"]}}}]\nThis default prompt will be used [/INST]".to_string();
assert_eq!(result.unwrap(), expected); assert_eq!(result.unwrap(), expected);
} }

View File

@@ -9,6 +9,7 @@ use crate::{PyTokenizer, Tokenizer};
use base64::{engine::general_purpose::STANDARD, Engine}; use base64::{engine::general_purpose::STANDARD, Engine};
use image::{ImageFormat, ImageReader}; use image::{ImageFormat, ImageReader};
use jsonschema::{Draft, JSONSchema}; use jsonschema::{Draft, JSONSchema};
use outlines_core::json_schema::to_regex as json_schema_to_regex;
use rand::{thread_rng, Rng}; use rand::{thread_rng, Rng};
use serde_json::Value; use serde_json::Value;
use std::io::Cursor; use std::io::Cursor;
@@ -351,11 +352,13 @@ impl Validation {
"Grammar must have a 'properties' field".to_string(), "Grammar must have a 'properties' field".to_string(),
))?; ))?;
// Serialize json to string // Do compilation in the router for performance. In the future, we
ValidGrammar::Json( // should also move regex -> automaton compilation in the router,
serde_json::to_string(&json) // but this is not yet supported in pure Rust by outlines-core.
.map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?, let grammar_regex = json_schema_to_regex(&json, None, &json)
) .map_err(ValidationError::RegexFromSchema)?;
ValidGrammar::Regex(grammar_regex.to_string())
} }
GrammarType::Regex(regex) => ValidGrammar::Regex(regex), GrammarType::Regex(regex) => ValidGrammar::Regex(regex),
}; };
@@ -810,6 +813,8 @@ pub enum ValidationError {
Grammar, Grammar,
#[error("grammar is not valid: {0}")] #[error("grammar is not valid: {0}")]
InvalidGrammar(String), InvalidGrammar(String),
#[error("cannot compile regex from schema: {0}")]
RegexFromSchema(anyhow::Error),
#[error("base64 encoding is invalid: {0}")] #[error("base64 encoding is invalid: {0}")]
InvalidBase64(#[from] base64::DecodeError), InvalidBase64(#[from] base64::DecodeError),
#[error("invalid image: {0}")] #[error("invalid image: {0}")]

View File

@@ -1,19 +1,19 @@
from functools import lru_cache
import math import math
import time
import torch import torch
from typing import List, Optional, DefaultDict
from loguru import logger from loguru import logger
from typing import Dict, Union from typing import Dict, Union
from text_generation_server.pb.generate_pb2 import GrammarType from text_generation_server.pb.generate_pb2 import GrammarType
from outlines.fsm.guide import RegexGuide from outlines.fsm.guide import RegexGuide
from outlines.fsm.json_schema import build_regex_from_schema
from functools import lru_cache
from typing import List, Optional, DefaultDict
import time
from transformers import ( from transformers import (
LogitsWarper, LogitsWarper,
LogitsProcessor, LogitsProcessor,
PreTrainedTokenizerBase,
TemperatureLogitsWarper, TemperatureLogitsWarper,
TopKLogitsWarper, TopKLogitsWarper,
TopPLogitsWarper, TopPLogitsWarper,
@@ -484,7 +484,13 @@ class GrammarLogitProcessor(LogitsProcessor):
fsm_state: DefaultDict[int, int] fsm_state: DefaultDict[int, int]
fsm: RegexGuide fsm: RegexGuide
def __init__(self, tokenizer, device, grammar, grammar_type): def __init__(
self,
tokenizer: Optional[PreTrainedTokenizerBase],
device: str,
grammar: str,
grammar_type: GrammarType,
):
self.device = device self.device = device
self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer) self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
self.fsm = GrammarLogitProcessor._cached_compile_fsm( self.fsm = GrammarLogitProcessor._cached_compile_fsm(
@@ -519,18 +525,20 @@
# TODO: move grammar compilation into the router # TODO: move grammar compilation into the router
@staticmethod @staticmethod
@lru_cache(maxsize=32, typed=True) @lru_cache(maxsize=32, typed=True)
def _cached_compile_fsm(grammar_type, schema, tokenizer): def _cached_compile_fsm(
grammar_type: GrammarType,
schema: str,
tokenizer: Optional[PreTrainedTokenizerBase],
):
start_time = time.time() start_time = time.time()
if grammar_type == GrammarType.GRAMMAR_TYPE_JSON: if grammar_type == GrammarType.GRAMMAR_TYPE_JSON:
try: # JSON schema is compiled by the v3 router.
schema = build_regex_from_schema(schema) logger.error(
# TODO: this is only here short term to avoid crashing the python server, mid term we want this in the rust/router layer "Non-regex grammars must be compiled by the router, grammar won't be enforced"
except Exception as e: )
logger.error(f"Error compiling FSM, grammar won't be enforced \n{e}") # allows everything
# allows everything schema = "(.*?)"
schema = "(.*?)"
elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX:
pass # schema is already a regex just here for clarity
fsm = RegexGuide.from_regex(schema, tokenizer) fsm = RegexGuide.from_regex(schema, tokenizer)
logger.debug(f"Compiled FSM in {time.time() - start_time:.2f}s") logger.debug(f"Compiled FSM in {time.time() - start_time:.2f}s")
return fsm return fsm