From cfa73b5c99bc009903fbc340f8b77a6d4674455d Mon Sep 17 00:00:00 2001
From: drbh
Date: Mon, 26 Aug 2024 20:19:38 -0400
Subject: [PATCH] Pr 2451 ci branch (#2454)

* fix[router]: Fix tools not passed in chat template

Signed-off-by: GitHub

* feat: improve default tool serialization and lints

* feat: refactor tool logic to include notify_error in prompt and adjust typing

* fix: adjust non tool template apply

* fix: simplify tool grammar logic and improve schema

* feat: avoid skip tool test and avoid empty tool prompts

* fix: increase test client timeout for grammar compilation tests

---------

Signed-off-by: GitHub
Co-authored-by: Simone Rossi
---
 Cargo.lock                                   |   1 +
 clients/python/text_generation/client.py     |   7 +-
 docs/openapi.json                            |   2 +-
 integration-tests/conftest.py                |   2 +-
 integration-tests/models/test_tools_llama.py |  50 +++---
 router/Cargo.toml                            |   2 +-
 router/src/infer/chat_template.rs            |  57 ++++---
 router/src/infer/mod.rs                      |   6 +-
 router/src/infer/tool_grammar.rs             | 121 +++++++-------
 router/src/lib.rs                            |  15 +-
 router/src/server.rs                         | 160 ++++++++++++++++---
 11 files changed, 268 insertions(+), 155 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index d298c379..aa5cb642 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2174,6 +2174,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "45f7e8e35b6c7b169bf40b0176d2c79291ab8ee53290b84e0668ab21d841aa9d"
 dependencies = [
  "serde",
+ "serde_json",
 ]
 
 [[package]]
diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py
index 12966747..45301b63 100644
--- a/clients/python/text_generation/client.py
+++ b/clients/python/text_generation/client.py
@@ -757,7 +757,12 @@ class AsyncClient:
                     continue
                 payload = byte_payload.decode("utf-8")
                 if payload.startswith("data:"):
-                    json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
+                    payload_data = (
+                        payload.lstrip("data:").rstrip("\n").removeprefix(" ")
+                    )
+                    if payload_data == "[DONE]":
+                        break
+                    json_payload = json.loads(payload_data)
                     try:
                         response = ChatCompletionChunk(**json_payload)
                         yield response
diff --git a/docs/openapi.json b/docs/openapi.json
index df21e19d..fd64a3ab 100644
--- a/docs/openapi.json
+++ b/docs/openapi.json
@@ -924,7 +924,7 @@
         "tool_prompt": {
           "type": "string",
           "description": "A prompt to be appended before the tools",
-          "example": "\"You will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n\"",
+          "example": "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.",
           "nullable": true
         },
         "tools": {
diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py
index 15af1cad..a8a77cd2 100644
--- a/integration-tests/conftest.py
+++ b/integration-tests/conftest.py
@@ -257,7 +257,7 @@ class IgnoreLogProbResponseComparator(ResponseComparator):
 
 class LauncherHandle:
     def __init__(self, port: int):
-        self.client = AsyncClient(f"http://localhost:{port}")
+        self.client = AsyncClient(f"http://localhost:{port}", timeout=30)
 
     def _inner_health(self):
         raise NotImplementedError
diff --git a/integration-tests/models/test_tools_llama.py b/integration-tests/models/test_tools_llama.py
index f831990a..9855cfda 100644
--- a/integration-tests/models/test_tools_llama.py
+++ b/integration-tests/models/test_tools_llama.py
@@ -36,6 +36,7 @@ tools = [
                     },
                 },
                 "required": ["location", "format"],
+                "additionalProperties": False,
             },
         },
     },
@@ -62,13 +63,13 @@ tools = [
                     },
                 },
                 "required": ["location", "format", "num_days"],
+                "additionalProperties": False,
             },
         },
     },
 ]
 
 
-@pytest.mark.skip(reason="Takes too long to run")
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot):
@@ -76,7 +77,7 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
         max_tokens=100,
         seed=1,
         tools=tools,
-        presence_penalty=-1.1,
+        temperature=0.0,
         messages=[
             {
                 "role": "system",
@@ -91,19 +92,18 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
     assert response.choices[0].message.content is None
     assert response.choices[0].message.tool_calls == [
         {
-            "id": 0,
+            "id": "0",
             "type": "function",
             "function": {
                 "description": None,
                 "name": "get_current_weather",
-                "arguments": {"format": "celsius", "location": "New York, NY"},
+                "arguments": {"format": "celsius", "location": "Brooklyn, NY"},
             },
         }
     ]
     assert response == response_snapshot
 
 
-@pytest.mark.skip(reason="Takes too long to run")
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_grammar_tools_auto(
@@ -113,8 +113,8 @@ async def test_flash_llama_grammar_tools_auto(
         max_tokens=100,
         seed=1,
         tools=tools,
+        temperature=0.0,
         tool_choice="auto",
-        presence_penalty=-1.1,
         messages=[
             {
                 "role": "system",
@@ -129,12 +129,12 @@ async def test_flash_llama_grammar_tools_auto(
     assert response.choices[0].message.content is None
     assert response.choices[0].message.tool_calls == [
         {
-            "id": 0,
+            "id": "0",
             "type": "function",
             "function": {
                 "description": None,
                 "name": "get_current_weather",
-                "arguments": {"format": "celsius", "location": "New York, NY"},
+                "arguments": {"format": "celsius", "location": "Brooklyn, NY"},
             },
         }
     ]
@@ -142,7 +142,6 @@ async def test_flash_llama_grammar_tools_auto(
     assert response == response_snapshot
 
 
-@pytest.mark.skip(reason="Takes too long to run")
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_grammar_tools_choice(
@@ -152,8 +151,8 @@ async def test_flash_llama_grammar_tools_choice(
         max_tokens=100,
         seed=1,
         tools=tools,
+        temperature=0.0,
         tool_choice="get_current_weather",
-        presence_penalty=-1.1,
         messages=[
             {
                 "role": "system",
@@ -168,12 +167,12 @@ async def test_flash_llama_grammar_tools_choice(
     assert response.choices[0].message.content is None
     assert response.choices[0].message.tool_calls == [
         {
-            "id": 0,
+            "id": "0",
             "type": "function",
             "function": {
                 "description": None,
                 "name": "get_current_weather",
-                "arguments": {"format": "celsius", "location": "New York, NY"},
"location": "New York, NY"}, + "arguments": {"format": "celsius", "location": "Brooklyn, NY"}, }, } ] @@ -181,7 +180,6 @@ async def test_flash_llama_grammar_tools_choice( assert response == response_snapshot -@pytest.mark.skip(reason="Takes too long to run") @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_grammar_tools_stream( @@ -191,8 +189,8 @@ async def test_flash_llama_grammar_tools_stream( max_tokens=100, seed=1, tools=tools, + temperature=0.0, tool_choice="get_current_weather", - presence_penalty=-1.1, messages=[ { "role": "system", @@ -210,11 +208,10 @@ async def test_flash_llama_grammar_tools_stream( async for response in responses: count += 1 - assert count == 38 + assert count == 48 assert response == response_snapshot -@pytest.mark.skip(reason="Takes too long to run") @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_grammar_tools_insufficient_information( @@ -222,13 +219,13 @@ async def test_flash_llama_grammar_tools_insufficient_information( ): responses = await flash_llama_grammar_tools.chat( max_tokens=100, - seed=8, + seed=24, tools=tools, tool_choice="auto", messages=[ { "role": "system", - "content": "ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION", + "content": "STRICTLY ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION", }, { "role": "user", @@ -239,18 +236,7 @@ async def test_flash_llama_grammar_tools_insufficient_information( ) assert responses.choices[0].message.content is None - assert responses.choices[0].message.tool_calls == [ - { - "function": { - "arguments": { - "error": "Cannot get current weather forecast from specified location and temperature unit. Please try again with different options." - }, - "description": None, - "name": "notify_error", - }, - "id": 0, - "type": "function", - } - ] - + assert ( + responses.choices[0].message.tool_calls[0]["function"]["name"] == "notify_error" + ) assert responses == response_snapshot diff --git a/router/Cargo.toml b/router/Cargo.toml index 7773e212..45acab8e 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -46,7 +46,7 @@ ngrok = { version = "0.13.1", features = ["axum"], optional = true } init-tracing-opentelemetry = { version = "0.14.1", features = [ "opentelemetry-otlp", ] } -minijinja = { version = "2.0.2" } +minijinja = { version = "2.0.2", features = ["json"] } minijinja-contrib = { version = "2.0.2", features = ["pycompat"] } futures-util = "0.3.30" regex = "1.10.3" diff --git a/router/src/infer/chat_template.rs b/router/src/infer/chat_template.rs index a8537818..bfa9421c 100644 --- a/router/src/infer/chat_template.rs +++ b/router/src/infer/chat_template.rs @@ -1,9 +1,7 @@ use std::collections::HashSet; use crate::infer::InferError; -use crate::{ - ChatTemplateInputs, GrammarType, Message, MessageChunk, TextMessage, TokenizerConfigToken, -}; +use crate::{ChatTemplateInputs, Message, MessageChunk, TextMessage, TokenizerConfigToken, Tool}; use minijinja::{Environment, ErrorKind, Template}; use minijinja_contrib::pycompat; @@ -32,6 +30,7 @@ impl ChatTemplate { env.set_unknown_method_callback(pycompat::unknown_method_callback); let template_str = template.into_boxed_str(); env.add_function("raise_exception", raise_exception); + tracing::debug!("Loading template: {:#?}", template_str); // leaking env and template_str as read-only, static resources for performance. 
         let template = Box::leak(env)
@@ -42,6 +41,7 @@ impl ChatTemplate {
         let variables = template.undeclared_variables(true);
         // check if the `tools` variable is used in the template
         let use_default_tool_template = !variables.contains("tools");
+        tracing::debug!("Use default tool template: {}", use_default_tool_template);
 
         Self {
             template,
@@ -56,25 +56,36 @@ impl ChatTemplate {
         &self,
         guideline: Option<&str>,
         mut messages: Vec<Message>,
-        grammar_with_prompt: Option<(GrammarType, String)>,
+        tools_and_prompt: Option<(Vec<Tool>, String)>,
     ) -> Result<String, InferError> {
-        if self.use_default_tool_template {
-            if let Some(last_message) = messages.last_mut() {
-                if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
-                    last_message.content.push(MessageChunk::Text {
-                        text: format!("\n---\n{}\n{}", tool_prompt, tools),
-                    });
-                }
-            }
-        }
-
-        let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
-
         // check if guideline is expected but not provided
         if self.variables.contains("guideline") && guideline.is_none() {
             return Err(InferError::MissingTemplateVariable("guideline".to_string()));
         }
 
+        let tools = match tools_and_prompt {
+            Some((tools, tool_prompt)) => {
+                // check if the `tools` variable is used in the template
+                // if not, we need to append the tools to the last message
+                let text = if self.use_default_tool_template {
+                    match serde_json::to_string(&tools) {
+                        Ok(tools_str) => format!("\n---\n{}\n{}", tools_str, tool_prompt),
+                        Err(e) => return Err(InferError::ToolError(e.to_string())),
+                    }
+                } else {
+                    // if the `tools` variable is used in the template, we just append the tool_prompt
+                    format!("\n---\n{}", tool_prompt)
+                };
+                if let Some(last_message) = messages.last_mut() {
+                    last_message.content.push(MessageChunk::Text { text });
+                }
+                Some(tools)
+            }
+            None => None,
+        };
+
+        let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
+
         self.template
             .render(ChatTemplateInputs {
                 guideline,
@@ -82,8 +93,7 @@ impl ChatTemplate {
                 bos_token: self.bos_token.as_deref(),
                 eos_token: self.eos_token.as_deref(),
                 add_generation_prompt: true,
-                tools: None,
-                tools_prompt: None,
+                tools,
             })
             .map_err(InferError::TemplateError)
     }
@@ -95,7 +105,7 @@ mod tests {
     use crate::infer::chat_template::raise_exception;
     use crate::infer::ChatTemplate;
     use crate::{
-        ChatTemplateInputs, GrammarType, Message, MessageContent, TextMessage, TokenizerConfigToken,
+        ChatTemplateInputs, Message, MessageContent, TextMessage, TokenizerConfigToken, Tool,
     };
     use minijinja::Environment;
@@ -854,11 +864,12 @@ mod tests {
             content: MessageContent::SingleText("Just testing".to_string()),
         },
     ];
-        let tools = serde_json::json!("[]");
+        let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string();
+        let tools: Vec<Tool> = serde_json::from_str(&tools_string).unwrap();
         let tool_prompt = "This default prompt will be used".to_string();
-        let grammer_with_prompt = (GrammarType::Json(tools), tool_prompt);
-        let result = ct.apply(None, msgs, Some(grammer_with_prompt));
-        let expected = "[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today? [INST] Just testing\n---\nThis default prompt will be used\n\"[]\" [/INST]".to_string();
+        let tools_and_prompt = Some((tools, tool_prompt));
+        let result = ct.apply(None, msgs, tools_and_prompt);
+        let expected = "[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today? [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"arguments\":{\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}}}]\nThis default prompt will be used [/INST]".to_string();
         assert_eq!(result.unwrap(), expected);
     }
 }
diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs
index c9354d9a..81c0d38f 100644
--- a/router/src/infer/mod.rs
+++ b/router/src/infer/mod.rs
@@ -3,7 +3,7 @@ mod chat_template;
 pub mod tool_grammar;
 
 use crate::validation::{ValidGenerateRequest, Validation, ValidationError};
-use crate::GrammarType;
+use crate::Tool;
 use crate::{
     ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, HubTokenizerConfig,
     Message, PrefillToken, Token,
@@ -140,12 +140,12 @@ impl Infer {
         &self,
         guideline: Option<String>,
         messages: Vec<Message>,
-        grammar_with_prompt: Option<(GrammarType, String)>,
+        tools_and_prompt: Option<(Vec<Tool>, String)>,
     ) -> Result<String, InferError> {
         self.chat_template
             .as_ref()
             .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
-            .apply(guideline.as_deref(), messages, grammar_with_prompt)
+            .apply(guideline.as_deref(), messages, tools_and_prompt)
             .map_err(|e| {
                 metrics::counter!("tgi_request_failure", "err" => "template").increment(1);
                 tracing::error!("{e}");
diff --git a/router/src/infer/tool_grammar.rs b/router/src/infer/tool_grammar.rs
index 05027f30..4fe15720 100644
--- a/router/src/infer/tool_grammar.rs
+++ b/router/src/infer/tool_grammar.rs
@@ -1,5 +1,8 @@
 use crate::infer::InferError;
-use crate::{FunctionRef, FunctionsMap, Properties, Tool, ToolChoice, ToolType, Tools};
+use crate::{
+    FunctionDefinition, FunctionRef, FunctionsMap, JsonSchemaTool, Properties, Tool, ToolChoice,
+    ToolType,
+};
 use serde_json::{json, Map, Value};
 use std::collections::HashMap;
@@ -16,17 +19,38 @@ impl ToolGrammar {
     }
 
     pub fn apply(
-        tools: Option<Vec<Tool>>,
+        tools: Vec<Tool>,
         tool_choice: ToolChoice,
-    ) -> Result<Option<Tools>, InferError> {
+    ) -> Result<(Vec<Tool>, Option<JsonSchemaTool>), InferError> {
         // if no tools are provided, we return None
-        let tools = match tools {
-            Some(tools) if !tools.is_empty() => tools,
-            _ => return Ok(None),
-        };
+        if tools.is_empty() {
+            return Ok((tools, None));
+        }
 
         let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf);
 
+        let mut tools = tools.clone();
+
+        // add the notify_error function to the tools
+        let notify_error = Tool {
+            r#type: "function".to_string(),
+            function: FunctionDefinition {
+                name: "notify_error".to_string(),
+                description: Some("Notify an error or issue".to_string()),
+                arguments: json!({
+                    "type": "object",
+                    "properties": {
+                        "error": {
+                            "type": "string",
+                            "description": "The error or issue to notify"
+                        }
+                    },
+                    "required": ["error"]
+                }),
+            },
+        };
+        tools.push(notify_error);
+
         // if tools are provided and no tool_choice we default to the OneOf
         let tools_to_use = match tool_choice {
             ToolType::FunctionName(name) => {
@@ -35,87 +59,57 @@ impl ToolGrammar {
             ToolType::Function { function } => {
                 vec![Self::find_tool_by_name(&tools, &function.name)?]
             }
-            ToolType::OneOf => tools,
-            ToolType::NoTool => return Ok(None),
+            ToolType::OneOf => tools.clone(),
+            ToolType::NoTool => return Ok((tools, None)),
         };
 
-        // adds the error notification function for LLM feedback if required
-        let mut text_response_properties = Map::new();
-        text_response_properties.insert(
-            "error".to_string(),
-            serde_json::json!({
-                "type": "string",
-                "description": "The error or issue to notify"
-            }),
-        );
-        text_response_properties.insert(
-            "_name".to_string(),
-            serde_json::json!({
-                "type": "string",
-                "const": "notify_error"
-            }),
-        );
-
         let functions: HashMap<String, Value> = tools_to_use
             .iter()
             .map(|tool| {
                 let func = tool.function.clone();
 
-                // Clone the existing parameters, which are expected to be a JSON object
-                let mut params = if let Value::Object(params) = &func.arguments {
-                    params.clone()
-                } else {
-                    Map::new()
-                };
+                let mut params = Map::new();
 
-                // Insert the function's description at the top level, outside of properties
                 params.insert(
                     "description".to_string(),
-                    Value::String(func.description.clone().unwrap_or_default()),
+                    Value::String(func.description.unwrap_or_default()),
                 );
 
-                // Ensure 'properties' exists and is an object
-                let properties = params
-                    .entry("properties".to_string())
-                    .or_insert_with(|| json!({}))
-                    .as_object_mut()
-                    .unwrap();
+                let mut properties = Map::new();
+                let mut required = vec![Value::String("_name".to_string())];
 
-                // Insert the constant for the function name inside 'properties'
                 properties.insert(
                     "_name".to_string(),
                     json!({
                         "type": "string",
                         "const": func.name.clone(),
-                        // "description": "The name of the function"
                     }),
                 );
 
-                // Check if 'required' exists, and it is an array. If not, create an empty array.
-                let required = params
-                    .entry("required".to_string())
-                    .or_insert_with(|| json!([]))
-                    .as_array_mut()
-                    .unwrap();
-
-                // Add 'name' to the 'required' array if it is not already present
-                if !required.iter().any(|r| r == "_name") {
-                    required.push(json!("_name"));
+                if let Value::Object(args) = func.arguments {
+                    if let Some(Value::Object(props)) = args.get("properties") {
+                        properties.extend(props.clone());
+                    }
+                    if let Some(Value::Array(reqs)) = args.get("required") {
+                        required.extend(reqs.clone());
+                    }
+                    params.insert(
+                        "additionalProperties".to_string(),
+                        Value::Bool(
+                            args.get("additionalProperties").and_then(|v| v.as_str())
+                                == Some("true"),
+                        ),
+                    );
                 }
 
+                params.insert("properties".to_string(), Value::Object(properties));
+                params.insert("required".to_string(), Value::Array(required));
+
                 (func.name, Value::Object(params))
             })
-            .chain([(
-                "notify_error".to_string(),
-                serde_json::json!({
-                    "properties": text_response_properties,
-                    "required": ["error", "_name"],
-                    "type": "object"
-                }),
-            )])
             .collect();
 
-        let tools = Tools {
+        let tool_schema = JsonSchemaTool {
             functions_map: FunctionsMap { functions },
             properties: Properties {
                 function: tools_to_use
@@ -123,13 +117,10 @@ impl ToolGrammar {
                     .map(|tool| FunctionRef {
                         ref_path: format!("#/$functions/{}", tool.function.name.clone()),
                     })
-                    .chain(std::iter::once(FunctionRef {
-                        ref_path: "#/$functions/notify_error".to_string(),
-                    }))
                    .collect(),
             },
         };
 
-        Ok(Some(tools))
+        Ok((tools, Some(tool_schema)))
     }
 }
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 1b2ff153..ce4f7c46 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -840,10 +840,10 @@ pub(crate) struct ChatRequest {
     pub tools: Option<Vec<Tool>>,
 
     /// A prompt to be appended before the tools
-    #[serde(default = "default_tool_prompt")]
"default_tool_prompt")] + #[serde(default)] #[schema( nullable = true, - example = "\"You will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n\"" + example = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables." )] pub tool_prompt: Option, @@ -865,10 +865,8 @@ pub(crate) struct ChatRequest { pub guideline: Option, } -fn default_tool_prompt() -> Option { - Some( - "\nYou will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n".to_string(), - ) +pub fn default_tool_prompt() -> String { + "\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\n".to_string() } #[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)] @@ -910,7 +908,7 @@ impl From for ToolChoice { } #[derive(Debug, Deserialize, Serialize, ToSchema, PartialEq)] -pub struct Tools { +pub struct JsonSchemaTool { #[serde(flatten)] functions_map: FunctionsMap, properties: Properties, @@ -968,8 +966,7 @@ pub(crate) struct ChatTemplateInputs<'a> { bos_token: Option<&'a str>, eos_token: Option<&'a str>, add_generation_prompt: bool, - tools: Option<&'a str>, - tools_prompt: Option<&'a str>, + tools: Option>, guideline: Option<&'a str>, } diff --git a/router/src/server.rs b/router/src/server.rs index 8ec7a871..8ebd1a33 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -8,7 +8,7 @@ use crate::kserve::{ kserve_model_metadata, kserve_model_metadata_ready, }; use crate::validation::ValidationError; -use crate::ChatTokenizeResponse; +use crate::{default_tool_prompt, ChatTokenizeResponse}; use crate::{ usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName, GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo, @@ -23,7 +23,7 @@ use crate::{ CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, VertexRequest, VertexResponse, }; -use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType, Tools}; +use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType}; use async_stream::__private::AsyncStream; use axum::extract::Extension; use axum::http::{HeaderMap, HeaderValue, Method, StatusCode}; @@ -146,7 +146,7 @@ async fn get_chat_tokenize( } = req; let tool_prompt = tool_prompt.unwrap_or_default(); - let (inputs, _grammar, _tool_grammar) = prepare_chat_input( + let (inputs, _grammar, _using_tools) = prepare_chat_input( &infer, response_format, tools, @@ -1158,14 +1158,16 @@ async fn chat_completions( let repetition_penalty = presence_penalty.map(|x| x + 2.0); let max_new_tokens = max_tokens.or(Some(100)); let logprobs = logprobs.unwrap_or(false); - let tool_prompt = tool_prompt.unwrap_or_default(); + let tool_prompt = tool_prompt + .filter(|s| !s.is_empty()) + .unwrap_or_else(default_tool_prompt); let stop = 
     // enable greedy only when temperature is 0
     let (do_sample, temperature) = match temperature {
         Some(temperature) if temperature == 0.0 => (false, None),
         other => (true, other),
     };
-    let (inputs, grammar, tool_grammar) = prepare_chat_input(
+    let (inputs, grammar, using_tools) = prepare_chat_input(
         &infer,
         response_format,
         tools,
@@ -1221,7 +1223,7 @@ async fn chat_completions(
         });
 
         // replace the content with the tool calls if grammar is present
-        let (content, tool_calls) = if tool_grammar.is_some() {
+        let (content, tool_calls) = if using_tools {
             (None, Some(vec![stream_token.token.text]))
         } else {
             let content = if !stream_token.token.special {
@@ -1275,7 +1277,7 @@ async fn chat_completions(
             .unwrap_or_else(|_| std::time::Duration::from_secs(0))
             .as_secs();
 
-        let (tool_calls, output) = if tool_grammar.is_some() {
+        let (tool_calls, output) = if using_tools {
             let gen_text_value: Value =
                 serde_json::from_str(&generation.generated_text).map_err(|e| {
                     InferError::ToolError(format!(
@@ -2539,7 +2541,7 @@ fn create_post_processor(
     Ok(post_processor)
 }
 
-type PreparedInput = (String, Option<GrammarType>, Option<Tools>);
+type PreparedInput = (String, Option<GrammarType>, bool);
 
 fn prepare_chat_input(
     infer: &Infer,
@@ -2556,19 +2558,139 @@ fn prepare_chat_input(
         ));
     }
 
+    // when response_format is set, tools are not included when applying the chat template to generate inputs
     if let Some(format) = response_format {
         let inputs = infer.apply_chat_template(guideline, messages, None)?;
-        return Ok((inputs, Some(format), None));
+        return Ok((inputs, Some(format), false));
     }
 
-    // if tools are set, apply the tool grammar and then the chat template
-    let tool_grammar: Option<Tools> = ToolGrammar::apply(tools, tool_choice)?;
-    let grammar = tool_grammar
-        .as_ref()
-        .map(|t| GrammarType::Json(serde_json::json!(t)));
-    let tools_grammar_prompt = tool_grammar
-        .as_ref()
-        .map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt.into()));
-    let inputs = infer.apply_chat_template(guideline, messages, tools_grammar_prompt)?;
-    Ok((inputs, grammar, tool_grammar))
+    // when no response_format is set and tools are included, apply the chat template with the tools
+    // to generate inputs
+    if let Some(tools) = tools {
+        let (updated_tools, tool_schema) = ToolGrammar::apply(tools, tool_choice)?;
+
+        let grammar = tool_schema
+            .as_ref()
+            .map(|t| GrammarType::Json(serde_json::json!(t)));
+
+        let inputs: String = infer.apply_chat_template(
+            guideline,
+            messages,
+            Some((updated_tools, tool_prompt.into())),
+        )?;
+        return Ok((inputs, grammar, tool_schema.is_some()));
+    }
+
+    // if no response_format or tools are set simply apply the chat template to generate inputs
+    let inputs = infer.apply_chat_template(guideline, messages, None)?;
+    Ok((inputs, None, false))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::ChatTemplateVersions;
+    use crate::HubTokenizerConfig;
+    use crate::TokenizerConfigToken;
+    use crate::Tool;
+
+    use serde_json::json;
+
+    #[test]
+    fn test_prepare_chat_input() {
+        // Mock Backend to avoid network requests
+        struct MockBackend;
+
+        impl Backend for MockBackend {
+            fn schedule(
+                &self,
+                _request: crate::validation::ValidGenerateRequest,
+            ) -> Result<
+                tokio_stream::wrappers::UnboundedReceiverStream<
+                    Result<InferStreamResponse, InferError>,
+                >,
+                InferError,
+            > {
+                unimplemented!("Never called in this test");
+            }
+            fn health<'a, 'async_trait>(
+                &'a self,
+                _current_health: bool,
+            ) -> core::pin::Pin<
+                Box<dyn core::future::Future<Output = bool> + core::marker::Send + 'async_trait>,
+            >
+            where
+                'a: 'async_trait,
+                Self: 'async_trait,
+            {
+                unimplemented!("Never called in this test");
+            }
+        }
+
+        let backend = MockBackend {};
+
+        let mut tokenizer_config = HubTokenizerConfig::default();
+
+        // mock tokenizer config values
+        tokenizer_config.bos_token = Some(TokenizerConfigToken::String("<s>".to_string()));
+        tokenizer_config.eos_token = Some(TokenizerConfigToken::String("</s>".to_string()));
+        tokenizer_config.chat_template = Some(
+            ChatTemplateVersions::Single("{%- if messages[0][\"role\"] == \"system\" %}\n    {%- set system_message = messages[0][\"content\"] %}\n    {%- set loop_messages = messages[1:] %}\n{%- else %}\n    {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n    {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n    {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n        {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n            {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n        {%- endif %}\n        {%- set ns.index = ns.index + 1 %}\n    {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n    {%- if message[\"role\"] == \"user\" %}\n        {%- if tools is not none and (message == user_messages[-1]) %}\n            {{- \"[AVAILABLE_TOOLS] [\" }}\n            {%- for tool in tools %}\n                {%- set tool = tool.function %}\n                {{- '{\"type\": \"function\", \"function\": {' }}\n                {%- for key, val in tool.items() if key != \"return\" %}\n                    {%- if val is string %}\n                        {{- '\"' + key + '\": \"' + val + '\"' }}\n                    {%- else %}\n                        {{- '\"' + key + '\": ' + val|tojson }}\n                    {%- endif %}\n                    {%- if not loop.last %}\n                        {{- \", \" }}\n                    {%- endif %}\n                {%- endfor %}\n                {{- \"}}\" }}\n                {%- if not loop.last %}\n                    {{- \", \" }}\n                {%- else %}\n                    {{- \"]\" }}\n                {%- endif %}\n            {%- endfor %}\n            {{- \"[/AVAILABLE_TOOLS]\" }}\n        {%- endif %}\n        {%- if loop.last and system_message is defined %}\n            {{- \"[INST] \" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n        {%- else %}\n            {{- \"[INST] \" + message[\"content\"] + \"[/INST]\" }}\n        {%- endif %}\n    {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n        {{- \"[TOOL_CALLS] [\" }}\n        {%- for tool_call in message.tool_calls %}\n            {%- set out = tool_call.function|tojson %}\n            {{- out[:-1] }}\n            {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n                {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n            {%- endif %}\n            {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n            {%- if not loop.last %}\n                {{- \", \" }}\n            {%- else %}\n                {{- \"]\" + eos_token }}\n            {%- endif %}\n        {%- endfor %}\n    {%- elif message[\"role\"] == \"assistant\" %}\n        {{- \" \" + message[\"content\"]|trim + eos_token}}\n    {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n        {%- if message.content is defined and message.content.content is defined %}\n            {%- set content = message.content.content %}\n        {%- else %}\n            {%- set content = message.content %}\n        {%- endif %}\n        {{- '[TOOL_RESULTS] {\"content\": ' + content|string + \", \" }}\n        {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n            {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n        {%- endif %}\n        {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n    {%- else %}\n        {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n    {%- endif %}\n{%- endfor %}\n".to_string())
+        );
+
+        let infer = Infer::new(
+            backend,
+            Validation::new(1, None, None, None, 1, 1, 1, 1, 1, false),
+            1,
+            tokenizer_config,
+            HubProcessorConfig::default(),
+        );
+        let response_format = None;
+        let tools = Some(vec![Tool {
+            r#type: "function".to_string(),
+            function: FunctionDefinition {
+                name: "get_current_weather".to_string(),
+                description: Some("Get the current weather".to_string()),
+                arguments: json!({
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA"
+                        },
+                        "format": {
+                            "type": "string",
+                            "enum": ["celsius", "fahrenheit"],
+                            "description": "The temperature unit to use. Infer this from the users location."
+                        }
+                    },
+                    "required": ["location", "format"]
+                }),
+            },
+        }]);
+        let tool_prompt = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.";
+        let guideline = None;
+        let messages = vec![Message {
+            name: None,
+            role: "user".to_string(),
+            content: MessageContent::SingleText(
+                "What is the weather like in New York?".to_string(),
+            ),
+        }];
+
+        let result = prepare_chat_input(
+            &infer,
+            response_format,
+            tools,
+            ToolChoice(None),
+            tool_prompt,
+            guideline,
+            messages,
+        );
+
+        assert!(result.is_ok());
+        let (inputs, _grammar, using_tools) = result.unwrap();
+        assert_eq!(using_tools, true);
+        assert_eq!(inputs, "[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"error\":{\"description\":\"The error or issue to notify\",\"type\":\"string\"}},\"required\":[\"error\"],\"type\":\"object\"}, \"description\": \"Notify an error or issue\", \"name\": \"notify_error\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string());
+    }
 }
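
Usage sketch: the integration tests above exercise the changed tool-calling path through the repo's own Python client, and the snippet below is a minimal standalone equivalent of what they do. It assumes a TGI server built from this branch is already serving a tool-capable Llama model on localhost:8080; the endpoint URL and the single-tool schema are illustrative assumptions, not values fixed by the patch.

import asyncio

from text_generation import AsyncClient

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    },
                    "format": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                        "description": "The temperature unit to use.",
                    },
                },
                "required": ["location", "format"],
                "additionalProperties": False,
            },
        },
    }
]


async def main():
    # timeout=30 mirrors the conftest.py change above: compiling the tool
    # grammar can make the first constrained request slow.
    client = AsyncClient("http://localhost:8080", timeout=30)
    response = await client.chat(
        max_tokens=100,
        seed=1,
        temperature=0.0,
        tools=tools,
        tool_choice="auto",
        messages=[
            {"role": "user", "content": "What is the weather like in Brooklyn, NY?"},
        ],
    )
    # With this patch applied, an on-topic question yields a constrained
    # get_current_weather call; an unanswerable one falls back to the injected
    # notify_error tool, as test_flash_llama_grammar_tools_insufficient_information asserts.
    print(response.choices[0].message.tool_calls)


asyncio.run(main())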