Move JSON grammar -> regex grammar conversion to the router (#2772)
* Move JSON grammar -> regex grammar conversion to the router

This change moves the JSON grammar -> regex grammar conversion to the router by adding a dependency on the `outlines-core` Rust crate. In contrast to the Python implementation, the conversions are not LRU-cached, since they seem to be fast enough:

simple schema: time: [5.8293 µs 5.8307 µs 5.8320 µs]; change: [-13.166% -12.884% -12.641%] (p = 0.00 < 0.05). Performance has improved.
complex schema: time: [14.875 µs 14.881 µs 14.887 µs]; change: [-2.1637% -1.9914% -1.7852%] (p = 0.00 < 0.05). Performance has improved.

Using the schemas from: https://github.com/dottxt-ai/outlines-core/blob/main/benchmarks/bench_json_schema.py
This commit is contained in:
parent
c637d68d74
commit
289aa48554
|
@ -4,6 +4,7 @@ repos:
|
||||||
hooks:
|
hooks:
|
||||||
- id: check-yaml
|
- id: check-yaml
|
||||||
- id: end-of-file-fixer
|
- id: end-of-file-fixer
|
||||||
|
exclude: crate-hashes.json
|
||||||
- id: trailing-whitespace
|
- id: trailing-whitespace
|
||||||
exclude: docs/source/reference/launcher.md
|
exclude: docs/source/reference/launcher.md
|
||||||
- repo: https://github.com/psf/black
|
- repo: https://github.com/psf/black
|
||||||
|
|
|
@ -3005,6 +3005,17 @@ dependencies = [
|
||||||
"num-traits",
|
"num-traits",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "outlines-core"
|
||||||
|
version = "0.1.0"
|
||||||
|
source = "git+https://github.com/dottxt-ai/outlines-core.git?rev=ba10c619fc9bf3c487e43f49bdecb95a24bb465c#ba10c619fc9bf3c487e43f49bdecb95a24bb465c"
|
||||||
|
dependencies = [
|
||||||
|
"anyhow",
|
||||||
|
"regex",
|
||||||
|
"serde-pyobject",
|
||||||
|
"serde_json",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "overload"
|
name = "overload"
|
||||||
version = "0.1.1"
|
version = "0.1.1"
|
||||||
|
@ -3952,6 +3963,16 @@ dependencies = [
|
||||||
"serde_derive",
|
"serde_derive",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "serde-pyobject"
|
||||||
|
version = "0.4.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "ca4b0aad8b225845739a0030a0d5cc2ae949c56a86a7daf9226c7df7c2016d16"
|
||||||
|
dependencies = [
|
||||||
|
"pyo3",
|
||||||
|
"serde",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "serde_cbor"
|
name = "serde_cbor"
|
||||||
version = "0.11.2"
|
version = "0.11.2"
|
||||||
|
@ -3979,6 +4000,7 @@ version = "1.0.133"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "c7fceb2473b9166b2294ef05efcb65a3db80803f0b03ef86a5fc88a2b85ee377"
|
checksum = "c7fceb2473b9166b2294ef05efcb65a3db80803f0b03ef86a5fc88a2b85ee377"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"indexmap 2.6.0",
|
||||||
"itoa",
|
"itoa",
|
||||||
"memchr",
|
"memchr",
|
||||||
"ryu",
|
"ryu",
|
||||||
|
@ -4430,6 +4452,7 @@ dependencies = [
|
||||||
name = "text-generation-router"
|
name = "text-generation-router"
|
||||||
version = "2.4.2-dev0"
|
version = "2.4.2-dev0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"anyhow",
|
||||||
"async-stream",
|
"async-stream",
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"axum 0.7.9",
|
"axum 0.7.9",
|
||||||
|
@ -4453,6 +4476,7 @@ dependencies = [
|
||||||
"once_cell",
|
"once_cell",
|
||||||
"opentelemetry 0.20.0",
|
"opentelemetry 0.20.0",
|
||||||
"opentelemetry-otlp",
|
"opentelemetry-otlp",
|
||||||
|
"outlines-core",
|
||||||
"pyo3",
|
"pyo3",
|
||||||
"rand",
|
"rand",
|
||||||
"regex",
|
"regex",
|
||||||
|
|
|
@ -0,0 +1,3 @@
|
||||||
|
{
|
||||||
|
"git+https://github.com/dottxt-ai/outlines-core.git?rev=ba10c619fc9bf3c487e43f49bdecb95a24bb465c#outlines-core@0.1.0": "1j9dcd831b0bmmjk2n4aag3x47qnqmkpg4gqpvwwyic7744llbfm"
|
||||||
|
}
|
|
@ -1,23 +1,23 @@
|
||||||
{
|
{
|
||||||
"choices": [
|
"choices": [
|
||||||
{
|
{
|
||||||
"finish_reason": "eos_token",
|
"finish_reason": "stop",
|
||||||
"index": 0,
|
"index": 0,
|
||||||
"logprobs": null,
|
"logprobs": null,
|
||||||
"message": {
|
"message": {
|
||||||
"content": "{ \"temperature\": [ 26, 30, 33, 29 ] ,\"unit\": \"Fahrenheit\" }",
|
"content": "{ \"unit\": \"fahrenheit\", \"temperature\": [ 72, 79, 88 ] }",
|
||||||
"role": "assistant"
|
"role": "assistant"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"created": 1718044128,
|
"created": 1732525803,
|
||||||
"id": "",
|
"id": "",
|
||||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||||
"object": "text_completion",
|
"object": "chat.completion",
|
||||||
"system_fingerprint": "2.0.5-dev0-native",
|
"system_fingerprint": "2.4.1-dev0-native",
|
||||||
"usage": {
|
"usage": {
|
||||||
"completion_tokens": 39,
|
"completion_tokens": 29,
|
||||||
"prompt_tokens": 136,
|
"prompt_tokens": 136,
|
||||||
"total_tokens": 175
|
"total_tokens": 165
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
{
|
{
|
||||||
"choices": [
|
"choices": [
|
||||||
{
|
{
|
||||||
"finish_reason": "eos_token",
|
"finish_reason": "stop",
|
||||||
"index": 0,
|
"index": 0,
|
||||||
"logprobs": null,
|
"logprobs": null,
|
||||||
"message": {
|
"message": {
|
||||||
|
@ -13,12 +13,12 @@
|
||||||
"function": {
|
"function": {
|
||||||
"arguments": {
|
"arguments": {
|
||||||
"format": "celsius",
|
"format": "celsius",
|
||||||
"location": "Brooklyn"
|
"location": "Brooklyn, New York"
|
||||||
},
|
},
|
||||||
"description": null,
|
"description": null,
|
||||||
"name": "get_current_weather"
|
"name": "get_current_weather"
|
||||||
},
|
},
|
||||||
"id": 0,
|
"id": "0",
|
||||||
"type": "function"
|
"type": "function"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
@ -26,14 +26,14 @@
|
||||||
"usage": null
|
"usage": null
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"created": 1712782670,
|
"created": 1732293383,
|
||||||
"id": "",
|
"id": "",
|
||||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
"object": "text_completion",
|
"object": "chat.completion",
|
||||||
"system_fingerprint": "2.0.1-native",
|
"system_fingerprint": "2.4.1-dev0-native",
|
||||||
"usage": {
|
"usage": {
|
||||||
"completion_tokens": 37,
|
"completion_tokens": 30,
|
||||||
"prompt_tokens": 524,
|
"prompt_tokens": 615,
|
||||||
"total_tokens": 561
|
"total_tokens": 645
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
{
|
{
|
||||||
"choices": [
|
"choices": [
|
||||||
{
|
{
|
||||||
"finish_reason": "eos_token",
|
"finish_reason": "stop",
|
||||||
"index": 0,
|
"index": 0,
|
||||||
"logprobs": null,
|
"logprobs": null,
|
||||||
"message": {
|
"message": {
|
||||||
|
@ -13,12 +13,12 @@
|
||||||
"function": {
|
"function": {
|
||||||
"arguments": {
|
"arguments": {
|
||||||
"format": "celsius",
|
"format": "celsius",
|
||||||
"location": "Brooklyn"
|
"location": "Brooklyn, New York"
|
||||||
},
|
},
|
||||||
"description": null,
|
"description": null,
|
||||||
"name": "get_current_weather"
|
"name": "get_current_weather"
|
||||||
},
|
},
|
||||||
"id": 0,
|
"id": "0",
|
||||||
"type": "function"
|
"type": "function"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
@ -26,14 +26,14 @@
|
||||||
"usage": null
|
"usage": null
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"created": 1712787937,
|
"created": 1732293384,
|
||||||
"id": "",
|
"id": "",
|
||||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
"object": "text_completion",
|
"object": "chat.completion",
|
||||||
"system_fingerprint": "2.0.1-native",
|
"system_fingerprint": "2.4.1-dev0-native",
|
||||||
"usage": {
|
"usage": {
|
||||||
"completion_tokens": 37,
|
"completion_tokens": 30,
|
||||||
"prompt_tokens": 524,
|
"prompt_tokens": 615,
|
||||||
"total_tokens": 561
|
"total_tokens": 645
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,10 +18,10 @@
|
||||||
"logprobs": null
|
"logprobs": null
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"created": 1729084854,
|
"created": 1732293254,
|
||||||
"id": "",
|
"id": "",
|
||||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
"object": "chat.completion.chunk",
|
"object": "chat.completion.chunk",
|
||||||
"system_fingerprint": "2.3.2-dev0-native",
|
"system_fingerprint": "2.4.1-dev0-native",
|
||||||
"usage": null
|
"usage": null
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,10 +19,10 @@
|
||||||
"logprobs": null
|
"logprobs": null
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"created": 1729084850,
|
"created": 1732293246,
|
||||||
"id": "",
|
"id": "",
|
||||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
"object": "chat.completion.chunk",
|
"object": "chat.completion.chunk",
|
||||||
"system_fingerprint": "2.3.2-dev0-native",
|
"system_fingerprint": "2.4.1-dev0-native",
|
||||||
"usage": null
|
"usage": null
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,7 +6,7 @@
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"tool_calls": {
|
"tool_calls": {
|
||||||
"function": {
|
"function": {
|
||||||
"arguments": "</s>",
|
"arguments": "<|eot_id|>",
|
||||||
"name": null
|
"name": null
|
||||||
},
|
},
|
||||||
"id": "",
|
"id": "",
|
||||||
|
@ -14,14 +14,15 @@
|
||||||
"type": "function"
|
"type": "function"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"finish_reason": "eos_token",
|
"finish_reason": "stop",
|
||||||
"index": 0,
|
"index": 0,
|
||||||
"logprobs": null
|
"logprobs": null
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"created": 1712788218,
|
"created": 1732293235,
|
||||||
"id": "",
|
"id": "",
|
||||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
"object": "text_completion",
|
"object": "chat.completion.chunk",
|
||||||
"system_fingerprint": "2.0.1-native"
|
"system_fingerprint": "2.4.1-dev0-native",
|
||||||
|
"usage": null
|
||||||
}
|
}
|
||||||
|
|
|
@ -55,7 +55,7 @@ async def test_grammar_response_format_llama_json(llama_grammar, response_snapsh
|
||||||
called = chat_completion["choices"][0]["message"]["content"]
|
called = chat_completion["choices"][0]["message"]["content"]
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert called == '{ "temperature": [ 26, 30, 33, 29 ] ,"unit": "Fahrenheit" }'
|
assert called == '{ "unit": "fahrenheit", "temperature": [ 72, 79, 88 ] }'
|
||||||
assert chat_completion == response_snapshot
|
assert chat_completion == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -101,7 +101,7 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
|
||||||
"function": {
|
"function": {
|
||||||
"description": None,
|
"description": None,
|
||||||
"name": "get_current_weather",
|
"name": "get_current_weather",
|
||||||
"arguments": {"format": "celsius", "location": "Brooklyn, NY"},
|
"arguments": {"format": "celsius", "location": "Brooklyn, New York"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
@ -138,7 +138,7 @@ async def test_flash_llama_grammar_tools_auto(
|
||||||
"function": {
|
"function": {
|
||||||
"description": None,
|
"description": None,
|
||||||
"name": "get_current_weather",
|
"name": "get_current_weather",
|
||||||
"arguments": {"format": "celsius", "location": "Brooklyn, NY"},
|
"arguments": {"format": "celsius", "location": "Brooklyn, New York"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
@ -219,7 +219,7 @@ async def test_flash_llama_grammar_tools_stream(
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
tool_calls_generated
|
tool_calls_generated
|
||||||
== '{"function": {"_name": "get_current_weather", "format": "celsius", "location": "Paris, France"}}<|eot_id|>'
|
== '{"function": {"_name": "get_current_weather", "location": "Paris, France", "format": "celsius"}}<|eot_id|>'
|
||||||
)
|
)
|
||||||
assert count == 28
|
assert count == 28
|
||||||
assert last_response == response_snapshot
|
assert last_response == response_snapshot
|
||||||
|
@ -366,7 +366,7 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_required(
|
||||||
assert count == 29
|
assert count == 29
|
||||||
assert (
|
assert (
|
||||||
tool_calls_generated
|
tool_calls_generated
|
||||||
== '{"function": {"_name": "get_current_weather", "format": "celsius", "location": "San Francisco, CA"}}<|eot_id|>'
|
== '{"function": {"_name": "get_current_weather", "location": "San Francisco, CA", "format": "celsius"}}<|eot_id|>'
|
||||||
)
|
)
|
||||||
assert last_response == response_snapshot
|
assert last_response == response_snapshot
|
||||||
|
|
||||||
|
@ -465,6 +465,6 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_function_object(
|
||||||
assert count == 39
|
assert count == 39
|
||||||
assert (
|
assert (
|
||||||
tool_calls_generated
|
tool_calls_generated
|
||||||
== '{"function": {"_name": "get_n_day_weather_forecast", "format": "celsius", "location": "San Francisco, CA", "num_days":3}}<|eot_id|>'
|
== '{"function": {"_name": "get_n_day_weather_forecast", "location": "San Francisco, CA", "format": "celsius", "num_days":3}}<|eot_id|>'
|
||||||
)
|
)
|
||||||
assert last_response == response_snapshot
|
assert last_response == response_snapshot
|
||||||
|
|
|
@ -8,6 +8,7 @@ authors.workspace = true
|
||||||
homepage.workspace = true
|
homepage.workspace = true
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
|
anyhow = "1"
|
||||||
async-trait = "0.1.74"
|
async-trait = "0.1.74"
|
||||||
async-stream = "0.3.5"
|
async-stream = "0.3.5"
|
||||||
axum = { version = "0.7", features = ["json"] }
|
axum = { version = "0.7", features = ["json"] }
|
||||||
|
@ -22,6 +23,7 @@ metrics-exporter-prometheus = { workspace = true }
|
||||||
nohash-hasher = "0.2.0"
|
nohash-hasher = "0.2.0"
|
||||||
opentelemetry = { version = "0.20.0", features = ["rt-tokio"] }
|
opentelemetry = { version = "0.20.0", features = ["rt-tokio"] }
|
||||||
opentelemetry-otlp = "0.13.0"
|
opentelemetry-otlp = "0.13.0"
|
||||||
|
outlines-core = { git = "https://github.com/dottxt-ai/outlines-core.git", rev = "ba10c619fc9bf3c487e43f49bdecb95a24bb465c" }
|
||||||
rand = "0.8.5"
|
rand = "0.8.5"
|
||||||
reqwest = { version = "0.11.20", features = [] }
|
reqwest = { version = "0.11.20", features = [] }
|
||||||
serde = "1.0.188"
|
serde = "1.0.188"
|
||||||
|
|
|
@ -804,7 +804,7 @@ mod tests {
|
||||||
let tool_prompt = "This default prompt will be used".to_string();
|
let tool_prompt = "This default prompt will be used".to_string();
|
||||||
let tools_and_prompt = Some((tools, tool_prompt));
|
let tools_and_prompt = Some((tools, tool_prompt));
|
||||||
let result = ct.apply(msgs, tools_and_prompt);
|
let result = ct.apply(msgs, tools_and_prompt);
|
||||||
let expected = "<s>[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today?</s> [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"arguments\":{\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}}}]\nThis default prompt will be used [/INST]".to_string();
|
let expected = "<s>[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today?</s> [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"arguments\":{\"type\":\"object\",\"properties\":{\"location\":{\"type\":\"string\",\"description\":\"The city and state, e.g. San Francisco, CA\"},\"format\":{\"type\":\"string\",\"enum\":[\"celsius\",\"fahrenheit\"],\"description\":\"The temperature unit to use. Infer this from the users location.\"}},\"required\":[\"location\",\"format\"]}}}]\nThis default prompt will be used [/INST]".to_string();
|
||||||
assert_eq!(result.unwrap(), expected);
|
assert_eq!(result.unwrap(), expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -9,6 +9,7 @@ use crate::{PyTokenizer, Tokenizer};
|
||||||
use base64::{engine::general_purpose::STANDARD, Engine};
|
use base64::{engine::general_purpose::STANDARD, Engine};
|
||||||
use image::{ImageFormat, ImageReader};
|
use image::{ImageFormat, ImageReader};
|
||||||
use jsonschema::{Draft, JSONSchema};
|
use jsonschema::{Draft, JSONSchema};
|
||||||
|
use outlines_core::json_schema::to_regex as json_schema_to_regex;
|
||||||
use rand::{thread_rng, Rng};
|
use rand::{thread_rng, Rng};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use std::io::Cursor;
|
use std::io::Cursor;
|
||||||
|
@ -351,11 +352,13 @@ impl Validation {
|
||||||
"Grammar must have a 'properties' field".to_string(),
|
"Grammar must have a 'properties' field".to_string(),
|
||||||
))?;
|
))?;
|
||||||
|
|
||||||
// Serialize json to string
|
// Do compilation in the router for performance. In the future, we
|
||||||
ValidGrammar::Json(
|
// should also move regex -> automaton compilation into the router,
|
||||||
serde_json::to_string(&json)
|
// but this is not yet supported in pure Rust by outlines-core.
|
||||||
.map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?,
|
let grammar_regex = json_schema_to_regex(&json, None, &json)
|
||||||
)
|
.map_err(ValidationError::RegexFromSchema)?;
|
||||||
|
|
||||||
|
ValidGrammar::Regex(grammar_regex.to_string())
|
||||||
}
|
}
|
||||||
GrammarType::Regex(regex) => ValidGrammar::Regex(regex),
|
GrammarType::Regex(regex) => ValidGrammar::Regex(regex),
|
||||||
};
|
};
|
||||||
|
@ -810,6 +813,8 @@ pub enum ValidationError {
|
||||||
Grammar,
|
Grammar,
|
||||||
#[error("grammar is not valid: {0}")]
|
#[error("grammar is not valid: {0}")]
|
||||||
InvalidGrammar(String),
|
InvalidGrammar(String),
|
||||||
|
#[error("cannot compile regex from schema: {0}")]
|
||||||
|
RegexFromSchema(anyhow::Error),
|
||||||
#[error("base64 encoding is invalid: {0}")]
|
#[error("base64 encoding is invalid: {0}")]
|
||||||
InvalidBase64(#[from] base64::DecodeError),
|
InvalidBase64(#[from] base64::DecodeError),
|
||||||
#[error("invalid image: {0}")]
|
#[error("invalid image: {0}")]
|
||||||
|
|
|
@ -1,19 +1,19 @@
|
||||||
|
from functools import lru_cache
|
||||||
import math
|
import math
|
||||||
|
import time
|
||||||
import torch
|
import torch
|
||||||
|
from typing import List, Optional, DefaultDict
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from typing import Dict, Union
|
from typing import Dict, Union
|
||||||
from text_generation_server.pb.generate_pb2 import GrammarType
|
from text_generation_server.pb.generate_pb2 import GrammarType
|
||||||
|
|
||||||
from outlines.fsm.guide import RegexGuide
|
from outlines.fsm.guide import RegexGuide
|
||||||
from outlines.fsm.json_schema import build_regex_from_schema
|
|
||||||
from functools import lru_cache
|
|
||||||
from typing import List, Optional, DefaultDict
|
|
||||||
import time
|
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
LogitsWarper,
|
LogitsWarper,
|
||||||
LogitsProcessor,
|
LogitsProcessor,
|
||||||
|
PreTrainedTokenizerBase,
|
||||||
TemperatureLogitsWarper,
|
TemperatureLogitsWarper,
|
||||||
TopKLogitsWarper,
|
TopKLogitsWarper,
|
||||||
TopPLogitsWarper,
|
TopPLogitsWarper,
|
||||||
|
@ -484,7 +484,13 @@ class GrammarLogitProcessor(LogitsProcessor):
|
||||||
fsm_state: DefaultDict[int, int]
|
fsm_state: DefaultDict[int, int]
|
||||||
fsm: RegexGuide
|
fsm: RegexGuide
|
||||||
|
|
||||||
def __init__(self, tokenizer, device, grammar, grammar_type):
|
def __init__(
|
||||||
|
self,
|
||||||
|
tokenizer: Optional[PreTrainedTokenizerBase],
|
||||||
|
device: str,
|
||||||
|
grammar: str,
|
||||||
|
grammar_type: GrammarType,
|
||||||
|
):
|
||||||
self.device = device
|
self.device = device
|
||||||
self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
|
self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
|
||||||
self.fsm = GrammarLogitProcessor._cached_compile_fsm(
|
self.fsm = GrammarLogitProcessor._cached_compile_fsm(
|
||||||
|
@ -519,18 +525,20 @@ class GrammarLogitProcessor(LogitsProcessor):
|
||||||
# TODO: move grammar compilation into the router
|
# TODO: move grammar compilation into the router
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@lru_cache(maxsize=32, typed=True)
|
@lru_cache(maxsize=32, typed=True)
|
||||||
def _cached_compile_fsm(grammar_type, schema, tokenizer):
|
def _cached_compile_fsm(
|
||||||
|
grammar_type: GrammarType,
|
||||||
|
schema: str,
|
||||||
|
tokenizer: Optional[PreTrainedTokenizerBase],
|
||||||
|
):
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
if grammar_type == GrammarType.GRAMMAR_TYPE_JSON:
|
if grammar_type == GrammarType.GRAMMAR_TYPE_JSON:
|
||||||
try:
|
# JSON schema is compiled by the v3 router.
|
||||||
schema = build_regex_from_schema(schema)
|
logger.error(
|
||||||
# TODO: this is only here short term to avoid crashing the python server, mid term we want this in the rust/router layer
|
"Non-regex grammars must be compiled by the router, grammar won't be enforced"
|
||||||
except Exception as e:
|
)
|
||||||
logger.error(f"Error compiling FSM, grammar won't be enforced \n{e}")
|
# allows everything
|
||||||
# allows everything
|
schema = "(.*?)"
|
||||||
schema = "(.*?)"
|
|
||||||
elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX:
|
|
||||||
pass # schema is already a regex just here for clarity
|
|
||||||
fsm = RegexGuide.from_regex(schema, tokenizer)
|
fsm = RegexGuide.from_regex(schema, tokenizer)
|
||||||
logger.debug(f"Compiled FSM in {time.time() - start_time:.2f}s")
|
logger.debug(f"Compiled FSM in {time.time() - start_time:.2f}s")
|
||||||
return fsm
|
return fsm
|
||||||
|
|
Loading…
Reference in New Issue