From 7276d4349510bdf3198d6b255a140fd7c1dc7216 Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 16 Apr 2024 09:02:46 -0400 Subject: [PATCH] feat: improve tools to include name and add tests (#1693) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR makes tool calling aware of the name of the function selected. Fixes: https://github.com/huggingface/text-generation-inference/issues/1657 Thank you @puppetm4st3r for the helpful snippets, large parts of this PR are simply refactors of the code shared šŸ™ **opening draft PR because small tweaks are needed before merging --- .../test_flash_llama_simple.json} | 2 +- .../test_flash_llama_grammar_tools.json | 19 +- .../test_flash_llama_grammar_tools_auto.json | 19 +- ...test_flash_llama_grammar_tools_choice.json | 16 +- ...rammar_tools_insufficient_information.json | 38 ++++ ...test_flash_llama_grammar_tools_stream.json | 2 +- integration-tests/models/test_chat_llama.py | 42 ++++ integration-tests/models/test_tools_llama.py | 113 ++++++----- router/src/infer.rs | 191 +++++++++++++++++- router/src/lib.rs | 21 +- router/src/server.rs | 141 ++++++------- 11 files changed, 429 insertions(+), 175 deletions(-) rename integration-tests/models/__snapshots__/{test_tools_llama/test_flash_llama_grammar_no_tools.json => test_chat_llama/test_flash_llama_simple.json} (97%) create mode 100644 integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_insufficient_information.json create mode 100644 integration-tests/models/test_chat_llama.py diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_no_tools.json b/integration-tests/models/__snapshots__/test_chat_llama/test_flash_llama_simple.json similarity index 97% rename from integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_no_tools.json rename to integration-tests/models/__snapshots__/test_chat_llama/test_flash_llama_simple.json index 153a508d..0ff874f1 100644 --- a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_no_tools.json +++ b/integration-tests/models/__snapshots__/test_chat_llama/test_flash_llama_simple.json @@ -13,7 +13,7 @@ "usage": null } ], - "created": 1710795556, + "created": 1712874856, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools.json index 56920b3e..45f8ca99 100644 --- a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools.json +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools.json @@ -11,13 +11,12 @@ "tool_calls": [ { "function": { - "description": null, - "name": "tools", - "parameters": { + "arguments": { "format": "celsius", - "location": "New York, NY", - "num_days": 14 - } + "location": "Brooklyn" + }, + "description": null, + "name": "get_current_weather" }, "id": 0, "type": "function" @@ -27,14 +26,14 @@ "usage": null } ], - "created": 1710795556, + "created": 1712782670, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", "system_fingerprint": "2.0.0-native", "usage": { - "completion_tokens": 29, - "prompt_tokens": 316, - "total_tokens": 345 + "completion_tokens": 37, + "prompt_tokens": 524, + "total_tokens": 561 } } diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_auto.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_auto.json index fe679362..e0ed0947 100644 --- a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_auto.json +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_auto.json @@ -11,13 +11,12 @@ "tool_calls": [ { "function": { - "description": null, - "name": "tools", - "parameters": { + "arguments": { "format": "celsius", - "location": "New York, NY", - "num_days": 14 - } + "location": "Brooklyn" + }, + "description": null, + "name": "get_current_weather" }, "id": 0, "type": "function" @@ -27,14 +26,14 @@ "usage": null } ], - "created": 1710795557, + "created": 1712787937, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", "system_fingerprint": "2.0.0-native", "usage": { - "completion_tokens": 29, - "prompt_tokens": 316, - "total_tokens": 345 + "completion_tokens": 37, + "prompt_tokens": 524, + "total_tokens": 561 } } diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_choice.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_choice.json index e48a1e7d..b70c2d6f 100644 --- a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_choice.json +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_choice.json @@ -11,12 +11,12 @@ "tool_calls": [ { "function": { - "description": null, - "name": "tools", - "parameters": { + "arguments": { "format": "celsius", "location": "New York, NY" - } + }, + "description": null, + "name": "get_current_weather" }, "id": 0, "type": "function" @@ -26,14 +26,14 @@ "usage": null } ], - "created": 1710795557, + "created": 1712852394, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", "system_fingerprint": "2.0.0-native", "usage": { - "completion_tokens": 21, - "prompt_tokens": 187, - "total_tokens": 208 + "completion_tokens": 48, + "prompt_tokens": 320, + "total_tokens": 368 } } diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_insufficient_information.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_insufficient_information.json new file mode 100644 index 00000000..0cd3c67f --- /dev/null +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_insufficient_information.json @@ -0,0 +1,38 @@ +{ + "choices": [ + { + "finish_reason": "eos_token", + "index": 0, + "logprobs": null, + "message": { + "content": null, + "name": null, + "role": "assistant", + "tool_calls": [ + { + "function": { + "arguments": { + "error": "Cannot get current weather forecast from specified location and temperature unit. Please try again with different options." + }, + "description": null, + "name": "notify_error" + }, + "id": 0, + "type": "function" + } + ] + }, + "usage": null + } + ], + "created": 1712852597, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "1.4.5-native", + "usage": { + "completion_tokens": 39, + "prompt_tokens": 496, + "total_tokens": 535 + } +} diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_stream.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_stream.json index cfebc05f..6787b39b 100644 --- a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_stream.json +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_stream.json @@ -19,7 +19,7 @@ "logprobs": null } ], - "created": 1710795499, + "created": 1712788218, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", diff --git a/integration-tests/models/test_chat_llama.py b/integration-tests/models/test_chat_llama.py new file mode 100644 index 00000000..11419a0e --- /dev/null +++ b/integration-tests/models/test_chat_llama.py @@ -0,0 +1,42 @@ +import pytest +import json + +from text_generation.types import GrammarType + + +@pytest.fixture(scope="module") +def flash_llama_chat_handle(launcher): + with launcher( + "TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, disable_grammar_support=False + ) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_llama_chat(flash_llama_chat_handle): + await flash_llama_chat_handle.health(300) + return flash_llama_chat_handle.client + + +@pytest.mark.private +async def test_flash_llama_simple(flash_llama_chat, response_snapshot): + response = await flash_llama_chat.chat( + max_tokens=100, + seed=1, + messages=[ + { + "role": "system", + "content": "Youre a helpful assistant! Answer the users question best you can.", + }, + { + "role": "user", + "content": "What is the weather like in Brooklyn, New York?", + }, + ], + ) + + assert ( + response.choices[0].message.content + == "As of today, there is a Update available for the Brooklyn, New York, area. According to the latest forecast, it's warm with high temperatures throughout the day. It's forecasted at 75Ā°F for today and 77Ā°F for tomorrow. However, in autumn, the weather typically changes drastically, becoming cooler and wetter. You can find the current weather forecast for the area through your local weather service. Additionally" + ) + assert response == response_snapshot diff --git a/integration-tests/models/test_tools_llama.py b/integration-tests/models/test_tools_llama.py index d0ae331f..0af3f66a 100644 --- a/integration-tests/models/test_tools_llama.py +++ b/integration-tests/models/test_tools_llama.py @@ -71,34 +71,7 @@ tools = [ ] -@pytest.mark.asyncio -@pytest.mark.private -async def test_flash_llama_grammar_no_tools( - flash_llama_grammar_tools, response_snapshot -): - response = await flash_llama_grammar_tools.chat( - max_tokens=100, - seed=1, - messages=[ - { - "role": "system", - "content": "Youre a helpful assistant! Answer the users question best you can.", - }, - { - "role": "user", - "content": "What is the weather like in Brooklyn, New York?", - }, - ], - ) - - assert ( - response.choices[0].message.content - == "As of today, there is a Update available for the Brooklyn, New York, area. According to the latest forecast, it's warm with high temperatures throughout the day. It's forecasted at 75Ā°F for today and 77Ā°F for tomorrow. However, in autumn, the weather typically changes drastically, becoming cooler and wetter. You can find the current weather forecast for the area through your local weather service. Additionally" - ) - assert response == response_snapshot - - -@pytest.mark.skip +@pytest.mark.skip(reason="Takes too long to run") @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot): @@ -121,23 +94,19 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna assert response.choices[0].message.content == None assert response.choices[0].message.tool_calls == [ { - "function": { - "description": None, - "name": "tools", - "parameters": { - "format": "celsius", - "location": "New York, NY", - "num_days": 14, - }, - }, "id": 0, "type": "function", + "function": { + "description": None, + "name": "get_current_weather", + "arguments": {"format": "celsius", "location": "New York, NY"}, + }, } ] assert response == response_snapshot -@pytest.mark.skip +@pytest.mark.skip(reason="Takes too long to run") @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_grammar_tools_auto( @@ -163,23 +132,20 @@ async def test_flash_llama_grammar_tools_auto( assert response.choices[0].message.content == None assert response.choices[0].message.tool_calls == [ { - "function": { - "description": None, - "name": "tools", - "parameters": { - "format": "celsius", - "location": "New York, NY", - "num_days": 14, - }, - }, "id": 0, "type": "function", + "function": { + "description": None, + "name": "get_current_weather", + "arguments": {"format": "celsius", "location": "New York, NY"}, + }, } ] + assert response == response_snapshot -@pytest.mark.skip +@pytest.mark.skip(reason="Takes too long to run") @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_grammar_tools_choice( @@ -209,15 +175,16 @@ async def test_flash_llama_grammar_tools_choice( "type": "function", "function": { "description": None, - "name": "tools", - "parameters": {"format": "celsius", "location": "New York, NY"}, + "name": "get_current_weather", + "arguments": {"format": "celsius", "location": "New York, NY"}, }, } ] + assert response == response_snapshot -@pytest.mark.skip +@pytest.mark.skip(reason="Takes too long to run") @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_grammar_tools_stream( @@ -246,5 +213,47 @@ async def test_flash_llama_grammar_tools_stream( async for response in responses: count += 1 - assert count == 20 + assert count == 38 assert response == response_snapshot + + +@pytest.mark.skip(reason="Takes too long to run") +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_grammar_tools_insufficient_information( + flash_llama_grammar_tools, response_snapshot +): + responses = await flash_llama_grammar_tools.chat( + max_tokens=100, + seed=8, + tools=tools, + tool_choice="auto", + messages=[ + { + "role": "system", + "content": "ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION", + }, + { + "role": "user", + "content": "Tell me a story about 3 sea creatures", + }, + ], + stream=False, + ) + + assert responses.choices[0].message.content == None + assert responses.choices[0].message.tool_calls == [ + { + "function": { + "arguments": { + "error": "Cannot get current weather forecast from specified location and temperature unit. Please try again with different options." + }, + "description": None, + "name": "notify_error", + }, + "id": 0, + "type": "function", + } + ] + + assert responses == response_snapshot diff --git a/router/src/infer.rs b/router/src/infer.rs index 075e76d8..da0db072 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -4,9 +4,12 @@ use crate::{ ChatTemplateInputs, ChatTemplateVersions, Entry, GenerateRequest, GenerateStreamResponse, HubTokenizerConfig, Message, PrefillToken, Queue, Token, }; +use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools}; use futures::future::try_join_all; use minijinja::{Environment, ErrorKind, Template}; use nohash_hasher::IntMap; +use serde_json::{json, Map, Value}; +use std::collections::HashMap; use std::sync::{ atomic::{AtomicBool, Ordering}, Arc, @@ -185,11 +188,15 @@ impl Infer { /// Apply the chat template to the chat request #[instrument(skip_all)] - pub(crate) fn apply_chat_template(&self, messages: Vec) -> Result { + pub(crate) fn apply_chat_template( + &self, + messages: Vec, + grammar_with_prompt: Option<(GrammarType, String)>, + ) -> Result { self.chat_template .as_ref() .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))? - .apply(messages) + .apply(messages, grammar_with_prompt) .map_err(|e| { metrics::increment_counter!("tgi_request_failure", "err" => "template"); tracing::error!("{e}"); @@ -322,6 +329,7 @@ struct ChatTemplate { template: Template<'static, 'static>, bos_token: Option, eos_token: Option, + use_default_tool_template: bool, } impl ChatTemplate { @@ -329,6 +337,10 @@ impl ChatTemplate { let mut env = Box::new(Environment::new()); let template_str = template.into_boxed_str(); env.add_function("raise_exception", raise_exception); + + // check if contains the tools variable within the template + let use_default_tool_template = + !template_str.as_ref().replace(' ', "").contains("{{tools}}"); // leaking env and template_str as read-only, static resources for performance. let template = Box::leak(env) .template_from_str(Box::leak(template_str)) @@ -338,21 +350,159 @@ impl ChatTemplate { template, bos_token, eos_token, + use_default_tool_template, } } - fn apply(&self, messages: Vec) -> Result { + fn apply( + &self, + mut messages: Vec, + grammar_with_prompt: Option<(GrammarType, String)>, + ) -> Result { + if self.use_default_tool_template { + if let Some(last_message) = messages.last_mut() { + if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt { + last_message.content = Some(format!( + "{}\n---\n{}\n{}", + last_message.content.as_deref().unwrap_or_default(), + tool_prompt, + tools + )); + } + } + } + self.template .render(ChatTemplateInputs { messages, bos_token: self.bos_token.as_deref(), eos_token: self.eos_token.as_deref(), add_generation_prompt: true, + tools: None, + tools_prompt: None, }) .map_err(InferError::TemplateError) } } +pub struct ToolGrammar {} + +impl ToolGrammar { + pub fn apply( + tools: Option>, + tool_choice: Option, + ) -> Result, InferError> { + if let Some((req_tools, tool_choice)) = tools.zip(tool_choice) { + // let tool_prompt = tool_prompt.unwrap_or_default(); + let tools_to_use = match tool_choice { + ToolType::FunctionName(name) => { + vec![req_tools + .iter() + .find(|tool| tool.function.name == *name) + .unwrap_or_else(|| panic!("Tool with name {} not found", name)) + .clone()] + } + ToolType::OneOf => req_tools.to_owned(), + }; + + // adds the error notification function for LLM feedback if required + let mut text_response_properties = Map::new(); + text_response_properties.insert( + "error".to_string(), + serde_json::json!({ + "type": "string", + "description": "The error or issue to notify" + }), + ); + text_response_properties.insert( + "_name".to_string(), + serde_json::json!({ + "type": "string", + "const": "notify_error" + }), + ); + + let functions: HashMap = tools_to_use + .iter() + .map(|tool| { + let func = tool.function.clone(); + + // Clone the existing parameters, which are expected to be a JSON object + let mut params = if let Value::Object(params) = &func.arguments { + params.clone() + } else { + Map::new() + }; + + // Insert the function's description at the top level, outside of properties + params.insert( + "description".to_string(), + Value::String(func.description.clone().unwrap_or_default()), + ); + + // Ensure 'properties' exists and is an object + let properties = params + .entry("properties".to_string()) + .or_insert_with(|| json!({})) + .as_object_mut() + .unwrap(); + + // Insert the constant for the function name inside 'properties' + properties.insert( + "_name".to_string(), + json!({ + "type": "string", + "const": func.name.clone(), + // "description": "The name of the function" + }), + ); + + // Check if 'required' exists, and it is an array. If not, create an empty array. + let required = params + .entry("required".to_string()) + .or_insert_with(|| json!([])) + .as_array_mut() + .unwrap(); + + // Add 'name' to the 'required' array if it is not already present + if !required.iter().any(|r| r == "_name") { + required.push(json!("_name")); + } + + (func.name, Value::Object(params)) + }) + .chain([( + "notify_error".to_string(), + serde_json::json!({ + "properties": text_response_properties, + "required": ["error", "_name"], + "type": "object" + }), + )]) + .collect(); + + let tools = Tools { + functions_map: FunctionsMap { functions }, + properties: Properties { + function: tools_to_use + .iter() + .map(|tool| FunctionRef { + ref_path: format!("#/$functions/{}", tool.function.name.clone()), + }) + .chain(std::iter::once(FunctionRef { + ref_path: "#/$functions/notify_error".to_string(), + })) + .collect(), + }, + }; + + return Ok(Some(tools)); + } + // Err(InferError::ToolError("No tools provided".to_string())) + Ok(None) + } +} + /// Batching logic /// Will be launched in a background Tokio task /// @@ -768,6 +918,8 @@ pub enum InferError { IncompleteGeneration, #[error("Template error: {0}")] TemplateError(#[from] minijinja::Error), + #[error("Tool error: {0}")] + ToolError(String), } impl InferError { @@ -778,6 +930,7 @@ impl InferError { InferError::ValidationError(_) => "validation", InferError::IncompleteGeneration => "incomplete_generation", InferError::TemplateError(_) => "template_error", + InferError::ToolError(_) => "tool_error", } } } @@ -849,6 +1002,7 @@ mod tests { bos_token: Some("[BOS]"), eos_token: Some("[EOS]"), add_generation_prompt: true, + ..Default::default() }; let result = tmpl.unwrap().render(chat_template_inputs).unwrap(); @@ -924,6 +1078,7 @@ mod tests { bos_token: Some("[BOS]"), eos_token: Some("[EOS]"), add_generation_prompt: true, + ..Default::default() }; let result = tmpl.unwrap().render(chat_template_inputs); //.err().unwrap(); @@ -998,6 +1153,7 @@ mod tests { bos_token: Some("[BOS]"), eos_token: Some("[EOS]"), add_generation_prompt: true, + ..Default::default() }; let result = tmpl.unwrap().render(chat_template_inputs).unwrap(); @@ -1056,6 +1212,7 @@ mod tests { bos_token: Some("[BOS]"), eos_token: Some("[EOS]"), add_generation_prompt: true, + ..Default::default() }; let result = tmpl.unwrap().render(chat_template_inputs).unwrap(); @@ -1115,6 +1272,7 @@ mod tests { add_generation_prompt: false, bos_token: Some(""), eos_token: Some(""), + ..Default::default() }, target: "<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n", }, @@ -1126,6 +1284,7 @@ mod tests { add_generation_prompt: false, bos_token: Some(""), eos_token: Some(""), + ..Default::default() }, target: " Hello, how are you? I'm doing great. How can I help you today? I'd like to show off how chat templating works!", }, @@ -1137,6 +1296,7 @@ mod tests { add_generation_prompt: false, bos_token: Some(""), eos_token: Some(""), + ..Default::default() }, target: " Hello, how are you? I'm doing great. How can I help you today? I'd like to show off how chat templating works!", }, @@ -1148,6 +1308,7 @@ mod tests { add_generation_prompt: false, bos_token: Some(""), eos_token: Some(""), + ..Default::default() }, target: "Hello, how are you?I'm doing great. How can I help you today?I'd like to show off how chat templating works!", }, @@ -1159,6 +1320,7 @@ mod tests { add_generation_prompt: false, bos_token: Some(""), eos_token: Some("<|endoftext|>"), + ..Default::default() }, target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>", }, @@ -1170,6 +1332,7 @@ mod tests { add_generation_prompt: false, bos_token: Some(""), eos_token: Some("<|endoftext|>"), + ..Default::default() }, target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>", }, @@ -1182,6 +1345,7 @@ mod tests { add_generation_prompt: true, bos_token: Some(""), eos_token: Some(""), + ..Default::default() }, target: "[INST] <>\nYou are a friendly chatbot who always responds in the style of a pirate\n<>\n\nHello, how are you? [/INST] I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]", }, @@ -1193,6 +1357,7 @@ mod tests { add_generation_prompt: true, bos_token: Some(""), eos_token: Some("<|endoftext|>"), + ..Default::default() }, target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>", }, @@ -1222,6 +1387,7 @@ mod tests { add_generation_prompt: false, bos_token: Some(""), eos_token: Some(""), + ..Default::default() }, target: "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate<|user|>\nHello, how are you?<|assistant|>\nI'm doing great. How can I help you today?<|user|>\nI'd like to show off how chat templating works!", }, @@ -1246,6 +1412,7 @@ mod tests { add_generation_prompt: true, bos_token: Some(""), eos_token: Some(""), + ..Default::default() }, target: "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate<|user|>\nHow many helicopters can a human eat in one sitting?<|assistant|>", }, @@ -1257,6 +1424,7 @@ mod tests { add_generation_prompt: false, bos_token: Some(""), eos_token: Some(""), + ..Default::default() }, target: "<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n", }, @@ -1268,6 +1436,7 @@ mod tests { add_generation_prompt: false, bos_token: Some(""), eos_token: Some(""), + ..Default::default() }, target: "[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]", }, @@ -1279,6 +1448,7 @@ mod tests { add_generation_prompt: false, bos_token: Some(""), eos_token: Some(""), + ..Default::default() }, target: "[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today?[INST] I'd like to show off how chat templating works! [/INST]", }, @@ -1290,6 +1460,7 @@ mod tests { add_generation_prompt: false, bos_token: Some(""), eos_token: Some(""), + ..Default::default() }, target: "<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n", }, @@ -1302,6 +1473,7 @@ mod tests { add_generation_prompt: false, bos_token: Some(""), eos_token: Some(""), + ..Default::default() }, target: "GPT4 Correct User: Hello, how are you?<|end_of_turn|>GPT4 Correct Assistant: I'm doing great. How can I help you today?<|end_of_turn|>GPT4 Correct User: I'd like to show off how chat templating works!<|end_of_turn|>", }, @@ -1313,6 +1485,7 @@ mod tests { add_generation_prompt: false, bos_token: Some(""), eos_token: Some(""), + ..Default::default() }, target: "Hello, how are you?I'm doing great. How can I help you today?I'd like to show off how chat templating works!", }, @@ -1325,6 +1498,7 @@ mod tests { add_generation_prompt: false, bos_token: Some(""), eos_token: Some(""), + ..Default::default() }, target: "Source: user\n\n Hello, how are you? Source: assistant\n\n I'm doing great. How can I help you today? Source: user\n\n I'd like to show off how chat templating works! Source: assistant\nDestination: user\n\n ", }, @@ -1336,6 +1510,7 @@ mod tests { add_generation_prompt: false, bos_token: Some(""), eos_token: Some(""), + ..Default::default() }, target: "### User:\nHello, how are you?### Assistant:\nI'm doing great. How can I help you today?### User:\nI'd like to show off how chat templating works!", }, @@ -1347,6 +1522,7 @@ mod tests { add_generation_prompt: false, bos_token: Some(""), eos_token: Some(""), + ..Default::default() }, target: "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!", }, @@ -1358,6 +1534,7 @@ mod tests { add_generation_prompt: false, bos_token: Some("<ļ½œbeginā–ofā–sentenceļ½œ>"), eos_token: Some("<ļ½œendā–ofā–sentenceļ½œ>"), + ..Default::default() }, target: "<ļ½œbeginā–ofā–sentenceļ½œ>User: Hello, how are you?\n\nAssistant: I'm doing great. How can I help you today?<ļ½œendā–ofā–sentenceļ½œ>User: I'd like to show off how chat templating works!\n\n", }, @@ -1369,6 +1546,7 @@ mod tests { add_generation_prompt: false, bos_token: Some(""), eos_token: Some(""), + ..Default::default() }, target: "<|prompt|>Hello, how are you?<|answer|>I'm doing great. How can I help you today?<|prompt|>I'd like to show off how chat templating works!", }, @@ -1380,6 +1558,7 @@ mod tests { add_generation_prompt: false, bos_token: Some(""), eos_token: Some(""), + ..Default::default() }, target: "<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n", }, @@ -1391,6 +1570,7 @@ mod tests { add_generation_prompt: false, bos_token: Some("<ļ½œbeginā–ofā–sentenceļ½œ>"), eos_token: Some("<|EOT|>"), + ..Default::default() }, target: "You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.\n### Instruction:\nHello, how are you?\n### Response:\nI'm doing great. How can I help you today?\n<|EOT|>\n### Instruction:\nI'd like to show off how chat templating works!\n### Response:\n", }, @@ -1403,6 +1583,7 @@ mod tests { add_generation_prompt: false, bos_token: Some("<|endoftext|>"), eos_token: Some("<|endoftext|>"), + ..Default::default() }, target: "[INST] Hello, how are you? [RESP] I'm doing great. How can I help you today?<|endoftext|>[INST] I'd like to show off how chat templating works!", }, @@ -1414,6 +1595,7 @@ mod tests { add_generation_prompt: false, bos_token: Some(""), eos_token: Some(""), + ..Default::default() }, target: "Hello, how are you? [/INST] I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]", }, @@ -1425,6 +1607,7 @@ mod tests { add_generation_prompt: false, bos_token: Some(""), eos_token: Some(""), + ..Default::default() }, target: "Below is an instruction that describes a task. Write a response that appropriately completes the request.### Instruction:Hello, how are you?### Response:I'm doing great. How can I help you today?### Instruction:I'd like to show off how chat templating works!", }, @@ -1436,6 +1619,7 @@ mod tests { add_generation_prompt: false, bos_token: Some("<ļ½œbeginā–ofā–sentenceļ½œ>"), eos_token: Some(""), + ..Default::default() }, target: "<ļ½œbeginā–ofā–sentenceļ½œ>You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\n### Instruction:\nHello, how are you?\n### Response:\nI'm doing great. How can I help you today?\n<|EOT|>\n### Instruction:\nI'd like to show off how chat templating works!\n", }, @@ -1451,6 +1635,7 @@ mod tests { add_generation_prompt: false, bos_token: Some(""), eos_token: Some(""), + ..Default::default() }, target: "You are a friendly chatbot who always responds in the style of a pirateYou are a friendly chatbot who always responds in the style of a pirate### Instruction: Hello, how are you?### Response: I'm doing great. How can I help you today?### Instruction: I'd like to show off how chat templating works!", }, diff --git a/router/src/lib.rs b/router/src/lib.rs index 2e412f1a..ddb28848 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -79,7 +79,7 @@ impl HubTokenizerConfig { } } -#[derive(Clone, Debug, Deserialize, ToSchema)] +#[derive(Clone, Debug, Deserialize, ToSchema, Serialize)] #[serde(tag = "type", content = "value")] pub(crate) enum GrammarType { /// A string that represents a [JSON Schema](https://json-schema.org/). @@ -669,7 +669,7 @@ pub(crate) struct ChatRequest { #[serde(default = "default_tool_prompt")] #[schema( nullable = true, - example = "\"Based on the conversation, please choose the most appropriate tool to use: \"" + example = "\"You will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n\"" )] pub tool_prompt: Option, @@ -682,7 +682,7 @@ pub(crate) struct ChatRequest { fn default_tool_prompt() -> Option { Some( - "\nBased on the conversation, please choose the most appropriate tool to use: ".to_string(), + "\nYou will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n".to_string(), ) } #[derive(Clone, Deserialize, ToSchema, Serialize)] @@ -727,26 +727,26 @@ mod deserialize_tool_choice { } } -#[derive(Debug, Deserialize, Serialize, ToSchema)] +#[derive(Debug, Deserialize, Serialize, ToSchema, PartialEq)] pub struct Tools { #[serde(flatten)] functions_map: FunctionsMap, properties: Properties, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, PartialEq)] struct FunctionsMap { #[serde(rename = "$functions")] functions: std::collections::HashMap, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, PartialEq)] struct FunctionRef { #[serde(rename = "$ref")] ref_path: String, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, PartialEq)] struct Properties { #[serde(serialize_with = "serialize_function")] function: Vec, @@ -767,7 +767,8 @@ pub(crate) struct FunctionDefinition { #[serde(default)] pub description: Option, pub name: String, - pub parameters: serde_json::Value, + #[serde(alias = "parameters")] + pub arguments: serde_json::Value, } #[derive(Clone, Debug, Deserialize, Serialize, ToSchema)] @@ -779,12 +780,14 @@ pub(crate) struct Tool { pub function: FunctionDefinition, } -#[derive(Clone, Serialize, Deserialize)] +#[derive(Clone, Serialize, Deserialize, Default)] pub(crate) struct ChatTemplateInputs<'a> { messages: Vec, bos_token: Option<&'a str>, eos_token: Option<&'a str>, add_generation_prompt: bool, + tools: Option<&'a str>, + tools_prompt: Option<&'a str>, } #[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)] diff --git a/router/src/server.rs b/router/src/server.rs index b8f93514..c1648f9e 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1,7 +1,7 @@ use crate::config::Config; /// HTTP Server logic use crate::health::Health; -use crate::infer::{InferError, InferResponse, InferStreamResponse}; +use crate::infer::{InferError, InferResponse, InferStreamResponse, ToolGrammar}; use crate::validation::ValidationError; use crate::{ BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, @@ -15,7 +15,7 @@ use crate::{ ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk, CompletionRequest, DeltaToolCall, Function, Tool, VertexRequest, VertexResponse, }; -use crate::{FunctionDefinition, FunctionRef, FunctionsMap, Properties, ToolCall, ToolType, Tools}; +use crate::{FunctionDefinition, ToolCall, ToolType}; use axum::extract::Extension; use axum::http::{HeaderMap, Method, StatusCode}; use axum::response::sse::{Event, KeepAlive, Sse}; @@ -29,7 +29,6 @@ use futures::Stream; use futures::TryStreamExt; use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle}; use serde_json::Value; -use std::collections::HashMap; use std::convert::Infallible; use std::net::SocketAddr; use std::sync::atomic::AtomicBool; @@ -757,19 +756,29 @@ async fn chat_completions( ) -> Result)> { metrics::increment_counter!("tgi_request_count"); - let stream = req.stream; - let max_new_tokens = req.max_tokens.or(Some(100)); - let repetition_penalty = req - .presence_penalty - // rescale repetition_penalty from (-2.0, 2.0) to (0.0, 4.0) - .map(|x| x + 2.0); - let logprobs = req.logprobs.unwrap_or(false); - let seed = req.seed; - let stop = req.stop.unwrap_or_default(); + let ChatRequest { + logprobs, + max_tokens, + messages, + presence_penalty, + seed, + stop, + stream, + tools, + tool_choice, + tool_prompt, + .. + } = req; - // apply chat template to flatten the request into a single input - let mut inputs = match infer.apply_chat_template(req.messages) { - Ok(inputs) => inputs, + let repetition_penalty = presence_penalty.map(|x| x + 2.0); + let max_new_tokens = max_tokens.or(Some(100)); + let logprobs = logprobs.unwrap_or(false); + let tool_prompt = tool_prompt.unwrap_or_default(); + let stop = stop.unwrap_or_default(); + + // extract tool grammar if present + let tool_grammar = match ToolGrammar::apply(tools, tool_choice) { + Ok(grammar) => grammar, Err(err) => { metrics::increment_counter!("tgi_request_failure", "err" => "validation"); tracing::error!("{err}"); @@ -783,60 +792,28 @@ async fn chat_completions( } }; - let tool_grammar = if let Some((req_tools, tool_choice)) = req.tools.zip(req.tool_choice) { - let tool_prompt = req.tool_prompt.unwrap_or_default(); - let tools_to_use = match tool_choice { - ToolType::FunctionName(name) => { - vec![req_tools - .iter() - .find(|tool| tool.function.name == *name) - .ok_or_else(|| { - ( - StatusCode::UNPROCESSABLE_ENTITY, - Json(ErrorResponse { - error: "Tool choice not found in tool names".to_string(), - error_type: "Tool not found".to_string(), - }), - ) - })? - .clone()] - } - ToolType::OneOf => req_tools.to_owned(), - }; + let grammar_with_prompt = tool_grammar + .as_ref() + .map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt)); - let functions: HashMap = tools_to_use - .iter() - .map(|tool| { - let func = tool.function.clone(); - (func.name, func.parameters) - }) - .collect(); + let typed_grammar = grammar_with_prompt + .as_ref() + .map(|(grammar, _)| grammar.clone()); - let tools = Tools { - functions_map: FunctionsMap { functions }, - properties: Properties { - function: tools_to_use - .iter() - .map(|tool| FunctionRef { - ref_path: format!("#/$functions/{}", tool.function.name.clone()), - }) - .collect(), - }, - }; - - let tools_str = serde_json::to_string(&tools).map_err(|e| { - ( + // apply chat template to flatten the request into a single input + let inputs = match infer.apply_chat_template(messages, grammar_with_prompt) { + Ok(inputs) => inputs, + Err(err) => { + metrics::increment_counter!("tgi_request_failure", "err" => "validation"); + tracing::error!("{err}"); + return Err(( StatusCode::UNPROCESSABLE_ENTITY, Json(ErrorResponse { - error: e.to_string(), - error_type: "Input validation error".to_string(), + error: err.to_string(), + error_type: err.error_type().to_string(), }), - ) - })?; - inputs = format!("{inputs}{tool_prompt}{tools_str}"); - Some(GrammarType::Json(serde_json::json!(tools))) - } else { - None + )); + } }; // build the request passing some parameters @@ -860,7 +837,7 @@ async fn chat_completions( decoder_input_details: !stream, seed, top_n_tokens: req.top_logprobs, - grammar: tool_grammar.clone(), + grammar: typed_grammar, }, }; @@ -943,27 +920,28 @@ async fn chat_completions( }), ) })?; - let tool_calls = vec![ToolCall { id: 0, r#type: "function".to_string(), function: FunctionDefinition { description: None, - name: "tools".to_string(), - parameters: gen_text_value.get("function").map_or_else( - || { - serde_json::from_str(&generation.generated_text).map_err(|e| { - ( - StatusCode::UNPROCESSABLE_ENTITY, - Json(ErrorResponse { - error: e.to_string(), - error_type: "Input validation error".to_string(), - }), - ) - }) - }, - |f| Ok(f.clone()), - )?, + name: gen_text_value + .get("function") + .and_then(|f| f.get("_name")) + .and_then(|name| name.as_str()) + .unwrap_or("default_function_name") + .to_string(), + // Serialize the JSON object obtained from "function" to an escaped JSON string + arguments: gen_text_value + .get("function") + .map(|f| { + let mut f_cloned = f.clone(); + if let Value::Object(ref mut props) = f_cloned { + props.remove("_name"); + } + f_cloned + }) + .unwrap_or_default(), }, }]; (Some(tool_calls), None) @@ -1539,6 +1517,7 @@ impl From for (StatusCode, Json) { InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY, InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR, InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY, + InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY, }; (