Pr 2451 ci branch (#2454)
* fix[router]: Fix tools not passed in chat template Signed-off-by: GitHub <noreply@github.com> * feat: improve default tool serialization and lints * feat: refactor tool logic to include notify_error in prompt and adjust typing * fix: adjust non tool template apply * fix: simplify tool grammar logic and improve schema * feat: avoid skip tool test and avoid empty tool prompts * fix: increase test client timeout for grammar compilation tests --------- Signed-off-by: GitHub <noreply@github.com> Co-authored-by: Simone Rossi <simone.rossi.93@gmail.com>
This commit is contained in:
parent
30be188400
commit
cfa73b5c99
|
@ -2174,6 +2174,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "45f7e8e35b6c7b169bf40b0176d2c79291ab8ee53290b84e0668ab21d841aa9d"
|
checksum = "45f7e8e35b6c7b169bf40b0176d2c79291ab8ee53290b84e0668ab21d841aa9d"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"serde",
|
"serde",
|
||||||
|
"serde_json",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
|
@ -757,7 +757,12 @@ class AsyncClient:
|
||||||
continue
|
continue
|
||||||
payload = byte_payload.decode("utf-8")
|
payload = byte_payload.decode("utf-8")
|
||||||
if payload.startswith("data:"):
|
if payload.startswith("data:"):
|
||||||
json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
|
payload_data = (
|
||||||
|
payload.lstrip("data:").rstrip("\n").removeprefix(" ")
|
||||||
|
)
|
||||||
|
if payload_data == "[DONE]":
|
||||||
|
break
|
||||||
|
json_payload = json.loads(payload_data)
|
||||||
try:
|
try:
|
||||||
response = ChatCompletionChunk(**json_payload)
|
response = ChatCompletionChunk(**json_payload)
|
||||||
yield response
|
yield response
|
||||||
|
|
|
@ -924,7 +924,7 @@
|
||||||
"tool_prompt": {
|
"tool_prompt": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "A prompt to be appended before the tools",
|
"description": "A prompt to be appended before the tools",
|
||||||
"example": "\"You will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n\"",
|
"example": "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.",
|
||||||
"nullable": true
|
"nullable": true
|
||||||
},
|
},
|
||||||
"tools": {
|
"tools": {
|
||||||
|
|
|
@ -257,7 +257,7 @@ class IgnoreLogProbResponseComparator(ResponseComparator):
|
||||||
|
|
||||||
class LauncherHandle:
|
class LauncherHandle:
|
||||||
def __init__(self, port: int):
|
def __init__(self, port: int):
|
||||||
self.client = AsyncClient(f"http://localhost:{port}")
|
self.client = AsyncClient(f"http://localhost:{port}", timeout=30)
|
||||||
|
|
||||||
def _inner_health(self):
|
def _inner_health(self):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
|
@ -36,6 +36,7 @@ tools = [
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"required": ["location", "format"],
|
"required": ["location", "format"],
|
||||||
|
"additionalProperties": False,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -62,13 +63,13 @@ tools = [
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"required": ["location", "format", "num_days"],
|
"required": ["location", "format", "num_days"],
|
||||||
|
"additionalProperties": False,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="Takes too long to run")
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
@pytest.mark.private
|
||||||
async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot):
|
async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot):
|
||||||
|
@ -76,7 +77,7 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
|
||||||
max_tokens=100,
|
max_tokens=100,
|
||||||
seed=1,
|
seed=1,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
presence_penalty=-1.1,
|
temperature=0.0,
|
||||||
messages=[
|
messages=[
|
||||||
{
|
{
|
||||||
"role": "system",
|
"role": "system",
|
||||||
|
@ -91,19 +92,18 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
|
||||||
assert response.choices[0].message.content is None
|
assert response.choices[0].message.content is None
|
||||||
assert response.choices[0].message.tool_calls == [
|
assert response.choices[0].message.tool_calls == [
|
||||||
{
|
{
|
||||||
"id": 0,
|
"id": "0",
|
||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {
|
||||||
"description": None,
|
"description": None,
|
||||||
"name": "get_current_weather",
|
"name": "get_current_weather",
|
||||||
"arguments": {"format": "celsius", "location": "New York, NY"},
|
"arguments": {"format": "celsius", "location": "Brooklyn, NY"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
assert response == response_snapshot
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="Takes too long to run")
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
@pytest.mark.private
|
||||||
async def test_flash_llama_grammar_tools_auto(
|
async def test_flash_llama_grammar_tools_auto(
|
||||||
|
@ -113,8 +113,8 @@ async def test_flash_llama_grammar_tools_auto(
|
||||||
max_tokens=100,
|
max_tokens=100,
|
||||||
seed=1,
|
seed=1,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
|
temperature=0.0,
|
||||||
tool_choice="auto",
|
tool_choice="auto",
|
||||||
presence_penalty=-1.1,
|
|
||||||
messages=[
|
messages=[
|
||||||
{
|
{
|
||||||
"role": "system",
|
"role": "system",
|
||||||
|
@ -129,12 +129,12 @@ async def test_flash_llama_grammar_tools_auto(
|
||||||
assert response.choices[0].message.content is None
|
assert response.choices[0].message.content is None
|
||||||
assert response.choices[0].message.tool_calls == [
|
assert response.choices[0].message.tool_calls == [
|
||||||
{
|
{
|
||||||
"id": 0,
|
"id": "0",
|
||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {
|
||||||
"description": None,
|
"description": None,
|
||||||
"name": "get_current_weather",
|
"name": "get_current_weather",
|
||||||
"arguments": {"format": "celsius", "location": "New York, NY"},
|
"arguments": {"format": "celsius", "location": "Brooklyn, NY"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
@ -142,7 +142,6 @@ async def test_flash_llama_grammar_tools_auto(
|
||||||
assert response == response_snapshot
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="Takes too long to run")
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
@pytest.mark.private
|
||||||
async def test_flash_llama_grammar_tools_choice(
|
async def test_flash_llama_grammar_tools_choice(
|
||||||
|
@ -152,8 +151,8 @@ async def test_flash_llama_grammar_tools_choice(
|
||||||
max_tokens=100,
|
max_tokens=100,
|
||||||
seed=1,
|
seed=1,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
|
temperature=0.0,
|
||||||
tool_choice="get_current_weather",
|
tool_choice="get_current_weather",
|
||||||
presence_penalty=-1.1,
|
|
||||||
messages=[
|
messages=[
|
||||||
{
|
{
|
||||||
"role": "system",
|
"role": "system",
|
||||||
|
@ -168,12 +167,12 @@ async def test_flash_llama_grammar_tools_choice(
|
||||||
assert response.choices[0].message.content is None
|
assert response.choices[0].message.content is None
|
||||||
assert response.choices[0].message.tool_calls == [
|
assert response.choices[0].message.tool_calls == [
|
||||||
{
|
{
|
||||||
"id": 0,
|
"id": "0",
|
||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {
|
||||||
"description": None,
|
"description": None,
|
||||||
"name": "get_current_weather",
|
"name": "get_current_weather",
|
||||||
"arguments": {"format": "celsius", "location": "New York, NY"},
|
"arguments": {"format": "celsius", "location": "Brooklyn, NY"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
@ -181,7 +180,6 @@ async def test_flash_llama_grammar_tools_choice(
|
||||||
assert response == response_snapshot
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="Takes too long to run")
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
@pytest.mark.private
|
||||||
async def test_flash_llama_grammar_tools_stream(
|
async def test_flash_llama_grammar_tools_stream(
|
||||||
|
@ -191,8 +189,8 @@ async def test_flash_llama_grammar_tools_stream(
|
||||||
max_tokens=100,
|
max_tokens=100,
|
||||||
seed=1,
|
seed=1,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
|
temperature=0.0,
|
||||||
tool_choice="get_current_weather",
|
tool_choice="get_current_weather",
|
||||||
presence_penalty=-1.1,
|
|
||||||
messages=[
|
messages=[
|
||||||
{
|
{
|
||||||
"role": "system",
|
"role": "system",
|
||||||
|
@ -210,11 +208,10 @@ async def test_flash_llama_grammar_tools_stream(
|
||||||
async for response in responses:
|
async for response in responses:
|
||||||
count += 1
|
count += 1
|
||||||
|
|
||||||
assert count == 38
|
assert count == 48
|
||||||
assert response == response_snapshot
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="Takes too long to run")
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
@pytest.mark.private
|
||||||
async def test_flash_llama_grammar_tools_insufficient_information(
|
async def test_flash_llama_grammar_tools_insufficient_information(
|
||||||
|
@ -222,13 +219,13 @@ async def test_flash_llama_grammar_tools_insufficient_information(
|
||||||
):
|
):
|
||||||
responses = await flash_llama_grammar_tools.chat(
|
responses = await flash_llama_grammar_tools.chat(
|
||||||
max_tokens=100,
|
max_tokens=100,
|
||||||
seed=8,
|
seed=24,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
tool_choice="auto",
|
tool_choice="auto",
|
||||||
messages=[
|
messages=[
|
||||||
{
|
{
|
||||||
"role": "system",
|
"role": "system",
|
||||||
"content": "ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION",
|
"content": "STRICTLY ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
|
@ -239,18 +236,7 @@ async def test_flash_llama_grammar_tools_insufficient_information(
|
||||||
)
|
)
|
||||||
|
|
||||||
assert responses.choices[0].message.content is None
|
assert responses.choices[0].message.content is None
|
||||||
assert responses.choices[0].message.tool_calls == [
|
assert (
|
||||||
{
|
responses.choices[0].message.tool_calls[0]["function"]["name"] == "notify_error"
|
||||||
"function": {
|
)
|
||||||
"arguments": {
|
|
||||||
"error": "Cannot get current weather forecast from specified location and temperature unit. Please try again with different options."
|
|
||||||
},
|
|
||||||
"description": None,
|
|
||||||
"name": "notify_error",
|
|
||||||
},
|
|
||||||
"id": 0,
|
|
||||||
"type": "function",
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
assert responses == response_snapshot
|
assert responses == response_snapshot
|
||||||
|
|
|
@ -46,7 +46,7 @@ ngrok = { version = "0.13.1", features = ["axum"], optional = true }
|
||||||
init-tracing-opentelemetry = { version = "0.14.1", features = [
|
init-tracing-opentelemetry = { version = "0.14.1", features = [
|
||||||
"opentelemetry-otlp",
|
"opentelemetry-otlp",
|
||||||
] }
|
] }
|
||||||
minijinja = { version = "2.0.2" }
|
minijinja = { version = "2.0.2", features = ["json"] }
|
||||||
minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
|
minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
|
||||||
futures-util = "0.3.30"
|
futures-util = "0.3.30"
|
||||||
regex = "1.10.3"
|
regex = "1.10.3"
|
||||||
|
|
|
@ -1,9 +1,7 @@
|
||||||
use std::collections::HashSet;
|
use std::collections::HashSet;
|
||||||
|
|
||||||
use crate::infer::InferError;
|
use crate::infer::InferError;
|
||||||
use crate::{
|
use crate::{ChatTemplateInputs, Message, MessageChunk, TextMessage, TokenizerConfigToken, Tool};
|
||||||
ChatTemplateInputs, GrammarType, Message, MessageChunk, TextMessage, TokenizerConfigToken,
|
|
||||||
};
|
|
||||||
use minijinja::{Environment, ErrorKind, Template};
|
use minijinja::{Environment, ErrorKind, Template};
|
||||||
use minijinja_contrib::pycompat;
|
use minijinja_contrib::pycompat;
|
||||||
|
|
||||||
|
@ -32,6 +30,7 @@ impl ChatTemplate {
|
||||||
env.set_unknown_method_callback(pycompat::unknown_method_callback);
|
env.set_unknown_method_callback(pycompat::unknown_method_callback);
|
||||||
let template_str = template.into_boxed_str();
|
let template_str = template.into_boxed_str();
|
||||||
env.add_function("raise_exception", raise_exception);
|
env.add_function("raise_exception", raise_exception);
|
||||||
|
tracing::debug!("Loading template: {:#?}", template_str);
|
||||||
|
|
||||||
// leaking env and template_str as read-only, static resources for performance.
|
// leaking env and template_str as read-only, static resources for performance.
|
||||||
let template = Box::leak(env)
|
let template = Box::leak(env)
|
||||||
|
@ -42,6 +41,7 @@ impl ChatTemplate {
|
||||||
let variables = template.undeclared_variables(true);
|
let variables = template.undeclared_variables(true);
|
||||||
// check if the `tools` variable is used in the template
|
// check if the `tools` variable is used in the template
|
||||||
let use_default_tool_template = !variables.contains("tools");
|
let use_default_tool_template = !variables.contains("tools");
|
||||||
|
tracing::debug!("Use default tool template: {}", use_default_tool_template);
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
template,
|
template,
|
||||||
|
@ -56,25 +56,36 @@ impl ChatTemplate {
|
||||||
&self,
|
&self,
|
||||||
guideline: Option<&str>,
|
guideline: Option<&str>,
|
||||||
mut messages: Vec<Message>,
|
mut messages: Vec<Message>,
|
||||||
grammar_with_prompt: Option<(GrammarType, String)>,
|
tools_and_prompt: Option<(Vec<Tool>, String)>,
|
||||||
) -> Result<String, InferError> {
|
) -> Result<String, InferError> {
|
||||||
if self.use_default_tool_template {
|
|
||||||
if let Some(last_message) = messages.last_mut() {
|
|
||||||
if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
|
|
||||||
last_message.content.push(MessageChunk::Text {
|
|
||||||
text: format!("\n---\n{}\n{}", tool_prompt, tools),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
|
|
||||||
|
|
||||||
// check if guideline is expected but not provided
|
// check if guideline is expected but not provided
|
||||||
if self.variables.contains("guideline") && guideline.is_none() {
|
if self.variables.contains("guideline") && guideline.is_none() {
|
||||||
return Err(InferError::MissingTemplateVariable("guideline".to_string()));
|
return Err(InferError::MissingTemplateVariable("guideline".to_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let tools = match tools_and_prompt {
|
||||||
|
Some((tools, tool_prompt)) => {
|
||||||
|
// check if the `tools` variable is used in the template
|
||||||
|
// if not, we need to append the tools to the last message
|
||||||
|
let text = if self.use_default_tool_template {
|
||||||
|
match serde_json::to_string(&tools) {
|
||||||
|
Ok(tools_str) => format!("\n---\n{}\n{}", tools_str, tool_prompt),
|
||||||
|
Err(e) => return Err(InferError::ToolError(e.to_string())),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// if the `tools` variable is used in the template, we just append the tool_prompt
|
||||||
|
format!("\n---\n{}", tool_prompt)
|
||||||
|
};
|
||||||
|
if let Some(last_message) = messages.last_mut() {
|
||||||
|
last_message.content.push(MessageChunk::Text { text });
|
||||||
|
}
|
||||||
|
Some(tools)
|
||||||
|
}
|
||||||
|
None => None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
|
||||||
|
|
||||||
self.template
|
self.template
|
||||||
.render(ChatTemplateInputs {
|
.render(ChatTemplateInputs {
|
||||||
guideline,
|
guideline,
|
||||||
|
@ -82,8 +93,7 @@ impl ChatTemplate {
|
||||||
bos_token: self.bos_token.as_deref(),
|
bos_token: self.bos_token.as_deref(),
|
||||||
eos_token: self.eos_token.as_deref(),
|
eos_token: self.eos_token.as_deref(),
|
||||||
add_generation_prompt: true,
|
add_generation_prompt: true,
|
||||||
tools: None,
|
tools,
|
||||||
tools_prompt: None,
|
|
||||||
})
|
})
|
||||||
.map_err(InferError::TemplateError)
|
.map_err(InferError::TemplateError)
|
||||||
}
|
}
|
||||||
|
@ -95,7 +105,7 @@ mod tests {
|
||||||
use crate::infer::chat_template::raise_exception;
|
use crate::infer::chat_template::raise_exception;
|
||||||
use crate::infer::ChatTemplate;
|
use crate::infer::ChatTemplate;
|
||||||
use crate::{
|
use crate::{
|
||||||
ChatTemplateInputs, GrammarType, Message, MessageContent, TextMessage, TokenizerConfigToken,
|
ChatTemplateInputs, Message, MessageContent, TextMessage, TokenizerConfigToken, Tool,
|
||||||
};
|
};
|
||||||
use minijinja::Environment;
|
use minijinja::Environment;
|
||||||
|
|
||||||
|
@ -854,11 +864,12 @@ mod tests {
|
||||||
content: MessageContent::SingleText("Just testing".to_string()),
|
content: MessageContent::SingleText("Just testing".to_string()),
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
let tools = serde_json::json!("[]");
|
let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string();
|
||||||
|
let tools: Vec<Tool> = serde_json::from_str(&tools_string).unwrap();
|
||||||
let tool_prompt = "This default prompt will be used".to_string();
|
let tool_prompt = "This default prompt will be used".to_string();
|
||||||
let grammer_with_prompt = (GrammarType::Json(tools), tool_prompt);
|
let tools_and_prompt = Some((tools, tool_prompt));
|
||||||
let result = ct.apply(None, msgs, Some(grammer_with_prompt));
|
let result = ct.apply(None, msgs, tools_and_prompt);
|
||||||
let expected = "<s>[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today?</s> [INST] Just testing\n---\nThis default prompt will be used\n\"[]\" [/INST]".to_string();
|
let expected = "<s>[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today?</s> [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"arguments\":{\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}}}]\nThis default prompt will be used [/INST]".to_string();
|
||||||
assert_eq!(result.unwrap(), expected);
|
assert_eq!(result.unwrap(), expected);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,7 +3,7 @@ mod chat_template;
|
||||||
pub mod tool_grammar;
|
pub mod tool_grammar;
|
||||||
|
|
||||||
use crate::validation::{ValidGenerateRequest, Validation, ValidationError};
|
use crate::validation::{ValidGenerateRequest, Validation, ValidationError};
|
||||||
use crate::GrammarType;
|
use crate::Tool;
|
||||||
use crate::{
|
use crate::{
|
||||||
ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, HubTokenizerConfig,
|
ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, HubTokenizerConfig,
|
||||||
Message, PrefillToken, Token,
|
Message, PrefillToken, Token,
|
||||||
|
@ -140,12 +140,12 @@ impl Infer {
|
||||||
&self,
|
&self,
|
||||||
guideline: Option<String>,
|
guideline: Option<String>,
|
||||||
messages: Vec<Message>,
|
messages: Vec<Message>,
|
||||||
grammar_with_prompt: Option<(GrammarType, String)>,
|
tools_and_prompt: Option<(Vec<Tool>, String)>,
|
||||||
) -> Result<String, InferError> {
|
) -> Result<String, InferError> {
|
||||||
self.chat_template
|
self.chat_template
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
|
.ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
|
||||||
.apply(guideline.as_deref(), messages, grammar_with_prompt)
|
.apply(guideline.as_deref(), messages, tools_and_prompt)
|
||||||
.map_err(|e| {
|
.map_err(|e| {
|
||||||
metrics::counter!("tgi_request_failure", "err" => "template").increment(1);
|
metrics::counter!("tgi_request_failure", "err" => "template").increment(1);
|
||||||
tracing::error!("{e}");
|
tracing::error!("{e}");
|
||||||
|
|
|
@ -1,5 +1,8 @@
|
||||||
use crate::infer::InferError;
|
use crate::infer::InferError;
|
||||||
use crate::{FunctionRef, FunctionsMap, Properties, Tool, ToolChoice, ToolType, Tools};
|
use crate::{
|
||||||
|
FunctionDefinition, FunctionRef, FunctionsMap, JsonSchemaTool, Properties, Tool, ToolChoice,
|
||||||
|
ToolType,
|
||||||
|
};
|
||||||
use serde_json::{json, Map, Value};
|
use serde_json::{json, Map, Value};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
@ -16,17 +19,38 @@ impl ToolGrammar {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn apply(
|
pub fn apply(
|
||||||
tools: Option<Vec<Tool>>,
|
tools: Vec<Tool>,
|
||||||
tool_choice: ToolChoice,
|
tool_choice: ToolChoice,
|
||||||
) -> Result<Option<Tools>, InferError> {
|
) -> Result<(Vec<Tool>, Option<JsonSchemaTool>), InferError> {
|
||||||
// if no tools are provided, we return None
|
// if no tools are provided, we return None
|
||||||
let tools = match tools {
|
if tools.is_empty() {
|
||||||
Some(tools) if !tools.is_empty() => tools,
|
return Ok((tools, None));
|
||||||
_ => return Ok(None),
|
}
|
||||||
};
|
|
||||||
|
|
||||||
let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf);
|
let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf);
|
||||||
|
|
||||||
|
let mut tools = tools.clone();
|
||||||
|
|
||||||
|
// add the notify_error function to the tools
|
||||||
|
let notify_error = Tool {
|
||||||
|
r#type: "function".to_string(),
|
||||||
|
function: FunctionDefinition {
|
||||||
|
name: "notify_error".to_string(),
|
||||||
|
description: Some("Notify an error or issue".to_string()),
|
||||||
|
arguments: json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"error": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The error or issue to notify"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["error"]
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
tools.push(notify_error);
|
||||||
|
|
||||||
// if tools are provided and no tool_choice we default to the OneOf
|
// if tools are provided and no tool_choice we default to the OneOf
|
||||||
let tools_to_use = match tool_choice {
|
let tools_to_use = match tool_choice {
|
||||||
ToolType::FunctionName(name) => {
|
ToolType::FunctionName(name) => {
|
||||||
|
@ -35,87 +59,57 @@ impl ToolGrammar {
|
||||||
ToolType::Function { function } => {
|
ToolType::Function { function } => {
|
||||||
vec![Self::find_tool_by_name(&tools, &function.name)?]
|
vec![Self::find_tool_by_name(&tools, &function.name)?]
|
||||||
}
|
}
|
||||||
ToolType::OneOf => tools,
|
ToolType::OneOf => tools.clone(),
|
||||||
ToolType::NoTool => return Ok(None),
|
ToolType::NoTool => return Ok((tools, None)),
|
||||||
};
|
};
|
||||||
|
|
||||||
// adds the error notification function for LLM feedback if required
|
|
||||||
let mut text_response_properties = Map::new();
|
|
||||||
text_response_properties.insert(
|
|
||||||
"error".to_string(),
|
|
||||||
serde_json::json!({
|
|
||||||
"type": "string",
|
|
||||||
"description": "The error or issue to notify"
|
|
||||||
}),
|
|
||||||
);
|
|
||||||
text_response_properties.insert(
|
|
||||||
"_name".to_string(),
|
|
||||||
serde_json::json!({
|
|
||||||
"type": "string",
|
|
||||||
"const": "notify_error"
|
|
||||||
}),
|
|
||||||
);
|
|
||||||
|
|
||||||
let functions: HashMap<String, serde_json::Value> = tools_to_use
|
let functions: HashMap<String, serde_json::Value> = tools_to_use
|
||||||
.iter()
|
.iter()
|
||||||
.map(|tool| {
|
.map(|tool| {
|
||||||
let func = tool.function.clone();
|
let func = tool.function.clone();
|
||||||
|
|
||||||
// Clone the existing parameters, which are expected to be a JSON object
|
let mut params = Map::new();
|
||||||
let mut params = if let Value::Object(params) = &func.arguments {
|
|
||||||
params.clone()
|
|
||||||
} else {
|
|
||||||
Map::new()
|
|
||||||
};
|
|
||||||
|
|
||||||
// Insert the function's description at the top level, outside of properties
|
|
||||||
params.insert(
|
params.insert(
|
||||||
"description".to_string(),
|
"description".to_string(),
|
||||||
Value::String(func.description.clone().unwrap_or_default()),
|
Value::String(func.description.unwrap_or_default()),
|
||||||
);
|
);
|
||||||
|
|
||||||
// Ensure 'properties' exists and is an object
|
let mut properties = Map::new();
|
||||||
let properties = params
|
let mut required = vec![Value::String("_name".to_string())];
|
||||||
.entry("properties".to_string())
|
|
||||||
.or_insert_with(|| json!({}))
|
|
||||||
.as_object_mut()
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
// Insert the constant for the function name inside 'properties'
|
|
||||||
properties.insert(
|
properties.insert(
|
||||||
"_name".to_string(),
|
"_name".to_string(),
|
||||||
json!({
|
json!({
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"const": func.name.clone(),
|
"const": func.name.clone(),
|
||||||
// "description": "The name of the function"
|
|
||||||
}),
|
}),
|
||||||
);
|
);
|
||||||
|
|
||||||
// Check if 'required' exists, and it is an array. If not, create an empty array.
|
if let Value::Object(args) = func.arguments {
|
||||||
let required = params
|
if let Some(Value::Object(props)) = args.get("properties") {
|
||||||
.entry("required".to_string())
|
properties.extend(props.clone());
|
||||||
.or_insert_with(|| json!([]))
|
|
||||||
.as_array_mut()
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
// Add 'name' to the 'required' array if it is not already present
|
|
||||||
if !required.iter().any(|r| r == "_name") {
|
|
||||||
required.push(json!("_name"));
|
|
||||||
}
|
}
|
||||||
|
if let Some(Value::Array(reqs)) = args.get("required") {
|
||||||
|
required.extend(reqs.clone());
|
||||||
|
}
|
||||||
|
params.insert(
|
||||||
|
"additionalProperties".to_string(),
|
||||||
|
Value::Bool(
|
||||||
|
args.get("additionalProperties").and_then(|v| v.as_str())
|
||||||
|
== Some("true"),
|
||||||
|
),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
params.insert("properties".to_string(), Value::Object(properties));
|
||||||
|
params.insert("required".to_string(), Value::Array(required));
|
||||||
|
|
||||||
(func.name, Value::Object(params))
|
(func.name, Value::Object(params))
|
||||||
})
|
})
|
||||||
.chain([(
|
|
||||||
"notify_error".to_string(),
|
|
||||||
serde_json::json!({
|
|
||||||
"properties": text_response_properties,
|
|
||||||
"required": ["error", "_name"],
|
|
||||||
"type": "object"
|
|
||||||
}),
|
|
||||||
)])
|
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let tools = Tools {
|
let tool_schema = JsonSchemaTool {
|
||||||
functions_map: FunctionsMap { functions },
|
functions_map: FunctionsMap { functions },
|
||||||
properties: Properties {
|
properties: Properties {
|
||||||
function: tools_to_use
|
function: tools_to_use
|
||||||
|
@ -123,13 +117,10 @@ impl ToolGrammar {
|
||||||
.map(|tool| FunctionRef {
|
.map(|tool| FunctionRef {
|
||||||
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
|
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
|
||||||
})
|
})
|
||||||
.chain(std::iter::once(FunctionRef {
|
|
||||||
ref_path: "#/$functions/notify_error".to_string(),
|
|
||||||
}))
|
|
||||||
.collect(),
|
.collect(),
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(Some(tools))
|
Ok((tools, Some(tool_schema)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -840,10 +840,10 @@ pub(crate) struct ChatRequest {
|
||||||
pub tools: Option<Vec<Tool>>,
|
pub tools: Option<Vec<Tool>>,
|
||||||
|
|
||||||
/// A prompt to be appended before the tools
|
/// A prompt to be appended before the tools
|
||||||
#[serde(default = "default_tool_prompt")]
|
#[serde(default)]
|
||||||
#[schema(
|
#[schema(
|
||||||
nullable = true,
|
nullable = true,
|
||||||
example = "\"You will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n\""
|
example = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables."
|
||||||
)]
|
)]
|
||||||
pub tool_prompt: Option<String>,
|
pub tool_prompt: Option<String>,
|
||||||
|
|
||||||
|
@ -865,10 +865,8 @@ pub(crate) struct ChatRequest {
|
||||||
pub guideline: Option<String>,
|
pub guideline: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_tool_prompt() -> Option<String> {
|
pub fn default_tool_prompt() -> String {
|
||||||
Some(
|
"\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\n".to_string()
|
||||||
"\nYou will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n".to_string(),
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)]
|
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)]
|
||||||
|
@ -910,7 +908,7 @@ impl From<ToolTypeDeserializer> for ToolChoice {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize, Serialize, ToSchema, PartialEq)]
|
#[derive(Debug, Deserialize, Serialize, ToSchema, PartialEq)]
|
||||||
pub struct Tools {
|
pub struct JsonSchemaTool {
|
||||||
#[serde(flatten)]
|
#[serde(flatten)]
|
||||||
functions_map: FunctionsMap,
|
functions_map: FunctionsMap,
|
||||||
properties: Properties,
|
properties: Properties,
|
||||||
|
@ -968,8 +966,7 @@ pub(crate) struct ChatTemplateInputs<'a> {
|
||||||
bos_token: Option<&'a str>,
|
bos_token: Option<&'a str>,
|
||||||
eos_token: Option<&'a str>,
|
eos_token: Option<&'a str>,
|
||||||
add_generation_prompt: bool,
|
add_generation_prompt: bool,
|
||||||
tools: Option<&'a str>,
|
tools: Option<Vec<Tool>>,
|
||||||
tools_prompt: Option<&'a str>,
|
|
||||||
guideline: Option<&'a str>,
|
guideline: Option<&'a str>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -8,7 +8,7 @@ use crate::kserve::{
|
||||||
kserve_model_metadata, kserve_model_metadata_ready,
|
kserve_model_metadata, kserve_model_metadata_ready,
|
||||||
};
|
};
|
||||||
use crate::validation::ValidationError;
|
use crate::validation::ValidationError;
|
||||||
use crate::ChatTokenizeResponse;
|
use crate::{default_tool_prompt, ChatTokenizeResponse};
|
||||||
use crate::{
|
use crate::{
|
||||||
usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
|
usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
|
||||||
GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
|
GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
|
||||||
|
@ -23,7 +23,7 @@ use crate::{
|
||||||
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, VertexRequest,
|
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, VertexRequest,
|
||||||
VertexResponse,
|
VertexResponse,
|
||||||
};
|
};
|
||||||
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType, Tools};
|
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType};
|
||||||
use async_stream::__private::AsyncStream;
|
use async_stream::__private::AsyncStream;
|
||||||
use axum::extract::Extension;
|
use axum::extract::Extension;
|
||||||
use axum::http::{HeaderMap, HeaderValue, Method, StatusCode};
|
use axum::http::{HeaderMap, HeaderValue, Method, StatusCode};
|
||||||
|
@ -146,7 +146,7 @@ async fn get_chat_tokenize(
|
||||||
} = req;
|
} = req;
|
||||||
|
|
||||||
let tool_prompt = tool_prompt.unwrap_or_default();
|
let tool_prompt = tool_prompt.unwrap_or_default();
|
||||||
let (inputs, _grammar, _tool_grammar) = prepare_chat_input(
|
let (inputs, _grammar, _using_tools) = prepare_chat_input(
|
||||||
&infer,
|
&infer,
|
||||||
response_format,
|
response_format,
|
||||||
tools,
|
tools,
|
||||||
|
@ -1158,14 +1158,16 @@ async fn chat_completions(
|
||||||
let repetition_penalty = presence_penalty.map(|x| x + 2.0);
|
let repetition_penalty = presence_penalty.map(|x| x + 2.0);
|
||||||
let max_new_tokens = max_tokens.or(Some(100));
|
let max_new_tokens = max_tokens.or(Some(100));
|
||||||
let logprobs = logprobs.unwrap_or(false);
|
let logprobs = logprobs.unwrap_or(false);
|
||||||
let tool_prompt = tool_prompt.unwrap_or_default();
|
let tool_prompt = tool_prompt
|
||||||
|
.filter(|s| !s.is_empty())
|
||||||
|
.unwrap_or_else(default_tool_prompt);
|
||||||
let stop = stop.unwrap_or_default();
|
let stop = stop.unwrap_or_default();
|
||||||
// enable greedy only when temperature is 0
|
// enable greedy only when temperature is 0
|
||||||
let (do_sample, temperature) = match temperature {
|
let (do_sample, temperature) = match temperature {
|
||||||
Some(temperature) if temperature == 0.0 => (false, None),
|
Some(temperature) if temperature == 0.0 => (false, None),
|
||||||
other => (true, other),
|
other => (true, other),
|
||||||
};
|
};
|
||||||
let (inputs, grammar, tool_grammar) = prepare_chat_input(
|
let (inputs, grammar, using_tools) = prepare_chat_input(
|
||||||
&infer,
|
&infer,
|
||||||
response_format,
|
response_format,
|
||||||
tools,
|
tools,
|
||||||
|
@ -1221,7 +1223,7 @@ async fn chat_completions(
|
||||||
});
|
});
|
||||||
|
|
||||||
// replace the content with the tool calls if grammar is present
|
// replace the content with the tool calls if grammar is present
|
||||||
let (content, tool_calls) = if tool_grammar.is_some() {
|
let (content, tool_calls) = if using_tools {
|
||||||
(None, Some(vec![stream_token.token.text]))
|
(None, Some(vec![stream_token.token.text]))
|
||||||
} else {
|
} else {
|
||||||
let content = if !stream_token.token.special {
|
let content = if !stream_token.token.special {
|
||||||
|
@ -1275,7 +1277,7 @@ async fn chat_completions(
|
||||||
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
||||||
.as_secs();
|
.as_secs();
|
||||||
|
|
||||||
let (tool_calls, output) = if tool_grammar.is_some() {
|
let (tool_calls, output) = if using_tools {
|
||||||
let gen_text_value: Value =
|
let gen_text_value: Value =
|
||||||
serde_json::from_str(&generation.generated_text).map_err(|e| {
|
serde_json::from_str(&generation.generated_text).map_err(|e| {
|
||||||
InferError::ToolError(format!(
|
InferError::ToolError(format!(
|
||||||
|
@ -2539,7 +2541,7 @@ fn create_post_processor(
|
||||||
Ok(post_processor)
|
Ok(post_processor)
|
||||||
}
|
}
|
||||||
|
|
||||||
type PreparedInput = (String, Option<GrammarType>, Option<Tools>);
|
type PreparedInput = (String, Option<GrammarType>, bool);
|
||||||
|
|
||||||
fn prepare_chat_input(
|
fn prepare_chat_input(
|
||||||
infer: &Infer,
|
infer: &Infer,
|
||||||
|
@ -2556,19 +2558,139 @@ fn prepare_chat_input(
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// when response_format is set, tools are not included when applying the chat template to generate inputs
|
||||||
if let Some(format) = response_format {
|
if let Some(format) = response_format {
|
||||||
let inputs = infer.apply_chat_template(guideline, messages, None)?;
|
let inputs = infer.apply_chat_template(guideline, messages, None)?;
|
||||||
return Ok((inputs, Some(format), None));
|
return Ok((inputs, Some(format), false));
|
||||||
}
|
}
|
||||||
|
|
||||||
// if tools are set, apply the tool grammar and then the chat template
|
// when no response_format is set and tools are included, apply the chat template with the tools
|
||||||
let tool_grammar: Option<Tools> = ToolGrammar::apply(tools, tool_choice)?;
|
// to generate inputs
|
||||||
let grammar = tool_grammar
|
if let Some(tools) = tools {
|
||||||
|
let (updated_tools, tool_schema) = ToolGrammar::apply(tools, tool_choice)?;
|
||||||
|
|
||||||
|
let grammar = tool_schema
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.map(|t| GrammarType::Json(serde_json::json!(t)));
|
.map(|t| GrammarType::Json(serde_json::json!(t)));
|
||||||
let tools_grammar_prompt = tool_grammar
|
|
||||||
.as_ref()
|
let inputs: String = infer.apply_chat_template(
|
||||||
.map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt.into()));
|
guideline,
|
||||||
let inputs = infer.apply_chat_template(guideline, messages, tools_grammar_prompt)?;
|
messages,
|
||||||
Ok((inputs, grammar, tool_grammar))
|
Some((updated_tools, tool_prompt.into())),
|
||||||
|
)?;
|
||||||
|
return Ok((inputs, grammar, tool_schema.is_some()));
|
||||||
|
}
|
||||||
|
|
||||||
|
// if no response_format or tools are set simply apply the chat template to generate inputs
|
||||||
|
let inputs = infer.apply_chat_template(guideline, messages, None)?;
|
||||||
|
Ok((inputs, None, false))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::ChatTemplateVersions;
|
||||||
|
use crate::HubTokenizerConfig;
|
||||||
|
use crate::TokenizerConfigToken;
|
||||||
|
use crate::Tool;
|
||||||
|
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_prepare_chat_input() {
|
||||||
|
// Mock Backend to avoid network requests
|
||||||
|
struct MockBackend;
|
||||||
|
|
||||||
|
impl Backend for MockBackend {
|
||||||
|
fn schedule(
|
||||||
|
&self,
|
||||||
|
_request: crate::validation::ValidGenerateRequest,
|
||||||
|
) -> Result<
|
||||||
|
tokio_stream::wrappers::UnboundedReceiverStream<
|
||||||
|
Result<InferStreamResponse, InferError>,
|
||||||
|
>,
|
||||||
|
InferError,
|
||||||
|
> {
|
||||||
|
unimplemented!("Never called in this test");
|
||||||
|
}
|
||||||
|
fn health<'a, 'async_trait>(
|
||||||
|
&'a self,
|
||||||
|
_current_health: bool,
|
||||||
|
) -> core::pin::Pin<
|
||||||
|
Box<dyn core::future::Future<Output = bool> + core::marker::Send + 'async_trait>,
|
||||||
|
>
|
||||||
|
where
|
||||||
|
'a: 'async_trait,
|
||||||
|
Self: 'async_trait,
|
||||||
|
{
|
||||||
|
unimplemented!("Never called in this test");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let backend = MockBackend {};
|
||||||
|
|
||||||
|
let mut tokenizer_config = HubTokenizerConfig::default();
|
||||||
|
|
||||||
|
// mock tokenizer config values
|
||||||
|
tokenizer_config.bos_token = Some(TokenizerConfigToken::String("<s>".to_string()));
|
||||||
|
tokenizer_config.eos_token = Some(TokenizerConfigToken::String("</s>".to_string()));
|
||||||
|
tokenizer_config.chat_template = Some(
|
||||||
|
ChatTemplateVersions::Single("{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS] [\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST] \" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST] \" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- \"[TOOL_CALLS] [\" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- \" \" + message[\"content\"]|trim + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS] {\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n".to_string())
|
||||||
|
);
|
||||||
|
|
||||||
|
let infer = Infer::new(
|
||||||
|
backend,
|
||||||
|
Validation::new(1, None, None, None, 1, 1, 1, 1, 1, false),
|
||||||
|
1,
|
||||||
|
tokenizer_config,
|
||||||
|
HubProcessorConfig::default(),
|
||||||
|
);
|
||||||
|
let response_format = None;
|
||||||
|
let tools = Some(vec![Tool {
|
||||||
|
r#type: "function".to_string(),
|
||||||
|
function: FunctionDefinition {
|
||||||
|
name: "get_current_weather".to_string(),
|
||||||
|
description: Some("Get the current weather".to_string()),
|
||||||
|
arguments: json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city and state, e.g. San Francisco, CA"
|
||||||
|
},
|
||||||
|
"format": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["celsius", "fahrenheit"],
|
||||||
|
"description": "The temperature unit to use. Infer this from the users location."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["location", "format"]
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
}]);
|
||||||
|
let tool_prompt = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.";
|
||||||
|
let guideline = None;
|
||||||
|
let messages = vec![Message {
|
||||||
|
name: None,
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: MessageContent::SingleText(
|
||||||
|
"What is the weather like in New York?".to_string(),
|
||||||
|
),
|
||||||
|
}];
|
||||||
|
|
||||||
|
let result = prepare_chat_input(
|
||||||
|
&infer,
|
||||||
|
response_format,
|
||||||
|
tools,
|
||||||
|
ToolChoice(None),
|
||||||
|
tool_prompt,
|
||||||
|
guideline,
|
||||||
|
messages,
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!(result.is_ok());
|
||||||
|
let (inputs, _grammar, using_tools) = result.unwrap();
|
||||||
|
assert_eq!(using_tools, true);
|
||||||
|
assert_eq!(inputs, "<s>[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"error\":{\"description\":\"The error or issue to notify\",\"type\":\"string\"}},\"required\":[\"error\"],\"type\":\"object\"}, \"description\": \"Notify an error or issue\", \"name\": \"notify_error\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue