feat: improve message content chunks handling

This commit is contained in:
drbh 2024-05-02 03:46:40 +00:00
parent 27b3a2c9fc
commit c98a6b9948
2 changed files with 119 additions and 68 deletions

View File

@ -362,16 +362,21 @@ impl ChatTemplate {
if self.use_default_tool_template {
if let Some(last_message) = messages.last_mut() {
if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
last_message.content = Some(format!(
"{}\n---\n{}\n{}",
last_message.content.as_deref().unwrap_or_default(),
tool_prompt,
tools
));
let last_message_content =
last_message.content.clone().unwrap_or_else(|| "".into());
last_message.content = Some(
format!(
"{:?}\n---\n{}\n{}",
last_message_content, tool_prompt, tools
)
.into(),
);
}
}
}
println!("{:?}", messages);
self.template
.render(ChatTemplateInputs {
messages,
@ -976,25 +981,25 @@ mod tests {
messages: vec![
Message {
role: "user".to_string(),
content: Some("Hi!".to_string()),
content: Some("Hi!".into()),
name: None,
tool_calls: None,
},
Message {
role: "assistant".to_string(),
content: Some("Hello how can I help?".to_string()),
content: Some("Hello how can I help?".into()),
name: None,
tool_calls: None,
},
Message {
role: "user".to_string(),
content: Some("What is Deep Learning?".to_string()),
content: Some("What is Deep Learning?".into()),
name: None,
tool_calls: None,
},
Message {
role: "assistant".to_string(),
content: Some("magic!".to_string()),
content: Some("magic!".into()),
name: None,
tool_calls: None,
},
@ -1046,31 +1051,31 @@ mod tests {
messages: vec![
Message {
role: "user".to_string(),
content: Some("Hi!".to_string()),
content: Some("Hi!".into()),
name: None,
tool_calls: None,
},
Message {
role: "user".to_string(),
content: Some("Hi again!".to_string()),
content: Some("Hi again!".into()),
name: None,
tool_calls: None,
},
Message {
role: "assistant".to_string(),
content: Some("Hello how can I help?".to_string()),
content: Some("Hello how can I help?".into()),
name: None,
tool_calls: None,
},
Message {
role: "user".to_string(),
content: Some("What is Deep Learning?".to_string()),
content: Some("What is Deep Learning?".into()),
name: None,
tool_calls: None,
},
Message {
role: "assistant".to_string(),
content: Some("magic!".to_string()),
content: Some("magic!".into()),
name: None,
tool_calls: None,
},
@ -1127,25 +1132,25 @@ mod tests {
messages: vec![
Message {
role: "user".to_string(),
content: Some("Hi!".to_string()),
content: Some("Hi!".into()),
name: None,
tool_calls: None,
},
Message {
role: "assistant".to_string(),
content: Some("Hello how can I help?".to_string()),
content: Some("Hello how can I help?".into()),
name: None,
tool_calls: None,
},
Message {
role: "user".to_string(),
content: Some("What is Deep Learning?".to_string()),
content: Some("What is Deep Learning?".into()),
name: None,
tool_calls: None,
},
Message {
role: "assistant".to_string(),
content: Some("magic!".to_string()),
content: Some("magic!".into()),
name: None,
tool_calls: None,
},
@ -1186,25 +1191,25 @@ mod tests {
messages: vec![
Message {
role: "user".to_string(),
content: Some("Hi!".to_string()),
content: Some("Hi!".into()),
name: None,
tool_calls: None,
},
Message {
role: "assistant".to_string(),
content: Some("Hello how can I help?".to_string()),
content: Some("Hello how can I help?".into()),
name: None,
tool_calls: None,
},
Message {
role: "user".to_string(),
content: Some("What is Deep Learning?".to_string()),
content: Some("What is Deep Learning?".into()),
name: None,
tool_calls: None,
},
Message {
role: "assistant".to_string(),
content: Some("magic!".to_string()),
content: Some("magic!".into()),
name: None,
tool_calls: None,
},
@ -1231,29 +1236,28 @@ mod tests {
let example_chat = vec![
Message {
role: "user".to_string(),
content: Some("Hello, how are you?".to_string()),
content: Some("Hello, how are you?".into()),
name: None,
tool_calls: None,
},
Message {
role: "assistant".to_string(),
content: Some("I'm doing great. How can I help you today?".to_string()),
content: Some("I'm doing great. How can I help you today?".into()),
name: None,
tool_calls: None,
},
Message {
role: "user".to_string(),
content: Some("I'd like to show off how chat templating works!".to_string()),
content: Some("I'd like to show off how chat templating works!".into()),
name: None,
tool_calls: None,
},
];
let example_chat_with_system = vec![Message {
let example_chat_with_system = [Message {
role: "system".to_string(),
content: Some(
"You are a friendly chatbot who always responds in the style of a pirate"
.to_string(),
"You are a friendly chatbot who always responds in the style of a pirate".into(),
),
name: None,
tool_calls: None,
@ -1373,7 +1377,7 @@ mod tests {
{
let mut env = Environment::new();
env.add_function("raise_exception", raise_exception);
let tmpl = env.template_from_str(&chat_template);
let tmpl = env.template_from_str(chat_template);
let result = tmpl.unwrap().render(input).unwrap();
assert_eq!(result, target);
}
@ -1398,13 +1402,13 @@ mod tests {
messages: vec![
Message {
role: "system".to_string(),
content: Some("You are a friendly chatbot who always responds in the style of a pirate".to_string()),
content: Some("You are a friendly chatbot who always responds in the style of a pirate".into()),
name: None,
tool_calls: None,
},
Message {
role: "user".to_string(),
content: Some("How many helicopters can a human eat in one sitting?".to_string()),
content: Some("How many helicopters can a human eat in one sitting?".into()),
name: None,
tool_calls: None,
},

View File

@ -541,7 +541,7 @@ impl ChatCompletion {
index: 0,
message: Message {
role: "assistant".into(),
content: output,
content: output.map(|content| vec![content.into()].into()),
name: None,
tool_calls,
},
@ -896,52 +896,99 @@ pub(crate) struct ImageUrl {
pub url: String,
}
#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)]
pub(crate) struct Content {
pub r#type: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub text: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub image_url: Option<ImageUrl>,
/// A single piece of message content.
///
/// A chunk is either plain text or a reference to an image by URL.
#[derive(Clone, Deserialize, Serialize, Debug)]
enum ContentChunk {
    /// Plain text content.
    Text(String),
    /// An image reference; the payload is the image URL as a string.
    ImageUrl(String),
}
/// Newtype wrapper around the list of [`ContentChunk`]s making up one message's content.
///
/// `Serialize` is not derived here; it is implemented by hand for this type.
#[derive(Clone, Deserialize, Debug)]
struct ContentChunks(Vec<ContentChunk>);
// Convert in and out of ContentChunk
impl From<String> for ContentChunk {
fn from(s: String) -> Self {
ContentChunk::Text(s)
}
}
/// Copies a borrowed string slice into a new text chunk.
impl From<&str> for ContentChunk {
    fn from(text: &str) -> Self {
        Self::Text(text.to_owned())
    }
}
// Convert in and out of ContentChunks
impl From<Vec<ContentChunk>> for ContentChunks {
fn from(chunks: Vec<ContentChunk>) -> Self {
Self(chunks)
}
}
impl From<&str> for ContentChunks {
fn from(s: &str) -> Self {
vec![s.into()].into()
}
}
impl From<String> for ContentChunks {
fn from(s: String) -> Self {
vec![s.into()].into()
}
}
/// Serializes the chunk list as one flat string.
///
/// Text chunks are emitted verbatim, image chunks are rendered as
/// `![](<url>)`, and consecutive chunks are separated by a single space.
/// An empty list serializes as the empty string.
impl Serialize for ContentChunks {
    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let mut rendered = String::new();
        for (position, piece) in self.0.iter().enumerate() {
            // The separator goes between chunks only, never before the first one.
            if position > 0 {
                rendered.push_str(" ");
            }
            match piece {
                ContentChunk::Text(text) => rendered.push_str(text),
                ContentChunk::ImageUrl(url) => rendered.push_str(&format!("![]({})", url)),
            }
        }
        serializer.serialize_str(&rendered)
    }
}
mod message_content_serde {
use super::*;
use serde::de;
use serde::Deserializer;
use serde::de::{Deserialize, Deserializer, Error};
use serde_json::Value;
pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<String>, D::Error>
pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<ContentChunks>, D::Error>
where
D: Deserializer<'de>,
{
let value = Value::deserialize(deserializer)?;
match value {
Value::String(s) => Ok(Some(s)),
Value::Array(arr) => {
let results: Result<Vec<String>, _> = arr
.into_iter()
.map(|v| {
let content: Content =
serde_json::from_value(v).map_err(de::Error::custom)?;
match content.r#type.as_str() {
"text" => Ok(content.text.unwrap_or_default()),
"image_url" => {
if let Some(url) = content.image_url {
Ok(format!("![]({})", url.url))
} else {
Ok(String::new())
}
}
_ => Err(de::Error::custom("invalid content type")),
}
})
.collect();
results.map(|strings| Some(strings.join("")))
}
match value {
Value::String(s) => Ok(Some(vec![s.into()].into())),
Value::Array(arr) => arr
.into_iter()
.map(|v| match v {
Value::String(s) => Ok(ContentChunk::Text(s)),
Value::Object(map) => match map
.get("image_url")
.and_then(|x| x.get("url").and_then(|u| u.as_str()))
{
Some(url) => Ok(ContentChunk::ImageUrl(url.to_string())),
None => map
.get("text")
.and_then(|t| t.as_str())
.map(|text| Ok(ContentChunk::Text(text.to_string())))
.map_or_else(
|| Err(Error::custom("Expected a string or an object")),
|x| x,
),
},
_ => Err(Error::custom("Expected a string or an object")),
})
.collect::<Result<Vec<_>, _>>()
.map(|chunks| Some(chunks.into())),
Value::Null => Ok(None),
_ => Err(de::Error::custom("invalid token format")),
_ => Err(Error::custom("Invalid content format")),
}
}
}
@ -953,7 +1000,7 @@ pub(crate) struct Message {
#[serde(skip_serializing_if = "Option::is_none")]
#[schema(example = "My name is David and I")]
#[serde(deserialize_with = "message_content_serde::deserialize")]
pub content: Option<String>,
pub content: Option<ContentChunks>,
#[serde(default, skip_serializing_if = "Option::is_none")]
#[schema(example = "\"David\"")]
pub name: Option<String>,