feat: improve message content chunks handling
This commit is contained in:
parent
27b3a2c9fc
commit
c98a6b9948
|
@ -362,16 +362,21 @@ impl ChatTemplate {
|
|||
if self.use_default_tool_template {
|
||||
if let Some(last_message) = messages.last_mut() {
|
||||
if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
|
||||
last_message.content = Some(format!(
|
||||
"{}\n---\n{}\n{}",
|
||||
last_message.content.as_deref().unwrap_or_default(),
|
||||
tool_prompt,
|
||||
tools
|
||||
));
|
||||
let last_message_content =
|
||||
last_message.content.clone().unwrap_or_else(|| "".into());
|
||||
last_message.content = Some(
|
||||
format!(
|
||||
"{:?}\n---\n{}\n{}",
|
||||
last_message_content, tool_prompt, tools
|
||||
)
|
||||
.into(),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
println!("{:?}", messages);
|
||||
|
||||
self.template
|
||||
.render(ChatTemplateInputs {
|
||||
messages,
|
||||
|
@ -976,25 +981,25 @@ mod tests {
|
|||
messages: vec![
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: Some("Hi!".to_string()),
|
||||
content: Some("Hi!".into()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "assistant".to_string(),
|
||||
content: Some("Hello how can I help?".to_string()),
|
||||
content: Some("Hello how can I help?".into()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: Some("What is Deep Learning?".to_string()),
|
||||
content: Some("What is Deep Learning?".into()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "assistant".to_string(),
|
||||
content: Some("magic!".to_string()),
|
||||
content: Some("magic!".into()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
|
@ -1046,31 +1051,31 @@ mod tests {
|
|||
messages: vec![
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: Some("Hi!".to_string()),
|
||||
content: Some("Hi!".into()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: Some("Hi again!".to_string()),
|
||||
content: Some("Hi again!".into()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "assistant".to_string(),
|
||||
content: Some("Hello how can I help?".to_string()),
|
||||
content: Some("Hello how can I help?".into()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: Some("What is Deep Learning?".to_string()),
|
||||
content: Some("What is Deep Learning?".into()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "assistant".to_string(),
|
||||
content: Some("magic!".to_string()),
|
||||
content: Some("magic!".into()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
|
@ -1127,25 +1132,25 @@ mod tests {
|
|||
messages: vec![
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: Some("Hi!".to_string()),
|
||||
content: Some("Hi!".into()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "assistant".to_string(),
|
||||
content: Some("Hello how can I help?".to_string()),
|
||||
content: Some("Hello how can I help?".into()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: Some("What is Deep Learning?".to_string()),
|
||||
content: Some("What is Deep Learning?".into()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "assistant".to_string(),
|
||||
content: Some("magic!".to_string()),
|
||||
content: Some("magic!".into()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
|
@ -1186,25 +1191,25 @@ mod tests {
|
|||
messages: vec![
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: Some("Hi!".to_string()),
|
||||
content: Some("Hi!".into()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "assistant".to_string(),
|
||||
content: Some("Hello how can I help?".to_string()),
|
||||
content: Some("Hello how can I help?".into()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: Some("What is Deep Learning?".to_string()),
|
||||
content: Some("What is Deep Learning?".into()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "assistant".to_string(),
|
||||
content: Some("magic!".to_string()),
|
||||
content: Some("magic!".into()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
|
@ -1231,29 +1236,28 @@ mod tests {
|
|||
let example_chat = vec![
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: Some("Hello, how are you?".to_string()),
|
||||
content: Some("Hello, how are you?".into()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "assistant".to_string(),
|
||||
content: Some("I'm doing great. How can I help you today?".to_string()),
|
||||
content: Some("I'm doing great. How can I help you today?".into()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: Some("I'd like to show off how chat templating works!".to_string()),
|
||||
content: Some("I'd like to show off how chat templating works!".into()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
];
|
||||
|
||||
let example_chat_with_system = vec![Message {
|
||||
let example_chat_with_system = [Message {
|
||||
role: "system".to_string(),
|
||||
content: Some(
|
||||
"You are a friendly chatbot who always responds in the style of a pirate"
|
||||
.to_string(),
|
||||
"You are a friendly chatbot who always responds in the style of a pirate".into(),
|
||||
),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
|
@ -1373,7 +1377,7 @@ mod tests {
|
|||
{
|
||||
let mut env = Environment::new();
|
||||
env.add_function("raise_exception", raise_exception);
|
||||
let tmpl = env.template_from_str(&chat_template);
|
||||
let tmpl = env.template_from_str(chat_template);
|
||||
let result = tmpl.unwrap().render(input).unwrap();
|
||||
assert_eq!(result, target);
|
||||
}
|
||||
|
@ -1398,13 +1402,13 @@ mod tests {
|
|||
messages: vec![
|
||||
Message {
|
||||
role: "system".to_string(),
|
||||
content: Some("You are a friendly chatbot who always responds in the style of a pirate".to_string()),
|
||||
content: Some("You are a friendly chatbot who always responds in the style of a pirate".into()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: Some("How many helicopters can a human eat in one sitting?".to_string()),
|
||||
content: Some("How many helicopters can a human eat in one sitting?".into()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
|
|
|
@ -541,7 +541,7 @@ impl ChatCompletion {
|
|||
index: 0,
|
||||
message: Message {
|
||||
role: "assistant".into(),
|
||||
content: output,
|
||||
content: output.map(|content| vec![content.into()].into()),
|
||||
name: None,
|
||||
tool_calls,
|
||||
},
|
||||
|
@ -896,52 +896,99 @@ pub(crate) struct ImageUrl {
|
|||
pub url: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)]
|
||||
pub(crate) struct Content {
|
||||
pub r#type: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub text: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub image_url: Option<ImageUrl>,
|
||||
#[derive(Clone, Deserialize, Serialize, Debug)]
|
||||
enum ContentChunk {
|
||||
Text(String),
|
||||
ImageUrl(String),
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Debug)]
|
||||
struct ContentChunks(Vec<ContentChunk>);
|
||||
|
||||
// Convert in and out of ContentChunk
|
||||
impl From<String> for ContentChunk {
|
||||
fn from(s: String) -> Self {
|
||||
ContentChunk::Text(s)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&str> for ContentChunk {
|
||||
fn from(s: &str) -> Self {
|
||||
s.to_string().into()
|
||||
}
|
||||
}
|
||||
|
||||
// Convert in and out of ContentChunks
|
||||
impl From<Vec<ContentChunk>> for ContentChunks {
|
||||
fn from(chunks: Vec<ContentChunk>) -> Self {
|
||||
Self(chunks)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&str> for ContentChunks {
|
||||
fn from(s: &str) -> Self {
|
||||
vec![s.into()].into()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for ContentChunks {
|
||||
fn from(s: String) -> Self {
|
||||
vec![s.into()].into()
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for ContentChunks {
|
||||
fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
|
||||
let formatted = self
|
||||
.0
|
||||
.iter()
|
||||
.map(|chunk| match chunk {
|
||||
ContentChunk::Text(s) => s.clone(),
|
||||
ContentChunk::ImageUrl(s) => format!("![]({})", s),
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join(" ");
|
||||
serializer.serialize_str(&formatted)
|
||||
}
|
||||
}
|
||||
|
||||
mod message_content_serde {
|
||||
use super::*;
|
||||
use serde::de;
|
||||
use serde::Deserializer;
|
||||
use serde::de::{Deserialize, Deserializer, Error};
|
||||
use serde_json::Value;
|
||||
|
||||
pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<String>, D::Error>
|
||||
pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<ContentChunks>, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let value = Value::deserialize(deserializer)?;
|
||||
match value {
|
||||
Value::String(s) => Ok(Some(s)),
|
||||
Value::Array(arr) => {
|
||||
let results: Result<Vec<String>, _> = arr
|
||||
.into_iter()
|
||||
.map(|v| {
|
||||
let content: Content =
|
||||
serde_json::from_value(v).map_err(de::Error::custom)?;
|
||||
match content.r#type.as_str() {
|
||||
"text" => Ok(content.text.unwrap_or_default()),
|
||||
"image_url" => {
|
||||
if let Some(url) = content.image_url {
|
||||
Ok(format!("![]({})", url.url))
|
||||
} else {
|
||||
Ok(String::new())
|
||||
}
|
||||
}
|
||||
_ => Err(de::Error::custom("invalid content type")),
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
results.map(|strings| Some(strings.join("")))
|
||||
}
|
||||
match value {
|
||||
Value::String(s) => Ok(Some(vec![s.into()].into())),
|
||||
Value::Array(arr) => arr
|
||||
.into_iter()
|
||||
.map(|v| match v {
|
||||
Value::String(s) => Ok(ContentChunk::Text(s)),
|
||||
Value::Object(map) => match map
|
||||
.get("image_url")
|
||||
.and_then(|x| x.get("url").and_then(|u| u.as_str()))
|
||||
{
|
||||
Some(url) => Ok(ContentChunk::ImageUrl(url.to_string())),
|
||||
None => map
|
||||
.get("text")
|
||||
.and_then(|t| t.as_str())
|
||||
.map(|text| Ok(ContentChunk::Text(text.to_string())))
|
||||
.map_or_else(
|
||||
|| Err(Error::custom("Expected a string or an object")),
|
||||
|x| x,
|
||||
),
|
||||
},
|
||||
_ => Err(Error::custom("Expected a string or an object")),
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.map(|chunks| Some(chunks.into())),
|
||||
Value::Null => Ok(None),
|
||||
_ => Err(de::Error::custom("invalid token format")),
|
||||
_ => Err(Error::custom("Invalid content format")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -953,7 +1000,7 @@ pub(crate) struct Message {
|
|||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[schema(example = "My name is David and I")]
|
||||
#[serde(deserialize_with = "message_content_serde::deserialize")]
|
||||
pub content: Option<String>,
|
||||
pub content: Option<ContentChunks>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
#[schema(example = "\"David\"")]
|
||||
pub name: Option<String>,
|
||||
|
|
Loading…
Reference in New Issue