feat: improve tools to include name and add tests (#1693)

This PR makes tool calling aware of the name of the function selected: tool-call responses now report the actual function name and its arguments instead of the generic "tools" placeholder.

Fixes:
https://github.com/huggingface/text-generation-inference/issues/1657

Thank you @puppetm4st3r for the helpful snippets; large parts of this PR
are simply refactors of the code you shared 🙏

Note: opening as a draft PR because small tweaks are needed before merging.
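For illustration, here is a minimal sketch of the behavior this PR targets, posted against the OpenAI-compatible /v1/chat/completions route. The host and port are hypothetical, as is the get_current_weather tool; any TGI deployment with grammar support should behave similarly.

# Hedged sketch: assumes a TGI server at localhost:3000 (hypothetical).
import requests

payload = {
    "model": "tgi",
    "messages": [
        {"role": "user", "content": "What is the weather like in Brooklyn?"}
    ],
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {"type": "string"},
                        "format": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                    },
                    "required": ["location", "format"],
                },
            },
        }
    ],
    "tool_choice": "auto",
    "max_tokens": 100,
    "seed": 1,
}

response = requests.post("http://localhost:3000/v1/chat/completions", json=payload).json()
call = response["choices"][0]["message"]["tool_calls"][0]["function"]
# Before this PR the name was always the placeholder "tools" and the schema came
# back under "parameters"; now the real function name and "arguments" are returned.
print(call["name"], call["arguments"])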
drbh 2024-04-16 09:02:46 -04:00 committed by GitHub
parent 88702d8763
commit 7276d43495
11 changed files with 429 additions and 175 deletions

View File

@ -13,7 +13,7 @@
"usage": null
}
],
"created": 1710795556,
"created": 1712874856,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",

View File

@ -11,13 +11,12 @@
"tool_calls": [
{
"function": {
"description": null,
"name": "tools",
"parameters": {
"arguments": {
"format": "celsius",
"location": "New York, NY",
"num_days": 14
}
"location": "Brooklyn"
},
"description": null,
"name": "get_current_weather"
},
"id": 0,
"type": "function"
@ -27,14 +26,14 @@
"usage": null
}
],
"created": 1710795556,
"created": 1712782670,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native",
"usage": {
"completion_tokens": 29,
"prompt_tokens": 316,
"total_tokens": 345
"completion_tokens": 37,
"prompt_tokens": 524,
"total_tokens": 561
}
}

View File

@ -11,13 +11,12 @@
"tool_calls": [
{
"function": {
"description": null,
"name": "tools",
"parameters": {
"arguments": {
"format": "celsius",
"location": "New York, NY",
"num_days": 14
}
"location": "Brooklyn"
},
"description": null,
"name": "get_current_weather"
},
"id": 0,
"type": "function"
@ -27,14 +26,14 @@
"usage": null
}
],
"created": 1710795557,
"created": 1712787937,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native",
"usage": {
"completion_tokens": 29,
"prompt_tokens": 316,
"total_tokens": 345
"completion_tokens": 37,
"prompt_tokens": 524,
"total_tokens": 561
}
}

View File

@ -11,12 +11,12 @@
"tool_calls": [
{
"function": {
"description": null,
"name": "tools",
"parameters": {
"arguments": {
"format": "celsius",
"location": "New York, NY"
}
},
"description": null,
"name": "get_current_weather"
},
"id": 0,
"type": "function"
@ -26,14 +26,14 @@
"usage": null
}
],
"created": 1710795557,
"created": 1712852394,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native",
"usage": {
"completion_tokens": 21,
"prompt_tokens": 187,
"total_tokens": 208
"completion_tokens": 48,
"prompt_tokens": 320,
"total_tokens": 368
}
}

View File

@ -0,0 +1,38 @@
{
"choices": [
{
"finish_reason": "eos_token",
"index": 0,
"logprobs": null,
"message": {
"content": null,
"name": null,
"role": "assistant",
"tool_calls": [
{
"function": {
"arguments": {
"error": "Cannot get current weather forecast from specified location and temperature unit. Please try again with different options."
},
"description": null,
"name": "notify_error"
},
"id": 0,
"type": "function"
}
]
},
"usage": null
}
],
"created": 1712852597,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "1.4.5-native",
"usage": {
"completion_tokens": 39,
"prompt_tokens": 496,
"total_tokens": 535
}
}

View File

@ -19,7 +19,7 @@
"logprobs": null
}
],
"created": 1710795499,
"created": 1712788218,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",

View File

@ -0,0 +1,42 @@
import pytest


@pytest.fixture(scope="module")
def flash_llama_chat_handle(launcher):
    with launcher(
        "TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, disable_grammar_support=False
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_llama_chat(flash_llama_chat_handle):
    await flash_llama_chat_handle.health(300)
    return flash_llama_chat_handle.client


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_simple(flash_llama_chat, response_snapshot):
    response = await flash_llama_chat.chat(
        max_tokens=100,
        seed=1,
        messages=[
            {
                "role": "system",
                "content": "Youre a helpful assistant! Answer the users question best you can.",
            },
            {
                "role": "user",
                "content": "What is the weather like in Brooklyn, New York?",
            },
        ],
    )

    assert (
        response.choices[0].message.content
        == "As of today, there is a Update available for the Brooklyn, New York, area. According to the latest forecast, it's warm with high temperatures throughout the day. It's forecasted at 75°F for today and 77°F for tomorrow. However, in autumn, the weather typically changes drastically, becoming cooler and wetter. You can find the current weather forecast for the area through your local weather service. Additionally"
    )
    assert response == response_snapshot

View File

@ -71,34 +71,7 @@ tools = [
]
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_no_tools(
flash_llama_grammar_tools, response_snapshot
):
response = await flash_llama_grammar_tools.chat(
max_tokens=100,
seed=1,
messages=[
{
"role": "system",
"content": "Youre a helpful assistant! Answer the users question best you can.",
},
{
"role": "user",
"content": "What is the weather like in Brooklyn, New York?",
},
],
)
assert (
response.choices[0].message.content
== "As of today, there is a Update available for the Brooklyn, New York, area. According to the latest forecast, it's warm with high temperatures throughout the day. It's forecasted at 75°F for today and 77°F for tomorrow. However, in autumn, the weather typically changes drastically, becoming cooler and wetter. You can find the current weather forecast for the area through your local weather service. Additionally"
)
assert response == response_snapshot
@pytest.mark.skip
@pytest.mark.skip(reason="Takes too long to run")
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot):
@ -121,23 +94,19 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
assert response.choices[0].message.content == None
assert response.choices[0].message.tool_calls == [
{
"function": {
"description": None,
"name": "tools",
"parameters": {
"format": "celsius",
"location": "New York, NY",
"num_days": 14,
},
},
"id": 0,
"type": "function",
"function": {
"description": None,
"name": "get_current_weather",
"arguments": {"format": "celsius", "location": "New York, NY"},
},
}
]
assert response == response_snapshot
@pytest.mark.skip
@pytest.mark.skip(reason="Takes too long to run")
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_auto(
@ -163,23 +132,20 @@ async def test_flash_llama_grammar_tools_auto(
assert response.choices[0].message.content == None
assert response.choices[0].message.tool_calls == [
{
"function": {
"description": None,
"name": "tools",
"parameters": {
"format": "celsius",
"location": "New York, NY",
"num_days": 14,
},
},
"id": 0,
"type": "function",
"function": {
"description": None,
"name": "get_current_weather",
"arguments": {"format": "celsius", "location": "New York, NY"},
},
}
]
assert response == response_snapshot
@pytest.mark.skip
@pytest.mark.skip(reason="Takes too long to run")
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_choice(
@ -209,15 +175,16 @@ async def test_flash_llama_grammar_tools_choice(
"type": "function",
"function": {
"description": None,
"name": "tools",
"parameters": {"format": "celsius", "location": "New York, NY"},
"name": "get_current_weather",
"arguments": {"format": "celsius", "location": "New York, NY"},
},
}
]
assert response == response_snapshot
@pytest.mark.skip
@pytest.mark.skip(reason="Takes too long to run")
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_stream(
@ -246,5 +213,47 @@ async def test_flash_llama_grammar_tools_stream(
async for response in responses:
count += 1
assert count == 20
assert count == 38
assert response == response_snapshot

@pytest.mark.skip(reason="Takes too long to run")
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_insufficient_information(
    flash_llama_grammar_tools, response_snapshot
):
    responses = await flash_llama_grammar_tools.chat(
        max_tokens=100,
        seed=8,
        tools=tools,
        tool_choice="auto",
        messages=[
            {
                "role": "system",
                "content": "ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION",
            },
            {
                "role": "user",
                "content": "Tell me a story about 3 sea creatures",
            },
        ],
        stream=False,
    )

    assert responses.choices[0].message.content == None
    assert responses.choices[0].message.tool_calls == [
        {
            "function": {
                "arguments": {
                    "error": "Cannot get current weather forecast from specified location and temperature unit. Please try again with different options."
                },
                "description": None,
                "name": "notify_error",
            },
            "id": 0,
            "type": "function",
        }
    ]
    assert responses == response_snapshot

View File

@ -4,9 +4,12 @@ use crate::{
ChatTemplateInputs, ChatTemplateVersions, Entry, GenerateRequest, GenerateStreamResponse,
HubTokenizerConfig, Message, PrefillToken, Queue, Token,
};
use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools};
use futures::future::try_join_all;
use minijinja::{Environment, ErrorKind, Template};
use nohash_hasher::IntMap;
use serde_json::{json, Map, Value};
use std::collections::HashMap;
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
@ -185,11 +188,15 @@ impl Infer {
/// Apply the chat template to the chat request
#[instrument(skip_all)]
pub(crate) fn apply_chat_template(&self, messages: Vec<Message>) -> Result<String, InferError> {
pub(crate) fn apply_chat_template(
&self,
messages: Vec<Message>,
grammar_with_prompt: Option<(GrammarType, String)>,
) -> Result<String, InferError> {
self.chat_template
.as_ref()
.ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
.apply(messages)
.apply(messages, grammar_with_prompt)
.map_err(|e| {
metrics::increment_counter!("tgi_request_failure", "err" => "template");
tracing::error!("{e}");
@ -322,6 +329,7 @@ struct ChatTemplate {
template: Template<'static, 'static>,
bos_token: Option<String>,
eos_token: Option<String>,
use_default_tool_template: bool,
}
impl ChatTemplate {
@ -329,6 +337,10 @@ impl ChatTemplate {
let mut env = Box::new(Environment::new());
let template_str = template.into_boxed_str();
env.add_function("raise_exception", raise_exception);
// check if the template itself contains the tools variable
let use_default_tool_template =
!template_str.as_ref().replace(' ', "").contains("{{tools}}");
// leaking env and template_str as read-only, static resources for performance.
let template = Box::leak(env)
.template_from_str(Box::leak(template_str))
@ -338,21 +350,159 @@ impl ChatTemplate {
template,
bos_token,
eos_token,
use_default_tool_template,
}
}
fn apply(&self, messages: Vec<Message>) -> Result<String, InferError> {
fn apply(
&self,
mut messages: Vec<Message>,
grammar_with_prompt: Option<(GrammarType, String)>,
) -> Result<String, InferError> {
if self.use_default_tool_template {
if let Some(last_message) = messages.last_mut() {
if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
last_message.content = Some(format!(
"{}\n---\n{}\n{}",
last_message.content.as_deref().unwrap_or_default(),
tool_prompt,
tools
));
}
}
}
self.template
.render(ChatTemplateInputs {
messages,
bos_token: self.bos_token.as_deref(),
eos_token: self.eos_token.as_deref(),
add_generation_prompt: true,
tools: None,
tools_prompt: None,
})
.map_err(InferError::TemplateError)
}
}
pub struct ToolGrammar {}
impl ToolGrammar {
pub fn apply(
tools: Option<Vec<Tool>>,
tool_choice: Option<ToolType>,
) -> Result<Option<Tools>, InferError> {
if let Some((req_tools, tool_choice)) = tools.zip(tool_choice) {
let tools_to_use = match tool_choice {
ToolType::FunctionName(name) => {
vec![req_tools
.iter()
.find(|tool| tool.function.name == *name)
.ok_or_else(|| InferError::ToolError(format!("Tool with name {} not found", name)))?
.clone()]
}
ToolType::OneOf => req_tools.to_owned(),
};
// adds the error notification function for LLM feedback if required
let mut text_response_properties = Map::new();
text_response_properties.insert(
"error".to_string(),
serde_json::json!({
"type": "string",
"description": "The error or issue to notify"
}),
);
text_response_properties.insert(
"_name".to_string(),
serde_json::json!({
"type": "string",
"const": "notify_error"
}),
);
let functions: HashMap<String, serde_json::Value> = tools_to_use
.iter()
.map(|tool| {
let func = tool.function.clone();
// Clone the existing parameters, which are expected to be a JSON object
let mut params = if let Value::Object(params) = &func.arguments {
params.clone()
} else {
Map::new()
};
// Insert the function's description at the top level, outside of properties
params.insert(
"description".to_string(),
Value::String(func.description.clone().unwrap_or_default()),
);
// Ensure 'properties' exists and is an object
let properties = params
.entry("properties".to_string())
.or_insert_with(|| json!({}))
.as_object_mut()
.unwrap();
// Insert the constant for the function name inside 'properties'
properties.insert(
"_name".to_string(),
json!({
"type": "string",
"const": func.name.clone(),
}),
);
// Ensure 'required' exists and is an array; create an empty array if not
let required = params
.entry("required".to_string())
.or_insert_with(|| json!([]))
.as_array_mut()
.unwrap();
// Add '_name' to the 'required' array if it is not already present
if !required.iter().any(|r| r == "_name") {
required.push(json!("_name"));
}
(func.name, Value::Object(params))
})
.chain([(
"notify_error".to_string(),
serde_json::json!({
"properties": text_response_properties,
"required": ["error", "_name"],
"type": "object"
}),
)])
.collect();
let tools = Tools {
functions_map: FunctionsMap { functions },
properties: Properties {
function: tools_to_use
.iter()
.map(|tool| FunctionRef {
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
})
.chain(std::iter::once(FunctionRef {
ref_path: "#/$functions/notify_error".to_string(),
}))
.collect(),
},
};
return Ok(Some(tools));
}
Ok(None)
}
}
/// Batching logic
/// Will be launched in a background Tokio task
///
@ -768,6 +918,8 @@ pub enum InferError {
IncompleteGeneration,
#[error("Template error: {0}")]
TemplateError(#[from] minijinja::Error),
#[error("Tool error: {0}")]
ToolError(String),
}
impl InferError {
@ -778,6 +930,7 @@ impl InferError {
InferError::ValidationError(_) => "validation",
InferError::IncompleteGeneration => "incomplete_generation",
InferError::TemplateError(_) => "template_error",
InferError::ToolError(_) => "tool_error",
}
}
}
@ -849,6 +1002,7 @@ mod tests {
bos_token: Some("[BOS]"),
eos_token: Some("[EOS]"),
add_generation_prompt: true,
..Default::default()
};
let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
@ -924,6 +1078,7 @@ mod tests {
bos_token: Some("[BOS]"),
eos_token: Some("[EOS]"),
add_generation_prompt: true,
..Default::default()
};
let result = tmpl.unwrap().render(chat_template_inputs); //.err().unwrap();
@ -998,6 +1153,7 @@ mod tests {
bos_token: Some("[BOS]"),
eos_token: Some("[EOS]"),
add_generation_prompt: true,
..Default::default()
};
let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
@ -1056,6 +1212,7 @@ mod tests {
bos_token: Some("[BOS]"),
eos_token: Some("[EOS]"),
add_generation_prompt: true,
..Default::default()
};
let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
@ -1115,6 +1272,7 @@ mod tests {
add_generation_prompt: false,
bos_token: Some(""),
eos_token: Some(""),
..Default::default()
},
target: "<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n",
},
@ -1126,6 +1284,7 @@ mod tests {
add_generation_prompt: false,
bos_token: Some(""),
eos_token: Some("</s>"),
..Default::default()
},
target: " Hello, how are you? I'm doing great. How can I help you today? I'd like to show off how chat templating works!</s>",
},
@ -1137,6 +1296,7 @@ mod tests {
add_generation_prompt: false,
bos_token: Some(""),
eos_token: Some("</s>"),
..Default::default()
},
target: " Hello, how are you? I'm doing great. How can I help you today? I'd like to show off how chat templating works!</s>",
},
@ -1148,6 +1308,7 @@ mod tests {
add_generation_prompt: false,
bos_token: Some(""),
eos_token: Some("</s>"),
..Default::default()
},
target: "Hello, how are you?</s>I'm doing great. How can I help you today?</s>I'd like to show off how chat templating works!</s>",
},
@ -1159,6 +1320,7 @@ mod tests {
add_generation_prompt: false,
bos_token: Some(""),
eos_token: Some("<|endoftext|>"),
..Default::default()
},
target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>",
},
@ -1170,6 +1332,7 @@ mod tests {
add_generation_prompt: false,
bos_token: Some(""),
eos_token: Some("<|endoftext|>"),
..Default::default()
},
target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>",
},
@ -1182,6 +1345,7 @@ mod tests {
add_generation_prompt: true,
bos_token: Some("<s>"),
eos_token: Some("</s>"),
..Default::default()
},
target: "<s>[INST] <<SYS>>\nYou are a friendly chatbot who always responds in the style of a pirate\n<</SYS>>\n\nHello, how are you? [/INST] I'm doing great. How can I help you today? </s><s>[INST] I'd like to show off how chat templating works! [/INST]",
},
@ -1193,6 +1357,7 @@ mod tests {
add_generation_prompt: true,
bos_token: Some(""),
eos_token: Some("<|endoftext|>"),
..Default::default()
},
target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>",
},
@ -1222,6 +1387,7 @@ mod tests {
add_generation_prompt: false,
bos_token: Some(""),
eos_token: Some("</s>"),
..Default::default()
},
target: "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate</s><|user|>\nHello, how are you?</s><|assistant|>\nI'm doing great. How can I help you today?</s><|user|>\nI'd like to show off how chat templating works!</s>",
},
@ -1246,6 +1412,7 @@ mod tests {
add_generation_prompt: true,
bos_token: Some(""),
eos_token: Some("</s>"),
..Default::default()
},
target: "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate</s><|user|>\nHow many helicopters can a human eat in one sitting?</s><|assistant|>",
},
@ -1257,6 +1424,7 @@ mod tests {
add_generation_prompt: false,
bos_token: Some("<bos>"),
eos_token: Some("<eos>"),
..Default::default()
},
target: "<bos><|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n",
},
@ -1268,6 +1436,7 @@ mod tests {
add_generation_prompt: false,
bos_token: Some("<s>"),
eos_token: Some("</s>"),
..Default::default()
},
target: "<s>[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today?</s> [INST] I'd like to show off how chat templating works! [/INST]",
},
@ -1279,6 +1448,7 @@ mod tests {
add_generation_prompt: false,
bos_token: Some("<s>"),
eos_token: Some("</s>"),
..Default::default()
},
target: "<s>[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today?</s>[INST] I'd like to show off how chat templating works! [/INST]",
},
@ -1290,6 +1460,7 @@ mod tests {
add_generation_prompt: false,
bos_token: Some("<s>"),
eos_token: Some("</s>"),
..Default::default()
},
target: "<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n",
},
@ -1302,6 +1473,7 @@ mod tests {
add_generation_prompt: false,
bos_token: Some("<s>"),
eos_token: Some("</s>"),
..Default::default()
},
target: "<s>GPT4 Correct User: Hello, how are you?<|end_of_turn|>GPT4 Correct Assistant: I'm doing great. How can I help you today?<|end_of_turn|>GPT4 Correct User: I'd like to show off how chat templating works!<|end_of_turn|>",
},
@ -1313,6 +1485,7 @@ mod tests {
add_generation_prompt: false,
bos_token: Some("<s>"),
eos_token: Some("</s>"),
..Default::default()
},
target: "Hello, how are you?</s>I'm doing great. How can I help you today?</s>I'd like to show off how chat templating works!</s>",
},
@ -1325,6 +1498,7 @@ mod tests {
add_generation_prompt: false,
bos_token: Some("<s>"),
eos_token: Some("</s>"),
..Default::default()
},
target: "<s>Source: user\n\n Hello, how are you? <step> Source: assistant\n\n I'm doing great. How can I help you today? <step> Source: user\n\n I'd like to show off how chat templating works! <step> Source: assistant\nDestination: user\n\n ",
},
@ -1336,6 +1510,7 @@ mod tests {
add_generation_prompt: false,
bos_token: Some("<s>"),
eos_token: Some("</s>"),
..Default::default()
},
target: "### User:\nHello, how are you?### Assistant:\nI'm doing great. How can I help you today?### User:\nI'd like to show off how chat templating works!",
},
@ -1347,6 +1522,7 @@ mod tests {
add_generation_prompt: false,
bos_token: Some("<s>"),
eos_token: Some("</s>"),
..Default::default()
},
target: "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!",
},
@ -1358,6 +1534,7 @@ mod tests {
add_generation_prompt: false,
bos_token: Some("<begin▁of▁sentence>"),
eos_token: Some("<end▁of▁sentence>"),
..Default::default()
},
target: "<begin▁of▁sentence>User: Hello, how are you?\n\nAssistant: I'm doing great. How can I help you today?<end▁of▁sentence>User: I'd like to show off how chat templating works!\n\n",
},
@ -1369,6 +1546,7 @@ mod tests {
add_generation_prompt: false,
bos_token: Some("<s>"),
eos_token: Some("</s>"),
..Default::default()
},
target: "<|prompt|>Hello, how are you?</s><|answer|>I'm doing great. How can I help you today?</s><|prompt|>I'd like to show off how chat templating works!</s>",
},
@ -1380,6 +1558,7 @@ mod tests {
add_generation_prompt: false,
bos_token: Some("<s>"),
eos_token: Some("</s>"),
..Default::default()
},
target: "<s><|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n",
},
@ -1391,6 +1570,7 @@ mod tests {
add_generation_prompt: false,
bos_token: Some("<begin▁of▁sentence>"),
eos_token: Some("<|EOT|>"),
..Default::default()
},
target: "You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.\n### Instruction:\nHello, how are you?\n### Response:\nI'm doing great. How can I help you today?\n<|EOT|>\n### Instruction:\nI'd like to show off how chat templating works!\n### Response:\n",
},
@ -1403,6 +1583,7 @@ mod tests {
add_generation_prompt: false,
bos_token: Some("<|endoftext|>"),
eos_token: Some("<|endoftext|>"),
..Default::default()
},
target: "[INST] Hello, how are you? [RESP] I'm doing great. How can I help you today?<|endoftext|>[INST] I'd like to show off how chat templating works!",
},
@ -1414,6 +1595,7 @@ mod tests {
add_generation_prompt: false,
bos_token: Some("<s>"),
eos_token: Some("</s>"),
..Default::default()
},
target: "Hello, how are you? [/INST] I'm doing great. How can I help you today? </s><s>[INST] I'd like to show off how chat templating works! [/INST]",
},
@ -1425,6 +1607,7 @@ mod tests {
add_generation_prompt: false,
bos_token: Some("<s>"),
eos_token: Some("</s>"),
..Default::default()
},
target: "Below is an instruction that describes a task. Write a response that appropriately completes the request.### Instruction:Hello, how are you?### Response:I'm doing great. How can I help you today?### Instruction:I'd like to show off how chat templating works!",
},
@ -1436,6 +1619,7 @@ mod tests {
add_generation_prompt: false,
bos_token: Some("<begin▁of▁sentence>"),
eos_token: Some("</EOT>"),
..Default::default()
},
target: "<begin▁of▁sentence>You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\n### Instruction:\nHello, how are you?\n### Response:\nI'm doing great. How can I help you today?\n<|EOT|>\n### Instruction:\nI'd like to show off how chat templating works!\n",
},
@ -1451,6 +1635,7 @@ mod tests {
add_generation_prompt: false,
bos_token: Some("<s>"),
eos_token: Some("</s>"),
..Default::default()
},
target: "You are a friendly chatbot who always responds in the style of a pirateYou are a friendly chatbot who always responds in the style of a pirate### Instruction: Hello, how are you?### Response: I'm doing great. How can I help you today?### Instruction: I'd like to show off how chat templating works!",
},

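To make the ToolGrammar::apply transformation above concrete, here is a hedged Python mirror of the schema it builds (illustrative only; the Rust above is authoritative, and the sample tool names are hypothetical). Each tool's parameter schema gains a constant _name property that the grammar forces the model to emit, and a synthetic notify_error function is appended so the model can report when no tool fits:

# Hedged Python mirror of ToolGrammar::apply; the Rust above is authoritative.
def build_tool_grammar(tools):
    functions = {}
    for tool in tools:
        # Start from the tool's own JSON-schema parameters (a shallow copy is
        # enough for illustration) and hoist its description to the top level.
        params = dict(tool["function"].get("parameters") or {})
        params["description"] = tool["function"].get("description") or ""
        # Force a constant _name property so the generated JSON names the function.
        props = params.setdefault("properties", {})
        props["_name"] = {"type": "string", "const": tool["function"]["name"]}
        required = params.setdefault("required", [])
        if "_name" not in required:
            required.append("_name")
        functions[tool["function"]["name"]] = params
    # Synthetic fallback so the model can surface an error instead of
    # inventing tool properties.
    functions["notify_error"] = {
        "properties": {
            "error": {"type": "string", "description": "The error or issue to notify"},
            "_name": {"type": "string", "const": "notify_error"},
        },
        "required": ["error", "_name"],
        "type": "object",
    }
    return {
        "$functions": functions,
        # The Rust side runs this list through a custom serializer
        # (serialize_function), so the exact wire shape may differ slightly.
        "properties": {
            "function": [{"$ref": f"#/$functions/{name}"} for name in functions]
        },
    }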
View File

@ -79,7 +79,7 @@ impl HubTokenizerConfig {
}
}
#[derive(Clone, Debug, Deserialize, ToSchema)]
#[derive(Clone, Debug, Deserialize, ToSchema, Serialize)]
#[serde(tag = "type", content = "value")]
pub(crate) enum GrammarType {
/// A string that represents a [JSON Schema](https://json-schema.org/).
@ -669,7 +669,7 @@ pub(crate) struct ChatRequest {
#[serde(default = "default_tool_prompt")]
#[schema(
nullable = true,
example = "\"Based on the conversation, please choose the most appropriate tool to use: \""
example = "\"You will be presented with a JSON schema representing a set of tools.\nIf the user request lacks sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n\""
)]
pub tool_prompt: Option<String>,
@ -682,7 +682,7 @@ pub(crate) struct ChatRequest {
fn default_tool_prompt() -> Option<String> {
Some(
"\nBased on the conversation, please choose the most appropriate tool to use: ".to_string(),
"\nYou will be presented with a JSON schema representing a set of tools.\nIf the user request lacks sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n".to_string(),
)
}
#[derive(Clone, Deserialize, ToSchema, Serialize)]
@ -727,26 +727,26 @@ mod deserialize_tool_choice {
}
}
#[derive(Debug, Deserialize, Serialize, ToSchema)]
#[derive(Debug, Deserialize, Serialize, ToSchema, PartialEq)]
pub struct Tools {
#[serde(flatten)]
functions_map: FunctionsMap,
properties: Properties,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, PartialEq)]
struct FunctionsMap {
#[serde(rename = "$functions")]
functions: std::collections::HashMap<String, serde_json::Value>,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, PartialEq)]
struct FunctionRef {
#[serde(rename = "$ref")]
ref_path: String,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, PartialEq)]
struct Properties {
#[serde(serialize_with = "serialize_function")]
function: Vec<FunctionRef>,
@ -767,7 +767,8 @@ pub(crate) struct FunctionDefinition {
#[serde(default)]
pub description: Option<String>,
pub name: String,
pub parameters: serde_json::Value,
#[serde(alias = "parameters")]
pub arguments: serde_json::Value,
}
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
@ -779,12 +780,14 @@ pub(crate) struct Tool {
pub function: FunctionDefinition,
}
#[derive(Clone, Serialize, Deserialize)]
#[derive(Clone, Serialize, Deserialize, Default)]
pub(crate) struct ChatTemplateInputs<'a> {
messages: Vec<Message>,
bos_token: Option<&'a str>,
eos_token: Option<&'a str>,
add_generation_prompt: bool,
tools: Option<&'a str>,
tools_prompt: Option<&'a str>,
}
#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)]

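One consequence of the #[serde(alias = "parameters")] change above is worth spelling out: the server now accepts either key when deserializing a FunctionDefinition, but always serializes the field back out as "arguments", which is why the snapshots earlier in this diff moved from "parameters" to "arguments". A hedged illustration with hypothetical fragments:

# Both (hypothetical) tool-definition fragments below deserialize into the
# same FunctionDefinition server-side; responses always use "arguments".
openai_style = {"name": "get_current_weather", "parameters": {"type": "object"}}
tgi_style = {"name": "get_current_weather", "arguments": {"type": "object"}}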
View File

@ -1,7 +1,7 @@
use crate::config::Config;
/// HTTP Server logic
use crate::health::Health;
use crate::infer::{InferError, InferResponse, InferStreamResponse};
use crate::infer::{InferError, InferResponse, InferStreamResponse, ToolGrammar};
use crate::validation::ValidationError;
use crate::{
BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
@ -15,7 +15,7 @@ use crate::{
ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk,
CompletionRequest, DeltaToolCall, Function, Tool, VertexRequest, VertexResponse,
};
use crate::{FunctionDefinition, FunctionRef, FunctionsMap, Properties, ToolCall, ToolType, Tools};
use crate::{FunctionDefinition, ToolCall, ToolType};
use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode};
use axum::response::sse::{Event, KeepAlive, Sse};
@ -29,7 +29,6 @@ use futures::Stream;
use futures::TryStreamExt;
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
use serde_json::Value;
use std::collections::HashMap;
use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::atomic::AtomicBool;
@ -757,19 +756,29 @@ async fn chat_completions(
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
metrics::increment_counter!("tgi_request_count");
let stream = req.stream;
let max_new_tokens = req.max_tokens.or(Some(100));
let repetition_penalty = req
.presence_penalty
// rescale repetition_penalty from (-2.0, 2.0) to (0.0, 4.0)
.map(|x| x + 2.0);
let logprobs = req.logprobs.unwrap_or(false);
let seed = req.seed;
let stop = req.stop.unwrap_or_default();
let ChatRequest {
logprobs,
max_tokens,
messages,
presence_penalty,
seed,
stop,
stream,
tools,
tool_choice,
tool_prompt,
..
} = req;
// apply chat template to flatten the request into a single input
let mut inputs = match infer.apply_chat_template(req.messages) {
Ok(inputs) => inputs,
let repetition_penalty = presence_penalty.map(|x| x + 2.0);
let max_new_tokens = max_tokens.or(Some(100));
let logprobs = logprobs.unwrap_or(false);
let tool_prompt = tool_prompt.unwrap_or_default();
let stop = stop.unwrap_or_default();
// extract tool grammar if present
let tool_grammar = match ToolGrammar::apply(tools, tool_choice) {
Ok(grammar) => grammar,
Err(err) => {
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
tracing::error!("{err}");
@ -783,60 +792,28 @@ async fn chat_completions(
}
};
let tool_grammar = if let Some((req_tools, tool_choice)) = req.tools.zip(req.tool_choice) {
let tool_prompt = req.tool_prompt.unwrap_or_default();
let tools_to_use = match tool_choice {
ToolType::FunctionName(name) => {
vec![req_tools
.iter()
.find(|tool| tool.function.name == *name)
.ok_or_else(|| {
(
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
error: "Tool choice not found in tool names".to_string(),
error_type: "Tool not found".to_string(),
}),
)
})?
.clone()]
}
ToolType::OneOf => req_tools.to_owned(),
};
let grammar_with_prompt = tool_grammar
.as_ref()
.map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt));
let functions: HashMap<String, Value> = tools_to_use
.iter()
.map(|tool| {
let func = tool.function.clone();
(func.name, func.parameters)
})
.collect();
let typed_grammar = grammar_with_prompt
.as_ref()
.map(|(grammar, _)| grammar.clone());
let tools = Tools {
functions_map: FunctionsMap { functions },
properties: Properties {
function: tools_to_use
.iter()
.map(|tool| FunctionRef {
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
})
.collect(),
},
};
let tools_str = serde_json::to_string(&tools).map_err(|e| {
(
// apply chat template to flatten the request into a single input
let inputs = match infer.apply_chat_template(messages, grammar_with_prompt) {
Ok(inputs) => inputs,
Err(err) => {
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
tracing::error!("{err}");
return Err((
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
error: e.to_string(),
error_type: "Input validation error".to_string(),
error: err.to_string(),
error_type: err.error_type().to_string(),
}),
)
})?;
inputs = format!("{inputs}{tool_prompt}{tools_str}");
Some(GrammarType::Json(serde_json::json!(tools)))
} else {
None
));
}
};
// build the request passing some parameters
@ -860,7 +837,7 @@ async fn chat_completions(
decoder_input_details: !stream,
seed,
top_n_tokens: req.top_logprobs,
grammar: tool_grammar.clone(),
grammar: typed_grammar,
},
};
@ -943,27 +920,28 @@ async fn chat_completions(
}),
)
})?;
let tool_calls = vec![ToolCall {
id: 0,
r#type: "function".to_string(),
function: FunctionDefinition {
description: None,
name: "tools".to_string(),
parameters: gen_text_value.get("function").map_or_else(
|| {
serde_json::from_str(&generation.generated_text).map_err(|e| {
(
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
error: e.to_string(),
error_type: "Input validation error".to_string(),
}),
)
})
},
|f| Ok(f.clone()),
)?,
name: gen_text_value
.get("function")
.and_then(|f| f.get("_name"))
.and_then(|name| name.as_str())
.unwrap_or("default_function_name")
.to_string(),
// Extract the "function" object and drop the internal "_name" marker; the remaining fields are the call's arguments
arguments: gen_text_value
.get("function")
.map(|f| {
let mut f_cloned = f.clone();
if let Value::Object(ref mut props) = f_cloned {
props.remove("_name");
}
f_cloned
})
.unwrap_or_default(),
},
}];
(Some(tool_calls), None)
@ -1539,6 +1517,7 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY,
InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY,
InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY,
};
(
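Finally, a hedged Python mirror of the new tool-call post-processing in chat_completions above (illustrative only): the grammar guarantees the generated JSON carries a function object with a _name field, which becomes the tool call's name, while the remaining fields become its arguments.

import json

def extract_tool_call(generated_text: str) -> dict:
    # Mirror of the Rust post-processing above (illustrative only).
    value = json.loads(generated_text)
    function = dict(value.get("function", {}))
    # _name was forced into the grammar; surface it as the call's name
    # (the Rust code falls back to "default_function_name" if it is missing)...
    name = function.pop("_name", "default_function_name")
    # ...and everything that remains is the function's arguments.
    return {
        "id": 0,
        "type": "function",
        "function": {"description": None, "name": name, "arguments": function},
    }

# extract_tool_call('{"function": {"_name": "get_current_weather", "location": "Brooklyn"}}')
# -> name "get_current_weather", arguments {"location": "Brooklyn"}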