From 5489406c4a06780c23357880588f807a5f2f52e7 Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 19 Nov 2024 13:31:59 -0500 Subject: [PATCH] PR 2634 CI - Fix the tool_choice format for named choice by adapting OpenAIs scheme (#2645) * add OpenAI like tool_choice for named choice * add tests * fix: run linter and bump api docs * fix: consolidate changes and remove old tool type * feat: improve, simplify and rename tool choice struct add required support and refactor * fix: simplify tool choice logic, improve tests, openapi and rust docs * fix: refactor away prepare_chat_input and improve tool grammar apply control flow * feat: update docs and add tool choice configuration section * fix: simplify naming, tool choice default and improve test * fix: adjust tool choice none logic, add test and small refactors * fix: add missing snapshot file * fix: adjust tool choice type in test * fix: adjust default when json tool choice is * fix: remove trailing space lint after rebase * fix: remove mostly mocked unit test --------- Co-authored-by: Linus Bierhoff --- docs/openapi.json | 19 +-- docs/source/basic_tutorials/using_guidance.md | 60 ++++++- ..._sea_creatures_stream_function_object.json | 27 +++ ...ammar_tools_sea_creatures_stream_none.json | 20 +++ ...r_tools_sea_creatures_stream_required.json | 28 +++ integration-tests/models/test_tools_llama.py | 143 +++++++++++++++- router/src/infer/tool_grammar.rs | 72 ++++---- router/src/lib.rs | 154 +++++++++++++---- router/src/server.rs | 161 +----------------- 9 files changed, 442 insertions(+), 242 deletions(-) create mode 100644 integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_function_object.json create mode 100644 integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_none.json create mode 100644 integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_required.json diff --git a/docs/openapi.json b/docs/openapi.json index e4c8ffdb..f42f9390 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -1102,6 +1102,7 @@ "$ref": "#/components/schemas/ToolChoice" } ], + "default": "auto", "nullable": true }, "tool_prompt": { @@ -2294,14 +2295,6 @@ } }, "ToolChoice": { - "allOf": [ - { - "$ref": "#/components/schemas/ToolType" - } - ], - "nullable": true - }, - "ToolType": { "oneOf": [ { "type": "string", @@ -2317,6 +2310,13 @@ "none" ] }, + { + "type": "string", + "description": "Means the model must call one or more tools.", + "enum": [ + "required" + ] + }, { "type": "object", "required": [ @@ -2329,8 +2329,7 @@ } } ], - "description": "Controls which (if any) tool is called by the model.", - "example": "auto" + "description": "" }, "Url": { "type": "object", diff --git a/docs/source/basic_tutorials/using_guidance.md b/docs/source/basic_tutorials/using_guidance.md index dfa3f0e4..2d55c952 100644 --- a/docs/source/basic_tutorials/using_guidance.md +++ b/docs/source/basic_tutorials/using_guidance.md @@ -315,8 +315,6 @@ print(chat.choices[0].message.tool_calls) TGI exposes an OpenAI-compatible API, which means you can use OpenAI's client libraries to interact with TGI's Messages API and Tool functions. -However there are some minor differences in the API, for example `tool_choice="auto"` will ALWAYS choose the tool for you. This is different from OpenAI's API where `tool_choice="auto"` will choose a tool if the model thinks it's necessary. - ```python from openai import OpenAI @@ -362,3 +360,61 @@ print(called) # }, # } ``` + +### Tool Choice Configuration + +When configuring how the model interacts with tools during a chat completion, there are several options for determining if or how a tool should be called. These options are controlled by the `tool_choice` parameter, which specifies the behavior of the model in relation to tool usage. The following modes are supported: + +1. **`auto`**: + + - The model decides whether to call a tool or generate a response message based on the user's input. + - If tools are provided, this is the default mode. + - Example usage: + ```python + tool_choice="auto" + ``` + +2. **`none`**: + + - The model will never call any tools and will only generate a response message. + - If no tools are provided, this is the default mode. + - Example usage: + ```python + tool_choice="none" + ``` + +3. **`required`**: + + - The model must call one or more tools and will not generate a response message on its own. + - Example usage: + ```python + tool_choice="required" + ``` + +4. **Specific Tool Call by Function Name**: + - You can force the model to call a specific tool either by specifying the tool function directly or by using an object definition. + - Two ways to do this: + 1. Provide the function name as a string: + ```python + tool_choice="get_current_weather" + ``` + 2. Use the function object format: + ```python + tool_choice={ + "type": "function", + "function": { + "name": "get_current_weather" + } + } + ``` + +These options allow flexibility when integrating tools with the chat completions endpoint. You can configure the model to either rely on tools automatically or force it to follow a predefined behavior, based on the needs of the task at hand. + +--- + +| **Tool Choice Option** | **Description** | **When to Use** | +| ------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------- | +| `auto` | The model decides whether to call a tool or generate a message. This is the default if tools are provided. | Use when you want the model to decide when a tool is necessary. | +| `none` | The model generates a message without calling any tools. This is the default if no tools are provided. | Use when you do not want the model to call any tools. | +| `required` | The model must call one or more tools and will not generate a message on its own. | Use when a tool call is mandatory, and you do not want a regular message generated. | +| Specific Tool Call (`name` or object) | Force the model to call a specific tool either by specifying its name (`tool_choice="get_current_weather"`) or using an object. | Use when you want to restrict the model to calling a particular tool for the response. | diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_function_object.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_function_object.json new file mode 100644 index 00000000..e64dd49d --- /dev/null +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_function_object.json @@ -0,0 +1,27 @@ +{ + "choices": [ + { + "delta": { + "role": "assistant", + "tool_calls": { + "function": { + "arguments": "<|eot_id|>", + "name": null + }, + "id": "", + "index": 0, + "type": "function" + } + }, + "finish_reason": "stop", + "index": 0, + "logprobs": null + } + ], + "created": 1729084854, + "id": "", + "model": "meta-llama/Llama-3.1-8B-Instruct", + "object": "chat.completion.chunk", + "system_fingerprint": "2.3.2-dev0-native", + "usage": null +} diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_none.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_none.json new file mode 100644 index 00000000..2ccab4a9 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_none.json @@ -0,0 +1,20 @@ +{ + "choices": [ + { + "delta": { + "content": " deep", + "role": "assistant", + "tool_calls": null + }, + "finish_reason": "length", + "index": 0, + "logprobs": null + } + ], + "created": 1729262528, + "id": "", + "model": "meta-llama/Llama-3.1-8B-Instruct", + "object": "chat.completion.chunk", + "system_fingerprint": "2.3.2-dev0-native", + "usage": null +} diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_required.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_required.json new file mode 100644 index 00000000..d8d538d6 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_required.json @@ -0,0 +1,28 @@ +{ + "choices": [ + { + "delta": { + "content": null, + "role": "assistant", + "tool_calls": { + "function": { + "arguments": "<|eot_id|>", + "name": null + }, + "id": "", + "index": 0, + "type": "function" + } + }, + "finish_reason": "stop", + "index": 0, + "logprobs": null + } + ], + "created": 1729084850, + "id": "", + "model": "meta-llama/Llama-3.1-8B-Instruct", + "object": "chat.completion.chunk", + "system_fingerprint": "2.3.2-dev0-native", + "usage": null +} diff --git a/integration-tests/models/test_tools_llama.py b/integration-tests/models/test_tools_llama.py index 98e75bb4..b5821945 100644 --- a/integration-tests/models/test_tools_llama.py +++ b/integration-tests/models/test_tools_llama.py @@ -1,4 +1,6 @@ import pytest +import requests +import json @pytest.fixture(scope="module") @@ -174,7 +176,7 @@ async def test_flash_llama_grammar_tools_choice( "function": { "description": None, "name": "get_current_weather", - "arguments": {"format": "celsius", "location": "Brooklyn, NY"}, + "arguments": {"format": "celsius", "location": "Brooklyn, New York"}, }, } ] @@ -327,3 +329,142 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream( == "Once upon a time, in the ocean, there lived three sea creatures. There was a wise old octopus named Bob, a mischievous seagull named Sam, and a gentle sea turtle named Luna. They all lived together in a beautiful coral reef, surrounded by colorful fish and swaying sea fans" ) assert last_response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_grammar_tools_sea_creatures_stream_required( + flash_llama_grammar_tools, response_snapshot +): + responses = await flash_llama_grammar_tools.chat( + max_tokens=100, + seed=24, + tools=tools, + tool_choice="required", + messages=[ + { + "role": "system", + "content": "You're a helpful assistant! Answer the users question best you can. If the question is not answerable by the tools, just generate a response.", + }, + { + "role": "user", + "content": "Tell me a story about 3 sea creatures", + }, + ], + stream=True, + ) + + count = 0 + tool_calls_generated = "" + last_response = None + async for response in responses: + count += 1 + assert response.choices[0].delta.content is None + tool_calls_generated += response.choices[0].delta.tool_calls.function.arguments + last_response = response + + assert count == 29 + assert ( + tool_calls_generated + == '{"function": {"_name": "get_current_weather", "format": "celsius", "location": "San Francisco, CA"}}<|eot_id|>' + ) + assert last_response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_grammar_tools_sea_creatures_stream_none( + flash_llama_grammar_tools, response_snapshot +): + responses = await flash_llama_grammar_tools.chat( + max_tokens=100, + seed=24, + tools=tools, + tool_choice="none", + messages=[ + { + "role": "system", + "content": "You're a helpful assistant! Answer the users question best you can. If the question is not answerable by the tools, just generate a response.", + }, + { + "role": "user", + "content": "Tell me a story about 3 sea creatures", + }, + ], + stream=True, + ) + + count = 0 + content_generated = "" + last_response = None + async for response in responses: + count += 1 + content_generated += response.choices[0].delta.content + last_response = response + assert response.choices[0].delta.tool_calls is None + + assert count == 100 + print(content_generated) + assert ( + content_generated + == "Once upon a time, in a vibrant ocean filled with coral reefs and schools of shimmering fish, lived three dear friends: Luna the sea turtle, Finley the friendly fish, and Crusty the wise crab.\n\nLuna was the oldest of the three. She had traveled the world, exploring hidden caves and shipwrecks, and collecting sparkling shells and shiny pebbles. Her shell was a beautiful mosaic of blues and greens, and her gentle eyes twinkled with the secrets of the deep" + ) + assert last_response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_grammar_tools_sea_creatures_stream_function_object( + flash_llama_grammar_tools, response_snapshot +): + # using `requests` to send the request until the client library supports tool_choice as a function object + responses = requests.post( + f"{flash_llama_grammar_tools.base_url}/v1/chat/completions", + headers=flash_llama_grammar_tools.headers, + json={ + "model": "tgi", + "messages": [ + { + "role": "system", + "content": "You're a helpful assistant! Answer the users question best you can. If the question is not answerable by the tools, just generate a response.", + }, + { + "role": "user", + "content": "Tell me a story about 3 sea creatures", + }, + ], + "tools": tools, + "tool_choice": { + "type": "function", + "function": {"name": "get_n_day_weather_forecast"}, + }, + "seed": 24, + "max_tokens": 100, + "stream": True, + }, + stream=True, + ) + # iterate over the response in chunks + count = 0 + tool_calls_generated = "" + last_response = None + for chunk in responses.iter_content(chunk_size=1024): + if chunk: + count += 1 + # remove the "data: " prefix, trailing newline, and split the chunk into individual lines + lines = chunk.decode("utf-8").replace("data: ", "").rstrip("\n").split("\n") + for line in lines: + if line == "[DONE]": + break + response = json.loads(line) + tool_calls_generated += response["choices"][0]["delta"]["tool_calls"][ + "function" + ]["arguments"] + last_response = response + + assert count == 39 + assert ( + tool_calls_generated + == '{"function": {"_name": "get_n_day_weather_forecast", "format": "celsius", "location": "San Francisco, CA", "num_days":3}}<|eot_id|>' + ) + assert last_response == response_snapshot diff --git a/router/src/infer/tool_grammar.rs b/router/src/infer/tool_grammar.rs index f86205fb..7770cd9d 100644 --- a/router/src/infer/tool_grammar.rs +++ b/router/src/infer/tool_grammar.rs @@ -1,7 +1,6 @@ use crate::infer::InferError; use crate::{ FunctionDefinition, FunctionRef, FunctionsMap, JsonSchemaTool, Properties, Tool, ToolChoice, - ToolType, }; use serde_json::{json, Map, Value}; use std::collections::HashMap; @@ -21,45 +20,46 @@ impl ToolGrammar { pub fn apply( tools: Vec, tool_choice: ToolChoice, - ) -> Result<(Vec, Option), InferError> { - // if no tools are provided, we return None - if tools.is_empty() { - return Ok((tools, None)); - } - - let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf); - - let mut tools = tools.clone(); - - // add the no_tool function to the tools - let no_tool = Tool { - r#type: "function".to_string(), - function: FunctionDefinition { - name: "no_tool".to_string(), - description: Some("Open ened response with no specific tool selected".to_string()), - arguments: json!({ - "type": "object", - "properties": { - "content": { - "type": "string", - "description": "The response content", - } - }, - "required": ["content"] - }), - }, - }; - tools.push(no_tool); - - // if tools are provided and no tool_choice we default to the OneOf + ) -> Result, JsonSchemaTool)>, InferError> { let tools_to_use = match tool_choice { - ToolType::Function(function) => { + ToolChoice::Function(function) => { vec![Self::find_tool_by_name(&tools, &function.name)?] } - ToolType::OneOf => tools.clone(), - ToolType::NoTool => return Ok((tools, None)), + ToolChoice::Required => tools, + ToolChoice::Auto => { + // only add the no_tool function if the user has selected the auto option + tools + .iter() + .cloned() + .chain(std::iter::once(Tool { + r#type: "function".to_string(), + function: FunctionDefinition { + name: "no_tool".to_string(), + description: Some( + "Open ended response with no specific tool selected".to_string(), + ), + arguments: json!({ + "type": "object", + "properties": { + "content": { + "type": "string", + "description": "The response content", + } + }, + "required": ["content"] + }), + }, + })) + .collect::>() + } + ToolChoice::NoTool => vec![], }; + // if no tools are provided or if the user has selected the no_tool option, return None + if tools_to_use.is_empty() { + return Ok(None); + } + let functions: HashMap = tools_to_use .iter() .map(|tool| { @@ -118,6 +118,6 @@ impl ToolGrammar { }, }; - Ok((tools, Some(tool_schema))) + Ok(Some((tools_to_use, tool_schema))) } } diff --git a/router/src/lib.rs b/router/src/lib.rs index c0155852..7f093b41 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -12,8 +12,8 @@ mod sagemaker; pub mod usage_stats; mod vertex; +use crate::infer::tool_grammar::ToolGrammar; use crate::infer::{Infer, InferError}; -use crate::server::prepare_chat_input; use pyo3::prelude::*; use pyo3::types::IntoPyDict; use serde::{Deserialize, Serialize}; @@ -899,7 +899,7 @@ pub(crate) struct ChatRequest { /// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter. #[serde(default)] - #[schema(nullable = true, example = "null")] + #[schema(nullable = true, default = "auto", example = "auto")] pub tool_choice: ToolChoice, /// Response format constraints for the generation. @@ -953,15 +953,43 @@ impl ChatRequest { Some(temperature) if temperature == 0.0 => (false, None), other => (true, other), }; - let (inputs, grammar, using_tools) = prepare_chat_input( - infer, - response_format, - tools, - tool_choice, - &tool_prompt, - guideline, - messages, - )?; + + if response_format.is_some() && tools.is_some() { + return Err(InferError::ToolError( + "Grammar and tools are mutually exclusive".into(), + )); + } + + let (inputs, grammar, using_tools) = match response_format { + Some(format) => { + let inputs = infer.apply_chat_template(guideline, messages, None)?; + (inputs, Some(format), false) + } + None => { + if let Some(tools) = tools { + match ToolGrammar::apply(tools, tool_choice)? { + Some((updated_tools, tool_schema)) => { + let grammar = GrammarType::Json(serde_json::json!(tool_schema)); + let inputs: String = infer.apply_chat_template( + guideline, + messages, + Some((updated_tools, tool_prompt)), + )?; + (inputs, Some(grammar), true) + } + None => { + // same as if no response_format or tools are set + let inputs = infer.apply_chat_template(guideline, messages, None)?; + (inputs, None, false) + } + } + } else { + // if no response_format or tools are set simply apply the chat template to generate inputs + let inputs = infer.apply_chat_template(guideline, messages, None)?; + (inputs, None, false) + } + } + }; Ok(( GenerateRequest { @@ -1006,19 +1034,11 @@ pub fn default_tool_prompt() -> String { "\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\n".to_string() } -#[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)] -#[schema(example = "auto")] -/// Controls which (if any) tool is called by the model. -pub enum ToolType { - /// Means the model can pick between generating a message or calling one or more tools. - #[schema(rename = "auto")] - OneOf, - /// Means the model will not call any tool and instead generates a message. - #[schema(rename = "none")] - NoTool, - /// Forces the model to call a specific tool. - #[schema(rename = "function")] - Function(FunctionName), +#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)] +#[serde(tag = "type")] +pub enum TypedChoice { + #[serde(rename = "function")] + Function { function: FunctionName }, } #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema)] @@ -1026,28 +1046,58 @@ pub struct FunctionName { pub name: String, } -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default, ToSchema)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema, Default)] #[serde(from = "ToolTypeDeserializer")] -pub struct ToolChoice(pub Option); +#[serde(rename_all = "snake_case")] +/// +pub enum ToolChoice { + /// Means the model can pick between generating a message or calling one or more tools. + #[default] + Auto, + /// Means the model will not call any tool and instead generates a message. + #[serde(rename = "none")] + NoTool, + /// Means the model must call one or more tools. + Required, + /// Forces the model to call a specific tool. This structure aligns with the `OpenAI` API schema to force a specific tool. + Function(FunctionName), +} -#[derive(Deserialize)] +#[derive(Deserialize, ToSchema)] #[serde(untagged)] +/// Controls which (if any) tool is called by the model. +/// - `none` means the model will not call any tool and instead generates a message. +/// - `auto` means the model can pick between generating a message or calling one or more tools. +/// - `required` means the model must call one or more tools. +/// - Specifying a particular tool via `{\"type\": \"function\", \"function\": {\"name\": \"my_function\"}}` forces the model to call that tool. +/// +/// `none` is the default when no tools are present. `auto` is the default if tools are present." enum ToolTypeDeserializer { + /// None means `null` was passed in the JSON, and the default choice is applied based on the presence of tools. Null, + + /// `auto` means the model can pick between generating a message or calling one or more tools. + #[schema(example = "auto")] String(String), - ToolType(ToolType), + + /// Specifying a particular tool forces the model to call that tool, with structured function details. + #[schema(example = r#"{"type": "function", "function": {"name": "my_function"}}"#)] + TypedChoice(TypedChoice), } impl From for ToolChoice { fn from(value: ToolTypeDeserializer) -> Self { match value { - ToolTypeDeserializer::Null => ToolChoice(None), + ToolTypeDeserializer::Null => ToolChoice::Auto, ToolTypeDeserializer::String(s) => match s.as_str() { - "none" => ToolChoice(Some(ToolType::NoTool)), - "auto" => ToolChoice(Some(ToolType::OneOf)), - _ => ToolChoice(Some(ToolType::Function(FunctionName { name: s }))), + "none" => ToolChoice::NoTool, + "auto" => ToolChoice::Auto, + "required" => ToolChoice::Required, + _ => ToolChoice::Function(FunctionName { name: s }), }, - ToolTypeDeserializer::ToolType(tool_type) => ToolChoice(Some(tool_type)), + ToolTypeDeserializer::TypedChoice(TypedChoice::Function { function }) => { + ToolChoice::Function(function) + } } } } @@ -1213,6 +1263,7 @@ pub(crate) enum OutputMessage { } #[derive(Clone, Debug, Deserialize, ToSchema)] +#[cfg_attr(test, derive(PartialEq))] pub(crate) struct GenerateRequest { #[schema(example = "My name is Olivier and I")] pub inputs: String, @@ -1653,4 +1704,41 @@ mod tests { r#"{"role":"assistant","tool_calls":[{"id":"0","type":"function","function":{"description":null,"name":"myfn","arguments":{"format":"csv"}}}]}"# ); } + + #[test] + fn tool_choice_formats() { + #[derive(Deserialize)] + struct TestRequest { + tool_choice: ToolChoice, + } + + let de_none: TestRequest = serde_json::from_str(r#"{"tool_choice":"none"}"#).unwrap(); + assert_eq!(de_none.tool_choice, ToolChoice::NoTool); + + let de_auto: TestRequest = serde_json::from_str(r#"{"tool_choice":"auto"}"#).unwrap(); + assert_eq!(de_auto.tool_choice, ToolChoice::Auto); + + let de_required: TestRequest = + serde_json::from_str(r#"{"tool_choice":"required"}"#).unwrap(); + assert_eq!(de_required.tool_choice, ToolChoice::Required); + + let de_named: TestRequest = serde_json::from_str(r#"{"tool_choice":"myfn"}"#).unwrap(); + assert_eq!( + de_named.tool_choice, + ToolChoice::Function(FunctionName { + name: "myfn".to_string(), + }) + ); + + let de_openai_named: TestRequest = serde_json::from_str( + r#"{"tool_choice":{"type":"function","function":{"name":"myfn"}}}"#, + ) + .unwrap(); + assert_eq!( + de_openai_named.tool_choice, + ToolChoice::Function(FunctionName { + name: "myfn".to_string(), + }) + ); + } } diff --git a/router/src/server.rs b/router/src/server.rs index cbb04174..c85635ff 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1,6 +1,5 @@ /// HTTP Server logic use crate::config::Config; -use crate::infer::tool_grammar::ToolGrammar; use crate::infer::{Backend, Infer, InferError, InferResponse, InferStreamResponse}; #[cfg(feature = "kserve")] use crate::kserve::{ @@ -28,7 +27,7 @@ use crate::{ ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal, CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, }; -use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType}; +use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice}; use crate::{ModelInfo, ModelsInfo}; use async_stream::__private::AsyncStream; use axum::extract::Extension; @@ -1559,7 +1558,6 @@ GrammarType, Usage, StreamOptions, DeltaToolCall, -ToolType, Tool, ToolCall, Function, @@ -2525,160 +2523,3 @@ pub enum WebServerError { #[error("Axum error: {0}")] Axum(#[from] axum::BoxError), } - -type PreparedInput = (String, Option, bool); - -pub(crate) fn prepare_chat_input( - infer: &Infer, - response_format: Option, - tools: Option>, - tool_choice: ToolChoice, - tool_prompt: &str, - guideline: Option, - messages: Vec, -) -> Result { - if response_format.is_some() && tools.is_some() { - return Err(InferError::ToolError( - "Grammar and tools are mutually exclusive".into(), - )); - } - - // when response_format is set, tools are not included when applying the chat template to generate inputs - if let Some(format) = response_format { - let inputs = infer.apply_chat_template(guideline, messages, None)?; - return Ok((inputs, Some(format), false)); - } - - // when no response_format is set and tools are included, apply the chat template with the tools - // to generate inputs - if let Some(tools) = tools { - let (updated_tools, tool_schema) = ToolGrammar::apply(tools, tool_choice)?; - - let grammar = tool_schema - .as_ref() - .map(|t| GrammarType::Json(serde_json::json!(t))); - - let inputs: String = infer.apply_chat_template( - guideline, - messages, - Some((updated_tools, tool_prompt.into())), - )?; - return Ok((inputs, grammar, tool_schema.is_some())); - } - - // if no response_format or tools are set simply apply the chat template to generate inputs - let inputs = infer.apply_chat_template(guideline, messages, None)?; - Ok((inputs, None, false)) -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::ChatTemplateVersions; - use crate::HubTokenizerConfig; - use crate::TokenizerConfigToken; - use crate::Tool; - - use crate::tests::get_tokenizer; - use serde_json::json; - - #[tokio::test] - async fn test_prepare_chat_input() { - // Mock Backend to avoid network requests - struct MockBackend; - - impl Backend for MockBackend { - fn schedule( - &self, - _request: crate::validation::ValidGenerateRequest, - ) -> Result< - tokio_stream::wrappers::UnboundedReceiverStream< - Result, - >, - InferError, - > { - unimplemented!("Never called in this test"); - } - fn health<'a, 'async_trait>( - &'a self, - _current_health: bool, - ) -> core::pin::Pin< - Box + core::marker::Send + 'async_trait>, - > - where - 'a: 'async_trait, - Self: 'async_trait, - { - unimplemented!("Never called in this test"); - } - } - - let backend = MockBackend {}; - - let mut tokenizer_config = HubTokenizerConfig::default(); - - // mock tokenizer config values - tokenizer_config.bos_token = Some(TokenizerConfigToken::String("".to_string())); - tokenizer_config.eos_token = Some(TokenizerConfigToken::String("".to_string())); - tokenizer_config.chat_template = Some( - ChatTemplateVersions::Single("{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS] [\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST] \" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST] \" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- \"[TOOL_CALLS] [\" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- \" \" + message[\"content\"]|trim + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS] {\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n".to_string()) - ); - - let tokenizer = get_tokenizer(); - - let infer = Infer::new( - backend, - Validation::new(1, tokenizer, None, None, 1, 1, 1, 1, 1, false), - 1, - tokenizer_config, - HubProcessorConfig::default(), - ); - let response_format = None; - let tools = Some(vec![Tool { - r#type: "function".to_string(), - function: FunctionDefinition { - name: "get_current_weather".to_string(), - description: Some("Get the current weather".to_string()), - arguments: json!({ - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and state, e.g. San Francisco, CA" - }, - "format": { - "type": "string", - "enum": ["celsius", "fahrenheit"], - "description": "The temperature unit to use. Infer this from the users location." - } - }, - "required": ["location", "format"] - }), - }, - }]); - let tool_prompt = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables."; - let guideline = None; - let messages = vec![Message { - name: None, - role: "user".to_string(), - content: MessageContent::SingleText( - "What is the weather like in New York?".to_string(), - ), - }]; - - let result = prepare_chat_input( - &infer, - response_format, - tools, - ToolChoice(None), - tool_prompt, - guideline, - messages, - ); - - assert!(result.is_ok()); - let (inputs, _grammar, using_tools) = result.expect("Failed to prepare chat input"); - assert_eq!(using_tools, true); - assert_eq!(inputs, "[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"content\":{\"description\":\"The response content\",\"type\":\"string\"}},\"required\":[\"content\"],\"type\":\"object\"}, \"description\": \"Open ened response with no specific tool selected\", \"name\": \"no_tool\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string()); - } -}