From c98a6b9948d22c0be6bd5c15b9822ae90e6a771a Mon Sep 17 00:00:00 2001 From: drbh Date: Thu, 2 May 2024 03:46:40 +0000 Subject: [PATCH] feat: improve message content chunks handling --- router/src/infer.rs | 68 +++++++++++++------------ router/src/lib.rs | 119 ++++++++++++++++++++++++++++++-------------- 2 files changed, 119 insertions(+), 68 deletions(-) diff --git a/router/src/infer.rs b/router/src/infer.rs index da0db072..58d96b3e 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -362,16 +362,21 @@ impl ChatTemplate { if self.use_default_tool_template { if let Some(last_message) = messages.last_mut() { if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt { - last_message.content = Some(format!( - "{}\n---\n{}\n{}", - last_message.content.as_deref().unwrap_or_default(), - tool_prompt, - tools - )); + let last_message_content = + last_message.content.clone().unwrap_or_else(|| "".into()); + last_message.content = Some( + format!( + "{:?}\n---\n{}\n{}", + last_message_content, tool_prompt, tools + ) + .into(), + ); } } } + println!("{:?}", messages); + self.template .render(ChatTemplateInputs { messages, @@ -976,25 +981,25 @@ mod tests { messages: vec![ Message { role: "user".to_string(), - content: Some("Hi!".to_string()), + content: Some("Hi!".into()), name: None, tool_calls: None, }, Message { role: "assistant".to_string(), - content: Some("Hello how can I help?".to_string()), + content: Some("Hello how can I help?".into()), name: None, tool_calls: None, }, Message { role: "user".to_string(), - content: Some("What is Deep Learning?".to_string()), + content: Some("What is Deep Learning?".into()), name: None, tool_calls: None, }, Message { role: "assistant".to_string(), - content: Some("magic!".to_string()), + content: Some("magic!".into()), name: None, tool_calls: None, }, @@ -1046,31 +1051,31 @@ mod tests { messages: vec![ Message { role: "user".to_string(), - content: Some("Hi!".to_string()), + content: Some("Hi!".into()), name: None, tool_calls: None, }, Message { role: "user".to_string(), - content: Some("Hi again!".to_string()), + content: Some("Hi again!".into()), name: None, tool_calls: None, }, Message { role: "assistant".to_string(), - content: Some("Hello how can I help?".to_string()), + content: Some("Hello how can I help?".into()), name: None, tool_calls: None, }, Message { role: "user".to_string(), - content: Some("What is Deep Learning?".to_string()), + content: Some("What is Deep Learning?".into()), name: None, tool_calls: None, }, Message { role: "assistant".to_string(), - content: Some("magic!".to_string()), + content: Some("magic!".into()), name: None, tool_calls: None, }, @@ -1127,25 +1132,25 @@ mod tests { messages: vec![ Message { role: "user".to_string(), - content: Some("Hi!".to_string()), + content: Some("Hi!".into()), name: None, tool_calls: None, }, Message { role: "assistant".to_string(), - content: Some("Hello how can I help?".to_string()), + content: Some("Hello how can I help?".into()), name: None, tool_calls: None, }, Message { role: "user".to_string(), - content: Some("What is Deep Learning?".to_string()), + content: Some("What is Deep Learning?".into()), name: None, tool_calls: None, }, Message { role: "assistant".to_string(), - content: Some("magic!".to_string()), + content: Some("magic!".into()), name: None, tool_calls: None, }, @@ -1186,25 +1191,25 @@ mod tests { messages: vec![ Message { role: "user".to_string(), - content: Some("Hi!".to_string()), + content: Some("Hi!".into()), name: None, tool_calls: None, }, Message { role: "assistant".to_string(), - content: Some("Hello how can I help?".to_string()), + content: Some("Hello how can I help?".into()), name: None, tool_calls: None, }, Message { role: "user".to_string(), - content: Some("What is Deep Learning?".to_string()), + content: Some("What is Deep Learning?".into()), name: None, tool_calls: None, }, Message { role: "assistant".to_string(), - content: Some("magic!".to_string()), + content: Some("magic!".into()), name: None, tool_calls: None, }, @@ -1231,29 +1236,28 @@ mod tests { let example_chat = vec![ Message { role: "user".to_string(), - content: Some("Hello, how are you?".to_string()), + content: Some("Hello, how are you?".into()), name: None, tool_calls: None, }, Message { role: "assistant".to_string(), - content: Some("I'm doing great. How can I help you today?".to_string()), + content: Some("I'm doing great. How can I help you today?".into()), name: None, tool_calls: None, }, Message { role: "user".to_string(), - content: Some("I'd like to show off how chat templating works!".to_string()), + content: Some("I'd like to show off how chat templating works!".into()), name: None, tool_calls: None, }, ]; - let example_chat_with_system = vec![Message { + let example_chat_with_system = [Message { role: "system".to_string(), content: Some( - "You are a friendly chatbot who always responds in the style of a pirate" - .to_string(), + "You are a friendly chatbot who always responds in the style of a pirate".into(), ), name: None, tool_calls: None, @@ -1373,7 +1377,7 @@ mod tests { { let mut env = Environment::new(); env.add_function("raise_exception", raise_exception); - let tmpl = env.template_from_str(&chat_template); + let tmpl = env.template_from_str(chat_template); let result = tmpl.unwrap().render(input).unwrap(); assert_eq!(result, target); } @@ -1398,13 +1402,13 @@ mod tests { messages: vec![ Message { role: "system".to_string(), - content: Some("You are a friendly chatbot who always responds in the style of a pirate".to_string()), + content: Some("You are a friendly chatbot who always responds in the style of a pirate".into()), name: None, tool_calls: None, }, Message { role: "user".to_string(), - content: Some("How many helicopters can a human eat in one sitting?".to_string()), + content: Some("How many helicopters can a human eat in one sitting?".into()), name: None, tool_calls: None, }, diff --git a/router/src/lib.rs b/router/src/lib.rs index fac4c14e..81d68fd1 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -541,7 +541,7 @@ impl ChatCompletion { index: 0, message: Message { role: "assistant".into(), - content: output, + content: output.map(|content| vec![content.into()].into()), name: None, tool_calls, }, @@ -896,52 +896,99 @@ pub(crate) struct ImageUrl { pub url: String, } -#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)] -pub(crate) struct Content { - pub r#type: String, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub text: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub image_url: Option, +#[derive(Clone, Deserialize, Serialize, Debug)] +enum ContentChunk { + Text(String), + ImageUrl(String), +} + +#[derive(Clone, Deserialize, Debug)] +struct ContentChunks(Vec); + +// Convert in and out of ContentChunk +impl From for ContentChunk { + fn from(s: String) -> Self { + ContentChunk::Text(s) + } +} + +impl From<&str> for ContentChunk { + fn from(s: &str) -> Self { + s.to_string().into() + } +} + +// Convert in and out of ContentChunks +impl From> for ContentChunks { + fn from(chunks: Vec) -> Self { + Self(chunks) + } +} + +impl From<&str> for ContentChunks { + fn from(s: &str) -> Self { + vec![s.into()].into() + } +} + +impl From for ContentChunks { + fn from(s: String) -> Self { + vec![s.into()].into() + } +} + +impl Serialize for ContentChunks { + fn serialize(&self, serializer: S) -> Result { + let formatted = self + .0 + .iter() + .map(|chunk| match chunk { + ContentChunk::Text(s) => s.clone(), + ContentChunk::ImageUrl(s) => format!("![]({})", s), + }) + .collect::>() + .join(" "); + serializer.serialize_str(&formatted) + } } mod message_content_serde { use super::*; - use serde::de; - use serde::Deserializer; + use serde::de::{Deserialize, Deserializer, Error}; use serde_json::Value; - pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> + pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> where D: Deserializer<'de>, { let value = Value::deserialize(deserializer)?; - match value { - Value::String(s) => Ok(Some(s)), - Value::Array(arr) => { - let results: Result, _> = arr - .into_iter() - .map(|v| { - let content: Content = - serde_json::from_value(v).map_err(de::Error::custom)?; - match content.r#type.as_str() { - "text" => Ok(content.text.unwrap_or_default()), - "image_url" => { - if let Some(url) = content.image_url { - Ok(format!("![]({})", url.url)) - } else { - Ok(String::new()) - } - } - _ => Err(de::Error::custom("invalid content type")), - } - }) - .collect(); - results.map(|strings| Some(strings.join(""))) - } + match value { + Value::String(s) => Ok(Some(vec![s.into()].into())), + Value::Array(arr) => arr + .into_iter() + .map(|v| match v { + Value::String(s) => Ok(ContentChunk::Text(s)), + Value::Object(map) => match map + .get("image_url") + .and_then(|x| x.get("url").and_then(|u| u.as_str())) + { + Some(url) => Ok(ContentChunk::ImageUrl(url.to_string())), + None => map + .get("text") + .and_then(|t| t.as_str()) + .map(|text| Ok(ContentChunk::Text(text.to_string()))) + .map_or_else( + || Err(Error::custom("Expected a string or an object")), + |x| x, + ), + }, + _ => Err(Error::custom("Expected a string or an object")), + }) + .collect::, _>>() + .map(|chunks| Some(chunks.into())), Value::Null => Ok(None), - _ => Err(de::Error::custom("invalid token format")), + _ => Err(Error::custom("Invalid content format")), } } } @@ -953,7 +1000,7 @@ pub(crate) struct Message { #[serde(skip_serializing_if = "Option::is_none")] #[schema(example = "My name is David and I")] #[serde(deserialize_with = "message_content_serde::deserialize")] - pub content: Option, + pub content: Option, #[serde(default, skip_serializing_if = "Option::is_none")] #[schema(example = "\"David\"")] pub name: Option,