Support chat response format (#2046)

* feat: support response_format in chat

* fix: adjust typos

* fix: add trufflehog lint
This commit is contained in:
drbh 2024-06-11 10:44:56 -04:00 committed by GitHub
parent a6e4d63c86
commit 376a0b7ada
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 156 additions and 7 deletions

View File

@ -16,4 +16,3 @@ jobs:
fetch-depth: 0 fetch-depth: 0
- name: Secret Scanning - name: Secret Scanning
uses: trufflesecurity/trufflehog@main uses: trufflesecurity/trufflehog@main

View File

@ -0,0 +1,23 @@
{
"choices": [
{
"finish_reason": "eos_token",
"index": 0,
"logprobs": null,
"message": {
"content": "{\n \"temperature\": [\n 35,\n 34,\n 36\n ],\n \"unit\": \"°c\"\n}",
"role": "assistant"
}
}
],
"created": 1718044128,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.5-dev0-native",
"usage": {
"completion_tokens": 39,
"prompt_tokens": 136,
"total_tokens": 175
}
}

View File

@ -0,0 +1,101 @@
import pytest
import requests
from pydantic import BaseModel
from typing import List
@pytest.fixture(scope="module")
def llama_grammar_handle(launcher):
with launcher(
"TinyLlama/TinyLlama-1.1B-Chat-v1.0",
num_shard=1,
disable_grammar_support=False,
use_flash_attention=False,
max_batch_prefill_tokens=3000,
) as handle:
yield handle
@pytest.fixture(scope="module")
async def llama_grammar(llama_grammar_handle):
await llama_grammar_handle.health(300)
return llama_grammar_handle.client
@pytest.mark.asyncio
async def test_grammar_response_format_llama_json(llama_grammar, response_snapshot):
class Weather(BaseModel):
unit: str
temperature: List[int]
# send the request
response = requests.post(
f"{llama_grammar.base_url}/v1/chat/completions",
headers=llama_grammar.headers,
json={
"model": "tgi",
"messages": [
{
"role": "system",
"content": f"Respond to the users questions and answer them in the following format: {Weather.schema()}",
},
{
"role": "user",
"content": "What's the weather like the next 3 days in San Francisco, CA?",
},
],
"seed": 42,
"max_tokens": 500,
"response_format": {"type": "json_object", "value": Weather.schema()},
},
)
chat_completion = response.json()
called = chat_completion["choices"][0]["message"]["content"]
assert response.status_code == 200
assert (
called
== '{\n "temperature": [\n 35,\n 34,\n 36\n ],\n "unit": "°c"\n}'
)
assert chat_completion == response_snapshot
@pytest.mark.asyncio
async def test_grammar_response_format_llama_error_if_tools_not_installed(
llama_grammar,
):
class Weather(BaseModel):
unit: str
temperature: List[int]
# send the request
response = requests.post(
f"{llama_grammar.base_url}/v1/chat/completions",
headers=llama_grammar.headers,
json={
"model": "tgi",
"messages": [
{
"role": "system",
"content": f"Respond to the users questions and answer them in the following format: {Weather.schema()}",
},
{
"role": "user",
"content": "What's the weather like the next 3 days in San Francisco, CA?",
},
],
"seed": 42,
"max_tokens": 500,
"tools": [],
"response_format": {"type": "json_object", "value": Weather.schema()},
},
)
# 422 means the server was unable to process the request because it contains invalid data.
assert response.status_code == 422
assert response.json() == {
"error": "Grammar and tools are mutually exclusive",
"error_type": "grammar and tools",
}

View File

@ -89,6 +89,7 @@ pub(crate) enum GrammarType {
/// JSON Schema is a declarative language that allows to annotate JSON documents /// JSON Schema is a declarative language that allows to annotate JSON documents
/// with types and descriptions. /// with types and descriptions.
#[serde(rename = "json")] #[serde(rename = "json")]
#[serde(alias = "json_object")]
#[schema(example = json ! ({"properties": {"location":{"type": "string"}}}))] #[schema(example = json ! ({"properties": {"location":{"type": "string"}}}))]
Json(serde_json::Value), Json(serde_json::Value),
#[serde(rename = "regex")] #[serde(rename = "regex")]
@ -791,6 +792,13 @@ pub(crate) struct ChatRequest {
#[schema(nullable = true, example = "null")] #[schema(nullable = true, example = "null")]
#[serde(deserialize_with = "deserialize_tool_choice::deserialize")] #[serde(deserialize_with = "deserialize_tool_choice::deserialize")]
pub tool_choice: Option<ToolType>, pub tool_choice: Option<ToolType>,
/// Response format constraints for the generation.
///
/// NOTE: A request can use `response_format` OR `tools` but not both.
#[serde(default)]
#[schema(nullable = true, default = "null", example = "null")]
pub response_format: Option<GrammarType>,
} }
fn default_tool_prompt() -> Option<String> { fn default_tool_prompt() -> Option<String> {

View File

@ -1016,6 +1016,7 @@ async fn chat_completions(
tool_choice, tool_choice,
tool_prompt, tool_prompt,
temperature, temperature,
response_format,
.. ..
} = req; } = req;
@ -1030,6 +1031,18 @@ async fn chat_completions(
other => (true, other), other => (true, other),
}; };
// response_format and tools are mutually exclusive
if response_format.is_some() && tools.as_ref().is_some() {
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
return Err((
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
error: "Grammar and tools are mutually exclusive".to_string(),
error_type: "grammar and tools".to_string(),
}),
));
}
// extract tool grammar if present // extract tool grammar if present
let tool_grammar = match ToolGrammar::apply(tools, tool_choice) { let tool_grammar = match ToolGrammar::apply(tools, tool_choice) {
Ok(grammar) => grammar, Ok(grammar) => grammar,
@ -1046,16 +1059,21 @@ async fn chat_completions(
} }
}; };
let grammar_with_prompt = tool_grammar // determine the appropriate arguments for apply_chat_template
let tools_grammar_prompt = tool_grammar
.as_ref() .as_ref()
.map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt)); .map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt));
let typed_grammar = grammar_with_prompt let (tools_grammar_prompt, grammar) = match response_format {
.as_ref() Some(response_format) => (None, Some(response_format)),
.map(|(grammar, _)| grammar.clone()); None => (
tools_grammar_prompt.clone(),
tools_grammar_prompt.map(|(grammar, _)| grammar.clone()),
),
};
// apply chat template to flatten the request into a single input // apply chat template to flatten the request into a single input
let inputs = match infer.apply_chat_template(messages, grammar_with_prompt) { let inputs = match infer.apply_chat_template(messages, tools_grammar_prompt) {
Ok(inputs) => inputs, Ok(inputs) => inputs,
Err(err) => { Err(err) => {
metrics::increment_counter!("tgi_request_failure", "err" => "validation"); metrics::increment_counter!("tgi_request_failure", "err" => "validation");
@ -1091,7 +1109,7 @@ async fn chat_completions(
decoder_input_details: !stream, decoder_input_details: !stream,
seed, seed,
top_n_tokens: req.top_logprobs, top_n_tokens: req.top_logprobs,
grammar: typed_grammar, grammar,
}, },
}; };