feat: support llama 3.1 tooling and remove grammar schema
parent f6e2f05b16
commit 3f07ddb469
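For context, the change below lets clients pass a `builtin_tools` list on the `/v1/chat/completions` route; the field is threaded through the chat template so Llama 3.1 can emit built-in tool calls such as `brave_search.call(...)`. A minimal sketch of such a request follows; the payload mirrors the integration tests added below, while the base URL is only an assumed local deployment used for illustration.

    import requests

    # Assumed local text-generation-inference endpoint; adjust to your deployment.
    base_url = "http://localhost:3000"

    payload = {
        "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
        "messages": [
            {"role": "user", "content": "What is the current weather in Menlo Park, California?"}
        ],
        "stream": False,
        "seed": 42,
        "max_tokens": 20,
        # New field introduced by this commit: built-in tools the model was trained to call.
        "builtin_tools": ["brave_search", "wolfram_alpha"],
    }

    response = requests.post(f"{base_url}/v1/chat/completions", json=payload)
    # Per the tests below, the model is expected to answer with a built-in tool call,
    # e.g. brave_search.call(query="current weather in Menlo Park, California").
    print(response.json()["choices"][0]["message"]["content"])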
@@ -0,0 +1,315 @@
import pytest
from huggingface_hub import InferenceClient

# to be removed once InferenceClient supports the latest parameters
import requests


@pytest.fixture(scope="module")
def flash_llama_grammar_tools_handle(launcher):
    with launcher(
        "meta-llama/Meta-Llama-3.1-8B-Instruct",
        num_shard=2,
        disable_grammar_support=False,
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_llama_grammar_tools(flash_llama_grammar_tools_handle):
    await flash_llama_grammar_tools_handle.health(300)
    return flash_llama_grammar_tools_handle.client


# All tests are based on the following model card
# https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_1/


@pytest.mark.asyncio
@pytest.mark.private
async def test_basic_gen(flash_llama_grammar_tools, response_snapshot):
    client = InferenceClient(
        base_url=flash_llama_grammar_tools.base_url + "/v1",
    )

    output = client.chat.completions.create(
        model="meta-llama/Meta-Llama-3.1-8B-Instruct",
        messages=[
            {
                "role": "system",
                "content": "You are a helpful assistant",
            },
            {
                "role": "user",
                "content": "What is the capital of France?",
            },
        ],
        stream=True,
        seed=42,
        max_tokens=20,
    )

    final_response = []
    for chunk in output:
        final_response.append(chunk.choices[0].delta.content)
    resp = ''.join(final_response)

    assert resp == "The capital of France is Paris."


@pytest.mark.asyncio
@pytest.mark.private
async def test_code_interpreter_gen(flash_llama_grammar_tools, response_snapshot):
    client = InferenceClient(
        base_url=flash_llama_grammar_tools.base_url + "/v1",
    )
    output = client.chat.completions.create(
        model="meta-llama/Meta-Llama-3.1-8B-Instruct",
        messages=[
            {
                "role": "system",
                "content": "Environment: ipython",
            },
            {
                "role": "user",
                "content": "Write code to check if number is prime, use that to see if the number 7 is prime",
            },
        ],
        stream=True,
        seed=42,
        max_tokens=20,
    )

    final_response = []
    for chunk in output:
        final_response.append(chunk.choices[0].delta.content)
    resp = ''.join(final_response)

    assert resp == "def is_prime(n):\n if n <= 1:\n return False\n if n"


@pytest.mark.asyncio
@pytest.mark.private
async def test_code_builtin_tools_gen(flash_llama_grammar_tools, response_snapshot):
    url = f"{flash_llama_grammar_tools.base_url}/v1/chat/completions"

    payload = {
        "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
        "messages": [
            {
                "role": "user",
                "content": "What is the current weather in Menlo Park, California?",
            }
        ],
        "stream": False,
        "seed": 42,
        "max_tokens": 20,
        "builtin_tools": ["brave_search", "wolfram_alpha"],
    }

    response = requests.request("POST", url, json=payload)
    response = response.json()
    resp = response.get("choices")[0].get("message").get("content")
    assert resp == "brave_search.call(query=\"current weather in Menlo Park, California\")"


@pytest.mark.asyncio
@pytest.mark.private
async def test_code_builtin_tools_explict_off_gen(flash_llama_grammar_tools, response_snapshot):
    url = f"{flash_llama_grammar_tools.base_url}/v1/chat/completions"

    payload = {
        "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
        "messages": [
            {
                "role": "user",
                "content": "What is the current weather in Menlo Park, California?",
            }
        ],
        "stream": False,
        "seed": 42,
        "max_tokens": 20,
        # "builtin_tools": ["brave_search", "wolfram_alpha"],
    }

    response = requests.request("POST", url, json=payload)
    response = response.json()
    resp = response.get("choices")[0].get("message").get("content")
    assert resp == "I can't provide real-time weather information. However, I can encourage you to check a weather website"


@pytest.mark.asyncio
@pytest.mark.private
async def test_code_builtin_tools_two_gen(flash_llama_grammar_tools, response_snapshot):
    url = f"{flash_llama_grammar_tools.base_url}/v1/chat/completions"

    payload = {
        "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
        "messages": [
            {
                "role": "system",
                "content": "You are a helpful assistant.",
            },
            {
                "role": "user",
                "content": "Can you help me solve this equation with wolfram_alpha: x^3 - 4x^2 + 6x - 24 = 0",
            },
        ],
        "stream": False,
        "seed": 42,
        "max_tokens": 50,
        "builtin_tools": ["brave_search", "wolfram_alpha"],
    }

    response = requests.request("POST", url, json=payload)
    response = response.json()
    resp = response.get("choices")[0].get("message").get("content")
    assert resp == "wolfram_alpha.call(query=\"solve x^3 - 4x^2 + 6x - 24 = 0\")"


@pytest.mark.asyncio
@pytest.mark.private
async def test_code_builtin_tools_function_response_gen(flash_llama_grammar_tools, response_snapshot):
    url = f"{flash_llama_grammar_tools.base_url}/v1/chat/completions"

    payload = {
        "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
        "messages": [
            {
                "role": "system",
                "content": "You are a helpful assistant.",
            },
            {
                "role": "user",
                "content": "Can you help me solve this equation with wolfram_alpha: x^3 - 4x^2 + 6x - 24 = 0",
            },
            {
                "role": "assistant",
                "content": "wolfram_alpha.call(query=\"solve x^3 - 4x^2 + 6x - 24 = 0\")",
            },
            {
                "role": "ipython",
                "content": "{\"queryresult\": {\"success\": true, \"inputstring\": \"solve x^3 - 4x^2 + 6x - 24 = 0\", \"pods\": [{\"title\": \"Input interpretation\", \"subpods\": [{\"title\": \"\", \"plaintext\": \"solve x^3 - 4 x^2 + 6 x - 24 = 0\"}]}, {\"title\": \"Results\", \"primary\": true, \"subpods\": [{\"title\": \"\", \"plaintext\": \"x = 4\"}, {\"title\": \"\", \"plaintext\": \"x = \u00b1 (i sqrt(6))\"}]}, ... ]}}",
            },
        ],
        "stream": False,
        "seed": 42,
        "max_tokens": 50,
        "builtin_tools": ["brave_search", "wolfram_alpha"],
    }

    response = requests.request("POST", url, json=payload)
    response = response.json()
    resp = response.get("choices")[0].get("message").get("content")
    assert resp == "The solutions to the equation x^3 - 4x^2 + 6x - 24 = 0 are x = 4, x = i√6, and x = -i√6."


@pytest.mark.asyncio
@pytest.mark.private
async def test_user_supplied_json_tool_gen(flash_llama_grammar_tools, response_snapshot):
    client = InferenceClient(
        base_url=flash_llama_grammar_tools.base_url + "/v1",
    )
    output = client.chat.completions.create(
        model="meta-llama/Meta-Llama-3.1-8B-Instruct",
        messages=[
            {
                "role": "system",
                "content": "You are a helpful assistant with tool calling capabilities"
            },
            {
                "role": "user",
                "content": "Question: what is the weather like in San Fransisco?"
            },
        ],
        tools=[
            {
                "type": "function",
                "function": {
                    "name": "get_current_conditions",
                    "description": "Get the current weather conditions for a specific location",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "location": {
                                "type": "string",
                                "description": "The city and state, e.g., San Francisco, CA"
                            },
                            "unit": {
                                "type": "string",
                                "enum": ["Celsius", "Fahrenheit"],
                                "description": "The temperature unit to use. Infer this from the user's location."
                            }
                        },
                        "required": ["location", "unit"]
                    }
                }
            }
        ],
        stream=True,
        seed=42,
        max_tokens=50,
    )

    final_response = []
    for chunk in output:
        final_response.append(chunk.choices[0].delta.content)
    resp = ''.join(final_response)

    assert resp == "{\"name\": \"get_current_conditions\", \"parameters\": {\"location\": \"San Francisco, CA\", \"unit\": \"Fahrenheit\"}}"


@pytest.mark.asyncio
@pytest.mark.private
async def test_user_supplied_json_tool_function_response_gen(flash_llama_grammar_tools, response_snapshot):
    client = InferenceClient(
        base_url=flash_llama_grammar_tools.base_url + "/v1",
    )
    output = client.chat.completions.create(
        model="meta-llama/Meta-Llama-3.1-8B-Instruct",
        messages=[
            {
                "role": "system",
                "content": "You are a helpful assistant with tool calling capabilities. When you receive a tool call response, use the output to format an answer to the orginal use question."
            },
            {
                "role": "user",
                "content": "Question: what is the weather like in San Fransisco?"
            },
            {
                "role": "assistant",
                "content": "{\"name\": \"get_current_conditions\", \"parameters\": {\"location\": \"San Francisco, CA\", \"unit\": \"Fahrenheit\"}}",
            },
            {
                "role": "ipython",
                "content": "{\"output\": \"Clouds giving way to sun Hi: 76° Tonight: Mainly clear early, then areas of low clouds forming Lo: 56°\"}",
            },
        ],
        tools=[
            {
                "type": "function",
                "function": {
                    "name": "get_current_conditions",
                    "description": "Get the current weather conditions for a specific location",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "location": {
                                "type": "string",
                                "description": "The city and state, e.g., San Francisco, CA"
                            },
                            "unit": {
                                "type": "string",
                                "enum": ["Celsius", "Fahrenheit"],
                                "description": "The temperature unit to use. Infer this from the user's location."
                            }
                        },
                        "required": ["location", "unit"]
                    }
                }
            }
        ],
        stream=True,
        seed=42,
        max_tokens=50,
    )

    final_response = []
    for chunk in output:
        final_response.append(chunk.choices[0].delta.content)
    resp = ''.join(final_response)
    assert resp == "The current weather conditions in San Francisco, CA are clouds giving way to sun with a high of 76°F and a low of 56°F."

@@ -56,6 +56,7 @@ impl ChatTemplate {
         guideline: Option<&str>,
         mut messages: Vec<Message>,
         tools_and_prompt: Option<(Vec<Tool>, String)>,
+        builtin_tools: Option<Vec<String>>,
     ) -> Result<String, InferError> {
         // check if guideline is expected but not provided
         if self.variables.contains("guideline") && guideline.is_none() {
@@ -68,12 +69,15 @@ impl ChatTemplate {
             // if not, we need to append the tools to the last message
             let text = if self.use_default_tool_template {
                 match serde_json::to_string(&tools) {
-                    Ok(tools_str) => format!("\n---\n{}\n{}", tools_str, tool_prompt),
+                    // Ok(tools_str) => format!("\n---\n{}\n{}", tools_str, tool_prompt),
+                    Ok(tools_str) => format!("\n{}\n{}", tools_str, tool_prompt),
                     Err(e) => return Err(InferError::ToolError(e.to_string())),
                 }
             } else {
                 // if the `tools` variable is used in the template, we just append the tool_prompt
-                format!("\n---\n{}", tool_prompt)
+                // format!("\n---\n{}", tool_prompt)
+                format!("\n{}", tool_prompt)
+                // format!("{}", "")
             };
             if let Some(last_message) = messages.last_mut() {
                 last_message.content.push(MessageChunk::Text { text });
@@ -93,6 +97,7 @@ impl ChatTemplate {
                 eos_token: self.eos_token.as_deref(),
                 add_generation_prompt: true,
                 tools,
+                builtin_tools,
             })
             .map_err(InferError::TemplateError)
     }

@@ -160,11 +160,17 @@ impl Infer {
         guideline: Option<String>,
         messages: Vec<Message>,
         tools_and_prompt: Option<(Vec<Tool>, String)>,
+        builtin_tools: Option<Vec<String>>,
     ) -> Result<String, InferError> {
         self.chat_template
             .as_ref()
             .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
-            .apply(guideline.as_deref(), messages, tools_and_prompt)
+            .apply(
+                guideline.as_deref(),
+                messages,
+                tools_and_prompt,
+                builtin_tools,
+            )
             .map_err(|e| {
                 metrics::counter!("tgi_request_failure", "err" => "template").increment(1);
                 tracing::error!("{e}");

@@ -1,8 +1,5 @@
 use crate::infer::InferError;
-use crate::{
-    FunctionDefinition, FunctionRef, FunctionsMap, JsonSchemaTool, Properties, Tool, ToolChoice,
-    ToolType,
-};
+use crate::{FunctionRef, FunctionsMap, JsonSchemaTool, Properties, Tool, ToolChoice, ToolType};
 use serde_json::{json, Map, Value};
 use std::collections::HashMap;
 
@@ -29,27 +26,27 @@ impl ToolGrammar {
 
         let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf);
 
-        let mut tools = tools.clone();
+        // let mut tools = tools.clone();
 
-        // add the notify_error function to the tools
-        let notify_error = Tool {
-            r#type: "function".to_string(),
-            function: FunctionDefinition {
-                name: "notify_error".to_string(),
-                description: Some("Notify an error or issue".to_string()),
-                arguments: json!({
-                    "type": "object",
-                    "properties": {
-                        "error": {
-                            "type": "string",
-                            "description": "The error or issue to notify"
-                        }
-                    },
-                    "required": ["error"]
-                }),
-            },
-        };
-        tools.push(notify_error);
+        // // add the notify_error function to the tools
+        // let notify_error = Tool {
+        //     r#type: "function".to_string(),
+        //     function: FunctionDefinition {
+        //         name: "notify_error".to_string(),
+        //         description: Some("Notify an error or issue".to_string()),
+        //         arguments: json!({
+        //             "type": "object",
+        //             "properties": {
+        //                 "error": {
+        //                     "type": "string",
+        //                     "description": "The error or issue to notify"
+        //                 }
+        //             },
+        //             "required": ["error"]
+        //         }),
+        //     },
+        // };
+        // tools.push(notify_error);
 
         // if tools are provided and no tool_choice we default to the OneOf
         let tools_to_use = match tool_choice {
@@ -86,7 +83,7 @@ impl ToolGrammar {
                 }),
             );
 
-            if let Value::Object(args) = func.arguments {
+            if let Value::Object(args) = func.parameters {
                 if let Some(Value::Object(props)) = args.get("properties") {
                     properties.extend(props.clone());
                 }
@@ -109,7 +106,7 @@ impl ToolGrammar {
             })
             .collect();
 
-        let tool_schema = JsonSchemaTool {
+        let _tool_schema = JsonSchemaTool {
             functions_map: FunctionsMap { functions },
             properties: Properties {
                 function: tools_to_use
@@ -121,6 +118,7 @@ impl ToolGrammar {
             },
         };
 
-        Ok((tools, Some(tool_schema)))
+        // Ok((tools, Some(tool_schema)))
+        Ok((tools, None))
     }
 }

@@ -864,6 +864,12 @@ pub(crate) struct ChatRequest {
     #[schema(nullable = true, default = "null", example = "null")]
     pub guideline: Option<String>,
 
+    /// A list of builtin_tools (these must be trained into the model).
+    /// See https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_1/#built-in-python-based-tool-calling for more information.
+    #[serde(default)]
+    #[schema(nullable = true, example = "null")]
+    pub builtin_tools: Option<Vec<String>>,
+
     /// Options for streaming response. Only set this when you set stream: true.
     #[serde(default)]
     #[schema(nullable = true, example = "null")]
@@ -885,6 +891,7 @@ impl ChatRequest {
             temperature,
             response_format,
             guideline,
+            builtin_tools,
             presence_penalty,
             frequency_penalty,
             top_p,
@@ -911,8 +918,12 @@ impl ChatRequest {
             &tool_prompt,
             guideline,
             messages,
+            builtin_tools,
         )?;
 
+        println!("inputs: {}", inputs);
+        // println!("grammar: {:?}", grammar);
+
         Ok((
             GenerateRequest {
                 inputs: inputs.to_string(),
@@ -953,7 +964,8 @@ struct StreamOptions {
 }
 
 pub fn default_tool_prompt() -> String {
-    "\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\n".to_string()
+    // "\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\n".to_string()
+    "".to_string()
 }
 
 #[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)]
@@ -1034,8 +1046,8 @@ pub(crate) struct FunctionDefinition {
     #[serde(default)]
     pub description: Option<String>,
     pub name: String,
-    #[serde(alias = "parameters")]
-    pub arguments: serde_json::Value,
+    // #[serde(alias = "parameters")]
+    pub parameters: serde_json::Value,
 }
 
 #[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
@@ -1056,6 +1068,8 @@ pub(crate) struct ChatTemplateInputs<'a> {
     add_generation_prompt: bool,
     tools: Option<Vec<Tool>>,
     guideline: Option<&'a str>,
+    // builtin_tools: Option<Vec<&'a str>>,
+    builtin_tools: Option<Vec<String>>,
 }
 
 #[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug, PartialEq)]

@@ -1267,7 +1267,7 @@ async fn chat_completions(
             function: FunctionDefinition {
                 description: None,
                 name,
-                arguments,
+                parameters: arguments,
             },
         }];
         (Some(tool_calls), None)
@@ -2370,6 +2370,7 @@ pub enum WebServerError {
 
 type PreparedInput = (String, Option<GrammarType>, bool);
 
+#[allow(clippy::too_many_arguments)]
 pub(crate) fn prepare_chat_input(
     infer: &Infer,
     response_format: Option<GrammarType>,
@@ -2378,6 +2379,7 @@ pub(crate) fn prepare_chat_input(
     tool_prompt: &str,
     guideline: Option<String>,
     messages: Vec<Message>,
+    builtin_tools: Option<Vec<String>>,
 ) -> Result<PreparedInput, InferError> {
     if response_format.is_some() && tools.is_some() {
         return Err(InferError::ToolError(
@@ -2387,7 +2389,7 @@ pub(crate) fn prepare_chat_input(
 
     // when response_format is set, tools are not included when applying the chat template to generate inputs
     if let Some(format) = response_format {
-        let inputs = infer.apply_chat_template(guideline, messages, None)?;
+        let inputs = infer.apply_chat_template(guideline, messages, None, builtin_tools)?;
         return Ok((inputs, Some(format), false));
     }
 
@@ -2404,12 +2406,13 @@ pub(crate) fn prepare_chat_input(
             guideline,
             messages,
             Some((updated_tools, tool_prompt.into())),
+            builtin_tools,
         )?;
         return Ok((inputs, grammar, tool_schema.is_some()));
     }
 
     // if no response_format or tools are set simply apply the chat template to generate inputs
-    let inputs = infer.apply_chat_template(guideline, messages, None)?;
+    let inputs = infer.apply_chat_template(guideline, messages, None, builtin_tools)?;
     Ok((inputs, None, false))
 }

@@ -138,6 +138,12 @@ pub(crate) struct VertexParameters {
     #[schema(nullable = true, default = "null", example = "null")]
     pub guideline: Option<String>,
 
+    /// A list of builtin_tools (these must be trained into the model).
+    /// See https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_1/#built-in-python-based-tool-calling for more information.
+    #[serde(default)]
+    #[schema(nullable = true, example = "null")]
+    pub builtin_tools: Option<Vec<String>>,
+
     /// Options for streaming response. Only set this when you set stream: true.
     #[serde(default)]
     #[schema(nullable = true, example = "null")]
@@ -150,6 +156,7 @@ impl From<VertexChat> for ChatRequest {
             messages: val.messages,
             frequency_penalty: val.parameters.frequency_penalty,
             guideline: val.parameters.guideline,
+            builtin_tools: val.parameters.builtin_tools,
             logit_bias: val.parameters.logit_bias,
             logprobs: val.parameters.logprobs,
             max_tokens: val.parameters.max_tokens,