feat: improve tools to include name and add tests (#1693)
This PR makes tool calling aware of the name of the function selected. Fixes: https://github.com/huggingface/text-generation-inference/issues/1657 Thank you @puppetm4st3r for the helpful snippets, large parts of this PR are simply refactors of the code shared 🙏 **opening draft PR because small tweaks are needed before merging
This commit is contained in:
parent
88702d8763
commit
7276d43495
|
@ -13,7 +13,7 @@
|
|||
"usage": null
|
||||
}
|
||||
],
|
||||
"created": 1710795556,
|
||||
"created": 1712874856,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
|
@ -11,13 +11,12 @@
|
|||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"description": null,
|
||||
"name": "tools",
|
||||
"parameters": {
|
||||
"arguments": {
|
||||
"format": "celsius",
|
||||
"location": "New York, NY",
|
||||
"num_days": 14
|
||||
}
|
||||
"location": "Brooklyn"
|
||||
},
|
||||
"description": null,
|
||||
"name": "get_current_weather"
|
||||
},
|
||||
"id": 0,
|
||||
"type": "function"
|
||||
|
@ -27,14 +26,14 @@
|
|||
"usage": null
|
||||
}
|
||||
],
|
||||
"created": 1710795556,
|
||||
"created": 1712782670,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.0-native",
|
||||
"usage": {
|
||||
"completion_tokens": 29,
|
||||
"prompt_tokens": 316,
|
||||
"total_tokens": 345
|
||||
"completion_tokens": 37,
|
||||
"prompt_tokens": 524,
|
||||
"total_tokens": 561
|
||||
}
|
||||
}
|
||||
|
|
|
@ -11,13 +11,12 @@
|
|||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"description": null,
|
||||
"name": "tools",
|
||||
"parameters": {
|
||||
"arguments": {
|
||||
"format": "celsius",
|
||||
"location": "New York, NY",
|
||||
"num_days": 14
|
||||
}
|
||||
"location": "Brooklyn"
|
||||
},
|
||||
"description": null,
|
||||
"name": "get_current_weather"
|
||||
},
|
||||
"id": 0,
|
||||
"type": "function"
|
||||
|
@ -27,14 +26,14 @@
|
|||
"usage": null
|
||||
}
|
||||
],
|
||||
"created": 1710795557,
|
||||
"created": 1712787937,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.0-native",
|
||||
"usage": {
|
||||
"completion_tokens": 29,
|
||||
"prompt_tokens": 316,
|
||||
"total_tokens": 345
|
||||
"completion_tokens": 37,
|
||||
"prompt_tokens": 524,
|
||||
"total_tokens": 561
|
||||
}
|
||||
}
|
||||
|
|
|
@ -11,12 +11,12 @@
|
|||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"description": null,
|
||||
"name": "tools",
|
||||
"parameters": {
|
||||
"arguments": {
|
||||
"format": "celsius",
|
||||
"location": "New York, NY"
|
||||
}
|
||||
},
|
||||
"description": null,
|
||||
"name": "get_current_weather"
|
||||
},
|
||||
"id": 0,
|
||||
"type": "function"
|
||||
|
@ -26,14 +26,14 @@
|
|||
"usage": null
|
||||
}
|
||||
],
|
||||
"created": 1710795557,
|
||||
"created": 1712852394,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.0-native",
|
||||
"usage": {
|
||||
"completion_tokens": 21,
|
||||
"prompt_tokens": 187,
|
||||
"total_tokens": 208
|
||||
"completion_tokens": 48,
|
||||
"prompt_tokens": 320,
|
||||
"total_tokens": 368
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "eos_token",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"message": {
|
||||
"content": null,
|
||||
"name": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"arguments": {
|
||||
"error": "Cannot get current weather forecast from specified location and temperature unit. Please try again with different options."
|
||||
},
|
||||
"description": null,
|
||||
"name": "notify_error"
|
||||
},
|
||||
"id": 0,
|
||||
"type": "function"
|
||||
}
|
||||
]
|
||||
},
|
||||
"usage": null
|
||||
}
|
||||
],
|
||||
"created": 1712852597,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "1.4.5-native",
|
||||
"usage": {
|
||||
"completion_tokens": 39,
|
||||
"prompt_tokens": 496,
|
||||
"total_tokens": 535
|
||||
}
|
||||
}
|
|
@ -19,7 +19,7 @@
|
|||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1710795499,
|
||||
"created": 1712788218,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
import pytest
|
||||
import json
|
||||
|
||||
from text_generation.types import GrammarType
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def flash_llama_chat_handle(launcher):
|
||||
with launcher(
|
||||
"TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, disable_grammar_support=False
|
||||
) as handle:
|
||||
yield handle
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def flash_llama_chat(flash_llama_chat_handle):
|
||||
await flash_llama_chat_handle.health(300)
|
||||
return flash_llama_chat_handle.client
|
||||
|
||||
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_simple(flash_llama_chat, response_snapshot):
|
||||
response = await flash_llama_chat.chat(
|
||||
max_tokens=100,
|
||||
seed=1,
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "Youre a helpful assistant! Answer the users question best you can.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the weather like in Brooklyn, New York?",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
assert (
|
||||
response.choices[0].message.content
|
||||
== "As of today, there is a Update available for the Brooklyn, New York, area. According to the latest forecast, it's warm with high temperatures throughout the day. It's forecasted at 75°F for today and 77°F for tomorrow. However, in autumn, the weather typically changes drastically, becoming cooler and wetter. You can find the current weather forecast for the area through your local weather service. Additionally"
|
||||
)
|
||||
assert response == response_snapshot
|
|
@ -71,34 +71,7 @@ tools = [
|
|||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_grammar_no_tools(
|
||||
flash_llama_grammar_tools, response_snapshot
|
||||
):
|
||||
response = await flash_llama_grammar_tools.chat(
|
||||
max_tokens=100,
|
||||
seed=1,
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "Youre a helpful assistant! Answer the users question best you can.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the weather like in Brooklyn, New York?",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
assert (
|
||||
response.choices[0].message.content
|
||||
== "As of today, there is a Update available for the Brooklyn, New York, area. According to the latest forecast, it's warm with high temperatures throughout the day. It's forecasted at 75°F for today and 77°F for tomorrow. However, in autumn, the weather typically changes drastically, becoming cooler and wetter. You can find the current weather forecast for the area through your local weather service. Additionally"
|
||||
)
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.skip(reason="Takes too long to run")
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot):
|
||||
|
@ -121,23 +94,19 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
|
|||
assert response.choices[0].message.content == None
|
||||
assert response.choices[0].message.tool_calls == [
|
||||
{
|
||||
"function": {
|
||||
"description": None,
|
||||
"name": "tools",
|
||||
"parameters": {
|
||||
"format": "celsius",
|
||||
"location": "New York, NY",
|
||||
"num_days": 14,
|
||||
},
|
||||
},
|
||||
"id": 0,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"description": None,
|
||||
"name": "get_current_weather",
|
||||
"arguments": {"format": "celsius", "location": "New York, NY"},
|
||||
},
|
||||
}
|
||||
]
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.skip(reason="Takes too long to run")
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_grammar_tools_auto(
|
||||
|
@ -163,23 +132,20 @@ async def test_flash_llama_grammar_tools_auto(
|
|||
assert response.choices[0].message.content == None
|
||||
assert response.choices[0].message.tool_calls == [
|
||||
{
|
||||
"function": {
|
||||
"description": None,
|
||||
"name": "tools",
|
||||
"parameters": {
|
||||
"format": "celsius",
|
||||
"location": "New York, NY",
|
||||
"num_days": 14,
|
||||
},
|
||||
},
|
||||
"id": 0,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"description": None,
|
||||
"name": "get_current_weather",
|
||||
"arguments": {"format": "celsius", "location": "New York, NY"},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.skip(reason="Takes too long to run")
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_grammar_tools_choice(
|
||||
|
@ -209,15 +175,16 @@ async def test_flash_llama_grammar_tools_choice(
|
|||
"type": "function",
|
||||
"function": {
|
||||
"description": None,
|
||||
"name": "tools",
|
||||
"parameters": {"format": "celsius", "location": "New York, NY"},
|
||||
"name": "get_current_weather",
|
||||
"arguments": {"format": "celsius", "location": "New York, NY"},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.skip(reason="Takes too long to run")
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_grammar_tools_stream(
|
||||
|
@ -246,5 +213,47 @@ async def test_flash_llama_grammar_tools_stream(
|
|||
async for response in responses:
|
||||
count += 1
|
||||
|
||||
assert count == 20
|
||||
assert count == 38
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Takes too long to run")
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_grammar_tools_insufficient_information(
|
||||
flash_llama_grammar_tools, response_snapshot
|
||||
):
|
||||
responses = await flash_llama_grammar_tools.chat(
|
||||
max_tokens=100,
|
||||
seed=8,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Tell me a story about 3 sea creatures",
|
||||
},
|
||||
],
|
||||
stream=False,
|
||||
)
|
||||
|
||||
assert responses.choices[0].message.content == None
|
||||
assert responses.choices[0].message.tool_calls == [
|
||||
{
|
||||
"function": {
|
||||
"arguments": {
|
||||
"error": "Cannot get current weather forecast from specified location and temperature unit. Please try again with different options."
|
||||
},
|
||||
"description": None,
|
||||
"name": "notify_error",
|
||||
},
|
||||
"id": 0,
|
||||
"type": "function",
|
||||
}
|
||||
]
|
||||
|
||||
assert responses == response_snapshot
|
||||
|
|
|
@ -4,9 +4,12 @@ use crate::{
|
|||
ChatTemplateInputs, ChatTemplateVersions, Entry, GenerateRequest, GenerateStreamResponse,
|
||||
HubTokenizerConfig, Message, PrefillToken, Queue, Token,
|
||||
};
|
||||
use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools};
|
||||
use futures::future::try_join_all;
|
||||
use minijinja::{Environment, ErrorKind, Template};
|
||||
use nohash_hasher::IntMap;
|
||||
use serde_json::{json, Map, Value};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
|
@ -185,11 +188,15 @@ impl Infer {
|
|||
|
||||
/// Apply the chat template to the chat request
|
||||
#[instrument(skip_all)]
|
||||
pub(crate) fn apply_chat_template(&self, messages: Vec<Message>) -> Result<String, InferError> {
|
||||
pub(crate) fn apply_chat_template(
|
||||
&self,
|
||||
messages: Vec<Message>,
|
||||
grammar_with_prompt: Option<(GrammarType, String)>,
|
||||
) -> Result<String, InferError> {
|
||||
self.chat_template
|
||||
.as_ref()
|
||||
.ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
|
||||
.apply(messages)
|
||||
.apply(messages, grammar_with_prompt)
|
||||
.map_err(|e| {
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "template");
|
||||
tracing::error!("{e}");
|
||||
|
@ -322,6 +329,7 @@ struct ChatTemplate {
|
|||
template: Template<'static, 'static>,
|
||||
bos_token: Option<String>,
|
||||
eos_token: Option<String>,
|
||||
use_default_tool_template: bool,
|
||||
}
|
||||
|
||||
impl ChatTemplate {
|
||||
|
@ -329,6 +337,10 @@ impl ChatTemplate {
|
|||
let mut env = Box::new(Environment::new());
|
||||
let template_str = template.into_boxed_str();
|
||||
env.add_function("raise_exception", raise_exception);
|
||||
|
||||
// check if contains the tools variable within the template
|
||||
let use_default_tool_template =
|
||||
!template_str.as_ref().replace(' ', "").contains("{{tools}}");
|
||||
// leaking env and template_str as read-only, static resources for performance.
|
||||
let template = Box::leak(env)
|
||||
.template_from_str(Box::leak(template_str))
|
||||
|
@ -338,21 +350,159 @@ impl ChatTemplate {
|
|||
template,
|
||||
bos_token,
|
||||
eos_token,
|
||||
use_default_tool_template,
|
||||
}
|
||||
}
|
||||
|
||||
fn apply(
|
||||
&self,
|
||||
mut messages: Vec<Message>,
|
||||
grammar_with_prompt: Option<(GrammarType, String)>,
|
||||
) -> Result<String, InferError> {
|
||||
if self.use_default_tool_template {
|
||||
if let Some(last_message) = messages.last_mut() {
|
||||
if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
|
||||
last_message.content = Some(format!(
|
||||
"{}\n---\n{}\n{}",
|
||||
last_message.content.as_deref().unwrap_or_default(),
|
||||
tool_prompt,
|
||||
tools
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn apply(&self, messages: Vec<Message>) -> Result<String, InferError> {
|
||||
self.template
|
||||
.render(ChatTemplateInputs {
|
||||
messages,
|
||||
bos_token: self.bos_token.as_deref(),
|
||||
eos_token: self.eos_token.as_deref(),
|
||||
add_generation_prompt: true,
|
||||
tools: None,
|
||||
tools_prompt: None,
|
||||
})
|
||||
.map_err(InferError::TemplateError)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ToolGrammar {}
|
||||
|
||||
impl ToolGrammar {
|
||||
pub fn apply(
|
||||
tools: Option<Vec<Tool>>,
|
||||
tool_choice: Option<ToolType>,
|
||||
) -> Result<Option<Tools>, InferError> {
|
||||
if let Some((req_tools, tool_choice)) = tools.zip(tool_choice) {
|
||||
// let tool_prompt = tool_prompt.unwrap_or_default();
|
||||
let tools_to_use = match tool_choice {
|
||||
ToolType::FunctionName(name) => {
|
||||
vec![req_tools
|
||||
.iter()
|
||||
.find(|tool| tool.function.name == *name)
|
||||
.unwrap_or_else(|| panic!("Tool with name {} not found", name))
|
||||
.clone()]
|
||||
}
|
||||
ToolType::OneOf => req_tools.to_owned(),
|
||||
};
|
||||
|
||||
// adds the error notification function for LLM feedback if required
|
||||
let mut text_response_properties = Map::new();
|
||||
text_response_properties.insert(
|
||||
"error".to_string(),
|
||||
serde_json::json!({
|
||||
"type": "string",
|
||||
"description": "The error or issue to notify"
|
||||
}),
|
||||
);
|
||||
text_response_properties.insert(
|
||||
"_name".to_string(),
|
||||
serde_json::json!({
|
||||
"type": "string",
|
||||
"const": "notify_error"
|
||||
}),
|
||||
);
|
||||
|
||||
let functions: HashMap<String, serde_json::Value> = tools_to_use
|
||||
.iter()
|
||||
.map(|tool| {
|
||||
let func = tool.function.clone();
|
||||
|
||||
// Clone the existing parameters, which are expected to be a JSON object
|
||||
let mut params = if let Value::Object(params) = &func.arguments {
|
||||
params.clone()
|
||||
} else {
|
||||
Map::new()
|
||||
};
|
||||
|
||||
// Insert the function's description at the top level, outside of properties
|
||||
params.insert(
|
||||
"description".to_string(),
|
||||
Value::String(func.description.clone().unwrap_or_default()),
|
||||
);
|
||||
|
||||
// Ensure 'properties' exists and is an object
|
||||
let properties = params
|
||||
.entry("properties".to_string())
|
||||
.or_insert_with(|| json!({}))
|
||||
.as_object_mut()
|
||||
.unwrap();
|
||||
|
||||
// Insert the constant for the function name inside 'properties'
|
||||
properties.insert(
|
||||
"_name".to_string(),
|
||||
json!({
|
||||
"type": "string",
|
||||
"const": func.name.clone(),
|
||||
// "description": "The name of the function"
|
||||
}),
|
||||
);
|
||||
|
||||
// Check if 'required' exists, and it is an array. If not, create an empty array.
|
||||
let required = params
|
||||
.entry("required".to_string())
|
||||
.or_insert_with(|| json!([]))
|
||||
.as_array_mut()
|
||||
.unwrap();
|
||||
|
||||
// Add 'name' to the 'required' array if it is not already present
|
||||
if !required.iter().any(|r| r == "_name") {
|
||||
required.push(json!("_name"));
|
||||
}
|
||||
|
||||
(func.name, Value::Object(params))
|
||||
})
|
||||
.chain([(
|
||||
"notify_error".to_string(),
|
||||
serde_json::json!({
|
||||
"properties": text_response_properties,
|
||||
"required": ["error", "_name"],
|
||||
"type": "object"
|
||||
}),
|
||||
)])
|
||||
.collect();
|
||||
|
||||
let tools = Tools {
|
||||
functions_map: FunctionsMap { functions },
|
||||
properties: Properties {
|
||||
function: tools_to_use
|
||||
.iter()
|
||||
.map(|tool| FunctionRef {
|
||||
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
|
||||
})
|
||||
.chain(std::iter::once(FunctionRef {
|
||||
ref_path: "#/$functions/notify_error".to_string(),
|
||||
}))
|
||||
.collect(),
|
||||
},
|
||||
};
|
||||
|
||||
return Ok(Some(tools));
|
||||
}
|
||||
// Err(InferError::ToolError("No tools provided".to_string()))
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
/// Batching logic
|
||||
/// Will be launched in a background Tokio task
|
||||
///
|
||||
|
@ -768,6 +918,8 @@ pub enum InferError {
|
|||
IncompleteGeneration,
|
||||
#[error("Template error: {0}")]
|
||||
TemplateError(#[from] minijinja::Error),
|
||||
#[error("Tool error: {0}")]
|
||||
ToolError(String),
|
||||
}
|
||||
|
||||
impl InferError {
|
||||
|
@ -778,6 +930,7 @@ impl InferError {
|
|||
InferError::ValidationError(_) => "validation",
|
||||
InferError::IncompleteGeneration => "incomplete_generation",
|
||||
InferError::TemplateError(_) => "template_error",
|
||||
InferError::ToolError(_) => "tool_error",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -849,6 +1002,7 @@ mod tests {
|
|||
bos_token: Some("[BOS]"),
|
||||
eos_token: Some("[EOS]"),
|
||||
add_generation_prompt: true,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
|
||||
|
@ -924,6 +1078,7 @@ mod tests {
|
|||
bos_token: Some("[BOS]"),
|
||||
eos_token: Some("[EOS]"),
|
||||
add_generation_prompt: true,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = tmpl.unwrap().render(chat_template_inputs); //.err().unwrap();
|
||||
|
@ -998,6 +1153,7 @@ mod tests {
|
|||
bos_token: Some("[BOS]"),
|
||||
eos_token: Some("[EOS]"),
|
||||
add_generation_prompt: true,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
|
||||
|
@ -1056,6 +1212,7 @@ mod tests {
|
|||
bos_token: Some("[BOS]"),
|
||||
eos_token: Some("[EOS]"),
|
||||
add_generation_prompt: true,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
|
||||
|
@ -1115,6 +1272,7 @@ mod tests {
|
|||
add_generation_prompt: false,
|
||||
bos_token: Some(""),
|
||||
eos_token: Some(""),
|
||||
..Default::default()
|
||||
},
|
||||
target: "<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n",
|
||||
},
|
||||
|
@ -1126,6 +1284,7 @@ mod tests {
|
|||
add_generation_prompt: false,
|
||||
bos_token: Some(""),
|
||||
eos_token: Some("</s>"),
|
||||
..Default::default()
|
||||
},
|
||||
target: " Hello, how are you? I'm doing great. How can I help you today? I'd like to show off how chat templating works!</s>",
|
||||
},
|
||||
|
@ -1137,6 +1296,7 @@ mod tests {
|
|||
add_generation_prompt: false,
|
||||
bos_token: Some(""),
|
||||
eos_token: Some("</s>"),
|
||||
..Default::default()
|
||||
},
|
||||
target: " Hello, how are you? I'm doing great. How can I help you today? I'd like to show off how chat templating works!</s>",
|
||||
},
|
||||
|
@ -1148,6 +1308,7 @@ mod tests {
|
|||
add_generation_prompt: false,
|
||||
bos_token: Some(""),
|
||||
eos_token: Some("</s>"),
|
||||
..Default::default()
|
||||
},
|
||||
target: "Hello, how are you?</s>I'm doing great. How can I help you today?</s>I'd like to show off how chat templating works!</s>",
|
||||
},
|
||||
|
@ -1159,6 +1320,7 @@ mod tests {
|
|||
add_generation_prompt: false,
|
||||
bos_token: Some(""),
|
||||
eos_token: Some("<|endoftext|>"),
|
||||
..Default::default()
|
||||
},
|
||||
target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>",
|
||||
},
|
||||
|
@ -1170,6 +1332,7 @@ mod tests {
|
|||
add_generation_prompt: false,
|
||||
bos_token: Some(""),
|
||||
eos_token: Some("<|endoftext|>"),
|
||||
..Default::default()
|
||||
},
|
||||
target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>",
|
||||
},
|
||||
|
@ -1182,6 +1345,7 @@ mod tests {
|
|||
add_generation_prompt: true,
|
||||
bos_token: Some("<s>"),
|
||||
eos_token: Some("</s>"),
|
||||
..Default::default()
|
||||
},
|
||||
target: "<s>[INST] <<SYS>>\nYou are a friendly chatbot who always responds in the style of a pirate\n<</SYS>>\n\nHello, how are you? [/INST] I'm doing great. How can I help you today? </s><s>[INST] I'd like to show off how chat templating works! [/INST]",
|
||||
},
|
||||
|
@ -1193,6 +1357,7 @@ mod tests {
|
|||
add_generation_prompt: true,
|
||||
bos_token: Some(""),
|
||||
eos_token: Some("<|endoftext|>"),
|
||||
..Default::default()
|
||||
},
|
||||
target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>",
|
||||
},
|
||||
|
@ -1222,6 +1387,7 @@ mod tests {
|
|||
add_generation_prompt: false,
|
||||
bos_token: Some(""),
|
||||
eos_token: Some("</s>"),
|
||||
..Default::default()
|
||||
},
|
||||
target: "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate</s><|user|>\nHello, how are you?</s><|assistant|>\nI'm doing great. How can I help you today?</s><|user|>\nI'd like to show off how chat templating works!</s>",
|
||||
},
|
||||
|
@ -1246,6 +1412,7 @@ mod tests {
|
|||
add_generation_prompt: true,
|
||||
bos_token: Some(""),
|
||||
eos_token: Some("</s>"),
|
||||
..Default::default()
|
||||
},
|
||||
target: "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate</s><|user|>\nHow many helicopters can a human eat in one sitting?</s><|assistant|>",
|
||||
},
|
||||
|
@ -1257,6 +1424,7 @@ mod tests {
|
|||
add_generation_prompt: false,
|
||||
bos_token: Some("<bos>"),
|
||||
eos_token: Some("<eos>"),
|
||||
..Default::default()
|
||||
},
|
||||
target: "<bos><|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n",
|
||||
},
|
||||
|
@ -1268,6 +1436,7 @@ mod tests {
|
|||
add_generation_prompt: false,
|
||||
bos_token: Some("<s>"),
|
||||
eos_token: Some("</s>"),
|
||||
..Default::default()
|
||||
},
|
||||
target: "<s>[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today?</s> [INST] I'd like to show off how chat templating works! [/INST]",
|
||||
},
|
||||
|
@ -1279,6 +1448,7 @@ mod tests {
|
|||
add_generation_prompt: false,
|
||||
bos_token: Some("<s>"),
|
||||
eos_token: Some("</s>"),
|
||||
..Default::default()
|
||||
},
|
||||
target: "<s>[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today?</s>[INST] I'd like to show off how chat templating works! [/INST]",
|
||||
},
|
||||
|
@ -1290,6 +1460,7 @@ mod tests {
|
|||
add_generation_prompt: false,
|
||||
bos_token: Some("<s>"),
|
||||
eos_token: Some("</s>"),
|
||||
..Default::default()
|
||||
},
|
||||
target: "<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n",
|
||||
},
|
||||
|
@ -1302,6 +1473,7 @@ mod tests {
|
|||
add_generation_prompt: false,
|
||||
bos_token: Some("<s>"),
|
||||
eos_token: Some("</s>"),
|
||||
..Default::default()
|
||||
},
|
||||
target: "<s>GPT4 Correct User: Hello, how are you?<|end_of_turn|>GPT4 Correct Assistant: I'm doing great. How can I help you today?<|end_of_turn|>GPT4 Correct User: I'd like to show off how chat templating works!<|end_of_turn|>",
|
||||
},
|
||||
|
@ -1313,6 +1485,7 @@ mod tests {
|
|||
add_generation_prompt: false,
|
||||
bos_token: Some("<s>"),
|
||||
eos_token: Some("</s>"),
|
||||
..Default::default()
|
||||
},
|
||||
target: "Hello, how are you?</s>I'm doing great. How can I help you today?</s>I'd like to show off how chat templating works!</s>",
|
||||
},
|
||||
|
@ -1325,6 +1498,7 @@ mod tests {
|
|||
add_generation_prompt: false,
|
||||
bos_token: Some("<s>"),
|
||||
eos_token: Some("</s>"),
|
||||
..Default::default()
|
||||
},
|
||||
target: "<s>Source: user\n\n Hello, how are you? <step> Source: assistant\n\n I'm doing great. How can I help you today? <step> Source: user\n\n I'd like to show off how chat templating works! <step> Source: assistant\nDestination: user\n\n ",
|
||||
},
|
||||
|
@ -1336,6 +1510,7 @@ mod tests {
|
|||
add_generation_prompt: false,
|
||||
bos_token: Some("<s>"),
|
||||
eos_token: Some("</s>"),
|
||||
..Default::default()
|
||||
},
|
||||
target: "### User:\nHello, how are you?### Assistant:\nI'm doing great. How can I help you today?### User:\nI'd like to show off how chat templating works!",
|
||||
},
|
||||
|
@ -1347,6 +1522,7 @@ mod tests {
|
|||
add_generation_prompt: false,
|
||||
bos_token: Some("<s>"),
|
||||
eos_token: Some("</s>"),
|
||||
..Default::default()
|
||||
},
|
||||
target: "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!",
|
||||
},
|
||||
|
@ -1358,6 +1534,7 @@ mod tests {
|
|||
add_generation_prompt: false,
|
||||
bos_token: Some("<|begin▁of▁sentence|>"),
|
||||
eos_token: Some("<|end▁of▁sentence|>"),
|
||||
..Default::default()
|
||||
},
|
||||
target: "<|begin▁of▁sentence|>User: Hello, how are you?\n\nAssistant: I'm doing great. How can I help you today?<|end▁of▁sentence|>User: I'd like to show off how chat templating works!\n\n",
|
||||
},
|
||||
|
@ -1369,6 +1546,7 @@ mod tests {
|
|||
add_generation_prompt: false,
|
||||
bos_token: Some("<s>"),
|
||||
eos_token: Some("</s>"),
|
||||
..Default::default()
|
||||
},
|
||||
target: "<|prompt|>Hello, how are you?</s><|answer|>I'm doing great. How can I help you today?</s><|prompt|>I'd like to show off how chat templating works!</s>",
|
||||
},
|
||||
|
@ -1380,6 +1558,7 @@ mod tests {
|
|||
add_generation_prompt: false,
|
||||
bos_token: Some("<s>"),
|
||||
eos_token: Some("</s>"),
|
||||
..Default::default()
|
||||
},
|
||||
target: "<s><|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n",
|
||||
},
|
||||
|
@ -1391,6 +1570,7 @@ mod tests {
|
|||
add_generation_prompt: false,
|
||||
bos_token: Some("<|begin▁of▁sentence|>"),
|
||||
eos_token: Some("<|EOT|>"),
|
||||
..Default::default()
|
||||
},
|
||||
target: "You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.\n### Instruction:\nHello, how are you?\n### Response:\nI'm doing great. How can I help you today?\n<|EOT|>\n### Instruction:\nI'd like to show off how chat templating works!\n### Response:\n",
|
||||
},
|
||||
|
@ -1403,6 +1583,7 @@ mod tests {
|
|||
add_generation_prompt: false,
|
||||
bos_token: Some("<|endoftext|>"),
|
||||
eos_token: Some("<|endoftext|>"),
|
||||
..Default::default()
|
||||
},
|
||||
target: "[INST] Hello, how are you? [RESP] I'm doing great. How can I help you today?<|endoftext|>[INST] I'd like to show off how chat templating works!",
|
||||
},
|
||||
|
@ -1414,6 +1595,7 @@ mod tests {
|
|||
add_generation_prompt: false,
|
||||
bos_token: Some("<s>"),
|
||||
eos_token: Some("</s>"),
|
||||
..Default::default()
|
||||
},
|
||||
target: "Hello, how are you? [/INST] I'm doing great. How can I help you today? </s><s>[INST] I'd like to show off how chat templating works! [/INST]",
|
||||
},
|
||||
|
@ -1425,6 +1607,7 @@ mod tests {
|
|||
add_generation_prompt: false,
|
||||
bos_token: Some("<s>"),
|
||||
eos_token: Some("</s>"),
|
||||
..Default::default()
|
||||
},
|
||||
target: "Below is an instruction that describes a task. Write a response that appropriately completes the request.### Instruction:Hello, how are you?### Response:I'm doing great. How can I help you today?### Instruction:I'd like to show off how chat templating works!",
|
||||
},
|
||||
|
@ -1436,6 +1619,7 @@ mod tests {
|
|||
add_generation_prompt: false,
|
||||
bos_token: Some("<|begin▁of▁sentence|>"),
|
||||
eos_token: Some("</EOT>"),
|
||||
..Default::default()
|
||||
},
|
||||
target: "<|begin▁of▁sentence|>You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\n### Instruction:\nHello, how are you?\n### Response:\nI'm doing great. How can I help you today?\n<|EOT|>\n### Instruction:\nI'd like to show off how chat templating works!\n",
|
||||
},
|
||||
|
@ -1451,6 +1635,7 @@ mod tests {
|
|||
add_generation_prompt: false,
|
||||
bos_token: Some("<s>"),
|
||||
eos_token: Some("</s>"),
|
||||
..Default::default()
|
||||
},
|
||||
target: "You are a friendly chatbot who always responds in the style of a pirateYou are a friendly chatbot who always responds in the style of a pirate### Instruction: Hello, how are you?### Response: I'm doing great. How can I help you today?### Instruction: I'd like to show off how chat templating works!",
|
||||
},
|
||||
|
|
|
@ -79,7 +79,7 @@ impl HubTokenizerConfig {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, ToSchema)]
|
||||
#[derive(Clone, Debug, Deserialize, ToSchema, Serialize)]
|
||||
#[serde(tag = "type", content = "value")]
|
||||
pub(crate) enum GrammarType {
|
||||
/// A string that represents a [JSON Schema](https://json-schema.org/).
|
||||
|
@ -669,7 +669,7 @@ pub(crate) struct ChatRequest {
|
|||
#[serde(default = "default_tool_prompt")]
|
||||
#[schema(
|
||||
nullable = true,
|
||||
example = "\"Based on the conversation, please choose the most appropriate tool to use: \""
|
||||
example = "\"You will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n\""
|
||||
)]
|
||||
pub tool_prompt: Option<String>,
|
||||
|
||||
|
@ -682,7 +682,7 @@ pub(crate) struct ChatRequest {
|
|||
|
||||
fn default_tool_prompt() -> Option<String> {
|
||||
Some(
|
||||
"\nBased on the conversation, please choose the most appropriate tool to use: ".to_string(),
|
||||
"\nYou will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n".to_string(),
|
||||
)
|
||||
}
|
||||
#[derive(Clone, Deserialize, ToSchema, Serialize)]
|
||||
|
@ -727,26 +727,26 @@ mod deserialize_tool_choice {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, ToSchema)]
|
||||
#[derive(Debug, Deserialize, Serialize, ToSchema, PartialEq)]
|
||||
pub struct Tools {
|
||||
#[serde(flatten)]
|
||||
functions_map: FunctionsMap,
|
||||
properties: Properties,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize, PartialEq)]
|
||||
struct FunctionsMap {
|
||||
#[serde(rename = "$functions")]
|
||||
functions: std::collections::HashMap<String, serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize, PartialEq)]
|
||||
struct FunctionRef {
|
||||
#[serde(rename = "$ref")]
|
||||
ref_path: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize, PartialEq)]
|
||||
struct Properties {
|
||||
#[serde(serialize_with = "serialize_function")]
|
||||
function: Vec<FunctionRef>,
|
||||
|
@ -767,7 +767,8 @@ pub(crate) struct FunctionDefinition {
|
|||
#[serde(default)]
|
||||
pub description: Option<String>,
|
||||
pub name: String,
|
||||
pub parameters: serde_json::Value,
|
||||
#[serde(alias = "parameters")]
|
||||
pub arguments: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
|
||||
|
@ -779,12 +780,14 @@ pub(crate) struct Tool {
|
|||
pub function: FunctionDefinition,
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
#[derive(Clone, Serialize, Deserialize, Default)]
|
||||
pub(crate) struct ChatTemplateInputs<'a> {
|
||||
messages: Vec<Message>,
|
||||
bos_token: Option<&'a str>,
|
||||
eos_token: Option<&'a str>,
|
||||
add_generation_prompt: bool,
|
||||
tools: Option<&'a str>,
|
||||
tools_prompt: Option<&'a str>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)]
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use crate::config::Config;
|
||||
/// HTTP Server logic
|
||||
use crate::health::Health;
|
||||
use crate::infer::{InferError, InferResponse, InferStreamResponse};
|
||||
use crate::infer::{InferError, InferResponse, InferStreamResponse, ToolGrammar};
|
||||
use crate::validation::ValidationError;
|
||||
use crate::{
|
||||
BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
|
||||
|
@ -15,7 +15,7 @@ use crate::{
|
|||
ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk,
|
||||
CompletionRequest, DeltaToolCall, Function, Tool, VertexRequest, VertexResponse,
|
||||
};
|
||||
use crate::{FunctionDefinition, FunctionRef, FunctionsMap, Properties, ToolCall, ToolType, Tools};
|
||||
use crate::{FunctionDefinition, ToolCall, ToolType};
|
||||
use axum::extract::Extension;
|
||||
use axum::http::{HeaderMap, Method, StatusCode};
|
||||
use axum::response::sse::{Event, KeepAlive, Sse};
|
||||
|
@ -29,7 +29,6 @@ use futures::Stream;
|
|||
use futures::TryStreamExt;
|
||||
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::convert::Infallible;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
|
@ -757,19 +756,29 @@ async fn chat_completions(
|
|||
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
||||
metrics::increment_counter!("tgi_request_count");
|
||||
|
||||
let stream = req.stream;
|
||||
let max_new_tokens = req.max_tokens.or(Some(100));
|
||||
let repetition_penalty = req
|
||||
.presence_penalty
|
||||
// rescale repetition_penalty from (-2.0, 2.0) to (0.0, 4.0)
|
||||
.map(|x| x + 2.0);
|
||||
let logprobs = req.logprobs.unwrap_or(false);
|
||||
let seed = req.seed;
|
||||
let stop = req.stop.unwrap_or_default();
|
||||
let ChatRequest {
|
||||
logprobs,
|
||||
max_tokens,
|
||||
messages,
|
||||
presence_penalty,
|
||||
seed,
|
||||
stop,
|
||||
stream,
|
||||
tools,
|
||||
tool_choice,
|
||||
tool_prompt,
|
||||
..
|
||||
} = req;
|
||||
|
||||
// apply chat template to flatten the request into a single input
|
||||
let mut inputs = match infer.apply_chat_template(req.messages) {
|
||||
Ok(inputs) => inputs,
|
||||
let repetition_penalty = presence_penalty.map(|x| x + 2.0);
|
||||
let max_new_tokens = max_tokens.or(Some(100));
|
||||
let logprobs = logprobs.unwrap_or(false);
|
||||
let tool_prompt = tool_prompt.unwrap_or_default();
|
||||
let stop = stop.unwrap_or_default();
|
||||
|
||||
// extract tool grammar if present
|
||||
let tool_grammar = match ToolGrammar::apply(tools, tool_choice) {
|
||||
Ok(grammar) => grammar,
|
||||
Err(err) => {
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
|
||||
tracing::error!("{err}");
|
||||
|
@ -783,60 +792,28 @@ async fn chat_completions(
|
|||
}
|
||||
};
|
||||
|
||||
let tool_grammar = if let Some((req_tools, tool_choice)) = req.tools.zip(req.tool_choice) {
|
||||
let tool_prompt = req.tool_prompt.unwrap_or_default();
|
||||
let tools_to_use = match tool_choice {
|
||||
ToolType::FunctionName(name) => {
|
||||
vec![req_tools
|
||||
.iter()
|
||||
.find(|tool| tool.function.name == *name)
|
||||
.ok_or_else(|| {
|
||||
(
|
||||
let grammar_with_prompt = tool_grammar
|
||||
.as_ref()
|
||||
.map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt));
|
||||
|
||||
let typed_grammar = grammar_with_prompt
|
||||
.as_ref()
|
||||
.map(|(grammar, _)| grammar.clone());
|
||||
|
||||
// apply chat template to flatten the request into a single input
|
||||
let inputs = match infer.apply_chat_template(messages, grammar_with_prompt) {
|
||||
Ok(inputs) => inputs,
|
||||
Err(err) => {
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
|
||||
tracing::error!("{err}");
|
||||
return Err((
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
Json(ErrorResponse {
|
||||
error: "Tool choice not found in tool names".to_string(),
|
||||
error_type: "Tool not found".to_string(),
|
||||
error: err.to_string(),
|
||||
error_type: err.error_type().to_string(),
|
||||
}),
|
||||
)
|
||||
})?
|
||||
.clone()]
|
||||
));
|
||||
}
|
||||
ToolType::OneOf => req_tools.to_owned(),
|
||||
};
|
||||
|
||||
let functions: HashMap<String, Value> = tools_to_use
|
||||
.iter()
|
||||
.map(|tool| {
|
||||
let func = tool.function.clone();
|
||||
(func.name, func.parameters)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let tools = Tools {
|
||||
functions_map: FunctionsMap { functions },
|
||||
properties: Properties {
|
||||
function: tools_to_use
|
||||
.iter()
|
||||
.map(|tool| FunctionRef {
|
||||
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
|
||||
})
|
||||
.collect(),
|
||||
},
|
||||
};
|
||||
|
||||
let tools_str = serde_json::to_string(&tools).map_err(|e| {
|
||||
(
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
Json(ErrorResponse {
|
||||
error: e.to_string(),
|
||||
error_type: "Input validation error".to_string(),
|
||||
}),
|
||||
)
|
||||
})?;
|
||||
inputs = format!("{inputs}{tool_prompt}{tools_str}");
|
||||
Some(GrammarType::Json(serde_json::json!(tools)))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// build the request passing some parameters
|
||||
|
@ -860,7 +837,7 @@ async fn chat_completions(
|
|||
decoder_input_details: !stream,
|
||||
seed,
|
||||
top_n_tokens: req.top_logprobs,
|
||||
grammar: tool_grammar.clone(),
|
||||
grammar: typed_grammar,
|
||||
},
|
||||
};
|
||||
|
||||
|
@ -943,27 +920,28 @@ async fn chat_completions(
|
|||
}),
|
||||
)
|
||||
})?;
|
||||
|
||||
let tool_calls = vec![ToolCall {
|
||||
id: 0,
|
||||
r#type: "function".to_string(),
|
||||
function: FunctionDefinition {
|
||||
description: None,
|
||||
name: "tools".to_string(),
|
||||
parameters: gen_text_value.get("function").map_or_else(
|
||||
|| {
|
||||
serde_json::from_str(&generation.generated_text).map_err(|e| {
|
||||
(
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
Json(ErrorResponse {
|
||||
error: e.to_string(),
|
||||
error_type: "Input validation error".to_string(),
|
||||
}),
|
||||
)
|
||||
name: gen_text_value
|
||||
.get("function")
|
||||
.and_then(|f| f.get("_name"))
|
||||
.and_then(|name| name.as_str())
|
||||
.unwrap_or("default_function_name")
|
||||
.to_string(),
|
||||
// Serialize the JSON object obtained from "function" to an escaped JSON string
|
||||
arguments: gen_text_value
|
||||
.get("function")
|
||||
.map(|f| {
|
||||
let mut f_cloned = f.clone();
|
||||
if let Value::Object(ref mut props) = f_cloned {
|
||||
props.remove("_name");
|
||||
}
|
||||
f_cloned
|
||||
})
|
||||
},
|
||||
|f| Ok(f.clone()),
|
||||
)?,
|
||||
.unwrap_or_default(),
|
||||
},
|
||||
}];
|
||||
(Some(tool_calls), None)
|
||||
|
@ -1539,6 +1517,7 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
|
|||
InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY,
|
||||
InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY,
|
||||
InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY,
|
||||
};
|
||||
|
||||
(
|
||||
|
|
Loading…
Reference in New Issue