From e36dfaa8de2d9a9fa67eeed5ce64fd5949916c99 Mon Sep 17 00:00:00 2001 From: drbh Date: Thu, 10 Oct 2024 09:28:25 -0400 Subject: [PATCH] feat: allow tool calling to respond without a tool (#2614) * feat: process token stream before returning to client * fix: expect content in test * fix: improve comparison via ruff lint * fix: return event in all cases * fix: always send event on error, avoid unwraps, refactor and improve tests * fix: prefer no_tool over notify_error to improve response * fix: adjust chat input test for no_tool * fix: adjust test expected content --------- Co-authored-by: System administrator --- ...rammar_tools_insufficient_information.json | 32 +- ...tools_insufficient_information_stream.json | 20 + ...ma_grammar_tools_sea_creatures_stream.json | 20 + integration-tests/models/test_tools_llama.py | 97 +++- router/src/infer/mod.rs | 3 + router/src/infer/tool_grammar.rs | 16 +- router/src/server.rs | 480 ++++++++++++------ 7 files changed, 472 insertions(+), 196 deletions(-) create mode 100644 integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_insufficient_information_stream.json create mode 100644 integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream.json diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_insufficient_information.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_insufficient_information.json index 0cd3c67f..70b20362 100644 --- a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_insufficient_information.json +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_insufficient_information.json @@ -1,38 +1,26 @@ { "choices": [ { - "finish_reason": "eos_token", + "finish_reason": "stop", "index": 0, "logprobs": null, "message": { - "content": null, + "content": "I am an AI assistant", 
"name": null, "role": "assistant", - "tool_calls": [ - { - "function": { - "arguments": { - "error": "Cannot get current weather forecast from specified location and temperature unit. Please try again with different options." - }, - "description": null, - "name": "notify_error" - }, - "id": 0, - "type": "function" - } - ] + "tool_calls": null }, "usage": null } ], - "created": 1712852597, + "created": 1728497062, "id": "", - "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", - "object": "text_completion", - "system_fingerprint": "1.4.5-native", + "model": "meta-llama/Llama-3.1-8B-Instruct", + "object": "chat.completion", + "system_fingerprint": "2.3.2-dev0-native", "usage": { - "completion_tokens": 39, - "prompt_tokens": 496, - "total_tokens": 535 + "completion_tokens": 23, + "prompt_tokens": 604, + "total_tokens": 627 } } diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_insufficient_information_stream.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_insufficient_information_stream.json new file mode 100644 index 00000000..fa208c54 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_insufficient_information_stream.json @@ -0,0 +1,20 @@ +{ + "choices": [ + { + "delta": { + "content": " assistant", + "role": "assistant", + "tool_calls": null + }, + "finish_reason": null, + "index": 0, + "logprobs": null + } + ], + "created": 1728497531, + "id": "", + "model": "meta-llama/Llama-3.1-8B-Instruct", + "object": "chat.completion.chunk", + "system_fingerprint": "2.3.2-dev0-native", + "usage": null +} diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream.json new file mode 100644 index 00000000..72232e17 --- /dev/null +++ 
b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream.json @@ -0,0 +1,20 @@ +{ + "choices": [ + { + "delta": { + "content": " fans", + "role": "assistant", + "tool_calls": null + }, + "finish_reason": null, + "index": 0, + "logprobs": null + } + ], + "created": 1728497461, + "id": "", + "model": "meta-llama/Llama-3.1-8B-Instruct", + "object": "chat.completion.chunk", + "system_fingerprint": "2.3.2-dev0-native", + "usage": null +} diff --git a/integration-tests/models/test_tools_llama.py b/integration-tests/models/test_tools_llama.py index c337afa1..98e75bb4 100644 --- a/integration-tests/models/test_tools_llama.py +++ b/integration-tests/models/test_tools_llama.py @@ -207,11 +207,20 @@ async def test_flash_llama_grammar_tools_stream( ) count = 0 + tool_calls_generated = "" + last_response = None async for response in responses: count += 1 + tool_calls_generated += response.choices[0].delta.tool_calls.function.arguments + last_response = response + assert response.choices[0].delta.content is None + assert ( + tool_calls_generated + == '{"function": {"_name": "get_current_weather", "format": "celsius", "location": "Paris, France"}}<|eot_id|>' + ) assert count == 28 - assert response == response_snapshot + assert last_response == response_snapshot @pytest.mark.asyncio @@ -227,18 +236,94 @@ async def test_flash_llama_grammar_tools_insufficient_information( messages=[ { "role": "system", - "content": "STRICTLY ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION", + "content": "You're a helpful assistant! 
Answer the users question best you can.", + }, + { + "role": "user", + "content": "Who are you?", + }, + ], + stream=False, + ) + + assert responses.choices[0].message.tool_calls is None + assert responses.choices[0].message.content == "I am an AI assistant" + + assert responses == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_grammar_tools_insufficient_information_stream( + flash_llama_grammar_tools, response_snapshot +): + responses = await flash_llama_grammar_tools.chat( + max_tokens=100, + seed=24, + tools=tools, + tool_choice="auto", + messages=[ + { + "role": "system", + "content": "You're a helpful assistant! Answer the users question best you can.", + }, + { + "role": "user", + "content": "Who are you?", + }, + ], + stream=True, + ) + + count = 0 + content_generated = "" + last_response = None + async for response in responses: + count += 1 + content_generated += response.choices[0].delta.content + last_response = response + assert response.choices[0].delta.tool_calls is None + + assert count == 5 + assert content_generated == "I am an AI assistant" + assert last_response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_grammar_tools_sea_creatures_stream( + flash_llama_grammar_tools, response_snapshot +): + responses = await flash_llama_grammar_tools.chat( + max_tokens=100, + seed=24, + tools=tools, + tool_choice="auto", + messages=[ + { + "role": "system", + "content": "You're a helpful assistant! Answer the users question best you can. 
If the question is not answerable by the tools, just generate a response.", }, { "role": "user", "content": "Tell me a story about 3 sea creatures", }, ], - stream=False, + stream=True, ) - assert responses.choices[0].message.content is None + count = 0 + content_generated = "" + last_response = None + async for response in responses: + count += 1 + content_generated += response.choices[0].delta.content + last_response = response + assert response.choices[0].delta.tool_calls is None + + assert count == 62 assert ( - responses.choices[0].message.tool_calls[0]["function"]["name"] == "notify_error" + content_generated + == "Once upon a time, in the ocean, there lived three sea creatures. There was a wise old octopus named Bob, a mischievous seagull named Sam, and a gentle sea turtle named Luna. They all lived together in a beautiful coral reef, surrounded by colorful fish and swaying sea fans" ) - assert responses == response_snapshot + assert last_response == response_snapshot diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs index 1c9d5620..896f4f43 100644 --- a/router/src/infer/mod.rs +++ b/router/src/infer/mod.rs @@ -355,6 +355,8 @@ pub enum InferError { MissingTemplateVariable(String), #[error("Tool error: {0}")] ToolError(String), + #[error("Stream event serialization error")] + StreamSerializationError(String), } impl InferError { @@ -368,6 +370,7 @@ impl InferError { InferError::TemplateError(_) => "template_error", InferError::MissingTemplateVariable(_) => "missing_template_variable", InferError::ToolError(_) => "tool_error", + InferError::StreamSerializationError(_) => "stream_serialization_error", } } } diff --git a/router/src/infer/tool_grammar.rs b/router/src/infer/tool_grammar.rs index cc9bf31d..f86205fb 100644 --- a/router/src/infer/tool_grammar.rs +++ b/router/src/infer/tool_grammar.rs @@ -31,25 +31,25 @@ impl ToolGrammar { let mut tools = tools.clone(); - // add the notify_error function to the tools - let notify_error = Tool { + // add 
the no_tool function to the tools + let no_tool = Tool { r#type: "function".to_string(), function: FunctionDefinition { - name: "notify_error".to_string(), - description: Some("Notify an error or issue".to_string()), + name: "no_tool".to_string(), + description: Some("Open ened response with no specific tool selected".to_string()), arguments: json!({ "type": "object", "properties": { - "error": { + "content": { "type": "string", - "description": "The error or issue to notify" + "description": "The response content", } }, - "required": ["error"] + "required": ["content"] }), }, }; - tools.push(notify_error); + tools.push(no_tool); // if tools are provided and no tool_choice we default to the OneOf let tools_to_use = match tool_choice { diff --git a/router/src/server.rs b/router/src/server.rs index 73b54321..5e6e6960 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -42,6 +42,7 @@ use hf_hub::{Cache, Repo, RepoType}; use http::header::AUTHORIZATION; use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle}; use pyo3::types::IntoPyDict; +use regex::Regex; use serde_json::Value; use std::convert::Infallible; use std::fs::File; @@ -452,12 +453,20 @@ async fn generate_stream( Sse>>, ) { let span = tracing::Span::current(); - let on_message_callback = |stream_token: StreamResponse| { - let event = Event::default(); - event.json_data(stream_token).unwrap() - }; let (headers, response_stream) = - generate_stream_internal(infer, compute_type, Json(req), on_message_callback, span).await; + generate_stream_internal(infer, compute_type, Json(req), span).await; + + let response_stream = async_stream::stream! 
{ + let mut response_stream = Box::pin(response_stream); + while let Some(raw_event) = response_stream.next().await { + yield Ok(raw_event.map_or_else(Event::from, |token| { + Event::default() + .json_data(token) + .unwrap_or_else(|e| InferError::StreamSerializationError(e.to_string()).into()) + })); + } + }; + let sse = Sse::new(response_stream).keep_alive(KeepAlive::default()); (headers, sse) } @@ -466,9 +475,11 @@ async fn generate_stream_internal( infer: Infer, ComputeType(compute_type): ComputeType, Json(req): Json, - on_message_callback: impl Fn(StreamResponse) -> Event, span: tracing::Span, -) -> (HeaderMap, impl Stream>) { +) -> ( + HeaderMap, + impl Stream>, +) { let start_time = Instant::now(); metrics::counter!("tgi_request_count").increment(1); @@ -500,12 +511,12 @@ async fn generate_stream_internal( let err = InferError::from(ValidationError::BestOfStream); metrics::counter!("tgi_request_failure", "err" => "validation").increment(1); tracing::error!("{err}"); - yield Ok(Event::from(err)); + yield Err(err); } else if req.parameters.decoder_input_details { let err = InferError::from(ValidationError::PrefillDetailsStream); metrics::counter!("tgi_request_failure", "err" => "validation").increment(1); tracing::error!("{err}"); - yield Ok(Event::from(err)); + yield Err(err); } else { match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await { // Keep permit as long as generate_stream lives @@ -535,8 +546,7 @@ async fn generate_stream_internal( generated_text: None, details: None, }; - let event = on_message_callback(stream_token); - yield Ok(event); + yield Ok(stream_token); } // Yield event for last token and compute timings InferStreamResponse::End { @@ -600,9 +610,7 @@ async fn generate_stream_internal( details }; - - let event = on_message_callback(stream_token); - yield Ok(event); + yield Ok(stream_token); break; } } @@ -610,7 +618,7 @@ async fn generate_stream_internal( // yield error Err(err) => { error = true; - 
yield Ok(Event::from(err)); + yield Err(err); break; } } @@ -619,7 +627,7 @@ async fn generate_stream_internal( // yield error Err(err) => { error = true; - yield Ok(Event::from(err)); + yield Err(err); } } // Check if generation reached the end @@ -628,7 +636,7 @@ async fn generate_stream_internal( let err = InferError::IncompleteGenerationStream; metrics::counter!("tgi_request_failure", "err" => "incomplete").increment(1); tracing::error!("{err}"); - yield Ok(Event::from(err)); + yield Err(err); } } }; @@ -771,75 +779,85 @@ async fn completions( // Create a future for each generate_stream_internal call. let generate_future = async move { - let on_message_callback = move |stream_token: StreamResponse| { - let event = Event::default(); - - let current_time = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap_or_else(|_| std::time::Duration::from_secs(0)) - .as_secs(); - - let message = match stream_token.details { - Some(details) => { - let completion_tokens = details.generated_tokens; - let prompt_tokens = details.input_length; - let total_tokens = prompt_tokens + completion_tokens; - - Completion::Final(CompletionFinal { - id: String::new(), - created: current_time, - model: model_id.clone(), - system_fingerprint: system_fingerprint.clone(), - choices: vec![CompletionComplete { - finish_reason: details.finish_reason.to_string(), - index: index as u32, - logprobs: None, - text: stream_token.token.text, - }], - usage: Usage { - prompt_tokens, - completion_tokens, - total_tokens, - }, - }) - } - None => Completion::Chunk(Chunk { - id: String::new(), - created: current_time, - choices: vec![CompletionComplete { - finish_reason: String::new(), - index: index as u32, - logprobs: None, - text: stream_token.token.text, - }], - model: model_id.clone(), - system_fingerprint: system_fingerprint.clone(), - }), - }; - - event - .json_data(message) - .unwrap_or_else(|_e| Event::default()) - }; - let (header_tx, header_rx) = oneshot::channel(); let 
(sse_tx, sse_rx) = tokio::sync::mpsc::unbounded_channel(); tokio::spawn(async move { - let (header_map, sse) = generate_stream_internal( + let (headers, response_stream) = generate_stream_internal( infer_clone.clone(), compute_type_clone.clone(), Json(generate_request), - on_message_callback, span_clone.clone(), ) .await; + let response_stream = async_stream::stream! { + let mut response_stream = Box::pin(response_stream); + + while let Some(stream_token) = response_stream.next().await { + match stream_token { + Ok(stream_token) => { + let event = Event::default(); + + let current_time = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_else(|_| std::time::Duration::from_secs(0)) + .as_secs(); + + let message = match stream_token.details { + Some(details) => { + let completion_tokens = details.generated_tokens; + let prompt_tokens = details.input_length; + let total_tokens = prompt_tokens + completion_tokens; + + Completion::Final(CompletionFinal { + id: String::new(), + created: current_time, + model: model_id.clone(), + system_fingerprint: system_fingerprint.clone(), + choices: vec![CompletionComplete { + finish_reason: details.finish_reason.to_string(), + index: index as u32, + logprobs: None, + text: stream_token.token.text, + }], + usage: Usage { + prompt_tokens, + completion_tokens, + total_tokens, + }, + }) + } + None => Completion::Chunk(Chunk { + id: String::new(), + created: current_time, + choices: vec![CompletionComplete { + finish_reason: String::new(), + index: index as u32, + logprobs: None, + text: stream_token.token.text, + }], + model: model_id.clone(), + system_fingerprint: system_fingerprint.clone(), + }), + }; + + let event = event + .json_data(message) + .unwrap_or_else(|_e| Event::default()); + + yield Ok(event); + } + Err(err) => yield Ok(Event::from(err)), + } + } + }; + // send and dont wait for response - let _ = header_tx.send(header_map); + let _ = header_tx.send(headers); // pin an emit messages to the 
sse_tx - let mut sse = Box::pin(sse); + let mut sse = Box::pin(response_stream); while let Some(event) = sse.next().await { if sse_tx.send(event).is_err() { tracing::error!("Failed to send event. Receiver dropped."); @@ -1072,6 +1090,84 @@ async fn completions( } } +enum StreamState { + Buffering, + BufferTrailing, + Content { skip_close_quote: bool }, +} + +/// Convert a StreamResponse into an Event to be sent over SSE +fn create_event_from_stream_token( + stream_token: &StreamResponse, + logprobs: bool, + stream_options: Option, + inner_using_tools: bool, + system_fingerprint: String, + model_id: String, +) -> Event { + let event = Event::default(); + let current_time = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_else(|_| std::time::Duration::from_secs(0)) + .as_secs(); + + let logprobs = logprobs.then(|| { + ChatCompletionLogprobs::from((stream_token.token.clone(), stream_token.top_tokens.clone())) + }); + + // replace the content with the tool calls if grammar is present + let (content, tool_calls) = if inner_using_tools { + (None, Some(vec![stream_token.token.text.clone()])) + } else { + let content = if !stream_token.token.special { + Some(stream_token.token.text.clone()) + } else { + None + }; + + (content, None) + }; + + let (usage, finish_reason) = match &stream_token.details { + Some(details) => { + let usage = if stream_options + .as_ref() + .map(|s| s.include_usage) + .unwrap_or(false) + { + let completion_tokens = details.generated_tokens; + let prompt_tokens = details.input_length; + let total_tokens = prompt_tokens + completion_tokens; + Some(Usage { + completion_tokens, + prompt_tokens, + total_tokens, + }) + } else { + None + }; + (usage, Some(details.finish_reason.format(true))) + } + None => (None, None), + }; + + let chat_complete = CompletionType::ChatCompletionChunk(ChatCompletionChunk::new( + model_id.clone(), + system_fingerprint.clone(), + content, + tool_calls, + current_time, + logprobs, + 
finish_reason, + usage, + )); + + event.json_data(chat_complete).unwrap_or_else(|e| { + println!("Failed to serialize ChatCompletionChunk: {:?}", e); + Event::default() + }) +} + /// Generate tokens #[utoipa::path( post, @@ -1128,88 +1224,135 @@ async fn chat_completions( // static values that will be returned in all cases let model_id = info.model_id.clone(); let system_fingerprint = format!("{}-{}", info.version, info.docker_label.unwrap_or("native")); - // switch on stream if stream { - // pass this callback to the stream generation and build the required event structure - let on_message_callback = move |stream_token: StreamResponse| { - let event = Event::default(); + let (headers, response_stream) = + generate_stream_internal(infer, compute_type, Json(generate_request), span).await; - let current_time = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap_or_else(|_| std::time::Duration::from_secs(0)) - .as_secs(); - - let logprobs = logprobs.then(|| { - ChatCompletionLogprobs::from((stream_token.token.clone(), stream_token.top_tokens)) - }); - - // replace the content with the tool calls if grammar is present - let (content, tool_calls) = if using_tools { - (None, Some(vec![stream_token.token.text])) - } else { - let content = if !stream_token.token.special { - Some(stream_token.token.text) - } else { - None - }; - - (content, None) - }; - - let (usage, finish_reason) = match stream_token.details { - Some(details) => { - let usage = if stream_options - .as_ref() - .map(|s| s.include_usage) - .unwrap_or(false) - { - let completion_tokens = details.generated_tokens; - let prompt_tokens = details.input_length; - let total_tokens = prompt_tokens + completion_tokens; - Some(Usage { - completion_tokens, - prompt_tokens, - total_tokens, - }) - } else { - None - }; - (usage, Some(details.finish_reason.format(true))) - } - None => (None, None), - }; - event - .json_data(CompletionType::ChatCompletionChunk( - ChatCompletionChunk::new( - 
model_id.clone(), - system_fingerprint.clone(), - content, - tool_calls, - current_time, - logprobs, - finish_reason, - usage, - ), + // regex to match any function name + let function_regex = match Regex::new(r#"\{"function":\{"_name":"([^"]+)""#) { + Ok(regex) => regex, + Err(e) => { + return Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + error: format!("Failed to compile regex: {}", e), + error_type: "regex".to_string(), + }), )) - .unwrap_or_else(|e| { - println!("Failed to serialize ChatCompletionChunk: {:?}", e); - Event::default() - }) + } }; - let (headers, response_stream) = generate_stream_internal( - infer, - compute_type, - Json(generate_request), - on_message_callback, - span, - ) - .await; + let response_stream = async_stream::stream! { + let mut response_stream = Box::pin(response_stream); + let mut buffer = Vec::new(); + let mut json_buffer = String::new(); + let mut state = if using_tools { + StreamState::Buffering + } else { + StreamState::Content { + skip_close_quote: false, + } + }; + let mut response_as_tool = using_tools; + while let Some(result) = response_stream.next().await { + if let Ok(stream_token) = result { + let token_text = &stream_token.token.text.clone(); + match state { + StreamState::Buffering => { + json_buffer.push_str(&token_text.replace(" ", "")); + buffer.push(stream_token); + if let Some(captures) = function_regex.captures(&json_buffer) { + let function_name = captures[1].to_string(); + if function_name == "no_tool" { + state = StreamState::BufferTrailing; + response_as_tool = false; + buffer.clear(); + json_buffer.clear(); + } else { + state = StreamState::Content { + skip_close_quote: false, + }; + // send all the buffered messages + for stream_token in &buffer { + let event = create_event_from_stream_token( + stream_token, + logprobs, + stream_options.clone(), + response_as_tool, + system_fingerprint.clone(), + model_id.clone(), + ); + yield Ok::(event); + } + } + } + } + // if we skipped sending the 
buffer we need to avoid sending the following json key and quotes + StreamState::BufferTrailing => { + let infix_text = "\"content\":\""; + json_buffer.push_str(&token_text.replace(" ", "")); + // keep capturing until we find the infix text + match json_buffer.find(infix_text) { + Some(content_key_index) => { + json_buffer = + json_buffer[content_key_index + infix_text.len()..].to_string(); + } + None => { + continue; + } + } + // if there is leftover text after removing the infix text, we need to send it + if !json_buffer.is_empty() { + let event = Event::default(); + let current_time = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_else(|_| std::time::Duration::from_secs(0)) + .as_secs(); + let chat_complete = + CompletionType::ChatCompletionChunk(ChatCompletionChunk::new( + model_id.clone(), + system_fingerprint.clone(), + Some(json_buffer.clone()), + None, + current_time, + None, + None, + None, + )); + yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| { + InferError::StreamSerializationError(e.to_string()).into() + })); + } + // cleanup the buffers + buffer.clear(); + json_buffer.clear(); + state = StreamState::Content { + skip_close_quote: true, + }; + } + StreamState::Content { skip_close_quote } => { + if skip_close_quote && token_text.contains('"') { + break; + } - let response_stream = response_stream.chain(futures::stream::once(async { - Ok(Event::default().data("[DONE]")) - })); + // send the content + let event = create_event_from_stream_token( + &stream_token, + logprobs, + stream_options.clone(), + response_as_tool, + system_fingerprint.clone(), + model_id.clone(), + ); + + yield Ok::(event); + } + } + } + } + yield Ok::(Event::default().data("[DONE]")); + }; let sse = Sse::new(response_stream).keep_alive(KeepAlive::default()); Ok((headers, sse).into_response()) @@ -1246,17 +1389,33 @@ async fn chat_completions( if let Value::Object(ref mut props) = arguments { props.remove("_name"); } - - let tool_calls = 
vec![ToolCall { - id: "0".to_string(), - r#type: "function".to_string(), - function: FunctionDefinition { - description: None, - name, - arguments, - }, - }]; - (Some(tool_calls), None) + match name.as_str() { + "no_tool" => { + // parse the content message + let content_message = arguments + .get("content") + .and_then(Value::as_str) + .ok_or_else(|| { + InferError::ToolError( + "No `content` found in generated text".to_string(), + ) + })? + .to_string(); + (None, Some(content_message)) + } + _ => { + let tool_calls = vec![ToolCall { + id: "0".to_string(), + r#type: "function".to_string(), + function: FunctionDefinition { + description: None, + name, + arguments, + }, + }]; + (Some(tool_calls), None) + } + } } else { (None, Some(generation.generated_text)) }; @@ -2323,6 +2482,7 @@ impl From for (StatusCode, Json) { InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY, InferError::MissingTemplateVariable(_) => StatusCode::UNPROCESSABLE_ENTITY, InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY, + InferError::StreamSerializationError(_) => StatusCode::INTERNAL_SERVER_ERROR, }; ( @@ -2500,8 +2660,8 @@ mod tests { ); assert!(result.is_ok()); - let (inputs, _grammar, using_tools) = result.unwrap(); + let (inputs, _grammar, using_tools) = result.expect("Failed to prepare chat input"); assert_eq!(using_tools, true); - assert_eq!(inputs, "[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. 
San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"error\":{\"description\":\"The error or issue to notify\",\"type\":\"string\"}},\"required\":[\"error\"],\"type\":\"object\"}, \"description\": \"Notify an error or issue\", \"name\": \"notify_error\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string()); + assert_eq!(inputs, "[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"content\":{\"description\":\"The response content\",\"type\":\"string\"}},\"required\":[\"content\"],\"type\":\"object\"}, \"description\": \"Open ened response with no specific tool selected\", \"name\": \"no_tool\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string()); } }