PR 2634 CI - Fix the tool_choice format for named choice by adapting OpenAI's scheme (#2645)

* add OpenAI-like tool_choice for named choice

* add tests

* fix: run linter and bump api docs

* fix: consolidate changes and remove old tool type

* feat: improve, simplify and rename tool choice struct; add required support and refactor

* fix: simplify tool choice logic, improve tests, openapi and rust docs

* fix: refactor away prepare_chat_input and improve tool grammar apply control flow

* feat: update docs and add tool choice configuration section

* fix: simplify naming, tool choice default and improve test

* fix: adjust tool choice none logic, add test and small refactors

* fix: add missing snapshot file

* fix: adjust tool choice type in test

* fix: adjust default when json tool choice is

* fix: remove trailing space lint after rebase

* fix: remove mostly mocked unit test

---------

Co-authored-by: Linus Bierhoff <linus.bierhoff@icloud.com>
drbh 2024-11-19 13:31:59 -05:00 committed by GitHub
parent 2007a9473a
commit 5489406c4a
9 changed files with 442 additions and 242 deletions

View File

@ -1102,6 +1102,7 @@
"$ref": "#/components/schemas/ToolChoice" "$ref": "#/components/schemas/ToolChoice"
} }
], ],
"default": "auto",
"nullable": true "nullable": true
}, },
"tool_prompt": { "tool_prompt": {
@ -2294,14 +2295,6 @@
} }
}, },
"ToolChoice": { "ToolChoice": {
"allOf": [
{
"$ref": "#/components/schemas/ToolType"
}
],
"nullable": true
},
"ToolType": {
"oneOf": [ "oneOf": [
{ {
"type": "string", "type": "string",
@ -2317,6 +2310,13 @@
"none" "none"
] ]
}, },
{
"type": "string",
"description": "Means the model must call one or more tools.",
"enum": [
"required"
]
},
{ {
"type": "object", "type": "object",
"required": [ "required": [
@ -2329,8 +2329,7 @@
} }
} }
], ],
"description": "Controls which (if any) tool is called by the model.", "description": "<https://platform.openai.com/docs/guides/function-calling/configuring-function-calling-behavior-using-the-tool_choice-parameter>"
"example": "auto"
}, },
"Url": { "Url": {
"type": "object", "type": "object",

View File

@ -315,8 +315,6 @@ print(chat.choices[0].message.tool_calls)
TGI exposes an OpenAI-compatible API, which means you can use OpenAI's client libraries to interact with TGI's Messages API and Tool functions. TGI exposes an OpenAI-compatible API, which means you can use OpenAI's client libraries to interact with TGI's Messages API and Tool functions.
However there are some minor differences in the API, for example `tool_choice="auto"` will ALWAYS choose the tool for you. This is different from OpenAI's API where `tool_choice="auto"` will choose a tool if the model thinks it's necessary.
```python ```python
from openai import OpenAI from openai import OpenAI
@ -362,3 +360,61 @@ print(called)
# }, # },
# } # }
``` ```
### Tool Choice Configuration
When configuring how the model interacts with tools during a chat completion, there are several options for determining if or how a tool should be called. These options are controlled by the `tool_choice` parameter, which specifies the behavior of the model in relation to tool usage. The following modes are supported:
1. **`auto`**:
- The model decides whether to call a tool or generate a response message based on the user's input.
- If tools are provided, this is the default mode.
- Example usage:
```python
tool_choice="auto"
```
2. **`none`**:
- The model will never call any tools and will only generate a response message.
- If no tools are provided, this is the default mode.
- Example usage:
```python
tool_choice="none"
```
3. **`required`**:
- The model must call one or more tools and will not generate a response message on its own.
- Example usage:
```python
tool_choice="required"
```
4. **Specific Tool Call by Function Name**:
- You can force the model to call a specific tool either by specifying the tool function directly or by using an object definition.
- Two ways to do this:
1. Provide the function name as a string:
```python
tool_choice="get_current_weather"
```
2. Use the function object format:
```python
tool_choice={
"type": "function",
"function": {
"name": "get_current_weather"
}
}
```
These options allow flexibility when integrating tools with the chat completions endpoint. You can configure the model to either rely on tools automatically or force it to follow a predefined behavior, based on the needs of the task at hand. A combined usage sketch follows the summary table below.
---
| **Tool Choice Option** | **Description** | **When to Use** |
| ------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------- |
| `auto` | The model decides whether to call a tool or generate a message. This is the default if tools are provided. | Use when you want the model to decide when a tool is necessary. |
| `none` | The model generates a message without calling any tools. This is the default if no tools are provided. | Use when you do not want the model to call any tools. |
| `required` | The model must call one or more tools and will not generate a message on its own. | Use when a tool call is mandatory, and you do not want a regular message generated. |
| Specific Tool Call (`name` or object) | Force the model to call a specific tool either by specifying its name (`tool_choice="get_current_weather"`) or using an object. | Use when you want to restrict the model to calling a particular tool for the response. |
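To tie the modes together, here is a minimal sketch using the OpenAI client against a TGI endpoint. It assumes a server at `http://localhost:3000`; the `model` name, API key placeholder, and the `get_current_weather` tool definition are illustrative and mirror the earlier examples in this guide rather than being part of the test suite above.
```python
from openai import OpenAI

# Assumed local TGI endpoint exposing the OpenAI-compatible API; adjust as needed.
client = OpenAI(base_url="http://localhost:3000/v1", api_key="-")

# Illustrative tool definition, mirroring the get_current_weather tool used earlier.
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    },
                    "format": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                },
                "required": ["location", "format"],
            },
        },
    }
]

# The same request issued once per supported tool_choice value.
choices = [
    "auto",
    "none",
    "required",
    {"type": "function", "function": {"name": "get_current_weather"}},
]
for tool_choice in choices:
    chat = client.chat.completions.create(
        model="tgi",
        messages=[{"role": "user", "content": "What is the weather like in Brooklyn, NY?"}],
        tools=tools,
        tool_choice=tool_choice,
        max_tokens=100,
    )
    message = chat.choices[0].message
    # "none" should produce a plain message; the other modes should produce tool_calls.
    print(tool_choice, message.tool_calls or message.content)
```
Note that the object form is the one that matches OpenAI's schema exactly; the bare function-name string (`tool_choice="get_current_weather"`) is accepted by TGI as a convenience.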

View File

@ -0,0 +1,27 @@
{
"choices": [
{
"delta": {
"role": "assistant",
"tool_calls": {
"function": {
"arguments": "<|eot_id|>",
"name": null
},
"id": "",
"index": 0,
"type": "function"
}
},
"finish_reason": "stop",
"index": 0,
"logprobs": null
}
],
"created": 1729084854,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.3.2-dev0-native",
"usage": null
}

View File

@ -0,0 +1,20 @@
{
"choices": [
{
"delta": {
"content": " deep",
"role": "assistant",
"tool_calls": null
},
"finish_reason": "length",
"index": 0,
"logprobs": null
}
],
"created": 1729262528,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.3.2-dev0-native",
"usage": null
}

View File

@ -0,0 +1,28 @@
{
"choices": [
{
"delta": {
"content": null,
"role": "assistant",
"tool_calls": {
"function": {
"arguments": "<|eot_id|>",
"name": null
},
"id": "",
"index": 0,
"type": "function"
}
},
"finish_reason": "stop",
"index": 0,
"logprobs": null
}
],
"created": 1729084850,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.3.2-dev0-native",
"usage": null
}

View File

@ -1,4 +1,6 @@
import pytest import pytest
import requests
import json
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
@ -174,7 +176,7 @@ async def test_flash_llama_grammar_tools_choice(
"function": { "function": {
"description": None, "description": None,
"name": "get_current_weather", "name": "get_current_weather",
"arguments": {"format": "celsius", "location": "Brooklyn, NY"}, "arguments": {"format": "celsius", "location": "Brooklyn, New York"},
}, },
} }
] ]
@ -327,3 +329,142 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream(
== "Once upon a time, in the ocean, there lived three sea creatures. There was a wise old octopus named Bob, a mischievous seagull named Sam, and a gentle sea turtle named Luna. They all lived together in a beautiful coral reef, surrounded by colorful fish and swaying sea fans" == "Once upon a time, in the ocean, there lived three sea creatures. There was a wise old octopus named Bob, a mischievous seagull named Sam, and a gentle sea turtle named Luna. They all lived together in a beautiful coral reef, surrounded by colorful fish and swaying sea fans"
) )
assert last_response == response_snapshot assert last_response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_sea_creatures_stream_required(
flash_llama_grammar_tools, response_snapshot
):
responses = await flash_llama_grammar_tools.chat(
max_tokens=100,
seed=24,
tools=tools,
tool_choice="required",
messages=[
{
"role": "system",
"content": "You're a helpful assistant! Answer the users question best you can. If the question is not answerable by the tools, just generate a response.",
},
{
"role": "user",
"content": "Tell me a story about 3 sea creatures",
},
],
stream=True,
)
count = 0
tool_calls_generated = ""
last_response = None
async for response in responses:
count += 1
assert response.choices[0].delta.content is None
tool_calls_generated += response.choices[0].delta.tool_calls.function.arguments
last_response = response
assert count == 29
assert (
tool_calls_generated
== '{"function": {"_name": "get_current_weather", "format": "celsius", "location": "San Francisco, CA"}}<|eot_id|>'
)
assert last_response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_sea_creatures_stream_none(
flash_llama_grammar_tools, response_snapshot
):
responses = await flash_llama_grammar_tools.chat(
max_tokens=100,
seed=24,
tools=tools,
tool_choice="none",
messages=[
{
"role": "system",
"content": "You're a helpful assistant! Answer the users question best you can. If the question is not answerable by the tools, just generate a response.",
},
{
"role": "user",
"content": "Tell me a story about 3 sea creatures",
},
],
stream=True,
)
count = 0
content_generated = ""
last_response = None
async for response in responses:
count += 1
content_generated += response.choices[0].delta.content
last_response = response
assert response.choices[0].delta.tool_calls is None
assert count == 100
print(content_generated)
assert (
content_generated
== "Once upon a time, in a vibrant ocean filled with coral reefs and schools of shimmering fish, lived three dear friends: Luna the sea turtle, Finley the friendly fish, and Crusty the wise crab.\n\nLuna was the oldest of the three. She had traveled the world, exploring hidden caves and shipwrecks, and collecting sparkling shells and shiny pebbles. Her shell was a beautiful mosaic of blues and greens, and her gentle eyes twinkled with the secrets of the deep"
)
assert last_response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_sea_creatures_stream_function_object(
flash_llama_grammar_tools, response_snapshot
):
# using `requests` to send the request until the client library supports tool_choice as a function object
responses = requests.post(
f"{flash_llama_grammar_tools.base_url}/v1/chat/completions",
headers=flash_llama_grammar_tools.headers,
json={
"model": "tgi",
"messages": [
{
"role": "system",
"content": "You're a helpful assistant! Answer the users question best you can. If the question is not answerable by the tools, just generate a response.",
},
{
"role": "user",
"content": "Tell me a story about 3 sea creatures",
},
],
"tools": tools,
"tool_choice": {
"type": "function",
"function": {"name": "get_n_day_weather_forecast"},
},
"seed": 24,
"max_tokens": 100,
"stream": True,
},
stream=True,
)
# iterate over the response in chunks
count = 0
tool_calls_generated = ""
last_response = None
for chunk in responses.iter_content(chunk_size=1024):
if chunk:
count += 1
# remove the "data: " prefix, trailing newline, and split the chunk into individual lines
lines = chunk.decode("utf-8").replace("data: ", "").rstrip("\n").split("\n")
for line in lines:
if line == "[DONE]":
break
response = json.loads(line)
tool_calls_generated += response["choices"][0]["delta"]["tool_calls"][
"function"
]["arguments"]
last_response = response
assert count == 39
assert (
tool_calls_generated
== '{"function": {"_name": "get_n_day_weather_forecast", "format": "celsius", "location": "San Francisco, CA", "num_days":3}}<|eot_id|>'
)
assert last_response == response_snapshot

View File

@ -1,7 +1,6 @@
use crate::infer::InferError; use crate::infer::InferError;
use crate::{ use crate::{
FunctionDefinition, FunctionRef, FunctionsMap, JsonSchemaTool, Properties, Tool, ToolChoice, FunctionDefinition, FunctionRef, FunctionsMap, JsonSchemaTool, Properties, Tool, ToolChoice,
ToolType,
}; };
use serde_json::{json, Map, Value}; use serde_json::{json, Map, Value};
use std::collections::HashMap; use std::collections::HashMap;
@ -21,45 +20,46 @@ impl ToolGrammar {
pub fn apply( pub fn apply(
tools: Vec<Tool>, tools: Vec<Tool>,
tool_choice: ToolChoice, tool_choice: ToolChoice,
) -> Result<(Vec<Tool>, Option<JsonSchemaTool>), InferError> { ) -> Result<Option<(Vec<Tool>, JsonSchemaTool)>, InferError> {
// if no tools are provided, we return None
if tools.is_empty() {
return Ok((tools, None));
}
let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf);
let mut tools = tools.clone();
// add the no_tool function to the tools
let no_tool = Tool {
r#type: "function".to_string(),
function: FunctionDefinition {
name: "no_tool".to_string(),
description: Some("Open ened response with no specific tool selected".to_string()),
arguments: json!({
"type": "object",
"properties": {
"content": {
"type": "string",
"description": "The response content",
}
},
"required": ["content"]
}),
},
};
tools.push(no_tool);
// if tools are provided and no tool_choice we default to the OneOf
let tools_to_use = match tool_choice { let tools_to_use = match tool_choice {
ToolType::Function(function) => { ToolChoice::Function(function) => {
vec![Self::find_tool_by_name(&tools, &function.name)?] vec![Self::find_tool_by_name(&tools, &function.name)?]
} }
ToolType::OneOf => tools.clone(), ToolChoice::Required => tools,
ToolType::NoTool => return Ok((tools, None)), ToolChoice::Auto => {
// only add the no_tool function if the user has selected the auto option
tools
.iter()
.cloned()
.chain(std::iter::once(Tool {
r#type: "function".to_string(),
function: FunctionDefinition {
name: "no_tool".to_string(),
description: Some(
"Open ended response with no specific tool selected".to_string(),
),
arguments: json!({
"type": "object",
"properties": {
"content": {
"type": "string",
"description": "The response content",
}
},
"required": ["content"]
}),
},
}))
.collect::<Vec<_>>()
}
ToolChoice::NoTool => vec![],
}; };
// if no tools are provided or if the user has selected the no_tool option, return None
if tools_to_use.is_empty() {
return Ok(None);
}
let functions: HashMap<String, serde_json::Value> = tools_to_use let functions: HashMap<String, serde_json::Value> = tools_to_use
.iter() .iter()
.map(|tool| { .map(|tool| {
@ -118,6 +118,6 @@ impl ToolGrammar {
}, },
}; };
Ok((tools, Some(tool_schema))) Ok(Some((tools_to_use, tool_schema)))
} }
} }

View File

@ -12,8 +12,8 @@ mod sagemaker;
pub mod usage_stats; pub mod usage_stats;
mod vertex; mod vertex;
use crate::infer::tool_grammar::ToolGrammar;
use crate::infer::{Infer, InferError}; use crate::infer::{Infer, InferError};
use crate::server::prepare_chat_input;
use pyo3::prelude::*; use pyo3::prelude::*;
use pyo3::types::IntoPyDict; use pyo3::types::IntoPyDict;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -899,7 +899,7 @@ pub(crate) struct ChatRequest {
/// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter. /// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
#[serde(default)] #[serde(default)]
#[schema(nullable = true, example = "null")] #[schema(nullable = true, default = "auto", example = "auto")]
pub tool_choice: ToolChoice, pub tool_choice: ToolChoice,
/// Response format constraints for the generation. /// Response format constraints for the generation.
@ -953,15 +953,43 @@ impl ChatRequest {
Some(temperature) if temperature == 0.0 => (false, None), Some(temperature) if temperature == 0.0 => (false, None),
other => (true, other), other => (true, other),
}; };
let (inputs, grammar, using_tools) = prepare_chat_input(
infer, if response_format.is_some() && tools.is_some() {
response_format, return Err(InferError::ToolError(
tools, "Grammar and tools are mutually exclusive".into(),
tool_choice, ));
&tool_prompt, }
guideline,
messages, let (inputs, grammar, using_tools) = match response_format {
)?; Some(format) => {
let inputs = infer.apply_chat_template(guideline, messages, None)?;
(inputs, Some(format), false)
}
None => {
if let Some(tools) = tools {
match ToolGrammar::apply(tools, tool_choice)? {
Some((updated_tools, tool_schema)) => {
let grammar = GrammarType::Json(serde_json::json!(tool_schema));
let inputs: String = infer.apply_chat_template(
guideline,
messages,
Some((updated_tools, tool_prompt)),
)?;
(inputs, Some(grammar), true)
}
None => {
// same as if no response_format or tools are set
let inputs = infer.apply_chat_template(guideline, messages, None)?;
(inputs, None, false)
}
}
} else {
// if no response_format or tools are set simply apply the chat template to generate inputs
let inputs = infer.apply_chat_template(guideline, messages, None)?;
(inputs, None, false)
}
}
};
Ok(( Ok((
GenerateRequest { GenerateRequest {
@ -1006,19 +1034,11 @@ pub fn default_tool_prompt() -> String {
"\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\n".to_string() "\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\n".to_string()
} }
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)] #[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
#[schema(example = "auto")] #[serde(tag = "type")]
/// Controls which (if any) tool is called by the model. pub enum TypedChoice {
pub enum ToolType { #[serde(rename = "function")]
/// Means the model can pick between generating a message or calling one or more tools. Function { function: FunctionName },
#[schema(rename = "auto")]
OneOf,
/// Means the model will not call any tool and instead generates a message.
#[schema(rename = "none")]
NoTool,
/// Forces the model to call a specific tool.
#[schema(rename = "function")]
Function(FunctionName),
} }
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema)] #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema)]
@ -1026,28 +1046,58 @@ pub struct FunctionName {
pub name: String, pub name: String,
} }
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default, ToSchema)] #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema, Default)]
#[serde(from = "ToolTypeDeserializer")] #[serde(from = "ToolTypeDeserializer")]
pub struct ToolChoice(pub Option<ToolType>); #[serde(rename_all = "snake_case")]
/// <https://platform.openai.com/docs/guides/function-calling/configuring-function-calling-behavior-using-the-tool_choice-parameter>
pub enum ToolChoice {
/// Means the model can pick between generating a message or calling one or more tools.
#[default]
Auto,
/// Means the model will not call any tool and instead generates a message.
#[serde(rename = "none")]
NoTool,
/// Means the model must call one or more tools.
Required,
/// Forces the model to call a specific tool. This structure aligns with the `OpenAI` API schema to force a specific tool.
Function(FunctionName),
}
#[derive(Deserialize)] #[derive(Deserialize, ToSchema)]
#[serde(untagged)] #[serde(untagged)]
/// Controls which (if any) tool is called by the model.
/// - `none` means the model will not call any tool and instead generates a message.
/// - `auto` means the model can pick between generating a message or calling one or more tools.
/// - `required` means the model must call one or more tools.
/// - Specifying a particular tool via `{\"type\": \"function\", \"function\": {\"name\": \"my_function\"}}` forces the model to call that tool.
///
/// `none` is the default when no tools are present. `auto` is the default if tools are present.
enum ToolTypeDeserializer { enum ToolTypeDeserializer {
/// None means `null` was passed in the JSON, and the default choice is applied based on the presence of tools.
Null, Null,
/// `auto` means the model can pick between generating a message or calling one or more tools.
#[schema(example = "auto")]
String(String), String(String),
ToolType(ToolType),
/// Specifying a particular tool forces the model to call that tool, with structured function details.
#[schema(example = r#"{"type": "function", "function": {"name": "my_function"}}"#)]
TypedChoice(TypedChoice),
} }
impl From<ToolTypeDeserializer> for ToolChoice { impl From<ToolTypeDeserializer> for ToolChoice {
fn from(value: ToolTypeDeserializer) -> Self { fn from(value: ToolTypeDeserializer) -> Self {
match value { match value {
ToolTypeDeserializer::Null => ToolChoice(None), ToolTypeDeserializer::Null => ToolChoice::Auto,
ToolTypeDeserializer::String(s) => match s.as_str() { ToolTypeDeserializer::String(s) => match s.as_str() {
"none" => ToolChoice(Some(ToolType::NoTool)), "none" => ToolChoice::NoTool,
"auto" => ToolChoice(Some(ToolType::OneOf)), "auto" => ToolChoice::Auto,
_ => ToolChoice(Some(ToolType::Function(FunctionName { name: s }))), "required" => ToolChoice::Required,
_ => ToolChoice::Function(FunctionName { name: s }),
}, },
ToolTypeDeserializer::ToolType(tool_type) => ToolChoice(Some(tool_type)), ToolTypeDeserializer::TypedChoice(TypedChoice::Function { function }) => {
ToolChoice::Function(function)
}
} }
} }
} }
@ -1213,6 +1263,7 @@ pub(crate) enum OutputMessage {
} }
#[derive(Clone, Debug, Deserialize, ToSchema)] #[derive(Clone, Debug, Deserialize, ToSchema)]
#[cfg_attr(test, derive(PartialEq))]
pub(crate) struct GenerateRequest { pub(crate) struct GenerateRequest {
#[schema(example = "My name is Olivier and I")] #[schema(example = "My name is Olivier and I")]
pub inputs: String, pub inputs: String,
@ -1653,4 +1704,41 @@ mod tests {
r#"{"role":"assistant","tool_calls":[{"id":"0","type":"function","function":{"description":null,"name":"myfn","arguments":{"format":"csv"}}}]}"# r#"{"role":"assistant","tool_calls":[{"id":"0","type":"function","function":{"description":null,"name":"myfn","arguments":{"format":"csv"}}}]}"#
); );
} }
#[test]
fn tool_choice_formats() {
#[derive(Deserialize)]
struct TestRequest {
tool_choice: ToolChoice,
}
let de_none: TestRequest = serde_json::from_str(r#"{"tool_choice":"none"}"#).unwrap();
assert_eq!(de_none.tool_choice, ToolChoice::NoTool);
let de_auto: TestRequest = serde_json::from_str(r#"{"tool_choice":"auto"}"#).unwrap();
assert_eq!(de_auto.tool_choice, ToolChoice::Auto);
let de_required: TestRequest =
serde_json::from_str(r#"{"tool_choice":"required"}"#).unwrap();
assert_eq!(de_required.tool_choice, ToolChoice::Required);
let de_named: TestRequest = serde_json::from_str(r#"{"tool_choice":"myfn"}"#).unwrap();
assert_eq!(
de_named.tool_choice,
ToolChoice::Function(FunctionName {
name: "myfn".to_string(),
})
);
let de_openai_named: TestRequest = serde_json::from_str(
r#"{"tool_choice":{"type":"function","function":{"name":"myfn"}}}"#,
)
.unwrap();
assert_eq!(
de_openai_named.tool_choice,
ToolChoice::Function(FunctionName {
name: "myfn".to_string(),
})
);
}
} }

View File

@ -1,6 +1,5 @@
/// HTTP Server logic /// HTTP Server logic
use crate::config::Config; use crate::config::Config;
use crate::infer::tool_grammar::ToolGrammar;
use crate::infer::{Backend, Infer, InferError, InferResponse, InferStreamResponse}; use crate::infer::{Backend, Infer, InferError, InferResponse, InferStreamResponse};
#[cfg(feature = "kserve")] #[cfg(feature = "kserve")]
use crate::kserve::{ use crate::kserve::{
@ -28,7 +27,7 @@ use crate::{
ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal, ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal,
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool,
}; };
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType}; use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice};
use crate::{ModelInfo, ModelsInfo}; use crate::{ModelInfo, ModelsInfo};
use async_stream::__private::AsyncStream; use async_stream::__private::AsyncStream;
use axum::extract::Extension; use axum::extract::Extension;
@ -1559,7 +1558,6 @@ GrammarType,
Usage, Usage,
StreamOptions, StreamOptions,
DeltaToolCall, DeltaToolCall,
ToolType,
Tool, Tool,
ToolCall, ToolCall,
Function, Function,
@ -2525,160 +2523,3 @@ pub enum WebServerError {
#[error("Axum error: {0}")] #[error("Axum error: {0}")]
Axum(#[from] axum::BoxError), Axum(#[from] axum::BoxError),
} }
type PreparedInput = (String, Option<GrammarType>, bool);
pub(crate) fn prepare_chat_input(
infer: &Infer,
response_format: Option<GrammarType>,
tools: Option<Vec<Tool>>,
tool_choice: ToolChoice,
tool_prompt: &str,
guideline: Option<String>,
messages: Vec<Message>,
) -> Result<PreparedInput, InferError> {
if response_format.is_some() && tools.is_some() {
return Err(InferError::ToolError(
"Grammar and tools are mutually exclusive".into(),
));
}
// when response_format is set, tools are not included when applying the chat template to generate inputs
if let Some(format) = response_format {
let inputs = infer.apply_chat_template(guideline, messages, None)?;
return Ok((inputs, Some(format), false));
}
// when no response_format is set and tools are included, apply the chat template with the tools
// to generate inputs
if let Some(tools) = tools {
let (updated_tools, tool_schema) = ToolGrammar::apply(tools, tool_choice)?;
let grammar = tool_schema
.as_ref()
.map(|t| GrammarType::Json(serde_json::json!(t)));
let inputs: String = infer.apply_chat_template(
guideline,
messages,
Some((updated_tools, tool_prompt.into())),
)?;
return Ok((inputs, grammar, tool_schema.is_some()));
}
// if no response_format or tools are set simply apply the chat template to generate inputs
let inputs = infer.apply_chat_template(guideline, messages, None)?;
Ok((inputs, None, false))
}
#[cfg(test)]
mod tests {
use super::*;
use crate::ChatTemplateVersions;
use crate::HubTokenizerConfig;
use crate::TokenizerConfigToken;
use crate::Tool;
use crate::tests::get_tokenizer;
use serde_json::json;
#[tokio::test]
async fn test_prepare_chat_input() {
// Mock Backend to avoid network requests
struct MockBackend;
impl Backend for MockBackend {
fn schedule(
&self,
_request: crate::validation::ValidGenerateRequest,
) -> Result<
tokio_stream::wrappers::UnboundedReceiverStream<
Result<InferStreamResponse, InferError>,
>,
InferError,
> {
unimplemented!("Never called in this test");
}
fn health<'a, 'async_trait>(
&'a self,
_current_health: bool,
) -> core::pin::Pin<
Box<dyn core::future::Future<Output = bool> + core::marker::Send + 'async_trait>,
>
where
'a: 'async_trait,
Self: 'async_trait,
{
unimplemented!("Never called in this test");
}
}
let backend = MockBackend {};
let mut tokenizer_config = HubTokenizerConfig::default();
// mock tokenizer config values
tokenizer_config.bos_token = Some(TokenizerConfigToken::String("<s>".to_string()));
tokenizer_config.eos_token = Some(TokenizerConfigToken::String("</s>".to_string()));
tokenizer_config.chat_template = Some(
ChatTemplateVersions::Single("{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS] [\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST] \" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST] \" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- \"[TOOL_CALLS] [\" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- \" \" + message[\"content\"]|trim + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS] {\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n".to_string())
);
let tokenizer = get_tokenizer();
let infer = Infer::new(
backend,
Validation::new(1, tokenizer, None, None, 1, 1, 1, 1, 1, false),
1,
tokenizer_config,
HubProcessorConfig::default(),
);
let response_format = None;
let tools = Some(vec![Tool {
r#type: "function".to_string(),
function: FunctionDefinition {
name: "get_current_weather".to_string(),
description: Some("Get the current weather".to_string()),
arguments: json!({
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
},
"format": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "The temperature unit to use. Infer this from the users location."
}
},
"required": ["location", "format"]
}),
},
}]);
let tool_prompt = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.";
let guideline = None;
let messages = vec![Message {
name: None,
role: "user".to_string(),
content: MessageContent::SingleText(
"What is the weather like in New York?".to_string(),
),
}];
let result = prepare_chat_input(
&infer,
response_format,
tools,
ToolChoice(None),
tool_prompt,
guideline,
messages,
);
assert!(result.is_ok());
let (inputs, _grammar, using_tools) = result.expect("Failed to prepare chat input");
assert_eq!(using_tools, true);
assert_eq!(inputs, "<s>[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"content\":{\"description\":\"The response content\",\"type\":\"string\"}},\"required\":[\"content\"],\"type\":\"object\"}, \"description\": \"Open ened response with no specific tool selected\", \"name\": \"no_tool\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string());
}
}