diff --git a/integration-tests/models/test_tools_llama31.py b/integration-tests/models/test_tools_llama31.py
new file mode 100644
index 00000000..4e803b5b
--- /dev/null
+++ b/integration-tests/models/test_tools_llama31.py
@@ -0,0 +1,315 @@
+import pytest
+from huggingface_hub import InferenceClient
+
+# to be removed when the InferenceClient client supports latest parameters
+import requests
+
+
+@pytest.fixture(scope="module")
+def flash_llama_grammar_tools_handle(launcher):
+    with launcher(
+        "meta-llama/Meta-Llama-3.1-8B-Instruct",
+        num_shard=2,
+        disable_grammar_support=False,
+    ) as handle:
+        yield handle
+
+
+@pytest.fixture(scope="module")
+async def flash_llama_grammar_tools(flash_llama_grammar_tools_handle):
+    await flash_llama_grammar_tools_handle.health(300)
+    return flash_llama_grammar_tools_handle.client
+
+
+# All tests are based on the following model card
+# https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_1/
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_basic_gen(flash_llama_grammar_tools, response_snapshot):
+    client = InferenceClient(
+        base_url=flash_llama_grammar_tools.base_url + "/v1",
+    )
+
+    output = client.chat.completions.create(
+        model="meta-llama/Meta-Llama-3.1-8B-Instruct",
+        messages=[
+            {
+                "role": "system",
+                "content": "You are a helpful assistant",
+            },
+            {
+                "role": "user",
+                "content": "What is the capital of France?",
+            },
+        ],
+        stream=True,
+        seed=42,
+        max_tokens=20,
+    )
+
+    final_response = []
+    for chunk in output:
+        final_response.append(chunk.choices[0].delta.content)
+    resp = ''.join(final_response)
+
+    assert resp == "The capital of France is Paris."
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_code_interpreter_gen(flash_llama_grammar_tools, response_snapshot):
+    client = InferenceClient(
+        base_url=flash_llama_grammar_tools.base_url + "/v1",
+    )
+    output = client.chat.completions.create(
+        model="meta-llama/Meta-Llama-3.1-8B-Instruct",
+        messages=[
+            {
+                "role": "system",
+                "content": "Environment: ipython",
+            },
+            {
+                "role": "user",
+                "content": "Write code to check if number is prime, use that to see if the number 7 is prime",
+            },
+        ],
+        stream=True,
+        seed=42,
+        max_tokens=20,
+    )
+
+    final_response = []
+    for chunk in output:
+        final_response.append(chunk.choices[0].delta.content)
+    resp = ''.join(final_response)
+
+    assert resp == "def is_prime(n):\n if n <= 1:\n return False\n if n"
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_code_builtin_tools_gen(flash_llama_grammar_tools, response_snapshot):
+    url = f"{flash_llama_grammar_tools.base_url}/v1/chat/completions"
+
+    payload = {
+        "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+        "messages": [
+            {
+                "role": "user",
+                "content": "What is the current weather in Menlo Park, California?",
+            }
+        ],
+        "stream": False,
+        "seed": 42,
+        "max_tokens": 20,
+        "builtin_tools": ["brave_search", "wolfram_alpha"],
+    }
+
+    response = requests.request("POST", url, json=payload)
+    response = response.json()
+    resp = response.get("choices")[0].get("message").get("content")
+    assert resp == "brave_search.call(query=\"current weather in Menlo Park, California\")"
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_code_builtin_tools_explicit_off_gen(flash_llama_grammar_tools, response_snapshot):
+    url = f"{flash_llama_grammar_tools.base_url}/v1/chat/completions"
+
+    payload = {
+        "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+        "messages": [
+            {
+                "role": "user",
+                "content": "What is the current weather in Menlo Park, California?",
+            }
+        ],
+        "stream": False,
+        "seed": 42,
+        "max_tokens": 20,
+        # "builtin_tools": ["brave_search", "wolfram_alpha"],
+    }
+
+    response = requests.request("POST", url, json=payload)
+    response = response.json()
+    resp = response.get("choices")[0].get("message").get("content")
+    assert resp == "I can't provide real-time weather information. However, I can encourage you to check a weather website"
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_code_builtin_tools_two_gen(flash_llama_grammar_tools, response_snapshot):
+    url = f"{flash_llama_grammar_tools.base_url}/v1/chat/completions"
+
+    payload = {
+        "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+        "messages": [
+            {
+                "role": "system",
+                "content": "You are a helpful assistant.",
+            },
+            {
+                "role": "user",
+                "content": "Can you help me solve this equation with wolfram_alpha: x^3 - 4x^2 + 6x - 24 = 0",
+            },
+        ],
+        "stream": False,
+        "seed": 42,
+        "max_tokens": 50,
+        "builtin_tools": ["brave_search", "wolfram_alpha"],
+    }
+
+    response = requests.request("POST", url, json=payload)
+    response = response.json()
+    resp = response.get("choices")[0].get("message").get("content")
+    assert resp == "wolfram_alpha.call(query=\"solve x^3 - 4x^2 + 6x - 24 = 0\")"
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_code_builtin_tools_function_response_gen(flash_llama_grammar_tools, response_snapshot):
+    url = f"{flash_llama_grammar_tools.base_url}/v1/chat/completions"
+
+    payload = {
+        "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+        "messages": [
+            {
+                "role": "system",
+                "content": "You are a helpful assistant.",
+            },
+            {
+                "role": "user",
+                "content": "Can you help me solve this equation with wolfram_alpha: x^3 - 4x^2 + 6x - 24 = 0",
+            },
+            {
+                "role": "assistant",
+                "content": "wolfram_alpha.call(query=\"solve x^3 - 4x^2 + 6x - 24 = 0\")",
+            },
+            {
+                "role": "ipython",
+                "content": "{\"queryresult\": {\"success\": true, \"inputstring\": \"solve x^3 - 4x^2 + 6x - 24 = 0\", \"pods\": [{\"title\": \"Input interpretation\", \"subpods\": [{\"title\": \"\", \"plaintext\": \"solve x^3 - 4 x^2 + 6 x - 24 = 0\"}]}, {\"title\": \"Results\", \"primary\": true, \"subpods\": [{\"title\": \"\", \"plaintext\": \"x = 4\"}, {\"title\": \"\", \"plaintext\": \"x = \u00b1 (i sqrt(6))\"}]}, ... ]}}",
+            },
+        ],
+        "stream": False,
+        "seed": 42,
+        "max_tokens": 50,
+        "builtin_tools": ["brave_search", "wolfram_alpha"],
+    }
+
+    response = requests.request("POST", url, json=payload)
+    response = response.json()
+    resp = response.get("choices")[0].get("message").get("content")
+    assert resp == "The solutions to the equation x^3 - 4x^2 + 6x - 24 = 0 are x = 4, x = i√6, and x = -i√6."
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_user_supplied_json_tool_gen(flash_llama_grammar_tools, response_snapshot):
+    client = InferenceClient(
+        base_url=flash_llama_grammar_tools.base_url + "/v1",
+    )
+    output = client.chat.completions.create(
+        model="meta-llama/Meta-Llama-3.1-8B-Instruct",
+        messages=[
+            {
+                "role": "system",
+                "content": "You are a helpful assistant with tool calling capabilities"
+            },
+            {
+                "role": "user",
+                "content": "Question: what is the weather like in San Fransisco?"
+            },
+        ],
+        tools=[
+            {
+                "type": "function",
+                "function": {
+                    "name": "get_current_conditions",
+                    "description": "Get the current weather conditions for a specific location",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "location": {
+                                "type": "string",
+                                "description": "The city and state, e.g., San Francisco, CA"
+                            },
+                            "unit": {
+                                "type": "string",
+                                "enum": ["Celsius", "Fahrenheit"],
+                                "description": "The temperature unit to use. Infer this from the user's location."
+                            }
+                        },
+                        "required": ["location", "unit"]
+                    }
+                }
+            }
+        ],
+        stream=True,
+        seed=42,
+        max_tokens=50,
+    )
+
+    final_response = []
+    for chunk in output:
+        final_response.append(chunk.choices[0].delta.content)
+    resp = ''.join(final_response)
+
+    assert resp == "{\"name\": \"get_current_conditions\", \"parameters\": {\"location\": \"San Francisco, CA\", \"unit\": \"Fahrenheit\"}}"
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_user_supplied_json_tool_function_response_gen(flash_llama_grammar_tools, response_snapshot):
+    client = InferenceClient(
+        base_url=flash_llama_grammar_tools.base_url + "/v1",
+    )
+    output = client.chat.completions.create(
+        model="meta-llama/Meta-Llama-3.1-8B-Instruct",
+        messages=[
+            {
+                "role": "system",
+                "content": "You are a helpful assistant with tool calling capabilities. When you receive a tool call response, use the output to format an answer to the orginal use question."
+            },
+            {
+                "role": "user",
+                "content": "Question: what is the weather like in San Fransisco?"
+            },
+            {
+                "role": "assistant",
+                "content": "{\"name\": \"get_current_conditions\", \"parameters\": {\"location\": \"San Francisco, CA\", \"unit\": \"Fahrenheit\"}}",
+            },
+            {
+                "role": "ipython",
+                "content": "{\"output\": \"Clouds giving way to sun Hi: 76° Tonight: Mainly clear early, then areas of low clouds forming Lo: 56°\"}",
+            },
+        ],
+        tools=[
+            {
+                "type": "function",
+                "function": {
+                    "name": "get_current_conditions",
+                    "description": "Get the current weather conditions for a specific location",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "location": {
+                                "type": "string",
+                                "description": "The city and state, e.g., San Francisco, CA"
+                            },
+                            "unit": {
+                                "type": "string",
+                                "enum": ["Celsius", "Fahrenheit"],
+                                "description": "The temperature unit to use. Infer this from the user's location."
+                            }
+                        },
+                        "required": ["location", "unit"]
+                    }
+                }
+            }
+        ],
+        stream=True,
+        seed=42,
+        max_tokens=50,
+    )
+
+    final_response = []
+    for chunk in output:
+        final_response.append(chunk.choices[0].delta.content)
+    resp = ''.join(final_response)
+    assert resp == "The current weather conditions in San Francisco, CA are clouds giving way to sun with a high of 76°F and a low of 56°F."
\ No newline at end of file
diff --git a/router/src/infer/chat_template.rs b/router/src/infer/chat_template.rs
index 1071d0ba..da1e70f3 100644
--- a/router/src/infer/chat_template.rs
+++ b/router/src/infer/chat_template.rs
@@ -56,6 +56,7 @@ impl ChatTemplate {
         guideline: Option<&str>,
         mut messages: Vec<Message>,
         tools_and_prompt: Option<(Vec<Tool>, String)>,
+        builtin_tools: Option<Vec<String>>,
     ) -> Result<String, InferError> {
         // check if guideline is expected but not provided
         if self.variables.contains("guideline") && guideline.is_none() {
@@ -68,12 +69,15 @@
             // if not, we need to append the tools to the last message
             let text = if self.use_default_tool_template {
                 match serde_json::to_string(&tools) {
-                    Ok(tools_str) => format!("\n---\n{}\n{}", tools_str, tool_prompt),
+                    // Ok(tools_str) => format!("\n---\n{}\n{}", tools_str, tool_prompt),
+                    Ok(tools_str) => format!("\n{}\n{}", tools_str, tool_prompt),
                     Err(e) => return Err(InferError::ToolError(e.to_string())),
                 }
             } else {
                 // if the `tools` variable is used in the template, we just append the tool_prompt
-                format!("\n---\n{}", tool_prompt)
+                // format!("\n---\n{}", tool_prompt)
+                format!("\n{}", tool_prompt)
+                // format!("{}", "")
             };
             if let Some(last_message) = messages.last_mut() {
                 last_message.content.push(MessageChunk::Text { text });
@@ -93,6 +97,7 @@
                 eos_token: self.eos_token.as_deref(),
                 add_generation_prompt: true,
                 tools,
+                builtin_tools,
             })
             .map_err(InferError::TemplateError)
     }
diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs
index 1c9d5620..1abd6e07 100644
--- a/router/src/infer/mod.rs
+++ b/router/src/infer/mod.rs
@@ -160,11 +160,17 @@ impl Infer {
         guideline: Option<String>,
         messages: Vec<Message>,
         tools_and_prompt: Option<(Vec<Tool>, String)>,
+        builtin_tools: Option<Vec<String>>,
     ) -> Result<String, InferError> {
         self.chat_template
             .as_ref()
             .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
-            .apply(guideline.as_deref(), messages, tools_and_prompt)
+            .apply(
+                guideline.as_deref(),
+                messages,
+                tools_and_prompt,
+                builtin_tools,
+            )
             .map_err(|e| {
                 metrics::counter!("tgi_request_failure", "err" => "template").increment(1);
                 tracing::error!("{e}");
diff --git a/router/src/infer/tool_grammar.rs b/router/src/infer/tool_grammar.rs
index 4fe15720..0243ec6f 100644
--- a/router/src/infer/tool_grammar.rs
+++ b/router/src/infer/tool_grammar.rs
@@ -1,8 +1,5 @@
 use crate::infer::InferError;
-use crate::{
-    FunctionDefinition, FunctionRef, FunctionsMap, JsonSchemaTool, Properties, Tool, ToolChoice,
-    ToolType,
-};
+use crate::{FunctionRef, FunctionsMap, JsonSchemaTool, Properties, Tool, ToolChoice, ToolType};
 use serde_json::{json, Map, Value};
 use std::collections::HashMap;
 
@@ -29,27 +26,27 @@
         let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf);
 
-        let mut tools = tools.clone();
+        // let mut tools = tools.clone();
 
-        // add the notify_error function to the tools
-        let notify_error = Tool {
-            r#type: "function".to_string(),
-            function: FunctionDefinition {
-                name: "notify_error".to_string(),
-                description: Some("Notify an error or issue".to_string()),
-                arguments: json!({
-                    "type": "object",
-                    "properties": {
-                        "error": {
-                            "type": "string",
-                            "description": "The error or issue to notify"
-                        }
-                    },
-                    "required": ["error"]
-                }),
-            },
-        };
-        tools.push(notify_error);
+        // // add the notify_error function to the tools
+        // let notify_error = Tool {
+        //     r#type: "function".to_string(),
+        //     function: FunctionDefinition {
+        //         name: "notify_error".to_string(),
+        //         description: Some("Notify an error or issue".to_string()),
+        //         arguments: json!({
+        //             "type": "object",
+        //             "properties": {
+        //                 "error": {
+        //                     "type": "string",
+        //                     "description": "The error or issue to notify"
+        //                 }
+        //             },
+        //             "required": ["error"]
+        //         }),
+        //     },
+        // };
+        // tools.push(notify_error);
 
         // if tools are provided and no tool_choice we default to the OneOf
         let tools_to_use = match tool_choice {
@@ -86,7 +83,7 @@
                         }),
                     );
 
-                    if let Value::Object(args) = func.arguments {
+                    if let Value::Object(args) = func.parameters {
                         if let Some(Value::Object(props)) = args.get("properties") {
                             properties.extend(props.clone());
                         }
@@ -109,7 +106,7 @@
             })
             .collect();
 
-        let tool_schema = JsonSchemaTool {
+        let _tool_schema = JsonSchemaTool {
             functions_map: FunctionsMap { functions },
             properties: Properties {
                 function: tools_to_use
@@ -121,6 +118,7 @@
             },
         };
 
-        Ok((tools, Some(tool_schema)))
+        // Ok((tools, Some(tool_schema)))
+        Ok((tools, None))
     }
 }
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 0901bafa..8754a7b2 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -864,6 +864,12 @@ pub(crate) struct ChatRequest {
     #[schema(nullable = true, default = "null", example = "null")]
     pub guideline: Option<String>,
 
+    /// A list of builtin_tools (these must be trained into the model).
+    /// See https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_1/#built-in-python-based-tool-calling for more information.
+    #[serde(default)]
+    #[schema(nullable = true, example = "null")]
+    pub builtin_tools: Option<Vec<String>>,
+
     /// Options for streaming response. Only set this when you set stream: true.
     #[serde(default)]
     #[schema(nullable = true, example = "null")]
@@ -885,6 +891,7 @@ impl ChatRequest {
             temperature,
             response_format,
             guideline,
+            builtin_tools,
             presence_penalty,
             frequency_penalty,
             top_p,
@@ -911,8 +918,12 @@
             &tool_prompt,
             guideline,
             messages,
+            builtin_tools,
         )?;
 
+        println!("inputs: {}", inputs);
+        // println!("grammar: {:?}", grammar);
+
         Ok((
             GenerateRequest {
                 inputs: inputs.to_string(),
@@ -953,7 +964,8 @@ struct StreamOptions {
 }
 
 pub fn default_tool_prompt() -> String {
-    "\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\n".to_string()
+    // "\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\n".to_string()
+    "".to_string()
 }
 
 #[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)]
@@ -1034,8 +1046,8 @@ pub(crate) struct FunctionDefinition {
     #[serde(default)]
     pub description: Option<String>,
     pub name: String,
-    #[serde(alias = "parameters")]
-    pub arguments: serde_json::Value,
+    // #[serde(alias = "parameters")]
+    pub parameters: serde_json::Value,
 }
 
 #[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
@@ -1056,6 +1068,8 @@ pub(crate) struct ChatTemplateInputs<'a> {
     add_generation_prompt: bool,
     tools: Option<Vec<Tool>>,
     guideline: Option<&'a str>,
+    // builtin_tools: Option>,
+    builtin_tools: Option<Vec<String>>,
 }
 
 #[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug, PartialEq)]
diff --git a/router/src/server.rs b/router/src/server.rs
index fb06b245..d4f9194e 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -1267,7 +1267,7 @@ async fn chat_completions(
                 function: FunctionDefinition {
                     description: None,
                     name,
-                    arguments,
+                    parameters: arguments,
                 },
             }];
             (Some(tool_calls), None)
@@ -2370,6 +2370,7 @@ pub enum WebServerError {
 
 type PreparedInput = (String, Option<GrammarType>, bool);
 
+#[allow(clippy::too_many_arguments)]
 pub(crate) fn prepare_chat_input(
     infer: &Infer,
     response_format: Option<GrammarType>,
@@ -2378,6 +2379,7 @@
     tool_prompt: &str,
     guideline: Option<String>,
     messages: Vec<Message>,
+    builtin_tools: Option<Vec<String>>,
 ) -> Result<PreparedInput, InferError> {
     if response_format.is_some() && tools.is_some() {
         return Err(InferError::ToolError(
@@ -2387,7 +2389,7 @@
 
     // when response_format is set, tools are not included when applying the chat template to generate inputs
     if let Some(format) = response_format {
-        let inputs = infer.apply_chat_template(guideline, messages, None)?;
+        let inputs = infer.apply_chat_template(guideline, messages, None, builtin_tools)?;
         return Ok((inputs, Some(format), false));
     }
 
@@ -2404,12 +2406,13 @@
             guideline,
             messages,
             Some((updated_tools, tool_prompt.into())),
+            builtin_tools,
         )?;
         return Ok((inputs, grammar, tool_schema.is_some()));
     }
 
     // if no response_format or tools are set simply apply the chat template to generate inputs
-    let inputs = infer.apply_chat_template(guideline, messages, None)?;
+    let inputs = infer.apply_chat_template(guideline, messages, None, builtin_tools)?;
     Ok((inputs, None, false))
 }
 
diff --git a/router/src/vertex.rs b/router/src/vertex.rs
index 0c1467fe..3566f771 100644
--- a/router/src/vertex.rs
+++ b/router/src/vertex.rs
@@ -138,6 +138,12 @@ pub(crate) struct VertexParameters {
     #[schema(nullable = true, default = "null", example = "null")]
     pub guideline: Option<String>,
 
+    /// A list of builtin_tools (these must be trained into the model).
+    /// See https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_1/#built-in-python-based-tool-calling for more information.
+    #[serde(default)]
+    #[schema(nullable = true, example = "null")]
+    pub builtin_tools: Option<Vec<String>>,
+
     /// Options for streaming response. Only set this when you set stream: true.
     #[serde(default)]
     #[schema(nullable = true, example = "null")]
@@ -150,6 +156,7 @@ impl From<VertexChat> for ChatRequest {
             messages: val.messages,
             frequency_penalty: val.parameters.frequency_penalty,
             guideline: val.parameters.guideline,
+            builtin_tools: val.parameters.builtin_tools,
             logit_bias: val.parameters.logit_bias,
             logprobs: val.parameters.logprobs,
             max_tokens: val.parameters.max_tokens,