fix: adjust default tool choice (#2244)

* fix: adjust default tool choice

* feat: improve tool choice syntax and response parsing/errors

* fix: remove dev tests

* feat: add ToolChoice to docs
This commit is contained in:
drbh 2024-07-19 11:12:02 -04:00 committed by GitHub
parent 40f5dc3ed6
commit 68a9685f1b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 160 additions and 147 deletions

View File

@ -909,7 +909,7 @@
"tool_choice": { "tool_choice": {
"allOf": [ "allOf": [
{ {
"$ref": "#/components/schemas/ToolType" "$ref": "#/components/schemas/ToolChoice"
} }
], ],
"nullable": true "nullable": true
@ -2035,6 +2035,14 @@
} }
} }
}, },
"ToolChoice": {
"allOf": [
{
"$ref": "#/components/schemas/ToolType"
}
],
"nullable": true
},
"ToolType": { "ToolType": {
"oneOf": [ "oneOf": [
{ {
@ -2055,6 +2063,11 @@
"$ref": "#/components/schemas/FunctionName" "$ref": "#/components/schemas/FunctionName"
} }
} }
},
{
"type": "object",
"default": null,
"nullable": true
} }
] ]
}, },

View File

@ -7,7 +7,7 @@ pub(crate) use health::HealthCheck;
use crate::validation::{ValidGenerateRequest, Validation, ValidationError}; use crate::validation::{ValidGenerateRequest, Validation, ValidationError};
use crate::{ use crate::{
ChatTemplateInputs, ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, ChatTemplateInputs, ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig,
HubTokenizerConfig, Message, MessageChunk, PrefillToken, TextMessage, Token, HubTokenizerConfig, Message, MessageChunk, PrefillToken, TextMessage, Token, ToolChoice,
}; };
use crate::{ use crate::{
FunctionRef, FunctionsMap, GrammarType, Properties, TokenizerConfigToken, Tool, ToolType, Tools, FunctionRef, FunctionsMap, GrammarType, Properties, TokenizerConfigToken, Tool, ToolType, Tools,
@ -332,126 +332,131 @@ impl ChatTemplate {
pub struct ToolGrammar {} pub struct ToolGrammar {}
impl ToolGrammar { impl ToolGrammar {
// find a tool by name
fn find_tool_by_name(tools: &[Tool], name: &str) -> Result<Tool, InferError> {
tools
.iter()
.find(|tool| tool.function.name == name)
.cloned()
.ok_or_else(|| InferError::ToolError(format!("Tool with name {} not found", name)))
}
pub fn apply( pub fn apply(
tools: Option<Vec<Tool>>, tools: Option<Vec<Tool>>,
tool_choice: Option<ToolType>, tool_choice: ToolChoice,
) -> Result<Option<Tools>, InferError> { ) -> Result<Option<Tools>, InferError> {
if let Some((req_tools, tool_choice)) = tools.zip(tool_choice) { // if no tools are provided, we return None
// let tool_prompt = tool_prompt.unwrap_or_default(); let tools = match tools {
let tools_to_use = match tool_choice { Some(tools) if !tools.is_empty() => tools,
ToolType::FunctionName(name) => { _ => return Ok(None),
vec![req_tools };
.iter()
.find(|tool| tool.function.name == *name)
.unwrap_or_else(|| panic!("Tool with name {} not found", name))
.clone()]
}
ToolType::Function { function } => {
let tool = req_tools
.iter()
.find(|tool| tool.function.name == function.name)
.unwrap_or_else(|| panic!("Tool with name {} not found", function.name))
.clone();
vec![tool]
}
ToolType::OneOf => req_tools.to_owned(),
};
// adds the error notification function for LLM feedback if required let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf);
let mut text_response_properties = Map::new();
text_response_properties.insert(
"error".to_string(),
serde_json::json!({
"type": "string",
"description": "The error or issue to notify"
}),
);
text_response_properties.insert(
"_name".to_string(),
serde_json::json!({
"type": "string",
"const": "notify_error"
}),
);
let functions: HashMap<String, serde_json::Value> = tools_to_use // if tools are provided and no tool_choice we default to the OneOf
.iter() let tools_to_use = match tool_choice {
.map(|tool| { ToolType::FunctionName(name) => {
let func = tool.function.clone(); vec![Self::find_tool_by_name(&tools, &name)?]
}
ToolType::Function { function } => {
vec![Self::find_tool_by_name(&tools, &function.name)?]
}
ToolType::OneOf => tools,
ToolType::NoTool => return Ok(None),
};
// Clone the existing parameters, which are expected to be a JSON object // adds the error notification function for LLM feedback if required
let mut params = if let Value::Object(params) = &func.arguments { let mut text_response_properties = Map::new();
params.clone() text_response_properties.insert(
} else { "error".to_string(),
Map::new() serde_json::json!({
}; "type": "string",
"description": "The error or issue to notify"
}),
);
text_response_properties.insert(
"_name".to_string(),
serde_json::json!({
"type": "string",
"const": "notify_error"
}),
);
// Insert the function's description at the top level, outside of properties let functions: HashMap<String, serde_json::Value> = tools_to_use
params.insert( .iter()
"description".to_string(), .map(|tool| {
Value::String(func.description.clone().unwrap_or_default()), let func = tool.function.clone();
);
// Ensure 'properties' exists and is an object // Clone the existing parameters, which are expected to be a JSON object
let properties = params let mut params = if let Value::Object(params) = &func.arguments {
.entry("properties".to_string()) params.clone()
.or_insert_with(|| json!({})) } else {
.as_object_mut() Map::new()
.unwrap(); };
// Insert the constant for the function name inside 'properties' // Insert the function's description at the top level, outside of properties
properties.insert( params.insert(
"_name".to_string(), "description".to_string(),
json!({ Value::String(func.description.clone().unwrap_or_default()),
"type": "string", );
"const": func.name.clone(),
// "description": "The name of the function"
}),
);
// Check if 'required' exists, and it is an array. If not, create an empty array. // Ensure 'properties' exists and is an object
let required = params let properties = params
.entry("required".to_string()) .entry("properties".to_string())
.or_insert_with(|| json!([])) .or_insert_with(|| json!({}))
.as_array_mut() .as_object_mut()
.unwrap(); .unwrap();
// Add 'name' to the 'required' array if it is not already present // Insert the constant for the function name inside 'properties'
if !required.iter().any(|r| r == "_name") { properties.insert(
required.push(json!("_name")); "_name".to_string(),
} json!({
"type": "string",
(func.name, Value::Object(params)) "const": func.name.clone(),
}) // "description": "The name of the function"
.chain([(
"notify_error".to_string(),
serde_json::json!({
"properties": text_response_properties,
"required": ["error", "_name"],
"type": "object"
}), }),
)]) );
.collect();
let tools = Tools { // Check if 'required' exists, and it is an array. If not, create an empty array.
functions_map: FunctionsMap { functions }, let required = params
properties: Properties { .entry("required".to_string())
function: tools_to_use .or_insert_with(|| json!([]))
.iter() .as_array_mut()
.map(|tool| FunctionRef { .unwrap();
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
})
.chain(std::iter::once(FunctionRef {
ref_path: "#/$functions/notify_error".to_string(),
}))
.collect(),
},
};
return Ok(Some(tools)); // Add 'name' to the 'required' array if it is not already present
} if !required.iter().any(|r| r == "_name") {
// Err(InferError::ToolError("No tools provided".to_string())) required.push(json!("_name"));
Ok(None) }
(func.name, Value::Object(params))
})
.chain([(
"notify_error".to_string(),
serde_json::json!({
"properties": text_response_properties,
"required": ["error", "_name"],
"type": "object"
}),
)])
.collect();
let tools = Tools {
functions_map: FunctionsMap { functions },
properties: Properties {
function: tools_to_use
.iter()
.map(|tool| FunctionRef {
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
})
.chain(std::iter::once(FunctionRef {
ref_path: "#/$functions/notify_error".to_string(),
}))
.collect(),
},
};
Ok(Some(tools))
} }
} }

View File

@ -826,7 +826,7 @@ pub(crate) struct ChatRequest {
/// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter. /// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
#[serde(default)] #[serde(default)]
#[schema(nullable = true, example = "null")] #[schema(nullable = true, example = "null")]
pub tool_choice: Option<ToolType>, pub tool_choice: ToolChoice,
/// Response format constraints for the generation. /// Response format constraints for the generation.
/// ///
@ -848,6 +848,7 @@ pub enum ToolType {
OneOf, OneOf,
FunctionName(String), FunctionName(String),
Function { function: FunctionName }, Function { function: FunctionName },
NoTool,
} }
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema)] #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema)]
@ -855,27 +856,26 @@ pub struct FunctionName {
pub name: String, pub name: String,
} }
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default, ToSchema)]
#[serde(from = "ToolTypeDeserializer")] #[serde(from = "ToolTypeDeserializer")]
pub struct ToolChoice(pub Option<ToolType>); pub struct ToolChoice(pub Option<ToolType>);
#[derive(Deserialize)] #[derive(Deserialize)]
#[serde(untagged)] #[serde(untagged)]
enum ToolTypeDeserializer { enum ToolTypeDeserializer {
None(Option<String>), String(String),
Some(ToolType), ToolType(ToolType),
} }
impl From<ToolTypeDeserializer> for ToolChoice { impl From<ToolTypeDeserializer> for ToolChoice {
fn from(value: ToolTypeDeserializer) -> Self { fn from(value: ToolTypeDeserializer) -> Self {
match value { match value {
ToolTypeDeserializer::None(opt) => match opt.as_deref() { ToolTypeDeserializer::String(s) => match s.as_str() {
Some("none") => ToolChoice(None), "none" => ToolChoice(Some(ToolType::NoTool)),
Some("auto") => ToolChoice(Some(ToolType::OneOf)), "auto" => ToolChoice(Some(ToolType::OneOf)),
Some(s) => ToolChoice(Some(ToolType::FunctionName(s.to_string()))), _ => ToolChoice(Some(ToolType::FunctionName(s))),
None => ToolChoice(Some(ToolType::OneOf)),
}, },
ToolTypeDeserializer::Some(tool_type) => ToolChoice(Some(tool_type)), ToolTypeDeserializer::ToolType(tool_type) => ToolChoice(Some(tool_type)),
} }
} }
} }

View File

@ -24,7 +24,7 @@ use crate::{
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, VertexRequest, CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, VertexRequest,
VertexResponse, VertexResponse,
}; };
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolType}; use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType};
use async_stream::__private::AsyncStream; use async_stream::__private::AsyncStream;
use axum::extract::Extension; use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode}; use axum::http::{HeaderMap, Method, StatusCode};
@ -1192,39 +1192,33 @@ async fn chat_completions(
.as_secs(); .as_secs();
let (tool_calls, output) = if tool_grammar.is_some() { let (tool_calls, output) = if tool_grammar.is_some() {
// gen_text should be valid json let gen_text_value: Value = serde_json::from_str(&generation.generated_text)
let gen_text_value: Value = .map_err(|e| InferError::ToolError(e.to_string()))?;
serde_json::from_str(&generation.generated_text).map_err(|e| {
( let function = gen_text_value.get("function").ok_or(InferError::ToolError(
StatusCode::UNPROCESSABLE_ENTITY, "No function found in generated text".to_string(),
Json(ErrorResponse { ))?;
error: e.to_string(),
error_type: "Input validation error".to_string(), let name = function
}), .get("_name")
) .and_then(Value::as_str)
})?; .ok_or(InferError::ToolError(
"No _name found in generated text".to_string(),
))?
.to_string();
let mut arguments = function.clone();
if let Value::Object(ref mut props) = arguments {
props.remove("_name");
}
let tool_calls = vec![ToolCall { let tool_calls = vec![ToolCall {
id: "0".to_string(), id: "0".to_string(),
r#type: "function".to_string(), r#type: "function".to_string(),
function: FunctionDefinition { function: FunctionDefinition {
description: None, description: None,
name: gen_text_value name,
.get("function") arguments,
.and_then(|f| f.get("_name"))
.and_then(|name| name.as_str())
.unwrap_or("default_function_name")
.to_string(),
// Serialize the JSON object obtained from "function" to an escaped JSON string
arguments: gen_text_value
.get("function")
.map(|f| {
let mut f_cloned = f.clone();
if let Value::Object(ref mut props) = f_cloned {
props.remove("_name");
}
f_cloned
})
.unwrap_or_default(),
}, },
}]; }];
(Some(tool_calls), None) (Some(tool_calls), None)
@ -1498,6 +1492,7 @@ pub async fn run(
ToolCall, ToolCall,
Function, Function,
FunctionDefinition, FunctionDefinition,
ToolChoice,
) )
), ),
tags( tags(