feat: allow tool calling to respond without a tool (#2614)

* feat: process token stream before returning to client

* fix: expect content in test

* fix: improve comparison via ruff lint

* fix: return event in all cases

* fix: always send event on error, avoid unwraps, refactor and improve tests

* fix: prefer no_tool over notify_error to improve response

* fix: adjust chat input test for no_tool

* fix: adjust test expected content

---------

Co-authored-by: System administrator <root@ip-10-90-0-186.ec2.internal>
drbh 2024-10-10 09:28:25 -04:00 committed by GitHub
parent 43f39f6894
commit e36dfaa8de
7 changed files with 472 additions and 196 deletions
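
For context, the behaviour this commit enables, seen from the client side: when tools are supplied but none of them fits the user's question, the router now returns plain assistant content instead of forcing a `notify_error` tool call. Below is a minimal sketch, assuming a local TGI server exposing the OpenAI-compatible `/v1/chat/completions` route; the base URL, API key and model name are placeholders and not part of this commit.

# Hypothetical client-side sketch of the new behaviour (endpoint, key and
# model name are placeholders, not part of this commit).
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8080/v1", api_key="-")

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"},
                    "format": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                },
                "required": ["location", "format"],
            },
        },
    }
]

# A question no tool can answer: before this change the router forced a
# notify_error tool call; now it should return plain content instead.
chat = client.chat.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",
    messages=[
        {"role": "system", "content": "You're a helpful assistant! Answer the users question best you can."},
        {"role": "user", "content": "Who are you?"},
    ],
    tools=tools,
    tool_choice="auto",
    max_tokens=100,
)

print(chat.choices[0].message.tool_calls)  # expected: None
print(chat.choices[0].message.content)     # expected: a plain answer, e.g. "I am an AI assistant"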


@ -1,38 +1,26 @@
{
"choices": [
{
-"finish_reason": "eos_token",
+"finish_reason": "stop",
"index": 0,
"logprobs": null,
"message": {
-"content": null,
+"content": "I am an AI assistant",
"name": null,
"role": "assistant",
-"tool_calls": [
-{
-"function": {
-"arguments": {
-"error": "Cannot get current weather forecast from specified location and temperature unit. Please try again with different options."
-},
-"description": null,
-"name": "notify_error"
-},
-"id": 0,
-"type": "function"
-}
-]
+"tool_calls": null
},
"usage": null
}
],
-"created": 1712852597,
+"created": 1728497062,
"id": "",
-"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
-"object": "text_completion",
-"system_fingerprint": "1.4.5-native",
+"model": "meta-llama/Llama-3.1-8B-Instruct",
+"object": "chat.completion",
+"system_fingerprint": "2.3.2-dev0-native",
"usage": {
-"completion_tokens": 39,
-"prompt_tokens": 496,
-"total_tokens": 535
+"completion_tokens": 23,
+"prompt_tokens": 604,
+"total_tokens": 627
}
}


@ -0,0 +1,20 @@
{
"choices": [
{
"delta": {
"content": " assistant",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1728497531,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.3.2-dev0-native",
"usage": null
}


@ -0,0 +1,20 @@
{
"choices": [
{
"delta": {
"content": " fans",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1728497461,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.3.2-dev0-native",
"usage": null
}


@ -207,11 +207,20 @@ async def test_flash_llama_grammar_tools_stream(
    )
    count = 0
+    tool_calls_generated = ""
+    last_response = None
    async for response in responses:
        count += 1
+        tool_calls_generated += response.choices[0].delta.tool_calls.function.arguments
+        last_response = response
+        assert response.choices[0].delta.content is None
+
+    assert (
+        tool_calls_generated
+        == '{"function": {"_name": "get_current_weather", "format": "celsius", "location": "Paris, France"}}<|eot_id|>'
+    )
    assert count == 28
-    assert response == response_snapshot
+    assert last_response == response_snapshot


@pytest.mark.asyncio
@ -227,18 +236,94 @@ async def test_flash_llama_grammar_tools_insufficient_information(
        messages=[
            {
                "role": "system",
-                "content": "STRICTLY ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION",
+                "content": "You're a helpful assistant! Answer the users question best you can.",
+            },
+            {
+                "role": "user",
+                "content": "Who are you?",
+            },
+        ],
+        stream=False,
+    )
+
+    assert responses.choices[0].message.tool_calls is None
+    assert responses.choices[0].message.content == "I am an AI assistant"
+
+    assert responses == response_snapshot
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_llama_grammar_tools_insufficient_information_stream(
+    flash_llama_grammar_tools, response_snapshot
+):
+    responses = await flash_llama_grammar_tools.chat(
+        max_tokens=100,
+        seed=24,
+        tools=tools,
+        tool_choice="auto",
+        messages=[
+            {
+                "role": "system",
+                "content": "You're a helpful assistant! Answer the users question best you can.",
+            },
+            {
+                "role": "user",
+                "content": "Who are you?",
+            },
+        ],
+        stream=True,
+    )
+
+    count = 0
+    content_generated = ""
+    last_response = None
+    async for response in responses:
+        count += 1
+        content_generated += response.choices[0].delta.content
+        last_response = response
+        assert response.choices[0].delta.tool_calls is None
+
+    assert count == 5
+    assert content_generated == "I am an AI assistant"
+    assert last_response == response_snapshot
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_llama_grammar_tools_sea_creatures_stream(
+    flash_llama_grammar_tools, response_snapshot
+):
+    responses = await flash_llama_grammar_tools.chat(
+        max_tokens=100,
+        seed=24,
+        tools=tools,
+        tool_choice="auto",
+        messages=[
+            {
+                "role": "system",
+                "content": "You're a helpful assistant! Answer the users question best you can. If the question is not answerable by the tools, just generate a response.",
            },
            {
                "role": "user",
                "content": "Tell me a story about 3 sea creatures",
            },
        ],
-        stream=False,
+        stream=True,
    )
-    assert responses.choices[0].message.content is None
+
+    count = 0
+    content_generated = ""
+    last_response = None
+    async for response in responses:
+        count += 1
+        content_generated += response.choices[0].delta.content
+        last_response = response
+        assert response.choices[0].delta.tool_calls is None
+
+    assert count == 62
    assert (
-        responses.choices[0].message.tool_calls[0]["function"]["name"] == "notify_error"
+        content_generated
+        == "Once upon a time, in the ocean, there lived three sea creatures. There was a wise old octopus named Bob, a mischievous seagull named Sam, and a gentle sea turtle named Luna. They all lived together in a beautiful coral reef, surrounded by colorful fish and swaying sea fans"
    )
-    assert responses == response_snapshot
+    assert last_response == response_snapshot


@ -355,6 +355,8 @@ pub enum InferError {
MissingTemplateVariable(String),
#[error("Tool error: {0}")]
ToolError(String),
+#[error("Stream event serialization error")]
+StreamSerializationError(String),
}
impl InferError {
@ -368,6 +370,7 @@ impl InferError {
InferError::TemplateError(_) => "template_error",
InferError::MissingTemplateVariable(_) => "missing_template_variable",
InferError::ToolError(_) => "tool_error",
+InferError::StreamSerializationError(_) => "stream_serialization_error",
}
}
}


@ -31,25 +31,25 @@ impl ToolGrammar {
let mut tools = tools.clone();
-// add the notify_error function to the tools
-let notify_error = Tool {
+// add the no_tool function to the tools
+let no_tool = Tool {
r#type: "function".to_string(),
function: FunctionDefinition {
-name: "notify_error".to_string(),
-description: Some("Notify an error or issue".to_string()),
+name: "no_tool".to_string(),
+description: Some("Open ened response with no specific tool selected".to_string()),
arguments: json!({
"type": "object",
"properties": {
-"error": {
+"content": {
"type": "string",
-"description": "The error or issue to notify"
+"description": "The response content",
}
},
-"required": ["error"]
+"required": ["content"]
}),
},
};
-tools.push(notify_error);
+tools.push(no_tool);
// if tools are provided and no tool_choice we default to the OneOf
let tools_to_use = match tool_choice {
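
Under the hood the grammar still forces the model to emit a single JSON function call; no_tool is simply one more callable entry whose only argument is the free-form answer. Raw generations therefore look like {"function": {"_name": "no_tool", "content": "..."}} or {"function": {"_name": "get_current_weather", ...}} (the latter string is asserted verbatim in the streaming test above). A hypothetical Python sketch of how such a generation is unwrapped into either tool calls or plain content follows; the helper name and the use of json/re are illustrative, the router performs the equivalent logic in Rust.

# Hypothetical sketch of how a raw grammar-constrained generation maps to a reply.
import json
import re

FUNCTION_NAME_RE = re.compile(r'\{"function": ?\{"_name": ?"([^"]+)"')

def unwrap_generated_call(raw: str):
    """Return (tool_calls, content) from the model's JSON function-call output."""
    name = FUNCTION_NAME_RE.search(raw).group(1)
    call = json.loads(raw)["function"]
    call.pop("_name", None)
    if name == "no_tool":
        # open-ended answer: surface the "content" argument as plain text
        return None, call["content"]
    return [{"id": "0", "type": "function", "function": {"name": name, "arguments": call}}], None

# Example inputs mirroring the integration tests:
print(unwrap_generated_call('{"function": {"_name": "no_tool", "content": "I am an AI assistant"}}'))
print(unwrap_generated_call('{"function": {"_name": "get_current_weather", "format": "celsius", "location": "Paris, France"}}'))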


@ -42,6 +42,7 @@ use hf_hub::{Cache, Repo, RepoType};
use http::header::AUTHORIZATION;
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
use pyo3::types::IntoPyDict;
+use regex::Regex;
use serde_json::Value;
use std::convert::Infallible;
use std::fs::File;
@ -452,12 +453,20 @@ async fn generate_stream(
Sse<impl Stream<Item = Result<Event, Infallible>>>,
) {
let span = tracing::Span::current();
-let on_message_callback = |stream_token: StreamResponse| {
-let event = Event::default();
-event.json_data(stream_token).unwrap()
-};
let (headers, response_stream) =
-generate_stream_internal(infer, compute_type, Json(req), on_message_callback, span).await;
+generate_stream_internal(infer, compute_type, Json(req), span).await;
+let response_stream = async_stream::stream! {
+let mut response_stream = Box::pin(response_stream);
+while let Some(raw_event) = response_stream.next().await {
+yield Ok(raw_event.map_or_else(Event::from, |token| {
+Event::default()
+.json_data(token)
+.unwrap_or_else(|e| InferError::StreamSerializationError(e.to_string()).into())
+}));
+}
+};
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
(headers, sse)
}
@ -466,9 +475,11 @@ async fn generate_stream_internal(
infer: Infer,
ComputeType(compute_type): ComputeType,
Json(req): Json<GenerateRequest>,
-on_message_callback: impl Fn(StreamResponse) -> Event,
span: tracing::Span,
-) -> (HeaderMap, impl Stream<Item = Result<Event, Infallible>>) {
+) -> (
+HeaderMap,
+impl Stream<Item = Result<StreamResponse, InferError>>,
+) {
let start_time = Instant::now();
metrics::counter!("tgi_request_count").increment(1);
@ -500,12 +511,12 @@ async fn generate_stream_internal(
let err = InferError::from(ValidationError::BestOfStream);
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
tracing::error!("{err}");
-yield Ok(Event::from(err));
+yield Err(err);
} else if req.parameters.decoder_input_details {
let err = InferError::from(ValidationError::PrefillDetailsStream);
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
tracing::error!("{err}");
-yield Ok(Event::from(err));
+yield Err(err);
} else {
match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
// Keep permit as long as generate_stream lives
@ -535,8 +546,7 @@ async fn generate_stream_internal(
generated_text: None,
details: None,
};
-let event = on_message_callback(stream_token);
-yield Ok(event);
+yield Ok(stream_token);
}
// Yield event for last token and compute timings
InferStreamResponse::End {
@ -600,9 +610,7 @@ async fn generate_stream_internal(
details
};
-let event = on_message_callback(stream_token);
-yield Ok(event);
+yield Ok(stream_token);
break;
}
}
@ -610,7 +618,7 @@ async fn generate_stream_internal(
// yield error
Err(err) => {
error = true;
-yield Ok(Event::from(err));
+yield Err(err);
break;
}
}
@ -619,7 +627,7 @@ async fn generate_stream_internal(
// yield error
Err(err) => {
error = true;
-yield Ok(Event::from(err));
+yield Err(err);
}
}
// Check if generation reached the end
@ -628,7 +636,7 @@ async fn generate_stream_internal(
let err = InferError::IncompleteGenerationStream;
metrics::counter!("tgi_request_failure", "err" => "incomplete").increment(1);
tracing::error!("{err}");
-yield Ok(Event::from(err));
+yield Err(err);
}
}
};
@ -771,75 +779,85 @@ async fn completions(
// Create a future for each generate_stream_internal call.
let generate_future = async move {
let on_message_callback = move |stream_token: StreamResponse| {
let event = Event::default();
let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
.as_secs();
let message = match stream_token.details {
Some(details) => {
let completion_tokens = details.generated_tokens;
let prompt_tokens = details.input_length;
let total_tokens = prompt_tokens + completion_tokens;
Completion::Final(CompletionFinal {
id: String::new(),
created: current_time,
model: model_id.clone(),
system_fingerprint: system_fingerprint.clone(),
choices: vec![CompletionComplete {
finish_reason: details.finish_reason.to_string(),
index: index as u32,
logprobs: None,
text: stream_token.token.text,
}],
usage: Usage {
prompt_tokens,
completion_tokens,
total_tokens,
},
})
}
None => Completion::Chunk(Chunk {
id: String::new(),
created: current_time,
choices: vec![CompletionComplete {
finish_reason: String::new(),
index: index as u32,
logprobs: None,
text: stream_token.token.text,
}],
model: model_id.clone(),
system_fingerprint: system_fingerprint.clone(),
}),
};
event
.json_data(message)
.unwrap_or_else(|_e| Event::default())
};
let (header_tx, header_rx) = oneshot::channel();
let (sse_tx, sse_rx) = tokio::sync::mpsc::unbounded_channel();
tokio::spawn(async move {
-let (header_map, sse) = generate_stream_internal(
+let (headers, response_stream) = generate_stream_internal(
infer_clone.clone(),
compute_type_clone.clone(),
Json(generate_request),
-on_message_callback,
span_clone.clone(),
)
.await;
let response_stream = async_stream::stream! {
let mut response_stream = Box::pin(response_stream);
while let Some(stream_token) = response_stream.next().await {
match stream_token {
Ok(stream_token) => {
let event = Event::default();
let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
.as_secs();
let message = match stream_token.details {
Some(details) => {
let completion_tokens = details.generated_tokens;
let prompt_tokens = details.input_length;
let total_tokens = prompt_tokens + completion_tokens;
Completion::Final(CompletionFinal {
id: String::new(),
created: current_time,
model: model_id.clone(),
system_fingerprint: system_fingerprint.clone(),
choices: vec![CompletionComplete {
finish_reason: details.finish_reason.to_string(),
index: index as u32,
logprobs: None,
text: stream_token.token.text,
}],
usage: Usage {
prompt_tokens,
completion_tokens,
total_tokens,
},
})
}
None => Completion::Chunk(Chunk {
id: String::new(),
created: current_time,
choices: vec![CompletionComplete {
finish_reason: String::new(),
index: index as u32,
logprobs: None,
text: stream_token.token.text,
}],
model: model_id.clone(),
system_fingerprint: system_fingerprint.clone(),
}),
};
let event = event
.json_data(message)
.unwrap_or_else(|_e| Event::default());
yield Ok(event);
}
Err(err) => yield Ok(Event::from(err)),
}
}
};
// send and dont wait for response
-let _ = header_tx.send(header_map);
+let _ = header_tx.send(headers);
// pin an emit messages to the sse_tx
-let mut sse = Box::pin(sse);
+let mut sse = Box::pin(response_stream);
while let Some(event) = sse.next().await {
if sse_tx.send(event).is_err() {
tracing::error!("Failed to send event. Receiver dropped.");
@ -1072,6 +1090,84 @@ async fn completions(
}
}
enum StreamState {
Buffering,
BufferTrailing,
Content { skip_close_quote: bool },
}
/// Convert a StreamResponse into an Event to be sent over SSE
fn create_event_from_stream_token(
stream_token: &StreamResponse,
logprobs: bool,
stream_options: Option<StreamOptions>,
inner_using_tools: bool,
system_fingerprint: String,
model_id: String,
) -> Event {
let event = Event::default();
let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
.as_secs();
let logprobs = logprobs.then(|| {
ChatCompletionLogprobs::from((stream_token.token.clone(), stream_token.top_tokens.clone()))
});
// replace the content with the tool calls if grammar is present
let (content, tool_calls) = if inner_using_tools {
(None, Some(vec![stream_token.token.text.clone()]))
} else {
let content = if !stream_token.token.special {
Some(stream_token.token.text.clone())
} else {
None
};
(content, None)
};
let (usage, finish_reason) = match &stream_token.details {
Some(details) => {
let usage = if stream_options
.as_ref()
.map(|s| s.include_usage)
.unwrap_or(false)
{
let completion_tokens = details.generated_tokens;
let prompt_tokens = details.input_length;
let total_tokens = prompt_tokens + completion_tokens;
Some(Usage {
completion_tokens,
prompt_tokens,
total_tokens,
})
} else {
None
};
(usage, Some(details.finish_reason.format(true)))
}
None => (None, None),
};
let chat_complete = CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
model_id.clone(),
system_fingerprint.clone(),
content,
tool_calls,
current_time,
logprobs,
finish_reason,
usage,
));
event.json_data(chat_complete).unwrap_or_else(|e| {
println!("Failed to serialize ChatCompletionChunk: {:?}", e);
Event::default()
})
}
/// Generate tokens
#[utoipa::path(
post,
@ -1128,88 +1224,135 @@ async fn chat_completions(
// static values that will be returned in all cases
let model_id = info.model_id.clone();
let system_fingerprint = format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));
// switch on stream
if stream {
-// pass this callback to the stream generation and build the required event structure
-let on_message_callback = move |stream_token: StreamResponse| {
-let event = Event::default();
-let current_time = std::time::SystemTime::now()
-.duration_since(std::time::UNIX_EPOCH)
-.unwrap_or_else(|_| std::time::Duration::from_secs(0))
-.as_secs();
-let logprobs = logprobs.then(|| {
-ChatCompletionLogprobs::from((stream_token.token.clone(), stream_token.top_tokens))
-});
-// replace the content with the tool calls if grammar is present
-let (content, tool_calls) = if using_tools {
-(None, Some(vec![stream_token.token.text]))
-} else {
-let content = if !stream_token.token.special {
-Some(stream_token.token.text)
-} else {
-None
-};
-(content, None)
-};
-let (usage, finish_reason) = match stream_token.details {
-Some(details) => {
-let usage = if stream_options
-.as_ref()
-.map(|s| s.include_usage)
-.unwrap_or(false)
-{
-let completion_tokens = details.generated_tokens;
-let prompt_tokens = details.input_length;
-let total_tokens = prompt_tokens + completion_tokens;
-Some(Usage {
-completion_tokens,
-prompt_tokens,
-total_tokens,
-})
-} else {
-None
-};
-(usage, Some(details.finish_reason.format(true)))
-}
-None => (None, None),
-};
-event
-.json_data(CompletionType::ChatCompletionChunk(
-ChatCompletionChunk::new(
-model_id.clone(),
-system_fingerprint.clone(),
-content,
-tool_calls,
-current_time,
-logprobs,
-finish_reason,
-usage,
-),
+let (headers, response_stream) =
+generate_stream_internal(infer, compute_type, Json(generate_request), span).await;
+// regex to match any function name
+let function_regex = match Regex::new(r#"\{"function":\{"_name":"([^"]+)""#) {
+Ok(regex) => regex,
+Err(e) => {
+return Err((
+StatusCode::INTERNAL_SERVER_ERROR,
+Json(ErrorResponse {
+error: format!("Failed to compile regex: {}", e),
+error_type: "regex".to_string(),
+}),
))
-.unwrap_or_else(|e| {
-println!("Failed to serialize ChatCompletionChunk: {:?}", e);
-Event::default()
-})
+}
};
-let (headers, response_stream) = generate_stream_internal(
-infer,
-compute_type,
-Json(generate_request),
-on_message_callback,
-span,
-)
-.await;
-let response_stream = response_stream.chain(futures::stream::once(async {
-Ok(Event::default().data("[DONE]"))
-}));
+let response_stream = async_stream::stream! {
+let mut response_stream = Box::pin(response_stream);
+let mut buffer = Vec::new();
+let mut json_buffer = String::new();
+let mut state = if using_tools {
+StreamState::Buffering
+} else {
+StreamState::Content {
+skip_close_quote: false,
+}
+};
+let mut response_as_tool = using_tools;
+while let Some(result) = response_stream.next().await {
+if let Ok(stream_token) = result {
+let token_text = &stream_token.token.text.clone();
+match state {
+StreamState::Buffering => {
+json_buffer.push_str(&token_text.replace(" ", ""));
+buffer.push(stream_token);
+if let Some(captures) = function_regex.captures(&json_buffer) {
+let function_name = captures[1].to_string();
+if function_name == "no_tool" {
+state = StreamState::BufferTrailing;
+response_as_tool = false;
+buffer.clear();
+json_buffer.clear();
+} else {
+state = StreamState::Content {
+skip_close_quote: false,
+};
+// send all the buffered messages
+for stream_token in &buffer {
+let event = create_event_from_stream_token(
+stream_token,
+logprobs,
+stream_options.clone(),
+response_as_tool,
+system_fingerprint.clone(),
+model_id.clone(),
+);
+yield Ok::<Event, Infallible>(event);
+}
+}
+}
+}
+// if we skipped sending the buffer we need to avoid sending the following json key and quotes
+StreamState::BufferTrailing => {
+let infix_text = "\"content\":\"";
+json_buffer.push_str(&token_text.replace(" ", ""));
+// keep capturing until we find the infix text
+match json_buffer.find(infix_text) {
+Some(content_key_index) => {
+json_buffer =
+json_buffer[content_key_index + infix_text.len()..].to_string();
+}
+None => {
+continue;
+}
+}
+// if there is leftover text after removing the infix text, we need to send it
+if !json_buffer.is_empty() {
+let event = Event::default();
+let current_time = std::time::SystemTime::now()
+.duration_since(std::time::UNIX_EPOCH)
+.unwrap_or_else(|_| std::time::Duration::from_secs(0))
+.as_secs();
+let chat_complete =
+CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
+model_id.clone(),
+system_fingerprint.clone(),
+Some(json_buffer.clone()),
+None,
+current_time,
+None,
+None,
+None,
+));
+yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| {
+InferError::StreamSerializationError(e.to_string()).into()
+}));
+}
+// cleanup the buffers
+buffer.clear();
+json_buffer.clear();
+state = StreamState::Content {
+skip_close_quote: true,
+};
+}
+StreamState::Content { skip_close_quote } => {
+if skip_close_quote && token_text.contains('"') {
+break;
+}
+// send the content
+let event = create_event_from_stream_token(
+&stream_token,
+logprobs,
+stream_options.clone(),
+response_as_tool,
+system_fingerprint.clone(),
+model_id.clone(),
+);
+yield Ok::<Event, Infallible>(event);
+}
+}
+}
+}
+yield Ok::<Event, Infallible>(Event::default().data("[DONE]"));
+};
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
Ok((headers, sse).into_response())
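
In the streaming branch above, tokens are buffered until the function-name regex fires: for no_tool the buffered JSON scaffolding is dropped, everything up to the "content":" key is swallowed, and only the answer text is forwarded (stopping at the closing quote); for any other function the buffer is flushed and chunks keep flowing as tool-call deltas. A simplified, hypothetical Python restatement of that control flow (not the server code, just the buffering idea):

# Hypothetical, simplified re-statement of the stream buffering above.
import re

FUNCTION_RE = re.compile(r'\{"function":\{"_name":"([^"]+)"')
INFIX = '"content":"'

def filter_stream(tokens):
    """Yield only the chunks a client should see, per the buffering logic above."""
    state = "buffering"
    skip_close_quote = False
    buffer, json_buffer = [], ""
    for text in tokens:
        if state == "buffering":
            json_buffer += text.replace(" ", "")
            buffer.append(text)
            match = FUNCTION_RE.search(json_buffer)
            if match is None:
                continue
            if match.group(1) == "no_tool":
                # open-ended answer: drop the JSON scaffolding entirely
                state, buffer, json_buffer = "buffer_trailing", [], ""
            else:
                state = "content"
                yield from buffer  # flush the buffered tool-call text
        elif state == "buffer_trailing":
            json_buffer += text.replace(" ", "")
            if INFIX not in json_buffer:
                continue
            leftover = json_buffer.split(INFIX, 1)[1]
            if leftover:
                yield leftover
            state, skip_close_quote = "content", True
        else:  # "content"
            if skip_close_quote and '"' in text:
                break  # closing quote of the no_tool JSON: stop the stream
            yield text

# Example mirroring the no_tool path:
tokens = ['{"', 'function', '":', ' {"', '_name', '":', ' "', 'no', '_tool', '",',
          ' "', 'content', '":', ' "', 'I am an AI assistant', '"}}']
print("".join(filter_stream(tokens)))  # -> I am an AI assistant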
@ -1246,17 +1389,33 @@ async fn chat_completions(
if let Value::Object(ref mut props) = arguments {
props.remove("_name");
}
-let tool_calls = vec![ToolCall {
-id: "0".to_string(),
-r#type: "function".to_string(),
-function: FunctionDefinition {
-description: None,
-name,
-arguments,
-},
-}];
-(Some(tool_calls), None)
+match name.as_str() {
+"no_tool" => {
+// parse the content message
+let content_message = arguments
+.get("content")
+.and_then(Value::as_str)
+.ok_or_else(|| {
+InferError::ToolError(
+"No `content` found in generated text".to_string(),
+)
+})?
+.to_string();
+(None, Some(content_message))
+}
+_ => {
+let tool_calls = vec![ToolCall {
+id: "0".to_string(),
+r#type: "function".to_string(),
+function: FunctionDefinition {
+description: None,
+name,
+arguments,
+},
+}];
+(Some(tool_calls), None)
+}
+}
} else {
(None, Some(generation.generated_text))
};
@ -2323,6 +2482,7 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY,
InferError::MissingTemplateVariable(_) => StatusCode::UNPROCESSABLE_ENTITY,
InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY,
+InferError::StreamSerializationError(_) => StatusCode::INTERNAL_SERVER_ERROR,
};
(
@ -2500,8 +2660,8 @@ mod tests {
);
assert!(result.is_ok());
-let (inputs, _grammar, using_tools) = result.unwrap();
+let (inputs, _grammar, using_tools) = result.expect("Failed to prepare chat input");
assert_eq!(using_tools, true);
-assert_eq!(inputs, "<s>[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"error\":{\"description\":\"The error or issue to notify\",\"type\":\"string\"}},\"required\":[\"error\"],\"type\":\"object\"}, \"description\": \"Notify an error or issue\", \"name\": \"notify_error\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string());
+assert_eq!(inputs, "<s>[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"content\":{\"description\":\"The response content\",\"type\":\"string\"}},\"required\":[\"content\"],\"type\":\"object\"}, \"description\": \"Open ened response with no specific tool selected\", \"name\": \"no_tool\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string());
}
}