diff --git a/router/src/infer.rs b/router/src/infer.rs index bfa7b55c..eef42989 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -2,7 +2,7 @@ use crate::validation::{Validation, ValidationError}; use crate::{ ChatTemplateInputs, ChatTemplateVersions, Entry, GenerateRequest, GenerateStreamResponse, - HubTokenizerConfig, Message, PrefillToken, Queue, Token, + HubTokenizerConfig, Message, MessageChunk, PrefillToken, Queue, Text, TextMessage, Token, }; use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools}; use futures::future::try_join_all; @@ -362,16 +362,15 @@ impl ChatTemplate { if self.use_default_tool_template { if let Some(last_message) = messages.last_mut() { if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt { - last_message.content = Some(format!( - "{}\n---\n{}\n{}", - last_message.content.as_deref().unwrap_or_default(), - tool_prompt, - tools - )); + last_message.content.push(MessageChunk::Text(Text { + text: format!("\n---\n{}\n{}", tool_prompt, tools), + })); } } } + let messages: Vec = messages.into_iter().map(|c| c.into()).collect(); + self.template .render(ChatTemplateInputs { messages, @@ -939,8 +938,7 @@ impl InferError { #[cfg(test)] mod tests { use crate::infer::raise_exception; - use crate::ChatTemplateInputs; - use crate::Message; + use crate::{ChatTemplateInputs, TextMessage}; use minijinja::Environment; #[test] @@ -974,33 +972,21 @@ mod tests { let chat_template_inputs = ChatTemplateInputs { messages: vec![ - Message { + TextMessage { role: "user".to_string(), - content: Some("Hi!".to_string()), - name: None, - tool_calls: None, - tool_call_id: None, + content: "Hi!".to_string(), }, - Message { + TextMessage { role: "assistant".to_string(), - content: Some("Hello how can I help?".to_string()), - name: None, - tool_calls: None, - tool_call_id: None, + content: "Hello how can I help?".to_string(), }, - Message { + TextMessage { role: "user".to_string(), - content: Some("What is Deep Learning?".to_string()), - name: None, - tool_calls: None, - tool_call_id: None, + content: "What is Deep Learning?".to_string(), }, - Message { + TextMessage { role: "assistant".to_string(), - content: Some("magic!".to_string()), - name: None, - tool_calls: None, - tool_call_id: None, + content: "magic!".to_string(), }, ], bos_token: Some("[BOS]"), @@ -1048,40 +1034,25 @@ mod tests { let chat_template_inputs = ChatTemplateInputs { messages: vec![ - Message { + TextMessage { role: "user".to_string(), - content: Some("Hi!".to_string()), - name: None, - tool_calls: None, - tool_call_id: None, + content: "Hi!".to_string(), }, - Message { + TextMessage { role: "user".to_string(), - content: Some("Hi again!".to_string()), - name: None, - tool_calls: None, - tool_call_id: None, + content: "Hi again!".to_string(), }, - Message { + TextMessage { role: "assistant".to_string(), - content: Some("Hello how can I help?".to_string()), - name: None, - tool_calls: None, - tool_call_id: None, + content: "Hello how can I help?".to_string(), }, - Message { + TextMessage { role: "user".to_string(), - content: Some("What is Deep Learning?".to_string()), - name: None, - tool_calls: None, - tool_call_id: None, + content: "What is Deep Learning?".to_string(), }, - Message { + TextMessage { role: "assistant".to_string(), - content: Some("magic!".to_string()), - name: None, - tool_calls: None, - tool_call_id: None, + content: "magic!".to_string(), }, ], bos_token: Some("[BOS]"), @@ -1134,33 +1105,21 @@ mod tests { let chat_template_inputs = ChatTemplateInputs { messages: vec![ - Message { + TextMessage { role: "user".to_string(), - content: Some("Hi!".to_string()), - name: None, - tool_calls: None, - tool_call_id: None, + content: "Hi!".to_string(), }, - Message { + TextMessage { role: "assistant".to_string(), - content: Some("Hello how can I help?".to_string()), - name: None, - tool_calls: None, - tool_call_id: None, + content: "Hello how can I help?".to_string(), }, - Message { + TextMessage { role: "user".to_string(), - content: Some("What is Deep Learning?".to_string()), - name: None, - tool_calls: None, - tool_call_id: None, + content: "What is Deep Learning?".to_string(), }, - Message { + TextMessage { role: "assistant".to_string(), - content: Some("magic!".to_string()), - name: None, - tool_calls: None, - tool_call_id: None, + content: "magic!".to_string(), }, ], bos_token: Some("[BOS]"), @@ -1197,33 +1156,21 @@ mod tests { let chat_template_inputs = ChatTemplateInputs { messages: vec![ - Message { + TextMessage { role: "user".to_string(), - content: Some("Hi!".to_string()), - name: None, - tool_calls: None, - tool_call_id: None, + content: "Hi!".to_string(), }, - Message { + TextMessage { role: "assistant".to_string(), - content: Some("Hello how can I help?".to_string()), - name: None, - tool_calls: None, - tool_call_id: None, + content: "Hello how can I help?".to_string(), }, - Message { + TextMessage { role: "user".to_string(), - content: Some("What is Deep Learning?".to_string()), - name: None, - tool_calls: None, - tool_call_id: None, + content: "What is Deep Learning?".to_string(), }, - Message { + TextMessage { role: "assistant".to_string(), - content: Some("magic!".to_string()), - name: None, - tool_calls: None, - tool_call_id: None, + content: "magic!".to_string(), }, ], bos_token: Some("[BOS]"), @@ -1246,38 +1193,24 @@ mod tests { #[test] fn test_many_chat_templates() { let example_chat = vec![ - Message { + TextMessage { role: "user".to_string(), - content: Some("Hello, how are you?".to_string()), - name: None, - tool_calls: None, - tool_call_id: None, + content: "Hello, how are you?".to_string(), }, - Message { + TextMessage { role: "assistant".to_string(), - content: Some("I'm doing great. How can I help you today?".to_string()), - name: None, - tool_calls: None, - tool_call_id: None, + content: "I'm doing great. How can I help you today?".to_string(), }, - Message { + TextMessage { role: "user".to_string(), - content: Some("I'd like to show off how chat templating works!".to_string()), - name: None, - tool_calls: None, - tool_call_id: None, + content: "I'd like to show off how chat templating works!".to_string(), }, ]; - let example_chat_with_system = [Message { + let example_chat_with_system = [TextMessage { role: "system".to_string(), - content: Some( - "You are a friendly chatbot who always responds in the style of a pirate" - .to_string(), - ), - name: None, - tool_calls: None, - tool_call_id: None, + content: "You are a friendly chatbot who always responds in the style of a pirate" + .to_string(), }] .iter() .chain(&example_chat) @@ -1417,19 +1350,13 @@ mod tests { chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}", input: ChatTemplateInputs { messages: vec![ - Message { + TextMessage{ role: "system".to_string(), - content: Some("You are a friendly chatbot who always responds in the style of a pirate".to_string()), - name: None, - tool_calls: None, - tool_call_id: None, + content: "You are a friendly chatbot who always responds in the style of a pirate".to_string(), }, - Message { + TextMessage{ role: "user".to_string(), - content: Some("How many helicopters can a human eat in one sitting?".to_string()), - name: None, - tool_calls: None, - tool_call_id: None, + content: "How many helicopters can a human eat in one sitting?".to_string(), }, ], add_generation_prompt: true, diff --git a/router/src/lib.rs b/router/src/lib.rs index 85e18dfb..5ae861dd 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -11,6 +11,7 @@ use queue::{Entry, Queue}; use serde::{Deserialize, Serialize}; use tokio::sync::OwnedSemaphorePermit; use tokio_stream::wrappers::UnboundedReceiverStream; +use tracing::warn; use utoipa::ToSchema; use validation::Validation; @@ -440,7 +441,7 @@ pub(crate) struct ChatCompletion { #[derive(Clone, Deserialize, Serialize, ToSchema)] pub(crate) struct ChatCompletionComplete { pub index: u32, - pub message: Message, + pub message: OutputMessage, pub logprobs: Option, pub finish_reason: String, } @@ -533,6 +534,30 @@ impl ChatCompletion { return_logprobs: bool, tool_calls: Option>, ) -> Self { + let message = match (output, tool_calls) { + (Some(content), None) => OutputMessage::ChatMessage(TextMessage { + role: "assistant".into(), + content, + }), + (None, Some(tool_calls)) => OutputMessage::ToolCall(ToolCallMessage { + role: "assistant".to_string(), + tool_calls, + }), + (Some(output), Some(_)) => { + warn!("Received both chat and tool call"); + OutputMessage::ChatMessage(TextMessage { + role: "assistant".into(), + content: output, + }) + } + (None, None) => { + warn!("Didn't receive an answer"); + OutputMessage::ChatMessage(TextMessage { + role: "assistant".into(), + content: "".to_string(), + }) + } + }; Self { id: String::new(), object: "text_completion".into(), @@ -541,13 +566,7 @@ impl ChatCompletion { system_fingerprint, choices: vec![ChatCompletionComplete { index: 0, - message: Message { - role: "assistant".into(), - content: output, - name: None, - tool_calls, - tool_call_id: None, - }, + message, logprobs: return_logprobs .then(|| ChatCompletionLogprobs::from((details.tokens, details.top_tokens))), finish_reason: details.finish_reason.to_string(), @@ -569,6 +588,7 @@ pub(crate) struct CompletionCompleteChunk { pub model: String, pub system_fingerprint: String, } + #[derive(Clone, Deserialize, Serialize, ToSchema)] pub(crate) struct ChatCompletionChunk { pub id: String, @@ -589,21 +609,20 @@ pub(crate) struct ChatCompletionChoice { pub finish_reason: Option, } -#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)] -pub(crate) struct ChatCompletionDelta { - #[schema(example = "user")] - // TODO Modify this to a true enum. - #[serde(default, skip_serializing_if = "Option::is_none")] - pub role: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] - #[schema(example = "What is Deep Learning?")] - pub content: Option, - // default to None - #[serde(default, skip_serializing_if = "Option::is_none")] - pub tool_calls: Option, +#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)] +pub struct ToolCallDelta { + #[schema(example = "assistant")] + role: String, + tool_calls: DeltaToolCall, } -#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)] +#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)] +enum ChatCompletionDelta { + Chat(TextMessage), + Tool(ToolCallDelta), +} + +#[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)] pub(crate) struct DeltaToolCall { pub index: u32, pub id: String, @@ -611,7 +630,7 @@ pub(crate) struct DeltaToolCall { pub function: Function, } -#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)] +#[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)] pub(crate) struct Function { pub name: Option, pub arguments: String, @@ -629,15 +648,13 @@ impl ChatCompletionChunk { finish_reason: Option, ) -> Self { let delta = match (delta, tool_calls) { - (Some(delta), _) => ChatCompletionDelta { - role: Some("assistant".to_string()), - content: Some(delta), - tool_calls: None, - }, - (None, Some(tool_calls)) => ChatCompletionDelta { - role: Some("assistant".to_string()), - content: None, - tool_calls: Some(DeltaToolCall { + (Some(delta), _) => ChatCompletionDelta::Chat(TextMessage { + role: "assistant".to_string(), + content: delta, + }), + (None, Some(tool_calls)) => ChatCompletionDelta::Tool(ToolCallDelta { + role: "assistant".to_string(), + tool_calls: DeltaToolCall { index: 0, id: String::new(), r#type: "function".to_string(), @@ -645,13 +662,12 @@ impl ChatCompletionChunk { name: None, arguments: tool_calls[0].to_string(), }, - }), - }, - (None, None) => ChatCompletionDelta { - role: None, - content: None, - tool_calls: None, - }, + }, + }), + (None, None) => ChatCompletionDelta::Chat(TextMessage { + role: "assistant".to_string(), + content: "".to_string(), + }), }; Self { id: String::new(), @@ -852,7 +868,7 @@ where state.end() } -#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default)] +#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default, PartialEq)] pub(crate) struct FunctionDefinition { #[serde(default)] pub description: Option, @@ -872,7 +888,7 @@ pub(crate) struct Tool { #[derive(Clone, Serialize, Deserialize, Default)] pub(crate) struct ChatTemplateInputs<'a> { - messages: Vec, + messages: Vec, bos_token: Option<&'a str>, eos_token: Option<&'a str>, add_generation_prompt: bool, @@ -880,91 +896,113 @@ pub(crate) struct ChatTemplateInputs<'a> { tools_prompt: Option<&'a str>, } -#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)] +#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug, PartialEq)] pub(crate) struct ToolCall { pub id: String, pub r#type: String, pub function: FunctionDefinition, } -#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)] -pub(crate) struct Text { - #[serde(default)] - pub text: String, +#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)] +struct Url { + url: String, } -#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)] -pub(crate) struct ImageUrl { - #[serde(default)] - pub url: String, +#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)] +struct ImageUrl { + image_url: Url, } -#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)] -pub(crate) struct Content { - pub r#type: String, +#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)] +struct Text { + text: String, +} + +#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)] +#[serde(tag = "type")] +#[serde(rename_all = "snake_case")] +enum MessageChunk { + Text(Text), + ImageUrl(ImageUrl), +} + +#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)] +pub struct Message { + #[schema(example = "user")] + role: String, + #[schema(example = "My name is David and I")] + #[serde(deserialize_with = "message_content_serde::deserialize")] + content: Vec, #[serde(default, skip_serializing_if = "Option::is_none")] - pub text: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub image_url: Option, + #[schema(example = "\"David\"")] + name: Option, } mod message_content_serde { use super::*; - use serde::de; - use serde::Deserializer; - use serde_json::Value; + use serde::{Deserialize, Deserializer}; - pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> + pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> where D: Deserializer<'de>, { - let value = Value::deserialize(deserializer)?; - match value { - Value::String(s) => Ok(Some(s)), - Value::Array(arr) => { - let results: Result, _> = arr - .into_iter() - .map(|v| { - let content: Content = - serde_json::from_value(v).map_err(de::Error::custom)?; - match content.r#type.as_str() { - "text" => Ok(content.text.unwrap_or_default()), - "image_url" => { - if let Some(url) = content.image_url { - Ok(format!("![]({})", url.url)) - } else { - Ok(String::new()) - } - } - _ => Err(de::Error::custom("invalid content type")), - } - }) - .collect(); - - results.map(|strings| Some(strings.join(""))) + #[derive(Deserialize)] + #[serde(untagged)] + enum Message { + Text(String), + Chunks(Vec), + } + let message: Message = Deserialize::deserialize(deserializer)?; + let chunks = match message { + Message::Text(text) => { + vec![MessageChunk::Text(Text { text })] } - Value::Null => Ok(None), - _ => Err(de::Error::custom("invalid token format")), + Message::Chunks(s) => s, + }; + Ok(chunks) + } +} + +#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)] +pub struct TextMessage { + #[schema(example = "user")] + pub role: String, + #[schema(example = "My name is David and I")] + pub content: String, +} + +impl From for TextMessage { + fn from(value: Message) -> Self { + TextMessage { + role: value.role, + content: value + .content + .into_iter() + .map(|c| match c { + MessageChunk::Text(Text { text }) => text, + MessageChunk::ImageUrl(image) => { + let url = image.image_url.url; + format!("![]({url})") + } + }) + .collect::>() + .join(""), } } } -#[derive(Clone, Deserialize, ToSchema, Serialize, Debug)] -pub(crate) struct Message { - #[schema(example = "user")] - pub role: String, - #[serde(skip_serializing_if = "Option::is_none")] - #[schema(example = "My name is David and I")] - #[serde(default, deserialize_with = "message_content_serde::deserialize")] - pub content: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] - #[schema(example = "\"David\"")] - pub name: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub tool_calls: Option>, - #[serde(default, skip_serializing_if = "Option::is_none")] - #[schema(example = "\"get_weather\"")] - pub tool_call_id: Option, +#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)] +pub struct ToolCallMessage { + #[schema(example = "assistant")] + role: String, + tool_calls: Vec, +} + +#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)] +#[serde(untagged)] +pub(crate) enum OutputMessage { + ChatMessage(TextMessage), + ToolCall(ToolCallMessage), } #[derive(Clone, Debug, Deserialize, ToSchema)] @@ -1127,7 +1165,7 @@ pub(crate) struct ErrorResponse { #[cfg(test)] mod tests { use super::*; - + use serde_json::json; use tokenizers::Tokenizer; pub(crate) async fn get_tokenizer() -> Tokenizer { @@ -1195,4 +1233,66 @@ mod tests { ); assert_eq!(config.eos_token, Some("<|end▁of▁sentence|>".to_string())); } + + #[test] + fn test_chat_simple_string() { + let json = json!( + + { + "model": "", + "messages": [ + {"role": "user", + "content": "What is Deep Learning?" + } + ] + }); + let request: ChatRequest = serde_json::from_str(json.to_string().as_str()).unwrap(); + + assert_eq!( + request.messages[0], + Message { + role: "user".to_string(), + content: vec![MessageChunk::Text(Text { + text: "What is Deep Learning?".to_string() + }),], + name: None + } + ); + } + + #[test] + fn test_chat_request() { + let json = json!( + + { + "model": "", + "messages": [ + {"role": "user", + "content": [ + {"type": "text", "text": "Whats in this image?"}, + { + "type": "image_url", + "image_url": { + "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png" + }, + }, + ] + + } + ] + }); + let request: ChatRequest = serde_json::from_str(json.to_string().as_str()).unwrap(); + + assert_eq!( + request.messages[0], + Message{ + role: "user".to_string(), + content: vec![ + MessageChunk::Text(Text { text: "Whats in this image?".to_string() }), + MessageChunk::ImageUrl(ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() } }) + ], + name: None + } + ); + } }