Support chat response format (#2046)

* feat: support response_format in chat
* fix: adjust typos
* fix: add trufflehog lint
parent a6e4d63c86
commit 376a0b7ada
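This change adds an optional `response_format` field to the chat endpoint, reusing the router's existing grammar support: `json_object` becomes a serde alias for the `json` grammar type, a request that sets both `response_format` and `tools` is rejected with a 422, and the commit ships an integration test, a response snapshot, and a Trufflehog secret-scanning step in CI. A minimal client-side sketch of the new field (the host and port are assumptions and the schema is illustrative; the payload shape follows the integration test below):

```python
import requests

# Hypothetical local TGI server; any JSON Schema can go in "value".
response = requests.post(
    "http://localhost:3000/v1/chat/completions",
    json={
        "model": "tgi",
        "messages": [
            {
                "role": "user",
                "content": "What's the weather like the next 3 days in San Francisco, CA?",
            }
        ],
        "max_tokens": 500,
        # "json_object" works because of the new serde alias; "json" does too.
        "response_format": {
            "type": "json_object",
            "value": {
                "properties": {
                    "unit": {"type": "string"},
                    "temperature": {"type": "array", "items": {"type": "integer"}},
                }
            },
        },
    },
)
print(response.json()["choices"][0]["message"]["content"])
```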
@@ -16,4 +16,3 @@ jobs:
           fetch-depth: 0
       - name: Secret Scanning
         uses: trufflesecurity/trufflehog@main
-
@@ -0,0 +1,23 @@
+{
+  "choices": [
+    {
+      "finish_reason": "eos_token",
+      "index": 0,
+      "logprobs": null,
+      "message": {
+        "content": "{\n \"temperature\": [\n 35,\n 34,\n 36\n ],\n \"unit\": \"°c\"\n}",
+        "role": "assistant"
+      }
+    }
+  ],
+  "created": 1718044128,
+  "id": "",
+  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+  "object": "text_completion",
+  "system_fingerprint": "2.0.5-dev0-native",
+  "usage": {
+    "completion_tokens": 39,
+    "prompt_tokens": 136,
+    "total_tokens": 175
+  }
+}
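The assistant `content` in this snapshot is itself a JSON document conforming to the requested schema, not free-form text. A quick sketch of pulling it back out:

```python
import json

# The exact "content" string from the snapshot above.
content = '{\n "temperature": [\n 35,\n 34,\n 36\n ],\n "unit": "°c"\n}'
weather = json.loads(content)
assert weather == {"temperature": [35, 34, 36], "unit": "°c"}
```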
@@ -0,0 +1,101 @@
+import pytest
+import requests
+from pydantic import BaseModel
+from typing import List
+
+
+@pytest.fixture(scope="module")
+def llama_grammar_handle(launcher):
+    with launcher(
+        "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+        num_shard=1,
+        disable_grammar_support=False,
+        use_flash_attention=False,
+        max_batch_prefill_tokens=3000,
+    ) as handle:
+        yield handle
+
+
+@pytest.fixture(scope="module")
+async def llama_grammar(llama_grammar_handle):
+    await llama_grammar_handle.health(300)
+    return llama_grammar_handle.client
+
+
+@pytest.mark.asyncio
+async def test_grammar_response_format_llama_json(llama_grammar, response_snapshot):
+
+    class Weather(BaseModel):
+        unit: str
+        temperature: List[int]
+
+    # send the request
+    response = requests.post(
+        f"{llama_grammar.base_url}/v1/chat/completions",
+        headers=llama_grammar.headers,
+        json={
+            "model": "tgi",
+            "messages": [
+                {
+                    "role": "system",
+                    "content": f"Respond to the users questions and answer them in the following format: {Weather.schema()}",
+                },
+                {
+                    "role": "user",
+                    "content": "What's the weather like the next 3 days in San Francisco, CA?",
+                },
+            ],
+            "seed": 42,
+            "max_tokens": 500,
+            "response_format": {"type": "json_object", "value": Weather.schema()},
+        },
+    )
+
+    chat_completion = response.json()
+    called = chat_completion["choices"][0]["message"]["content"]
+
+    assert response.status_code == 200
+    assert (
+        called
+        == '{\n "temperature": [\n 35,\n 34,\n 36\n ],\n "unit": "°c"\n}'
+    )
+    assert chat_completion == response_snapshot
+
+
+@pytest.mark.asyncio
+async def test_grammar_response_format_llama_error_if_tools_not_installed(
+    llama_grammar,
+):
+    class Weather(BaseModel):
+        unit: str
+        temperature: List[int]
+
+    # send the request
+    response = requests.post(
+        f"{llama_grammar.base_url}/v1/chat/completions",
+        headers=llama_grammar.headers,
+        json={
+            "model": "tgi",
+            "messages": [
+                {
+                    "role": "system",
+                    "content": f"Respond to the users questions and answer them in the following format: {Weather.schema()}",
+                },
+                {
+                    "role": "user",
+                    "content": "What's the weather like the next 3 days in San Francisco, CA?",
+                },
+            ],
+            "seed": 42,
+            "max_tokens": 500,
+            "tools": [],
+            "response_format": {"type": "json_object", "value": Weather.schema()},
+        },
+    )
+
+    # 422 means the server was unable to process the request because it contains invalid data.
+    assert response.status_code == 422
+    assert response.json() == {
+        "error": "Grammar and tools are mutually exclusive",
+        "error_type": "grammar and tools",
+    }
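Because the request derives its grammar from `Weather.schema()`, the asserted string also round-trips through the same Pydantic model. A sketch using the pydantic v1 API the test itself relies on (`.schema()`, so `parse_raw` is the matching parser):

```python
from typing import List

from pydantic import BaseModel


class Weather(BaseModel):
    unit: str
    temperature: List[int]


# The exact string the first test asserts against.
raw = '{\n "temperature": [\n 35,\n 34,\n 36\n ],\n "unit": "°c"\n}'
weather = Weather.parse_raw(raw)
assert weather.unit == "°c"
assert weather.temperature == [35, 34, 36]
```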
@@ -89,6 +89,7 @@ pub(crate) enum GrammarType {
     /// JSON Schema is a declarative language that allows to annotate JSON documents
     /// with types and descriptions.
     #[serde(rename = "json")]
+    #[serde(alias = "json_object")]
     #[schema(example = json ! ({"properties": {"location":{"type": "string"}}}))]
     Json(serde_json::Value),
     #[serde(rename = "regex")]
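The added `#[serde(alias = "json_object")]` lets OpenAI-style clients use the `json_object` spelling; both forms deserialize to the same `GrammarType::Json` variant. Two equivalent wire payloads:

```python
schema = {"properties": {"location": {"type": "string"}}}

# Both deserialize to GrammarType::Json on the router side after this change.
native_form = {"type": "json", "value": schema}
openai_form = {"type": "json_object", "value": schema}
```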
@@ -791,6 +792,13 @@ pub(crate) struct ChatRequest {
     #[schema(nullable = true, example = "null")]
     #[serde(deserialize_with = "deserialize_tool_choice::deserialize")]
     pub tool_choice: Option<ToolType>,
+
+    /// Response format constraints for the generation.
+    ///
+    /// NOTE: A request can use `response_format` OR `tools` but not both.
+    #[serde(default)]
+    #[schema(nullable = true, default = "null", example = "null")]
+    pub response_format: Option<GrammarType>,
 }

 fn default_tool_prompt() -> Option<String> {
@@ -1016,6 +1016,7 @@ async fn chat_completions(
         tool_choice,
         tool_prompt,
         temperature,
+        response_format,
         ..
     } = req;

@@ -1030,6 +1031,18 @@ async fn chat_completions(
         other => (true, other),
     };

+    // response_format and tools are mutually exclusive
+    if response_format.is_some() && tools.as_ref().is_some() {
+        metrics::increment_counter!("tgi_request_failure", "err" => "validation");
+        return Err((
+            StatusCode::UNPROCESSABLE_ENTITY,
+            Json(ErrorResponse {
+                error: "Grammar and tools are mutually exclusive".to_string(),
+                error_type: "grammar and tools".to_string(),
+            }),
+        ));
+    }
+
     // extract tool grammar if present
     let tool_grammar = match ToolGrammar::apply(tools, tool_choice) {
         Ok(grammar) => grammar,
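Tools and `response_format` both compile down to a single grammar constraint (see the next hunk), so the router rejects requests that set both rather than silently picking one. A sketch of what a client observes, mirroring the second integration test (the server address is an assumption):

```python
import requests

response = requests.post(
    "http://localhost:3000/v1/chat/completions",
    json={
        "model": "tgi",
        "messages": [{"role": "user", "content": "Hi"}],
        "tools": [],
        "response_format": {"type": "json_object", "value": {}},
    },
)
assert response.status_code == 422
assert response.json() == {
    "error": "Grammar and tools are mutually exclusive",
    "error_type": "grammar and tools",
}
```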
@@ -1046,16 +1059,21 @@ async fn chat_completions(
         }
     };

-    let grammar_with_prompt = tool_grammar
+    // determine the appropriate arguments for apply_chat_template
+    let tools_grammar_prompt = tool_grammar
         .as_ref()
         .map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt));

-    let typed_grammar = grammar_with_prompt
-        .as_ref()
-        .map(|(grammar, _)| grammar.clone());
+    let (tools_grammar_prompt, grammar) = match response_format {
+        Some(response_format) => (None, Some(response_format)),
+        None => (
+            tools_grammar_prompt.clone(),
+            tools_grammar_prompt.map(|(grammar, _)| grammar.clone()),
+        ),
+    };

     // apply chat template to flatten the request into a single input
-    let inputs = match infer.apply_chat_template(messages, grammar_with_prompt) {
+    let inputs = match infer.apply_chat_template(messages, tools_grammar_prompt) {
         Ok(inputs) => inputs,
         Err(err) => {
             metrics::increment_counter!("tgi_request_failure", "err" => "validation");
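The new `match` gives `response_format` precedence: when it is set, the grammar comes straight from the request and no tool prompt is threaded into the chat template; otherwise both are derived from the tool grammar, as before. A Python paraphrase of that selection logic (illustrative only, the authoritative code is the Rust above):

```python
def resolve_grammar(response_format, tool_grammar, tool_prompt):
    """Paraphrase of the (tools_grammar_prompt, grammar) match in the router."""
    if response_format is not None:
        # response_format wins: no tool prompt, grammar taken from the request.
        return None, response_format
    if tool_grammar is not None:
        # Tools path: the (grammar, prompt) pair feeds apply_chat_template and
        # the grammar half also constrains generation.
        pair = ({"type": "json", "value": tool_grammar}, tool_prompt)
        return pair, pair[0]
    return None, None
```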
@@ -1091,7 +1109,7 @@ async fn chat_completions(
         decoder_input_details: !stream,
         seed,
         top_n_tokens: req.top_logprobs,
-        grammar: typed_grammar,
+        grammar,
     },
 };