feat: allow tool calling to respond without a tool (#2614)
* feat: process token stream before returning to client
* fix: expect content in test
* fix: improve comparison via ruff lint
* fix: return event in all cases
* fix: always send event on error, avoid unwraps, refactor and improve tests
* fix: prefer no_tool over notify_error to improve response
* fix: adjust chat input test for no_tool
* fix: adjust test expected content

Co-authored-by: System administrator <root@ip-10-90-0-186.ec2.internal>
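For context, a minimal client-side sketch of the new behavior (a sketch only, assuming a local TGI server exposing the OpenAI-compatible `/v1` route on `http://localhost:8080`; the exact reply text varies by model and seed). Before this change, a question no tool could answer produced a forced `notify_error` tool call; with the injected `no_tool` branch the model can reply with plain assistant content instead:

```python
from openai import OpenAI

# Sketch only: assumes a local TGI server with the OpenAI-compatible /v1 route.
client = OpenAI(base_url="http://localhost:8080/v1", api_key="-")

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    },
                    "format": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                        "description": "The temperature unit to use.",
                    },
                },
                "required": ["location", "format"],
            },
        },
    }
]

# "Who are you?" is unanswerable by any tool: previously the router forced a
# notify_error tool call; now the injected no_tool branch lets the model
# answer with plain assistant content.
response = client.chat.completions.create(
    model="tgi",
    messages=[{"role": "user", "content": "Who are you?"}],
    tools=tools,
    tool_choice="auto",
)
message = response.choices[0].message
assert message.tool_calls is None  # no forced tool call anymore
print(message.content)  # e.g. "I am an AI assistant"
```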
This commit is contained in:
parent 43f39f6894
commit e36dfaa8de
@@ -1,38 +1,26 @@
 {
   "choices": [
     {
-      "finish_reason": "eos_token",
+      "finish_reason": "stop",
       "index": 0,
       "logprobs": null,
       "message": {
-        "content": null,
+        "content": "I am an AI assistant",
         "name": null,
         "role": "assistant",
-        "tool_calls": [
-          {
-            "function": {
-              "arguments": {
-                "error": "Cannot get current weather forecast from specified location and temperature unit. Please try again with different options."
-              },
-              "description": null,
-              "name": "notify_error"
-            },
-            "id": 0,
-            "type": "function"
-          }
-        ]
+        "tool_calls": null
       },
       "usage": null
     }
   ],
-  "created": 1712852597,
+  "created": 1728497062,
   "id": "",
-  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
-  "object": "text_completion",
-  "system_fingerprint": "1.4.5-native",
+  "model": "meta-llama/Llama-3.1-8B-Instruct",
+  "object": "chat.completion",
+  "system_fingerprint": "2.3.2-dev0-native",
   "usage": {
-    "completion_tokens": 39,
-    "prompt_tokens": 496,
-    "total_tokens": 535
+    "completion_tokens": 23,
+    "prompt_tokens": 604,
+    "total_tokens": 627
   }
 }
@@ -0,0 +1,20 @@
+{
+  "choices": [
+    {
+      "delta": {
+        "content": " assistant",
+        "role": "assistant",
+        "tool_calls": null
+      },
+      "finish_reason": null,
+      "index": 0,
+      "logprobs": null
+    }
+  ],
+  "created": 1728497531,
+  "id": "",
+  "model": "meta-llama/Llama-3.1-8B-Instruct",
+  "object": "chat.completion.chunk",
+  "system_fingerprint": "2.3.2-dev0-native",
+  "usage": null
+}
@@ -0,0 +1,20 @@
+{
+  "choices": [
+    {
+      "delta": {
+        "content": " fans",
+        "role": "assistant",
+        "tool_calls": null
+      },
+      "finish_reason": null,
+      "index": 0,
+      "logprobs": null
+    }
+  ],
+  "created": 1728497461,
+  "id": "",
+  "model": "meta-llama/Llama-3.1-8B-Instruct",
+  "object": "chat.completion.chunk",
+  "system_fingerprint": "2.3.2-dev0-native",
+  "usage": null
+}
@@ -207,11 +207,20 @@ async def test_flash_llama_grammar_tools_stream(
     )

     count = 0
+    tool_calls_generated = ""
+    last_response = None
     async for response in responses:
         count += 1
+        tool_calls_generated += response.choices[0].delta.tool_calls.function.arguments
+        last_response = response
+        assert response.choices[0].delta.content is None
+
+    assert (
+        tool_calls_generated
+        == '{"function": {"_name": "get_current_weather", "format": "celsius", "location": "Paris, France"}}<|eot_id|>'
+    )
     assert count == 28
-    assert response == response_snapshot
+    assert last_response == response_snapshot


 @pytest.mark.asyncio
@@ -227,18 +236,94 @@ async def test_flash_llama_grammar_tools_insufficient_information(
         messages=[
             {
                 "role": "system",
-                "content": "STRICTLY ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION",
+                "content": "You're a helpful assistant! Answer the users question best you can.",
+            },
+            {
+                "role": "user",
+                "content": "Who are you?",
+            },
+        ],
+        stream=False,
+    )
+
+    assert responses.choices[0].message.tool_calls is None
+    assert responses.choices[0].message.content == "I am an AI assistant"
+
+    assert responses == response_snapshot
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_llama_grammar_tools_insufficient_information_stream(
+    flash_llama_grammar_tools, response_snapshot
+):
+    responses = await flash_llama_grammar_tools.chat(
+        max_tokens=100,
+        seed=24,
+        tools=tools,
+        tool_choice="auto",
+        messages=[
+            {
+                "role": "system",
+                "content": "You're a helpful assistant! Answer the users question best you can.",
+            },
+            {
+                "role": "user",
+                "content": "Who are you?",
+            },
+        ],
+        stream=True,
+    )
+
+    count = 0
+    content_generated = ""
+    last_response = None
+    async for response in responses:
+        count += 1
+        content_generated += response.choices[0].delta.content
+        last_response = response
+        assert response.choices[0].delta.tool_calls is None
+
+    assert count == 5
+    assert content_generated == "I am an AI assistant"
+    assert last_response == response_snapshot
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_llama_grammar_tools_sea_creatures_stream(
+    flash_llama_grammar_tools, response_snapshot
+):
+    responses = await flash_llama_grammar_tools.chat(
+        max_tokens=100,
+        seed=24,
+        tools=tools,
+        tool_choice="auto",
+        messages=[
+            {
+                "role": "system",
+                "content": "You're a helpful assistant! Answer the users question best you can. If the question is not answerable by the tools, just generate a response.",
             },
             {
                 "role": "user",
                 "content": "Tell me a story about 3 sea creatures",
             },
         ],
-        stream=False,
+        stream=True,
     )

-    assert responses.choices[0].message.content is None
+    count = 0
+    content_generated = ""
+    last_response = None
+    async for response in responses:
+        count += 1
+        content_generated += response.choices[0].delta.content
+        last_response = response
+        assert response.choices[0].delta.tool_calls is None
+
+    assert count == 62
     assert (
-        responses.choices[0].message.tool_calls[0]["function"]["name"] == "notify_error"
+        content_generated
+        == "Once upon a time, in the ocean, there lived three sea creatures. There was a wise old octopus named Bob, a mischievous seagull named Sam, and a gentle sea turtle named Luna. They all lived together in a beautiful coral reef, surrounded by colorful fish and swaying sea fans"
     )
-    assert responses == response_snapshot
+    assert last_response == response_snapshot
@@ -355,6 +355,8 @@ pub enum InferError {
     MissingTemplateVariable(String),
     #[error("Tool error: {0}")]
     ToolError(String),
+    #[error("Stream event serialization error")]
+    StreamSerializationError(String),
 }

 impl InferError {
@@ -368,6 +370,7 @@ impl InferError {
             InferError::TemplateError(_) => "template_error",
             InferError::MissingTemplateVariable(_) => "missing_template_variable",
             InferError::ToolError(_) => "tool_error",
+            InferError::StreamSerializationError(_) => "stream_serialization_error",
         }
     }
 }
@@ -31,25 +31,25 @@ impl ToolGrammar {

         let mut tools = tools.clone();

-        // add the notify_error function to the tools
-        let notify_error = Tool {
+        // add the no_tool function to the tools
+        let no_tool = Tool {
             r#type: "function".to_string(),
             function: FunctionDefinition {
-                name: "notify_error".to_string(),
-                description: Some("Notify an error or issue".to_string()),
+                name: "no_tool".to_string(),
+                description: Some("Open ended response with no specific tool selected".to_string()),
                 arguments: json!({
                     "type": "object",
                     "properties": {
-                        "error": {
+                        "content": {
                             "type": "string",
-                            "description": "The error or issue to notify"
+                            "description": "The response content",
                         }
                     },
-                    "required": ["error"]
+                    "required": ["content"]
                 }),
             },
         };
-        tools.push(notify_error);
+        tools.push(no_tool);

         // if tools are provided and no tool_choice we default to the OneOf
         let tools_to_use = match tool_choice {
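For reference, a sketch of the injected tool rendered as an OpenAI-style dict (assumed rendering; the router builds it from the `Tool`/`FunctionDefinition` structs above):

```python
# Sketch (not repo code): the appended tool, rendered as an OpenAI-style dict.
# With tool_choice "auto" the grammar built from the tool list then permits
# either a real call:
#   {"function": {"_name": "get_current_weather", "format": "celsius", ...}}
# or the escape hatch:
#   {"function": {"_name": "no_tool", "content": "I am an AI assistant"}}
no_tool = {
    "type": "function",
    "function": {
        "name": "no_tool",
        "description": "Open ended response with no specific tool selected",
        "parameters": {
            "type": "object",
            "properties": {
                "content": {
                    "type": "string",
                    "description": "The response content",
                }
            },
            "required": ["content"],
        },
    },
}
```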
@@ -42,6 +42,7 @@ use hf_hub::{Cache, Repo, RepoType};
 use http::header::AUTHORIZATION;
 use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
 use pyo3::types::IntoPyDict;
+use regex::Regex;
 use serde_json::Value;
 use std::convert::Infallible;
 use std::fs::File;
@@ -452,12 +453,20 @@ async fn generate_stream(
     Sse<impl Stream<Item = Result<Event, Infallible>>>,
 ) {
     let span = tracing::Span::current();
-    let on_message_callback = |stream_token: StreamResponse| {
-        let event = Event::default();
-        event.json_data(stream_token).unwrap()
-    };
     let (headers, response_stream) =
-        generate_stream_internal(infer, compute_type, Json(req), on_message_callback, span).await;
+        generate_stream_internal(infer, compute_type, Json(req), span).await;
+
+    let response_stream = async_stream::stream! {
+        let mut response_stream = Box::pin(response_stream);
+        while let Some(raw_event) = response_stream.next().await {
+            yield Ok(raw_event.map_or_else(Event::from, |token| {
+                Event::default()
+                    .json_data(token)
+                    .unwrap_or_else(|e| InferError::StreamSerializationError(e.to_string()).into())
+            }));
+        }
+    };
+
     let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
     (headers, sse)
 }
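With the callback gone, `generate_stream_internal` now yields `Result<StreamResponse, InferError>` and each route maps both arms onto SSE events, so failures reach the client as a serialized error event rather than an `unwrap` panic. A rough consumer-side sketch (hypothetical usage; the error payload shape is assumed to follow TGI's `ErrorResponse { error, error_type }`):

```python
# Rough consumer-side sketch, assuming a TGI server on localhost:8080.
# Each SSE `data:` line is JSON: either a StreamResponse with a `token`
# field or, after this change, a serialized error ({"error", "error_type"}).
import json

import requests

with requests.post(
    "http://localhost:8080/generate_stream",
    json={"inputs": "What is Deep Learning?", "parameters": {"max_new_tokens": 20}},
    stream=True,
) as resp:
    for line in resp.iter_lines():
        if not line.startswith(b"data:"):
            continue  # skip keep-alives and comments
        payload = json.loads(line[len(b"data:"):])
        if "error" in payload:
            raise RuntimeError(f"{payload['error_type']}: {payload['error']}")
        print(payload["token"]["text"], end="", flush=True)
```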
@@ -466,9 +475,11 @@ async fn generate_stream_internal(
     infer: Infer,
     ComputeType(compute_type): ComputeType,
     Json(req): Json<GenerateRequest>,
-    on_message_callback: impl Fn(StreamResponse) -> Event,
     span: tracing::Span,
-) -> (HeaderMap, impl Stream<Item = Result<Event, Infallible>>) {
+) -> (
+    HeaderMap,
+    impl Stream<Item = Result<StreamResponse, InferError>>,
+) {
     let start_time = Instant::now();
     metrics::counter!("tgi_request_count").increment(1);
@@ -500,12 +511,12 @@ async fn generate_stream_internal(
             let err = InferError::from(ValidationError::BestOfStream);
             metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
             tracing::error!("{err}");
-            yield Ok(Event::from(err));
+            yield Err(err);
         } else if req.parameters.decoder_input_details {
             let err = InferError::from(ValidationError::PrefillDetailsStream);
             metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
             tracing::error!("{err}");
-            yield Ok(Event::from(err));
+            yield Err(err);
         } else {
             match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
                 // Keep permit as long as generate_stream lives
@@ -535,8 +546,7 @@ async fn generate_stream_internal(
                                 generated_text: None,
                                 details: None,
                             };
-                            let event = on_message_callback(stream_token);
-                            yield Ok(event);
+                            yield Ok(stream_token);
                         }
                         // Yield event for last token and compute timings
                         InferStreamResponse::End {
@@ -600,9 +610,7 @@ async fn generate_stream_internal(
                                 details
                             };

-                            let event = on_message_callback(stream_token);
-                            yield Ok(event);
+                            yield Ok(stream_token);
                             break;
                         }
                     }
@@ -610,7 +618,7 @@ async fn generate_stream_internal(
                     // yield error
                     Err(err) => {
                         error = true;
-                        yield Ok(Event::from(err));
+                        yield Err(err);
                         break;
                     }
                 }
@@ -619,7 +627,7 @@ async fn generate_stream_internal(
                 // yield error
                 Err(err) => {
                     error = true;
-                    yield Ok(Event::from(err));
+                    yield Err(err);
                 }
             }
             // Check if generation reached the end
@@ -628,7 +636,7 @@ async fn generate_stream_internal(
             let err = InferError::IncompleteGenerationStream;
             metrics::counter!("tgi_request_failure", "err" => "incomplete").increment(1);
             tracing::error!("{err}");
-            yield Ok(Event::from(err));
+            yield Err(err);
         }
     };
@@ -771,75 +779,85 @@ async fn completions(

     // Create a future for each generate_stream_internal call.
     let generate_future = async move {
-        let on_message_callback = move |stream_token: StreamResponse| {
-            let event = Event::default();
-
-            let current_time = std::time::SystemTime::now()
-                .duration_since(std::time::UNIX_EPOCH)
-                .unwrap_or_else(|_| std::time::Duration::from_secs(0))
-                .as_secs();
-
-            let message = match stream_token.details {
-                Some(details) => {
-                    let completion_tokens = details.generated_tokens;
-                    let prompt_tokens = details.input_length;
-                    let total_tokens = prompt_tokens + completion_tokens;
-
-                    Completion::Final(CompletionFinal {
-                        id: String::new(),
-                        created: current_time,
-                        model: model_id.clone(),
-                        system_fingerprint: system_fingerprint.clone(),
-                        choices: vec![CompletionComplete {
-                            finish_reason: details.finish_reason.to_string(),
-                            index: index as u32,
-                            logprobs: None,
-                            text: stream_token.token.text,
-                        }],
-                        usage: Usage {
-                            prompt_tokens,
-                            completion_tokens,
-                            total_tokens,
-                        },
-                    })
-                }
-                None => Completion::Chunk(Chunk {
-                    id: String::new(),
-                    created: current_time,
-                    choices: vec![CompletionComplete {
-                        finish_reason: String::new(),
-                        index: index as u32,
-                        logprobs: None,
-                        text: stream_token.token.text,
-                    }],
-                    model: model_id.clone(),
-                    system_fingerprint: system_fingerprint.clone(),
-                }),
-            };
-
-            event
-                .json_data(message)
-                .unwrap_or_else(|_e| Event::default())
-        };
-
         let (header_tx, header_rx) = oneshot::channel();
         let (sse_tx, sse_rx) = tokio::sync::mpsc::unbounded_channel();

         tokio::spawn(async move {
-            let (header_map, sse) = generate_stream_internal(
+            let (headers, response_stream) = generate_stream_internal(
                 infer_clone.clone(),
                 compute_type_clone.clone(),
                 Json(generate_request),
-                on_message_callback,
                 span_clone.clone(),
             )
             .await;

+            let response_stream = async_stream::stream! {
+                let mut response_stream = Box::pin(response_stream);
+
+                while let Some(stream_token) = response_stream.next().await {
+                    match stream_token {
+                        Ok(stream_token) => {
+                            let event = Event::default();
+
+                            let current_time = std::time::SystemTime::now()
+                                .duration_since(std::time::UNIX_EPOCH)
+                                .unwrap_or_else(|_| std::time::Duration::from_secs(0))
+                                .as_secs();
+
+                            let message = match stream_token.details {
+                                Some(details) => {
+                                    let completion_tokens = details.generated_tokens;
+                                    let prompt_tokens = details.input_length;
+                                    let total_tokens = prompt_tokens + completion_tokens;
+
+                                    Completion::Final(CompletionFinal {
+                                        id: String::new(),
+                                        created: current_time,
+                                        model: model_id.clone(),
+                                        system_fingerprint: system_fingerprint.clone(),
+                                        choices: vec![CompletionComplete {
+                                            finish_reason: details.finish_reason.to_string(),
+                                            index: index as u32,
+                                            logprobs: None,
+                                            text: stream_token.token.text,
+                                        }],
+                                        usage: Usage {
+                                            prompt_tokens,
+                                            completion_tokens,
+                                            total_tokens,
+                                        },
+                                    })
+                                }
+                                None => Completion::Chunk(Chunk {
+                                    id: String::new(),
+                                    created: current_time,
+                                    choices: vec![CompletionComplete {
+                                        finish_reason: String::new(),
+                                        index: index as u32,
+                                        logprobs: None,
+                                        text: stream_token.token.text,
+                                    }],
+                                    model: model_id.clone(),
+                                    system_fingerprint: system_fingerprint.clone(),
+                                }),
+                            };
+
+                            let event = event
+                                .json_data(message)
+                                .unwrap_or_else(|_e| Event::default());
+
+                            yield Ok(event);
+                        }
+                        Err(err) => yield Ok(Event::from(err)),
+                    }
+                }
+            };
+
             // send and don't wait for response
-            let _ = header_tx.send(header_map);
+            let _ = header_tx.send(headers);

             // pin and emit messages to the sse_tx
-            let mut sse = Box::pin(sse);
+            let mut sse = Box::pin(response_stream);
             while let Some(event) = sse.next().await {
                 if sse_tx.send(event).is_err() {
                     tracing::error!("Failed to send event. Receiver dropped.");
@@ -1072,6 +1090,84 @@ async fn completions(
     }
 }

+enum StreamState {
+    Buffering,
+    BufferTrailing,
+    Content { skip_close_quote: bool },
+}
+
+/// Convert a StreamResponse into an Event to be sent over SSE
+fn create_event_from_stream_token(
+    stream_token: &StreamResponse,
+    logprobs: bool,
+    stream_options: Option<StreamOptions>,
+    inner_using_tools: bool,
+    system_fingerprint: String,
+    model_id: String,
+) -> Event {
+    let event = Event::default();
+    let current_time = std::time::SystemTime::now()
+        .duration_since(std::time::UNIX_EPOCH)
+        .unwrap_or_else(|_| std::time::Duration::from_secs(0))
+        .as_secs();
+
+    let logprobs = logprobs.then(|| {
+        ChatCompletionLogprobs::from((stream_token.token.clone(), stream_token.top_tokens.clone()))
+    });
+
+    // replace the content with the tool calls if grammar is present
+    let (content, tool_calls) = if inner_using_tools {
+        (None, Some(vec![stream_token.token.text.clone()]))
+    } else {
+        let content = if !stream_token.token.special {
+            Some(stream_token.token.text.clone())
+        } else {
+            None
+        };
+
+        (content, None)
+    };
+
+    let (usage, finish_reason) = match &stream_token.details {
+        Some(details) => {
+            let usage = if stream_options
+                .as_ref()
+                .map(|s| s.include_usage)
+                .unwrap_or(false)
+            {
+                let completion_tokens = details.generated_tokens;
+                let prompt_tokens = details.input_length;
+                let total_tokens = prompt_tokens + completion_tokens;
+                Some(Usage {
+                    completion_tokens,
+                    prompt_tokens,
+                    total_tokens,
+                })
+            } else {
+                None
+            };
+            (usage, Some(details.finish_reason.format(true)))
+        }
+        None => (None, None),
+    };
+
+    let chat_complete = CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
+        model_id.clone(),
+        system_fingerprint.clone(),
+        content,
+        tool_calls,
+        current_time,
+        logprobs,
+        finish_reason,
+        usage,
+    ));
+
+    event.json_data(chat_complete).unwrap_or_else(|e| {
+        println!("Failed to serialize ChatCompletionChunk: {:?}", e);
+        Event::default()
+    })
+}
+
 /// Generate tokens
 #[utoipa::path(
     post,
@@ -1128,88 +1224,135 @@ async fn chat_completions(
     // static values that will be returned in all cases
     let model_id = info.model_id.clone();
     let system_fingerprint = format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));

     // switch on stream
     if stream {
-        // pass this callback to the stream generation and build the required event structure
-        let on_message_callback = move |stream_token: StreamResponse| {
-            let event = Event::default();
-
-            let current_time = std::time::SystemTime::now()
-                .duration_since(std::time::UNIX_EPOCH)
-                .unwrap_or_else(|_| std::time::Duration::from_secs(0))
-                .as_secs();
-
-            let logprobs = logprobs.then(|| {
-                ChatCompletionLogprobs::from((stream_token.token.clone(), stream_token.top_tokens))
-            });
-
-            // replace the content with the tool calls if grammar is present
-            let (content, tool_calls) = if using_tools {
-                (None, Some(vec![stream_token.token.text]))
-            } else {
-                let content = if !stream_token.token.special {
-                    Some(stream_token.token.text)
-                } else {
-                    None
-                };
-
-                (content, None)
-            };
-
-            let (usage, finish_reason) = match stream_token.details {
-                Some(details) => {
-                    let usage = if stream_options
-                        .as_ref()
-                        .map(|s| s.include_usage)
-                        .unwrap_or(false)
-                    {
-                        let completion_tokens = details.generated_tokens;
-                        let prompt_tokens = details.input_length;
-                        let total_tokens = prompt_tokens + completion_tokens;
-                        Some(Usage {
-                            completion_tokens,
-                            prompt_tokens,
-                            total_tokens,
-                        })
-                    } else {
-                        None
-                    };
-                    (usage, Some(details.finish_reason.format(true)))
-                }
-                None => (None, None),
-            };
-            event
-                .json_data(CompletionType::ChatCompletionChunk(
-                    ChatCompletionChunk::new(
-                        model_id.clone(),
-                        system_fingerprint.clone(),
-                        content,
-                        tool_calls,
-                        current_time,
-                        logprobs,
-                        finish_reason,
-                        usage,
-                    ),
-                ))
-                .unwrap_or_else(|e| {
-                    println!("Failed to serialize ChatCompletionChunk: {:?}", e);
-                    Event::default()
-                })
-        };
-
-        let (headers, response_stream) = generate_stream_internal(
-            infer,
-            compute_type,
-            Json(generate_request),
-            on_message_callback,
-            span,
-        )
-        .await;
-
-        let response_stream = response_stream.chain(futures::stream::once(async {
-            Ok(Event::default().data("[DONE]"))
-        }));
+        let (headers, response_stream) =
+            generate_stream_internal(infer, compute_type, Json(generate_request), span).await;
+
+        // regex to match any function name
+        let function_regex = match Regex::new(r#"\{"function":\{"_name":"([^"]+)""#) {
+            Ok(regex) => regex,
+            Err(e) => {
+                return Err((
+                    StatusCode::INTERNAL_SERVER_ERROR,
+                    Json(ErrorResponse {
+                        error: format!("Failed to compile regex: {}", e),
+                        error_type: "regex".to_string(),
+                    }),
+                ))
+            }
+        };
+
+        let response_stream = async_stream::stream! {
+            let mut response_stream = Box::pin(response_stream);
+            let mut buffer = Vec::new();
+            let mut json_buffer = String::new();
+            let mut state = if using_tools {
+                StreamState::Buffering
+            } else {
+                StreamState::Content {
+                    skip_close_quote: false,
+                }
+            };
+            let mut response_as_tool = using_tools;
+            while let Some(result) = response_stream.next().await {
+                if let Ok(stream_token) = result {
+                    let token_text = &stream_token.token.text.clone();
+                    match state {
+                        StreamState::Buffering => {
+                            json_buffer.push_str(&token_text.replace(" ", ""));
+                            buffer.push(stream_token);
+                            if let Some(captures) = function_regex.captures(&json_buffer) {
+                                let function_name = captures[1].to_string();
+                                if function_name == "no_tool" {
+                                    state = StreamState::BufferTrailing;
+                                    response_as_tool = false;
+                                    buffer.clear();
+                                    json_buffer.clear();
+                                } else {
+                                    state = StreamState::Content {
+                                        skip_close_quote: false,
+                                    };
+                                    // send all the buffered messages
+                                    for stream_token in &buffer {
+                                        let event = create_event_from_stream_token(
+                                            stream_token,
+                                            logprobs,
+                                            stream_options.clone(),
+                                            response_as_tool,
+                                            system_fingerprint.clone(),
+                                            model_id.clone(),
+                                        );
+                                        yield Ok::<Event, Infallible>(event);
+                                    }
+                                }
+                            }
+                        }
+                        // if we skipped sending the buffer we need to avoid sending the following json key and quotes
+                        StreamState::BufferTrailing => {
+                            let infix_text = "\"content\":\"";
+                            json_buffer.push_str(&token_text.replace(" ", ""));
+                            // keep capturing until we find the infix text
+                            match json_buffer.find(infix_text) {
+                                Some(content_key_index) => {
+                                    json_buffer =
+                                        json_buffer[content_key_index + infix_text.len()..].to_string();
+                                }
+                                None => {
+                                    continue;
+                                }
+                            }
+                            // if there is leftover text after removing the infix text, we need to send it
+                            if !json_buffer.is_empty() {
+                                let event = Event::default();
+                                let current_time = std::time::SystemTime::now()
+                                    .duration_since(std::time::UNIX_EPOCH)
+                                    .unwrap_or_else(|_| std::time::Duration::from_secs(0))
+                                    .as_secs();
+                                let chat_complete =
+                                    CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
+                                        model_id.clone(),
+                                        system_fingerprint.clone(),
+                                        Some(json_buffer.clone()),
+                                        None,
+                                        current_time,
+                                        None,
+                                        None,
+                                        None,
+                                    ));
+                                yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| {
+                                    InferError::StreamSerializationError(e.to_string()).into()
+                                }));
+                            }
+                            // cleanup the buffers
+                            buffer.clear();
+                            json_buffer.clear();
+                            state = StreamState::Content {
+                                skip_close_quote: true,
+                            };
+                        }
+                        StreamState::Content { skip_close_quote } => {
+                            if skip_close_quote && token_text.contains('"') {
+                                break;
+                            }
+
+                            // send the content
+                            let event = create_event_from_stream_token(
+                                &stream_token,
+                                logprobs,
+                                stream_options.clone(),
+                                response_as_tool,
+                                system_fingerprint.clone(),
+                                model_id.clone(),
+                            );
+
+                            yield Ok::<Event, Infallible>(event);
+                        }
+                    }
+                }
+            }
+            yield Ok::<Event, Infallible>(Event::default().data("[DONE]"));
+        };

         let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
         Ok((headers, sse).into_response())
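The streaming state machine above is the heart of the change. An illustrative Python re-implementation (not repo code) of the same `Buffering -> BufferTrailing -> Content` flow, under the assumption that tokens arrive as plain text chunks:

```python
# Illustrative re-implementation (not repo code): buffer tokens until the
# grammar's "_name" is known; a `no_tool` match drops the JSON scaffolding
# and re-emits only what follows `"content":"`, while any other name flushes
# the buffer as tool-call chunks.
import re

FUNCTION_RE = re.compile(r'\{"function":\{"_name":"([^"]+)"')
INFIX = '"content":"'


def route_tokens(tokens, using_tools):
    """Yield ("tool_calls" | "content", text) pairs for each emitted chunk."""
    state = "buffering" if using_tools else "content"
    response_as_tool = using_tools
    buffer, json_buffer, skip_close_quote = [], "", False
    for text in tokens:
        if state == "buffering":
            json_buffer += text.replace(" ", "")
            buffer.append(text)
            match = FUNCTION_RE.search(json_buffer)
            if match and match.group(1) == "no_tool":
                # escape hatch taken: drop the scaffolding, stream plain content
                state, response_as_tool = "buffer_trailing", False
                buffer, json_buffer = [], ""
            elif match:
                # a real tool call: flush everything buffered so far
                state = "content"
                for buffered in buffer:
                    yield ("tool_calls", buffered)
        elif state == "buffer_trailing":
            json_buffer += text.replace(" ", "")
            idx = json_buffer.find(INFIX)
            if idx == -1:
                continue  # keep buffering until `"content":"` appears
            if json_buffer[idx + len(INFIX):]:
                yield ("content", json_buffer[idx + len(INFIX):])
            state, json_buffer, skip_close_quote = "content", "", True
        else:  # state == "content"
            if skip_close_quote and '"' in text:
                break  # closing quote of the no_tool JSON: end of stream
            yield ("tool_calls" if response_as_tool else "content", text)


# list(route_tokens(['{"function":', ' {"_name":"no_tool",', ' "content":"Hi', '"}'], True))
# -> [("content", "Hi")]
```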
@@ -1246,17 +1389,33 @@ async fn chat_completions(
             if let Value::Object(ref mut props) = arguments {
                 props.remove("_name");
             }
-            let tool_calls = vec![ToolCall {
-                id: "0".to_string(),
-                r#type: "function".to_string(),
-                function: FunctionDefinition {
-                    description: None,
-                    name,
-                    arguments,
-                },
-            }];
-            (Some(tool_calls), None)
+            match name.as_str() {
+                "no_tool" => {
+                    // parse the content message
+                    let content_message = arguments
+                        .get("content")
+                        .and_then(Value::as_str)
+                        .ok_or_else(|| {
+                            InferError::ToolError(
+                                "No `content` found in generated text".to_string(),
+                            )
+                        })?
+                        .to_string();
+                    (None, Some(content_message))
+                }
+                _ => {
+                    let tool_calls = vec![ToolCall {
+                        id: "0".to_string(),
+                        r#type: "function".to_string(),
+                        function: FunctionDefinition {
+                            description: None,
+                            name,
+                            arguments,
+                        },
+                    }];
+                    (Some(tool_calls), None)
+                }
+            }
         } else {
             (None, Some(generation.generated_text))
         };
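The non-streaming path does the equivalent after the fact: the full grammar-constrained JSON is parsed, and a `no_tool` call is unwrapped back into plain message content. A Python sketch of the same unwrapping (illustrative, not repo code):

```python
# Illustrative Python equivalent (not repo code) of the non-streaming branch:
# unwrap a grammar-forced `no_tool` call back into plain message content.
import json


def split_tool_response(generated_text: str):
    """Return (tool_calls, content) from grammar-constrained output."""
    call = json.loads(generated_text)["function"]
    name = call.pop("_name")
    if name == "no_tool":
        if "content" not in call:
            raise ValueError("No `content` found in generated text")
        return None, call["content"]
    return [{"id": "0", "type": "function",
             "function": {"name": name, "arguments": call}}], None


assert split_tool_response(
    '{"function": {"_name": "no_tool", "content": "I am an AI assistant"}}'
) == (None, "I am an AI assistant")
```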
@@ -2323,6 +2482,7 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
             InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY,
             InferError::MissingTemplateVariable(_) => StatusCode::UNPROCESSABLE_ENTITY,
             InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY,
+            InferError::StreamSerializationError(_) => StatusCode::INTERNAL_SERVER_ERROR,
         };

         (
@@ -2500,8 +2660,8 @@ mod tests {
         );

         assert!(result.is_ok());
-        let (inputs, _grammar, using_tools) = result.unwrap();
+        let (inputs, _grammar, using_tools) = result.expect("Failed to prepare chat input");
         assert_eq!(using_tools, true);
-        assert_eq!(inputs, "<s>[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"error\":{\"description\":\"The error or issue to notify\",\"type\":\"string\"}},\"required\":[\"error\"],\"type\":\"object\"}, \"description\": \"Notify an error or issue\", \"name\": \"notify_error\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string());
+        assert_eq!(inputs, "<s>[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"content\":{\"description\":\"The response content\",\"type\":\"string\"}},\"required\":[\"content\"],\"type\":\"object\"}, \"description\": \"Open ended response with no specific tool selected\", \"name\": \"no_tool\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string());
     }
 }