diff --git a/docs/openapi.json b/docs/openapi.json index 3e7050ab..7000c7b7 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -909,7 +909,7 @@ "tool_choice": { "allOf": [ { - "$ref": "#/components/schemas/ToolType" + "$ref": "#/components/schemas/ToolChoice" } ], "nullable": true @@ -2035,6 +2035,14 @@ } } }, + "ToolChoice": { + "allOf": [ + { + "$ref": "#/components/schemas/ToolType" + } + ], + "nullable": true + }, "ToolType": { "oneOf": [ { @@ -2055,6 +2063,11 @@ "$ref": "#/components/schemas/FunctionName" } } + }, + { + "type": "object", + "default": null, + "nullable": true } ] }, diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs index f3b10450..db9070d4 100644 --- a/router/src/infer/mod.rs +++ b/router/src/infer/mod.rs @@ -7,7 +7,7 @@ pub(crate) use health::HealthCheck; use crate::validation::{ValidGenerateRequest, Validation, ValidationError}; use crate::{ ChatTemplateInputs, ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, - HubTokenizerConfig, Message, MessageChunk, PrefillToken, TextMessage, Token, + HubTokenizerConfig, Message, MessageChunk, PrefillToken, TextMessage, Token, ToolChoice, }; use crate::{ FunctionRef, FunctionsMap, GrammarType, Properties, TokenizerConfigToken, Tool, ToolType, Tools, @@ -332,126 +332,131 @@ impl ChatTemplate { pub struct ToolGrammar {} impl ToolGrammar { + // find a tool by name + fn find_tool_by_name(tools: &[Tool], name: &str) -> Result { + tools + .iter() + .find(|tool| tool.function.name == name) + .cloned() + .ok_or_else(|| InferError::ToolError(format!("Tool with name {} not found", name))) + } + pub fn apply( tools: Option>, - tool_choice: Option, + tool_choice: ToolChoice, ) -> Result, InferError> { - if let Some((req_tools, tool_choice)) = tools.zip(tool_choice) { - // let tool_prompt = tool_prompt.unwrap_or_default(); - let tools_to_use = match tool_choice { - ToolType::FunctionName(name) => { - vec![req_tools - .iter() - .find(|tool| tool.function.name == *name) - .unwrap_or_else(|| panic!("Tool with name {} not found", name)) - .clone()] - } - ToolType::Function { function } => { - let tool = req_tools - .iter() - .find(|tool| tool.function.name == function.name) - .unwrap_or_else(|| panic!("Tool with name {} not found", function.name)) - .clone(); - vec![tool] - } - ToolType::OneOf => req_tools.to_owned(), - }; + // if no tools are provided, we return None + let tools = match tools { + Some(tools) if !tools.is_empty() => tools, + _ => return Ok(None), + }; - // adds the error notification function for LLM feedback if required - let mut text_response_properties = Map::new(); - text_response_properties.insert( - "error".to_string(), - serde_json::json!({ - "type": "string", - "description": "The error or issue to notify" - }), - ); - text_response_properties.insert( - "_name".to_string(), - serde_json::json!({ - "type": "string", - "const": "notify_error" - }), - ); + let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf); - let functions: HashMap = tools_to_use - .iter() - .map(|tool| { - let func = tool.function.clone(); + // if tools are provided and no tool_choice we default to the OneOf + let tools_to_use = match tool_choice { + ToolType::FunctionName(name) => { + vec![Self::find_tool_by_name(&tools, &name)?] + } + ToolType::Function { function } => { + vec![Self::find_tool_by_name(&tools, &function.name)?] + } + ToolType::OneOf => tools, + ToolType::NoTool => return Ok(None), + }; - // Clone the existing parameters, which are expected to be a JSON object - let mut params = if let Value::Object(params) = &func.arguments { - params.clone() - } else { - Map::new() - }; + // adds the error notification function for LLM feedback if required + let mut text_response_properties = Map::new(); + text_response_properties.insert( + "error".to_string(), + serde_json::json!({ + "type": "string", + "description": "The error or issue to notify" + }), + ); + text_response_properties.insert( + "_name".to_string(), + serde_json::json!({ + "type": "string", + "const": "notify_error" + }), + ); - // Insert the function's description at the top level, outside of properties - params.insert( - "description".to_string(), - Value::String(func.description.clone().unwrap_or_default()), - ); + let functions: HashMap = tools_to_use + .iter() + .map(|tool| { + let func = tool.function.clone(); - // Ensure 'properties' exists and is an object - let properties = params - .entry("properties".to_string()) - .or_insert_with(|| json!({})) - .as_object_mut() - .unwrap(); + // Clone the existing parameters, which are expected to be a JSON object + let mut params = if let Value::Object(params) = &func.arguments { + params.clone() + } else { + Map::new() + }; - // Insert the constant for the function name inside 'properties' - properties.insert( - "_name".to_string(), - json!({ - "type": "string", - "const": func.name.clone(), - // "description": "The name of the function" - }), - ); + // Insert the function's description at the top level, outside of properties + params.insert( + "description".to_string(), + Value::String(func.description.clone().unwrap_or_default()), + ); - // Check if 'required' exists, and it is an array. If not, create an empty array. - let required = params - .entry("required".to_string()) - .or_insert_with(|| json!([])) - .as_array_mut() - .unwrap(); + // Ensure 'properties' exists and is an object + let properties = params + .entry("properties".to_string()) + .or_insert_with(|| json!({})) + .as_object_mut() + .unwrap(); - // Add 'name' to the 'required' array if it is not already present - if !required.iter().any(|r| r == "_name") { - required.push(json!("_name")); - } - - (func.name, Value::Object(params)) - }) - .chain([( - "notify_error".to_string(), - serde_json::json!({ - "properties": text_response_properties, - "required": ["error", "_name"], - "type": "object" + // Insert the constant for the function name inside 'properties' + properties.insert( + "_name".to_string(), + json!({ + "type": "string", + "const": func.name.clone(), + // "description": "The name of the function" }), - )]) - .collect(); + ); - let tools = Tools { - functions_map: FunctionsMap { functions }, - properties: Properties { - function: tools_to_use - .iter() - .map(|tool| FunctionRef { - ref_path: format!("#/$functions/{}", tool.function.name.clone()), - }) - .chain(std::iter::once(FunctionRef { - ref_path: "#/$functions/notify_error".to_string(), - })) - .collect(), - }, - }; + // Check if 'required' exists, and it is an array. If not, create an empty array. + let required = params + .entry("required".to_string()) + .or_insert_with(|| json!([])) + .as_array_mut() + .unwrap(); - return Ok(Some(tools)); - } - // Err(InferError::ToolError("No tools provided".to_string())) - Ok(None) + // Add 'name' to the 'required' array if it is not already present + if !required.iter().any(|r| r == "_name") { + required.push(json!("_name")); + } + + (func.name, Value::Object(params)) + }) + .chain([( + "notify_error".to_string(), + serde_json::json!({ + "properties": text_response_properties, + "required": ["error", "_name"], + "type": "object" + }), + )]) + .collect(); + + let tools = Tools { + functions_map: FunctionsMap { functions }, + properties: Properties { + function: tools_to_use + .iter() + .map(|tool| FunctionRef { + ref_path: format!("#/$functions/{}", tool.function.name.clone()), + }) + .chain(std::iter::once(FunctionRef { + ref_path: "#/$functions/notify_error".to_string(), + })) + .collect(), + }, + }; + + Ok(Some(tools)) } } diff --git a/router/src/lib.rs b/router/src/lib.rs index 52926c6c..b6e0d09d 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -826,7 +826,7 @@ pub(crate) struct ChatRequest { /// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter. #[serde(default)] #[schema(nullable = true, example = "null")] - pub tool_choice: Option, + pub tool_choice: ToolChoice, /// Response format constraints for the generation. /// @@ -848,6 +848,7 @@ pub enum ToolType { OneOf, FunctionName(String), Function { function: FunctionName }, + NoTool, } #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema)] @@ -855,27 +856,26 @@ pub struct FunctionName { pub name: String, } -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default, ToSchema)] #[serde(from = "ToolTypeDeserializer")] pub struct ToolChoice(pub Option); #[derive(Deserialize)] #[serde(untagged)] enum ToolTypeDeserializer { - None(Option), - Some(ToolType), + String(String), + ToolType(ToolType), } impl From for ToolChoice { fn from(value: ToolTypeDeserializer) -> Self { match value { - ToolTypeDeserializer::None(opt) => match opt.as_deref() { - Some("none") => ToolChoice(None), - Some("auto") => ToolChoice(Some(ToolType::OneOf)), - Some(s) => ToolChoice(Some(ToolType::FunctionName(s.to_string()))), - None => ToolChoice(Some(ToolType::OneOf)), + ToolTypeDeserializer::String(s) => match s.as_str() { + "none" => ToolChoice(Some(ToolType::NoTool)), + "auto" => ToolChoice(Some(ToolType::OneOf)), + _ => ToolChoice(Some(ToolType::FunctionName(s))), }, - ToolTypeDeserializer::Some(tool_type) => ToolChoice(Some(tool_type)), + ToolTypeDeserializer::ToolType(tool_type) => ToolChoice(Some(tool_type)), } } } diff --git a/router/src/server.rs b/router/src/server.rs index d3a280ca..c56c39a3 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -24,7 +24,7 @@ use crate::{ CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, VertexRequest, VertexResponse, }; -use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolType}; +use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType}; use async_stream::__private::AsyncStream; use axum::extract::Extension; use axum::http::{HeaderMap, Method, StatusCode}; @@ -1192,39 +1192,33 @@ async fn chat_completions( .as_secs(); let (tool_calls, output) = if tool_grammar.is_some() { - // gen_text should be valid json - let gen_text_value: Value = - serde_json::from_str(&generation.generated_text).map_err(|e| { - ( - StatusCode::UNPROCESSABLE_ENTITY, - Json(ErrorResponse { - error: e.to_string(), - error_type: "Input validation error".to_string(), - }), - ) - })?; + let gen_text_value: Value = serde_json::from_str(&generation.generated_text) + .map_err(|e| InferError::ToolError(e.to_string()))?; + + let function = gen_text_value.get("function").ok_or(InferError::ToolError( + "No function found in generated text".to_string(), + ))?; + + let name = function + .get("_name") + .and_then(Value::as_str) + .ok_or(InferError::ToolError( + "No _name found in generated text".to_string(), + ))? + .to_string(); + + let mut arguments = function.clone(); + if let Value::Object(ref mut props) = arguments { + props.remove("_name"); + } + let tool_calls = vec![ToolCall { id: "0".to_string(), r#type: "function".to_string(), function: FunctionDefinition { description: None, - name: gen_text_value - .get("function") - .and_then(|f| f.get("_name")) - .and_then(|name| name.as_str()) - .unwrap_or("default_function_name") - .to_string(), - // Serialize the JSON object obtained from "function" to an escaped JSON string - arguments: gen_text_value - .get("function") - .map(|f| { - let mut f_cloned = f.clone(); - if let Value::Object(ref mut props) = f_cloned { - props.remove("_name"); - } - f_cloned - }) - .unwrap_or_default(), + name, + arguments, }, }]; (Some(tool_calls), None) @@ -1498,6 +1492,7 @@ pub async fn run( ToolCall, Function, FunctionDefinition, + ToolChoice, ) ), tags(