diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 0c8b6885..b5d939cd 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -4,6 +4,7 @@ repos:
hooks:
- id: check-yaml
- id: end-of-file-fixer
+ exclude: crate-hashes.json
- id: trailing-whitespace
exclude: docs/source/reference/launcher.md
- repo: https://github.com/psf/black
diff --git a/Cargo.lock b/Cargo.lock
index 77c49408..72f70fdc 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3005,6 +3005,17 @@ dependencies = [
"num-traits",
]
+[[package]]
+name = "outlines-core"
+version = "0.1.0"
+source = "git+https://github.com/dottxt-ai/outlines-core.git?rev=ba10c619fc9bf3c487e43f49bdecb95a24bb465c#ba10c619fc9bf3c487e43f49bdecb95a24bb465c"
+dependencies = [
+ "anyhow",
+ "regex",
+ "serde-pyobject",
+ "serde_json",
+]
+
[[package]]
name = "overload"
version = "0.1.1"
@@ -3952,6 +3963,16 @@ dependencies = [
"serde_derive",
]
+[[package]]
+name = "serde-pyobject"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ca4b0aad8b225845739a0030a0d5cc2ae949c56a86a7daf9226c7df7c2016d16"
+dependencies = [
+ "pyo3",
+ "serde",
+]
+
[[package]]
name = "serde_cbor"
version = "0.11.2"
@@ -3979,6 +4000,7 @@ version = "1.0.133"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7fceb2473b9166b2294ef05efcb65a3db80803f0b03ef86a5fc88a2b85ee377"
dependencies = [
+ "indexmap 2.6.0",
"itoa",
"memchr",
"ryu",
@@ -4430,6 +4452,7 @@ dependencies = [
name = "text-generation-router"
version = "2.4.2-dev0"
dependencies = [
+ "anyhow",
"async-stream",
"async-trait",
"axum 0.7.9",
@@ -4453,6 +4476,7 @@ dependencies = [
"once_cell",
"opentelemetry 0.20.0",
"opentelemetry-otlp",
+ "outlines-core",
"pyo3",
"rand",
"regex",
diff --git a/crate-hashes.json b/crate-hashes.json
new file mode 100644
index 00000000..2694759c
--- /dev/null
+++ b/crate-hashes.json
@@ -0,0 +1,3 @@
+{
+ "git+https://github.com/dottxt-ai/outlines-core.git?rev=ba10c619fc9bf3c487e43f49bdecb95a24bb465c#outlines-core@0.1.0": "1j9dcd831b0bmmjk2n4aag3x47qnqmkpg4gqpvwwyic7744llbfm"
+}
\ No newline at end of file
diff --git a/integration-tests/models/__snapshots__/test_grammar_response_format_llama/test_grammar_response_format_llama_json.json b/integration-tests/models/__snapshots__/test_grammar_response_format_llama/test_grammar_response_format_llama_json.json
index 2bd79b1d..38229e0a 100644
--- a/integration-tests/models/__snapshots__/test_grammar_response_format_llama/test_grammar_response_format_llama_json.json
+++ b/integration-tests/models/__snapshots__/test_grammar_response_format_llama/test_grammar_response_format_llama_json.json
@@ -1,23 +1,23 @@
{
"choices": [
{
- "finish_reason": "eos_token",
+ "finish_reason": "stop",
"index": 0,
"logprobs": null,
"message": {
- "content": "{ \"temperature\": [ 26, 30, 33, 29 ] ,\"unit\": \"Fahrenheit\" }",
+ "content": "{ \"unit\": \"fahrenheit\", \"temperature\": [ 72, 79, 88 ] }",
"role": "assistant"
}
}
],
- "created": 1718044128,
+ "created": 1732525803,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
- "object": "text_completion",
- "system_fingerprint": "2.0.5-dev0-native",
+ "object": "chat.completion",
+ "system_fingerprint": "2.4.1-dev0-native",
"usage": {
- "completion_tokens": 39,
+ "completion_tokens": 29,
"prompt_tokens": 136,
- "total_tokens": 175
+ "total_tokens": 165
}
}
diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools.json
index a4c34a10..33e223ba 100644
--- a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools.json
+++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools.json
@@ -1,7 +1,7 @@
{
"choices": [
{
- "finish_reason": "eos_token",
+ "finish_reason": "stop",
"index": 0,
"logprobs": null,
"message": {
@@ -13,12 +13,12 @@
"function": {
"arguments": {
"format": "celsius",
- "location": "Brooklyn"
+ "location": "Brooklyn, New York"
},
"description": null,
"name": "get_current_weather"
},
- "id": 0,
+ "id": "0",
"type": "function"
}
]
@@ -26,14 +26,14 @@
"usage": null
}
],
- "created": 1712782670,
+ "created": 1732293383,
"id": "",
- "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
- "object": "text_completion",
- "system_fingerprint": "2.0.1-native",
+ "model": "meta-llama/Llama-3.1-8B-Instruct",
+ "object": "chat.completion",
+ "system_fingerprint": "2.4.1-dev0-native",
"usage": {
- "completion_tokens": 37,
- "prompt_tokens": 524,
- "total_tokens": 561
+ "completion_tokens": 30,
+ "prompt_tokens": 615,
+ "total_tokens": 645
}
}
diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_auto.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_auto.json
index 04bcdc4e..92ffbbc1 100644
--- a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_auto.json
+++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_auto.json
@@ -1,7 +1,7 @@
{
"choices": [
{
- "finish_reason": "eos_token",
+ "finish_reason": "stop",
"index": 0,
"logprobs": null,
"message": {
@@ -13,12 +13,12 @@
"function": {
"arguments": {
"format": "celsius",
- "location": "Brooklyn"
+ "location": "Brooklyn, New York"
},
"description": null,
"name": "get_current_weather"
},
- "id": 0,
+ "id": "0",
"type": "function"
}
]
@@ -26,14 +26,14 @@
"usage": null
}
],
- "created": 1712787937,
+ "created": 1732293384,
"id": "",
- "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
- "object": "text_completion",
- "system_fingerprint": "2.0.1-native",
+ "model": "meta-llama/Llama-3.1-8B-Instruct",
+ "object": "chat.completion",
+ "system_fingerprint": "2.4.1-dev0-native",
"usage": {
- "completion_tokens": 37,
- "prompt_tokens": 524,
- "total_tokens": 561
+ "completion_tokens": 30,
+ "prompt_tokens": 615,
+ "total_tokens": 645
}
}
diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_function_object.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_function_object.json
index e64dd49d..bb8d61c8 100644
--- a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_function_object.json
+++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_function_object.json
@@ -18,10 +18,10 @@
"logprobs": null
}
],
- "created": 1729084854,
+ "created": 1732293254,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
- "system_fingerprint": "2.3.2-dev0-native",
+ "system_fingerprint": "2.4.1-dev0-native",
"usage": null
}
diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_required.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_required.json
index d8d538d6..dbced5b8 100644
--- a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_required.json
+++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_required.json
@@ -19,10 +19,10 @@
"logprobs": null
}
],
- "created": 1729084850,
+ "created": 1732293246,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
- "system_fingerprint": "2.3.2-dev0-native",
+ "system_fingerprint": "2.4.1-dev0-native",
"usage": null
}
diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_stream.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_stream.json
index f72a5d38..27d2f9ca 100644
--- a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_stream.json
+++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_stream.json
@@ -6,7 +6,7 @@
"role": "assistant",
"tool_calls": {
"function": {
- "arguments": "",
+ "arguments": "<|eot_id|>",
"name": null
},
"id": "",
@@ -14,14 +14,15 @@
"type": "function"
}
},
- "finish_reason": "eos_token",
+ "finish_reason": "stop",
"index": 0,
"logprobs": null
}
],
- "created": 1712788218,
+ "created": 1732293235,
"id": "",
- "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
- "object": "text_completion",
- "system_fingerprint": "2.0.1-native"
+ "model": "meta-llama/Llama-3.1-8B-Instruct",
+ "object": "chat.completion.chunk",
+ "system_fingerprint": "2.4.1-dev0-native",
+ "usage": null
}
diff --git a/integration-tests/models/test_grammar_response_format_llama.py b/integration-tests/models/test_grammar_response_format_llama.py
index 3c46cefe..f2a8a96d 100644
--- a/integration-tests/models/test_grammar_response_format_llama.py
+++ b/integration-tests/models/test_grammar_response_format_llama.py
@@ -55,7 +55,7 @@ async def test_grammar_response_format_llama_json(llama_grammar, response_snapsh
called = chat_completion["choices"][0]["message"]["content"]
assert response.status_code == 200
- assert called == '{ "temperature": [ 26, 30, 33, 29 ] ,"unit": "Fahrenheit" }'
+ assert called == '{ "unit": "fahrenheit", "temperature": [ 72, 79, 88 ] }'
assert chat_completion == response_snapshot
diff --git a/integration-tests/models/test_tools_llama.py b/integration-tests/models/test_tools_llama.py
index b5821945..70c3aff0 100644
--- a/integration-tests/models/test_tools_llama.py
+++ b/integration-tests/models/test_tools_llama.py
@@ -101,7 +101,7 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
"function": {
"description": None,
"name": "get_current_weather",
- "arguments": {"format": "celsius", "location": "Brooklyn, NY"},
+ "arguments": {"format": "celsius", "location": "Brooklyn, New York"},
},
}
]
@@ -138,7 +138,7 @@ async def test_flash_llama_grammar_tools_auto(
"function": {
"description": None,
"name": "get_current_weather",
- "arguments": {"format": "celsius", "location": "Brooklyn, NY"},
+ "arguments": {"format": "celsius", "location": "Brooklyn, New York"},
},
}
]
@@ -219,7 +219,7 @@ async def test_flash_llama_grammar_tools_stream(
assert (
tool_calls_generated
- == '{"function": {"_name": "get_current_weather", "format": "celsius", "location": "Paris, France"}}<|eot_id|>'
+ == '{"function": {"_name": "get_current_weather", "location": "Paris, France", "format": "celsius"}}<|eot_id|>'
)
assert count == 28
assert last_response == response_snapshot
@@ -366,7 +366,7 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_required(
assert count == 29
assert (
tool_calls_generated
- == '{"function": {"_name": "get_current_weather", "format": "celsius", "location": "San Francisco, CA"}}<|eot_id|>'
+ == '{"function": {"_name": "get_current_weather", "location": "San Francisco, CA", "format": "celsius"}}<|eot_id|>'
)
assert last_response == response_snapshot
@@ -465,6 +465,6 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_function_object(
assert count == 39
assert (
tool_calls_generated
- == '{"function": {"_name": "get_n_day_weather_forecast", "format": "celsius", "location": "San Francisco, CA", "num_days":3}}<|eot_id|>'
+ == '{"function": {"_name": "get_n_day_weather_forecast", "location": "San Francisco, CA", "format": "celsius", "num_days":3}}<|eot_id|>'
)
assert last_response == response_snapshot
diff --git a/router/Cargo.toml b/router/Cargo.toml
index 83d85327..9258fe03 100644
--- a/router/Cargo.toml
+++ b/router/Cargo.toml
@@ -8,6 +8,7 @@ authors.workspace = true
homepage.workspace = true
[dependencies]
+anyhow = "1"
async-trait = "0.1.74"
async-stream = "0.3.5"
axum = { version = "0.7", features = ["json"] }
@@ -22,6 +23,7 @@ metrics-exporter-prometheus = { workspace = true }
nohash-hasher = "0.2.0"
opentelemetry = { version = "0.20.0", features = ["rt-tokio"] }
opentelemetry-otlp = "0.13.0"
+outlines-core = { git = "https://github.com/dottxt-ai/outlines-core.git", rev = "ba10c619fc9bf3c487e43f49bdecb95a24bb465c" }
rand = "0.8.5"
reqwest = { version = "0.11.20", features = [] }
serde = "1.0.188"
diff --git a/router/src/infer/chat_template.rs b/router/src/infer/chat_template.rs
index ceba14a6..f5f1dbca 100644
--- a/router/src/infer/chat_template.rs
+++ b/router/src/infer/chat_template.rs
@@ -804,7 +804,7 @@ mod tests {
let tool_prompt = "This default prompt will be used".to_string();
let tools_and_prompt = Some((tools, tool_prompt));
let result = ct.apply(msgs, tools_and_prompt);
- let expected = "[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today? [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"arguments\":{\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}}}]\nThis default prompt will be used [/INST]".to_string();
+ let expected = "[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today? [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"arguments\":{\"type\":\"object\",\"properties\":{\"location\":{\"type\":\"string\",\"description\":\"The city and state, e.g. San Francisco, CA\"},\"format\":{\"type\":\"string\",\"enum\":[\"celsius\",\"fahrenheit\"],\"description\":\"The temperature unit to use. Infer this from the users location.\"}},\"required\":[\"location\",\"format\"]}}}]\nThis default prompt will be used [/INST]".to_string();
assert_eq!(result.unwrap(), expected);
}
diff --git a/router/src/validation.rs b/router/src/validation.rs
index 3cd85a6e..032638ab 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -9,6 +9,7 @@ use crate::{PyTokenizer, Tokenizer};
use base64::{engine::general_purpose::STANDARD, Engine};
use image::{ImageFormat, ImageReader};
use jsonschema::{Draft, JSONSchema};
+use outlines_core::json_schema::to_regex as json_schema_to_regex;
use rand::{thread_rng, Rng};
use serde_json::Value;
use std::io::Cursor;
@@ -351,11 +352,13 @@ impl Validation {
"Grammar must have a 'properties' field".to_string(),
))?;
- // Serialize json to string
- ValidGrammar::Json(
- serde_json::to_string(&json)
- .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?,
- )
+ // Do compilation in the router for performance. In the future, we
+ // should also move regex -> automaton compilation in the router,
+ // but this is not yet supported in pure Rust by outlines-core.
+ let grammar_regex = json_schema_to_regex(&json, None, &json)
+ .map_err(ValidationError::RegexFromSchema)?;
+
+ ValidGrammar::Regex(grammar_regex.to_string())
}
GrammarType::Regex(regex) => ValidGrammar::Regex(regex),
};
@@ -810,6 +813,8 @@ pub enum ValidationError {
Grammar,
#[error("grammar is not valid: {0}")]
InvalidGrammar(String),
+ #[error("cannot compile regex from schema: {0}")]
+ RegexFromSchema(anyhow::Error),
#[error("base64 encoding is invalid: {0}")]
InvalidBase64(#[from] base64::DecodeError),
#[error("invalid image: {0}")]
diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py
index d53f070c..132e441b 100644
--- a/server/text_generation_server/utils/logits_process.py
+++ b/server/text_generation_server/utils/logits_process.py
@@ -1,19 +1,19 @@
+from functools import lru_cache
import math
+import time
import torch
+from typing import List, Optional, DefaultDict
from loguru import logger
from typing import Dict, Union
from text_generation_server.pb.generate_pb2 import GrammarType
from outlines.fsm.guide import RegexGuide
-from outlines.fsm.json_schema import build_regex_from_schema
-from functools import lru_cache
-from typing import List, Optional, DefaultDict
-import time
from transformers import (
LogitsWarper,
LogitsProcessor,
+ PreTrainedTokenizerBase,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
@@ -484,7 +484,13 @@ class GrammarLogitProcessor(LogitsProcessor):
fsm_state: DefaultDict[int, int]
fsm: RegexGuide
- def __init__(self, tokenizer, device, grammar, grammar_type):
+ def __init__(
+ self,
+ tokenizer: Optional[PreTrainedTokenizerBase],
+ device: str,
+ grammar: str,
+ grammar_type: GrammarType,
+ ):
self.device = device
self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
self.fsm = GrammarLogitProcessor._cached_compile_fsm(
@@ -519,18 +525,20 @@ class GrammarLogitProcessor(LogitsProcessor):
# TODO: move grammar compilation into the router
@staticmethod
@lru_cache(maxsize=32, typed=True)
- def _cached_compile_fsm(grammar_type, schema, tokenizer):
+ def _cached_compile_fsm(
+ grammar_type: GrammarType,
+ schema: str,
+ tokenizer: Optional[PreTrainedTokenizerBase],
+ ):
start_time = time.time()
if grammar_type == GrammarType.GRAMMAR_TYPE_JSON:
- try:
- schema = build_regex_from_schema(schema)
- # TODO: this is only here short term to avoid crashing the python server, mid term we want this in the rust/router layer
- except Exception as e:
- logger.error(f"Error compiling FSM, grammar won't be enforced \n{e}")
- # allows everything
- schema = "(.*?)"
- elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX:
- pass # schema is already a regex just here for clarity
+ # JSON schema is compiled by the v3 router.
+ logger.error(
+ "Non-regex grammars must be compiled by the router, grammar won't be enforced"
+ )
+ # allows everything
+ schema = "(.*?)"
+
fsm = RegexGuide.from_regex(schema, tokenizer)
logger.debug(f"Compiled FSM in {time.time() - start_time:.2f}s")
return fsm