Support chat response format (#2046)
* feat: support response_format in chat
* fix: adjust typos
* fix: add trufflehog lint
parent a6e4d63c86
commit 376a0b7ada
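For context, the new `response_format` field rides on the existing OpenAI-compatible `/v1/chat/completions` route. A minimal client-side sketch, assuming a TGI server with grammar support running at localhost:8080 (an assumption; the payload shape mirrors the integration test below):

# Sketch only: the base URL is an assumption; the payload mirrors the test below.
import requests
from pydantic import BaseModel
from typing import List

class Weather(BaseModel):
    unit: str
    temperature: List[int]

response = requests.post(
    "http://localhost:8080/v1/chat/completions",  # assumed local TGI instance
    json={
        "model": "tgi",
        "messages": [
            {
                "role": "user",
                "content": "What's the weather like the next 3 days in San Francisco, CA?",
            }
        ],
        "max_tokens": 500,
        "response_format": {"type": "json_object", "value": Weather.schema()},
    },
)
print(response.json()["choices"][0]["message"]["content"])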
@@ -16,4 +16,3 @@ jobs:
          fetch-depth: 0
      - name: Secret Scanning
        uses: trufflesecurity/trufflehog@main
@@ -0,0 +1,23 @@
{
  "choices": [
    {
      "finish_reason": "eos_token",
      "index": 0,
      "logprobs": null,
      "message": {
        "content": "{\n \"temperature\": [\n 35,\n 34,\n 36\n ],\n \"unit\": \"°c\"\n}",
        "role": "assistant"
      }
    }
  ],
  "created": 1718044128,
  "id": "",
  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
  "object": "text_completion",
  "system_fingerprint": "2.0.5-dev0-native",
  "usage": {
    "completion_tokens": 39,
    "prompt_tokens": 136,
    "total_tokens": 175
  }
}
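Since the `content` above is itself JSON constrained by the request's schema, a client can feed it straight back into the model that defined the schema. A small sketch, reusing the `Weather` pydantic model from the test below (the `content` literal is copied from this snapshot):

# Round-trip the constrained output through the schema it was generated under.
import json
from pydantic import BaseModel
from typing import List

class Weather(BaseModel):
    unit: str
    temperature: List[int]

content = '{\n "temperature": [\n 35,\n 34,\n 36\n ],\n "unit": "°c"\n}'
weather = Weather(**json.loads(content))
print(weather.temperature)  # [35, 34, 36]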
@@ -0,0 +1,101 @@
import pytest
import requests
from pydantic import BaseModel
from typing import List


@pytest.fixture(scope="module")
def llama_grammar_handle(launcher):
    with launcher(
        "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        num_shard=1,
        disable_grammar_support=False,
        use_flash_attention=False,
        max_batch_prefill_tokens=3000,
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def llama_grammar(llama_grammar_handle):
    await llama_grammar_handle.health(300)
    return llama_grammar_handle.client


@pytest.mark.asyncio
async def test_grammar_response_format_llama_json(llama_grammar, response_snapshot):
    class Weather(BaseModel):
        unit: str
        temperature: List[int]

    # send the request
    response = requests.post(
        f"{llama_grammar.base_url}/v1/chat/completions",
        headers=llama_grammar.headers,
        json={
            "model": "tgi",
            "messages": [
                {
                    "role": "system",
                    "content": f"Respond to the users questions and answer them in the following format: {Weather.schema()}",
                },
                {
                    "role": "user",
                    "content": "What's the weather like the next 3 days in San Francisco, CA?",
                },
            ],
            "seed": 42,
            "max_tokens": 500,
            "response_format": {"type": "json_object", "value": Weather.schema()},
        },
    )

    chat_completion = response.json()
    called = chat_completion["choices"][0]["message"]["content"]

    assert response.status_code == 200
    assert (
        called
        == '{\n "temperature": [\n 35,\n 34,\n 36\n ],\n "unit": "°c"\n}'
    )
    assert chat_completion == response_snapshot


@pytest.mark.asyncio
async def test_grammar_response_format_llama_error_if_tools_not_installed(
    llama_grammar,
):
    class Weather(BaseModel):
        unit: str
        temperature: List[int]

    # send the request
    response = requests.post(
        f"{llama_grammar.base_url}/v1/chat/completions",
        headers=llama_grammar.headers,
        json={
            "model": "tgi",
            "messages": [
                {
                    "role": "system",
                    "content": f"Respond to the users questions and answer them in the following format: {Weather.schema()}",
                },
                {
                    "role": "user",
                    "content": "What's the weather like the next 3 days in San Francisco, CA?",
                },
            ],
            "seed": 42,
            "max_tokens": 500,
            "tools": [],
            "response_format": {"type": "json_object", "value": Weather.schema()},
        },
    )

    # 422 means the server was unable to process the request because it contains invalid data.
    assert response.status_code == 422
    assert response.json() == {
        "error": "Grammar and tools are mutually exclusive",
        "error_type": "grammar and tools",
    }
@@ -89,6 +89,7 @@ pub(crate) enum GrammarType {
     /// JSON Schema is a declarative language that allows to annotate JSON documents
     /// with types and descriptions.
     #[serde(rename = "json")]
+    #[serde(alias = "json_object")]
     #[schema(example = json ! ({"properties": {"location":{"type": "string"}}}))]
     Json(serde_json::Value),
     #[serde(rename = "regex")]
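The added alias means the wire format accepts both the internal spelling and the OpenAI-style one. A hedged illustration as Python dicts (the schema value is borrowed from the `#[schema(example = ...)]` attribute above):

# Both payloads deserialize to GrammarType::Json: "json" is the serde rename,
# "json_object" the newly added alias.
schema = {"properties": {"location": {"type": "string"}}}
format_internal = {"type": "json", "value": schema}
format_openai_style = {"type": "json_object", "value": schema}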
@@ -791,6 +792,13 @@ pub(crate) struct ChatRequest {
     #[schema(nullable = true, example = "null")]
     #[serde(deserialize_with = "deserialize_tool_choice::deserialize")]
     pub tool_choice: Option<ToolType>,
+
+    /// Response format constraints for the generation.
+    ///
+    /// NOTE: A request can use `response_format` OR `tools` but not both.
+    #[serde(default)]
+    #[schema(nullable = true, default = "null", example = "null")]
+    pub response_format: Option<GrammarType>,
 }

 fn default_tool_prompt() -> Option<String> {
@@ -1016,6 +1016,7 @@ async fn chat_completions(
         tool_choice,
         tool_prompt,
         temperature,
+        response_format,
         ..
     } = req;

@@ -1030,6 +1031,18 @@ async fn chat_completions(
         other => (true, other),
     };

+    // response_format and tools are mutually exclusive
+    if response_format.is_some() && tools.as_ref().is_some() {
+        metrics::increment_counter!("tgi_request_failure", "err" => "validation");
+        return Err((
+            StatusCode::UNPROCESSABLE_ENTITY,
+            Json(ErrorResponse {
+                error: "Grammar and tools are mutually exclusive".to_string(),
+                error_type: "grammar and tools".to_string(),
+            }),
+        ));
+    }
+
     // extract tool grammar if present
     let tool_grammar = match ToolGrammar::apply(tools, tool_choice) {
         Ok(grammar) => grammar,
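On the wire, this conflict surfaces as a 422 carrying the error body constructed above. A brief sketch of what a client sees, again assuming a local server (the status code and error strings come from the hunk above):

# Sending both `tools` and `response_format` is rejected before generation starts.
import requests

response = requests.post(
    "http://localhost:8080/v1/chat/completions",  # assumed local TGI instance
    json={
        "model": "tgi",
        "messages": [{"role": "user", "content": "Hi"}],
        "tools": [],
        "response_format": {"type": "json_object", "value": {}},
    },
)
assert response.status_code == 422
print(response.json()["error"])  # Grammar and tools are mutually exclusive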
@@ -1046,16 +1059,21 @@ async fn chat_completions(
         }
     };

-    let grammar_with_prompt = tool_grammar
+    // determine the appropriate arguments for apply_chat_template
+    let tools_grammar_prompt = tool_grammar
         .as_ref()
         .map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt));

-    let typed_grammar = grammar_with_prompt
-        .as_ref()
-        .map(|(grammar, _)| grammar.clone());
+    let (tools_grammar_prompt, grammar) = match response_format {
+        Some(response_format) => (None, Some(response_format)),
+        None => (
+            tools_grammar_prompt.clone(),
+            tools_grammar_prompt.map(|(grammar, _)| grammar.clone()),
+        ),
+    };

     // apply chat template to flatten the request into a single input
-    let inputs = match infer.apply_chat_template(messages, grammar_with_prompt) {
+    let inputs = match infer.apply_chat_template(messages, tools_grammar_prompt) {
         Ok(inputs) => inputs,
         Err(err) => {
             metrics::increment_counter!("tgi_request_failure", "err" => "validation");
@@ -1091,7 +1109,7 @@ async fn chat_completions(
             decoder_input_details: !stream,
             seed,
             top_n_tokens: req.top_logprobs,
-            grammar: typed_grammar,
+            grammar,
         },
     };
