Fixing types. (#1906)
# What does this PR do?

Fixes # (issue)

## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

## Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
parent d8402eaf67
commit f5d43414c2
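The diff below splits the router's message handling into two types: an API-facing `Message` whose `content` is a list of typed chunks (text or image URL), and a flat `TextMessage { role, content }` that chat templates and responses consume. A minimal standalone sketch of that shape, assuming nothing beyond the Rust standard library (the real types also derive serde and utoipa traits, carry extra fields such as `name`, and use a custom deserializer; the URL is the one used in the new tests):

```rust
// Text chunk of a chat message.
#[derive(Clone, Debug, PartialEq)]
struct Text {
    text: String,
}

#[derive(Clone, Debug, PartialEq)]
struct Url {
    url: String,
}

#[derive(Clone, Debug, PartialEq)]
struct ImageUrl {
    image_url: Url,
}

// A chunk is either text or an image reference ("type"-tagged in the real API).
#[derive(Clone, Debug, PartialEq)]
enum MessageChunk {
    Text(Text),
    ImageUrl(ImageUrl),
}

// API-facing message: content is a list of chunks.
#[derive(Clone, Debug, PartialEq)]
struct Message {
    role: String,
    content: Vec<MessageChunk>,
}

// Flattened form handed to chat templates: content is a single string.
#[derive(Clone, Debug, PartialEq)]
struct TextMessage {
    role: String,
    content: String,
}

// Flattening rule used by the PR: text chunks pass through, image chunks
// become a markdown image reference.
impl From<Message> for TextMessage {
    fn from(value: Message) -> Self {
        TextMessage {
            role: value.role,
            content: value
                .content
                .into_iter()
                .map(|c| match c {
                    MessageChunk::Text(Text { text }) => text,
                    MessageChunk::ImageUrl(image) => format!("![]({})", image.image_url.url),
                })
                .collect::<Vec<_>>()
                .join(""),
        }
    }
}

fn main() {
    let message = Message {
        role: "user".to_string(),
        content: vec![
            MessageChunk::Text(Text {
                text: "Whats in this image?".to_string(),
            }),
            MessageChunk::ImageUrl(ImageUrl {
                image_url: Url {
                    url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string(),
                },
            }),
        ],
    };
    let flat: TextMessage = message.into();
    println!("{} -> {}", flat.role, flat.content);
}
```

The design keeps the OpenAI-style chunked input at the API boundary while everything downstream of it keeps working with plain strings.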
@@ -2,7 +2,7 @@
 use crate::validation::{Validation, ValidationError};
 use crate::{
     ChatTemplateInputs, ChatTemplateVersions, Entry, GenerateRequest, GenerateStreamResponse,
-    HubTokenizerConfig, Message, PrefillToken, Queue, Token,
+    HubTokenizerConfig, Message, MessageChunk, PrefillToken, Queue, Text, TextMessage, Token,
 };
 use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools};
 use futures::future::try_join_all;
@@ -362,16 +362,15 @@ impl ChatTemplate {
         if self.use_default_tool_template {
             if let Some(last_message) = messages.last_mut() {
                 if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
-                    last_message.content = Some(format!(
-                        "{}\n---\n{}\n{}",
-                        last_message.content.as_deref().unwrap_or_default(),
-                        tool_prompt,
-                        tools
-                    ));
+                    last_message.content.push(MessageChunk::Text(Text {
+                        text: format!("\n---\n{}\n{}", tool_prompt, tools),
+                    }));
                 }
             }
         }

+        let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
+
         self.template
             .render(ChatTemplateInputs {
                 messages,
@@ -939,8 +938,7 @@ impl InferError {
 #[cfg(test)]
 mod tests {
     use crate::infer::raise_exception;
-    use crate::ChatTemplateInputs;
-    use crate::Message;
+    use crate::{ChatTemplateInputs, TextMessage};
     use minijinja::Environment;

     #[test]
@@ -974,33 +972,21 @@ mod tests {

         let chat_template_inputs = ChatTemplateInputs {
             messages: vec![
-                Message {
+                TextMessage {
                     role: "user".to_string(),
-                    content: Some("Hi!".to_string()),
-                    name: None,
-                    tool_calls: None,
-                    tool_call_id: None,
+                    content: "Hi!".to_string(),
                 },
-                Message {
+                TextMessage {
                     role: "assistant".to_string(),
-                    content: Some("Hello how can I help?".to_string()),
-                    name: None,
-                    tool_calls: None,
-                    tool_call_id: None,
+                    content: "Hello how can I help?".to_string(),
                 },
-                Message {
+                TextMessage {
                     role: "user".to_string(),
-                    content: Some("What is Deep Learning?".to_string()),
-                    name: None,
-                    tool_calls: None,
-                    tool_call_id: None,
+                    content: "What is Deep Learning?".to_string(),
                 },
-                Message {
+                TextMessage {
                     role: "assistant".to_string(),
-                    content: Some("magic!".to_string()),
-                    name: None,
-                    tool_calls: None,
-                    tool_call_id: None,
+                    content: "magic!".to_string(),
                 },
             ],
             bos_token: Some("[BOS]"),
@@ -1048,40 +1034,25 @@ mod tests {

         let chat_template_inputs = ChatTemplateInputs {
             messages: vec![
-                Message {
+                TextMessage {
                     role: "user".to_string(),
-                    content: Some("Hi!".to_string()),
-                    name: None,
-                    tool_calls: None,
-                    tool_call_id: None,
+                    content: "Hi!".to_string(),
                 },
-                Message {
+                TextMessage {
                     role: "user".to_string(),
-                    content: Some("Hi again!".to_string()),
-                    name: None,
-                    tool_calls: None,
-                    tool_call_id: None,
+                    content: "Hi again!".to_string(),
                 },
-                Message {
+                TextMessage {
                     role: "assistant".to_string(),
-                    content: Some("Hello how can I help?".to_string()),
-                    name: None,
-                    tool_calls: None,
-                    tool_call_id: None,
+                    content: "Hello how can I help?".to_string(),
                 },
-                Message {
+                TextMessage {
                     role: "user".to_string(),
-                    content: Some("What is Deep Learning?".to_string()),
-                    name: None,
-                    tool_calls: None,
-                    tool_call_id: None,
+                    content: "What is Deep Learning?".to_string(),
                 },
-                Message {
+                TextMessage {
                     role: "assistant".to_string(),
-                    content: Some("magic!".to_string()),
-                    name: None,
-                    tool_calls: None,
-                    tool_call_id: None,
+                    content: "magic!".to_string(),
                 },
             ],
             bos_token: Some("[BOS]"),
@@ -1134,33 +1105,21 @@ mod tests {

         let chat_template_inputs = ChatTemplateInputs {
             messages: vec![
-                Message {
+                TextMessage {
                     role: "user".to_string(),
-                    content: Some("Hi!".to_string()),
-                    name: None,
-                    tool_calls: None,
-                    tool_call_id: None,
+                    content: "Hi!".to_string(),
                 },
-                Message {
+                TextMessage {
                     role: "assistant".to_string(),
-                    content: Some("Hello how can I help?".to_string()),
-                    name: None,
-                    tool_calls: None,
-                    tool_call_id: None,
+                    content: "Hello how can I help?".to_string(),
                 },
-                Message {
+                TextMessage {
                     role: "user".to_string(),
-                    content: Some("What is Deep Learning?".to_string()),
-                    name: None,
-                    tool_calls: None,
-                    tool_call_id: None,
+                    content: "What is Deep Learning?".to_string(),
                 },
-                Message {
+                TextMessage {
                     role: "assistant".to_string(),
-                    content: Some("magic!".to_string()),
-                    name: None,
-                    tool_calls: None,
-                    tool_call_id: None,
+                    content: "magic!".to_string(),
                 },
             ],
             bos_token: Some("[BOS]"),
@@ -1197,33 +1156,21 @@ mod tests {

         let chat_template_inputs = ChatTemplateInputs {
             messages: vec![
-                Message {
+                TextMessage {
                     role: "user".to_string(),
-                    content: Some("Hi!".to_string()),
-                    name: None,
-                    tool_calls: None,
-                    tool_call_id: None,
+                    content: "Hi!".to_string(),
                 },
-                Message {
+                TextMessage {
                     role: "assistant".to_string(),
-                    content: Some("Hello how can I help?".to_string()),
-                    name: None,
-                    tool_calls: None,
-                    tool_call_id: None,
+                    content: "Hello how can I help?".to_string(),
                 },
-                Message {
+                TextMessage {
                     role: "user".to_string(),
-                    content: Some("What is Deep Learning?".to_string()),
-                    name: None,
-                    tool_calls: None,
-                    tool_call_id: None,
+                    content: "What is Deep Learning?".to_string(),
                 },
-                Message {
+                TextMessage {
                     role: "assistant".to_string(),
-                    content: Some("magic!".to_string()),
-                    name: None,
-                    tool_calls: None,
-                    tool_call_id: None,
+                    content: "magic!".to_string(),
                 },
             ],
             bos_token: Some("[BOS]"),
@@ -1246,38 +1193,24 @@ mod tests {
     #[test]
     fn test_many_chat_templates() {
         let example_chat = vec![
-            Message {
+            TextMessage {
                 role: "user".to_string(),
-                content: Some("Hello, how are you?".to_string()),
-                name: None,
-                tool_calls: None,
-                tool_call_id: None,
+                content: "Hello, how are you?".to_string(),
             },
-            Message {
+            TextMessage {
                 role: "assistant".to_string(),
-                content: Some("I'm doing great. How can I help you today?".to_string()),
-                name: None,
-                tool_calls: None,
-                tool_call_id: None,
+                content: "I'm doing great. How can I help you today?".to_string(),
             },
-            Message {
+            TextMessage {
                 role: "user".to_string(),
-                content: Some("I'd like to show off how chat templating works!".to_string()),
-                name: None,
-                tool_calls: None,
-                tool_call_id: None,
+                content: "I'd like to show off how chat templating works!".to_string(),
             },
         ];

-        let example_chat_with_system = [Message {
+        let example_chat_with_system = [TextMessage {
             role: "system".to_string(),
-            content: Some(
-                "You are a friendly chatbot who always responds in the style of a pirate"
+            content: "You are a friendly chatbot who always responds in the style of a pirate"
                 .to_string(),
-            ),
-            name: None,
-            tool_calls: None,
-            tool_call_id: None,
         }]
         .iter()
         .chain(&example_chat)
@@ -1417,19 +1350,13 @@ mod tests {
                 chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}",
                 input: ChatTemplateInputs {
                     messages: vec![
-                        Message {
+                        TextMessage{
                             role: "system".to_string(),
-                            content: Some("You are a friendly chatbot who always responds in the style of a pirate".to_string()),
-                            name: None,
-                            tool_calls: None,
-                            tool_call_id: None,
+                            content: "You are a friendly chatbot who always responds in the style of a pirate".to_string(),
                         },
-                        Message {
+                        TextMessage{
                             role: "user".to_string(),
-                            content: Some("How many helicopters can a human eat in one sitting?".to_string()),
-                            name: None,
-                            tool_calls: None,
-                            tool_call_id: None,
+                            content: "How many helicopters can a human eat in one sitting?".to_string(),
                         },
                     ],
                     add_generation_prompt: true,
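The test hunks above now feed `ChatTemplateInputs` with plain `TextMessage` values and render them through minijinja, as the existing tests already did. A small sketch of the same pattern outside the router, assuming only the `minijinja` and `serde` crates; the template string and the `TextMessage` struct here are illustrative stand-ins, not the repository's own definitions:

```rust
use minijinja::{context, Environment};
use serde::Serialize;

#[derive(Serialize)]
struct TextMessage {
    role: String,
    content: String,
}

fn main() {
    let mut env = Environment::new();
    // A ChatML-like template, purely for demonstration.
    env.add_template(
        "chat",
        "{% for m in messages %}<|{{ m.role }}|>\n{{ m.content }}\n{% endfor %}{% if add_generation_prompt %}<|assistant|>\n{% endif %}",
    )
    .unwrap();

    let messages = vec![
        TextMessage {
            role: "user".to_string(),
            content: "Hi!".to_string(),
        },
        TextMessage {
            role: "assistant".to_string(),
            content: "Hello how can I help?".to_string(),
        },
    ];

    // Flat `role`/`content` values serialize straight into the template context.
    let rendered = env
        .get_template("chat")
        .unwrap()
        .render(context! { messages => messages, add_generation_prompt => true })
        .unwrap();
    println!("{rendered}");
}
```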
@@ -11,6 +11,7 @@ use queue::{Entry, Queue};
 use serde::{Deserialize, Serialize};
 use tokio::sync::OwnedSemaphorePermit;
 use tokio_stream::wrappers::UnboundedReceiverStream;
+use tracing::warn;
 use utoipa::ToSchema;
 use validation::Validation;

@@ -440,7 +441,7 @@ pub(crate) struct ChatCompletion {
 #[derive(Clone, Deserialize, Serialize, ToSchema)]
 pub(crate) struct ChatCompletionComplete {
     pub index: u32,
-    pub message: Message,
+    pub message: OutputMessage,
     pub logprobs: Option<ChatCompletionLogprobs>,
     pub finish_reason: String,
 }
@@ -533,6 +534,30 @@ impl ChatCompletion {
         return_logprobs: bool,
         tool_calls: Option<Vec<ToolCall>>,
     ) -> Self {
+        let message = match (output, tool_calls) {
+            (Some(content), None) => OutputMessage::ChatMessage(TextMessage {
+                role: "assistant".into(),
+                content,
+            }),
+            (None, Some(tool_calls)) => OutputMessage::ToolCall(ToolCallMessage {
+                role: "assistant".to_string(),
+                tool_calls,
+            }),
+            (Some(output), Some(_)) => {
+                warn!("Received both chat and tool call");
+                OutputMessage::ChatMessage(TextMessage {
+                    role: "assistant".into(),
+                    content: output,
+                })
+            }
+            (None, None) => {
+                warn!("Didn't receive an answer");
+                OutputMessage::ChatMessage(TextMessage {
+                    role: "assistant".into(),
+                    content: "".to_string(),
+                })
+            }
+        };
         Self {
             id: String::new(),
             object: "text_completion".into(),
@@ -541,13 +566,7 @@ impl ChatCompletion {
             system_fingerprint,
             choices: vec![ChatCompletionComplete {
                 index: 0,
-                message: Message {
-                    role: "assistant".into(),
-                    content: output,
-                    name: None,
-                    tool_calls,
-                    tool_call_id: None,
-                },
+                message,
                 logprobs: return_logprobs
                     .then(|| ChatCompletionLogprobs::from((details.tokens, details.top_tokens))),
                 finish_reason: details.finish_reason.to_string(),
@@ -569,6 +588,7 @@ pub(crate) struct CompletionCompleteChunk {
     pub model: String,
     pub system_fingerprint: String,
 }
+
 #[derive(Clone, Deserialize, Serialize, ToSchema)]
 pub(crate) struct ChatCompletionChunk {
     pub id: String,
@@ -589,21 +609,20 @@ pub(crate) struct ChatCompletionChoice {
     pub finish_reason: Option<String>,
 }

-#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
-pub(crate) struct ChatCompletionDelta {
-    #[schema(example = "user")]
-    // TODO Modify this to a true enum.
-    #[serde(default, skip_serializing_if = "Option::is_none")]
-    pub role: Option<String>,
-    #[serde(default, skip_serializing_if = "Option::is_none")]
-    #[schema(example = "What is Deep Learning?")]
-    pub content: Option<String>,
-    // default to None
-    #[serde(default, skip_serializing_if = "Option::is_none")]
-    pub tool_calls: Option<DeltaToolCall>,
+#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
+pub struct ToolCallDelta {
+    #[schema(example = "assistant")]
+    role: String,
+    tool_calls: DeltaToolCall,
 }

-#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
+#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
+enum ChatCompletionDelta {
+    Chat(TextMessage),
+    Tool(ToolCallDelta),
+}
+
+#[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)]
 pub(crate) struct DeltaToolCall {
     pub index: u32,
     pub id: String,
@@ -611,7 +630,7 @@ pub(crate) struct DeltaToolCall {
     pub function: Function,
 }

-#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
+#[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)]
 pub(crate) struct Function {
     pub name: Option<String>,
     pub arguments: String,
@@ -629,15 +648,13 @@ impl ChatCompletionChunk {
         finish_reason: Option<String>,
     ) -> Self {
         let delta = match (delta, tool_calls) {
-            (Some(delta), _) => ChatCompletionDelta {
-                role: Some("assistant".to_string()),
-                content: Some(delta),
-                tool_calls: None,
-            },
-            (None, Some(tool_calls)) => ChatCompletionDelta {
-                role: Some("assistant".to_string()),
-                content: None,
-                tool_calls: Some(DeltaToolCall {
+            (Some(delta), _) => ChatCompletionDelta::Chat(TextMessage {
+                role: "assistant".to_string(),
+                content: delta,
+            }),
+            (None, Some(tool_calls)) => ChatCompletionDelta::Tool(ToolCallDelta {
+                role: "assistant".to_string(),
+                tool_calls: DeltaToolCall {
                     index: 0,
                     id: String::new(),
                     r#type: "function".to_string(),
@@ -645,13 +662,12 @@ impl ChatCompletionChunk {
                         name: None,
                         arguments: tool_calls[0].to_string(),
                     },
-                }),
-            },
-            (None, None) => ChatCompletionDelta {
-                role: None,
-                content: None,
-                tool_calls: None,
-            },
+                },
+            }),
+            (None, None) => ChatCompletionDelta::Chat(TextMessage {
+                role: "assistant".to_string(),
+                content: "".to_string(),
+            }),
         };
         Self {
             id: String::new(),
@@ -852,7 +868,7 @@ where
     state.end()
 }

-#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default)]
+#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default, PartialEq)]
 pub(crate) struct FunctionDefinition {
     #[serde(default)]
     pub description: Option<String>,
@@ -872,7 +888,7 @@ pub(crate) struct Tool {

 #[derive(Clone, Serialize, Deserialize, Default)]
 pub(crate) struct ChatTemplateInputs<'a> {
-    messages: Vec<Message>,
+    messages: Vec<TextMessage>,
     bos_token: Option<&'a str>,
     eos_token: Option<&'a str>,
     add_generation_prompt: bool,
@@ -880,91 +896,113 @@ pub(crate) struct ChatTemplateInputs<'a> {
     tools_prompt: Option<&'a str>,
 }

-#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)]
+#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug, PartialEq)]
 pub(crate) struct ToolCall {
     pub id: String,
     pub r#type: String,
     pub function: FunctionDefinition,
 }

-#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)]
-pub(crate) struct Text {
-    #[serde(default)]
-    pub text: String,
+#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
+struct Url {
+    url: String,
 }

-#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)]
-pub(crate) struct ImageUrl {
-    #[serde(default)]
-    pub url: String,
+#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
+struct ImageUrl {
+    image_url: Url,
 }

-#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)]
-pub(crate) struct Content {
-    pub r#type: String,
+#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
+struct Text {
+    text: String,
+}
+
+#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
+#[serde(tag = "type")]
+#[serde(rename_all = "snake_case")]
+enum MessageChunk {
+    Text(Text),
+    ImageUrl(ImageUrl),
+}
+
+#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
+pub struct Message {
+    #[schema(example = "user")]
+    role: String,
+    #[schema(example = "My name is David and I")]
+    #[serde(deserialize_with = "message_content_serde::deserialize")]
+    content: Vec<MessageChunk>,
     #[serde(default, skip_serializing_if = "Option::is_none")]
-    pub text: Option<String>,
-    #[serde(default, skip_serializing_if = "Option::is_none")]
-    pub image_url: Option<ImageUrl>,
+    #[schema(example = "\"David\"")]
+    name: Option<String>,
 }

 mod message_content_serde {
     use super::*;
-    use serde::de;
-    use serde::Deserializer;
-    use serde_json::Value;
+    use serde::{Deserialize, Deserializer};

-    pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<String>, D::Error>
+    pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<MessageChunk>, D::Error>
     where
         D: Deserializer<'de>,
     {
-        let value = Value::deserialize(deserializer)?;
-        match value {
-            Value::String(s) => Ok(Some(s)),
-            Value::Array(arr) => {
-                let results: Result<Vec<String>, _> = arr
-                    .into_iter()
-                    .map(|v| {
-                        let content: Content =
-                            serde_json::from_value(v).map_err(de::Error::custom)?;
-                        match content.r#type.as_str() {
-                            "text" => Ok(content.text.unwrap_or_default()),
-                            "image_url" => {
-                                if let Some(url) = content.image_url {
-                                    Ok(format!("![]({})", url.url))
-                                } else {
-                                    Ok(String::new())
-                                }
-                            }
-                            _ => Err(de::Error::custom("invalid content type")),
-                        }
-                    })
-                    .collect();
-
-                results.map(|strings| Some(strings.join("")))
-            }
-            Value::Null => Ok(None),
-            _ => Err(de::Error::custom("invalid token format")),
-        }
+        #[derive(Deserialize)]
+        #[serde(untagged)]
+        enum Message {
+            Text(String),
+            Chunks(Vec<MessageChunk>),
+        }
+        let message: Message = Deserialize::deserialize(deserializer)?;
+        let chunks = match message {
+            Message::Text(text) => {
+                vec![MessageChunk::Text(Text { text })]
+            }
+            Message::Chunks(s) => s,
+        };
+        Ok(chunks)
     }
 }

-#[derive(Clone, Deserialize, ToSchema, Serialize, Debug)]
-pub(crate) struct Message {
+#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
+pub struct TextMessage {
     #[schema(example = "user")]
     pub role: String,
-    #[serde(skip_serializing_if = "Option::is_none")]
     #[schema(example = "My name is David and I")]
-    #[serde(default, deserialize_with = "message_content_serde::deserialize")]
-    pub content: Option<String>,
-    #[serde(default, skip_serializing_if = "Option::is_none")]
-    #[schema(example = "\"David\"")]
-    pub name: Option<String>,
-    #[serde(default, skip_serializing_if = "Option::is_none")]
-    pub tool_calls: Option<Vec<ToolCall>>,
-    #[serde(default, skip_serializing_if = "Option::is_none")]
-    #[schema(example = "\"get_weather\"")]
-    pub tool_call_id: Option<String>,
+    pub content: String,
+}
+
+impl From<Message> for TextMessage {
+    fn from(value: Message) -> Self {
+        TextMessage {
+            role: value.role,
+            content: value
+                .content
+                .into_iter()
+                .map(|c| match c {
+                    MessageChunk::Text(Text { text }) => text,
+                    MessageChunk::ImageUrl(image) => {
+                        let url = image.image_url.url;
+                        format!("![]({url})")
+                    }
+                })
+                .collect::<Vec<_>>()
+                .join(""),
+        }
+    }
+}
+
+#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
+pub struct ToolCallMessage {
+    #[schema(example = "assistant")]
+    role: String,
+    tool_calls: Vec<ToolCall>,
+}
+
+#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
+#[serde(untagged)]
+pub(crate) enum OutputMessage {
+    ChatMessage(TextMessage),
+    ToolCall(ToolCallMessage),
 }

 #[derive(Clone, Debug, Deserialize, ToSchema)]
@@ -1127,7 +1165,7 @@ pub(crate) struct ErrorResponse {
 #[cfg(test)]
 mod tests {
     use super::*;
-
+    use serde_json::json;
     use tokenizers::Tokenizer;

     pub(crate) async fn get_tokenizer() -> Tokenizer {
@@ -1195,4 +1233,66 @@ mod tests {
         );
         assert_eq!(config.eos_token, Some("<|end▁of▁sentence|>".to_string()));
     }
+
+    #[test]
+    fn test_chat_simple_string() {
+        let json = json!(
+
+        {
+            "model": "",
+            "messages": [
+                {"role": "user",
+                "content": "What is Deep Learning?"
+                }
+            ]
+        });
+        let request: ChatRequest = serde_json::from_str(json.to_string().as_str()).unwrap();
+
+        assert_eq!(
+            request.messages[0],
+            Message {
+                role: "user".to_string(),
+                content: vec![MessageChunk::Text(Text {
+                    text: "What is Deep Learning?".to_string()
+                }),],
+                name: None
+            }
+        );
+    }
+
+    #[test]
+    fn test_chat_request() {
+        let json = json!(
+
+        {
+            "model": "",
+            "messages": [
+                {"role": "user",
+                "content": [
+                    {"type": "text", "text": "Whats in this image?"},
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"
+                        },
+                    },
+                ]
+
+                }
+            ]
+        });
+        let request: ChatRequest = serde_json::from_str(json.to_string().as_str()).unwrap();
+
+        assert_eq!(
+            request.messages[0],
+            Message{
+                role: "user".to_string(),
+                content: vec![
+                    MessageChunk::Text(Text { text: "Whats in this image?".to_string() }),
+                    MessageChunk::ImageUrl(ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() } })
+                ],
+                name: None
+            }
+        );
+    }
 }
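For reference, the `message_content_serde` rewrite above can be exercised on its own: OpenAI-style requests may send `content` either as a bare string or as an array of typed chunks, and the helper normalizes both into `Vec<MessageChunk>`. A self-contained sketch under that assumption, using only `serde` and `serde_json`; the free function name `deserialize_content` is mine (the diff uses a `message_content_serde` module) and the utoipa schema attributes are omitted:

```rust
use serde::{Deserialize, Deserializer};

#[derive(Debug, Deserialize, PartialEq)]
struct Text {
    text: String,
}

#[derive(Debug, Deserialize, PartialEq)]
struct Url {
    url: String,
}

#[derive(Debug, Deserialize, PartialEq)]
struct ImageUrl {
    image_url: Url,
}

// "type"-tagged chunk, matching the enum added in the diff.
#[derive(Debug, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
enum MessageChunk {
    Text(Text),
    ImageUrl(ImageUrl),
}

// Accept either `"content": "plain text"` or `"content": [{...}, {...}]`.
fn deserialize_content<'de, D>(deserializer: D) -> Result<Vec<MessageChunk>, D::Error>
where
    D: Deserializer<'de>,
{
    #[derive(Deserialize)]
    #[serde(untagged)]
    enum Content {
        Text(String),
        Chunks(Vec<MessageChunk>),
    }
    Ok(match Content::deserialize(deserializer)? {
        Content::Text(text) => vec![MessageChunk::Text(Text { text })],
        Content::Chunks(chunks) => chunks,
    })
}

#[derive(Debug, Deserialize)]
struct Message {
    role: String,
    #[serde(deserialize_with = "deserialize_content")]
    content: Vec<MessageChunk>,
}

fn main() {
    // Plain-string content is wrapped into a single text chunk.
    let simple: Message =
        serde_json::from_str(r#"{"role": "user", "content": "What is Deep Learning?"}"#).unwrap();
    assert_eq!(simple.content.len(), 1);

    // Chunked content passes through as typed chunks.
    let chunked: Message = serde_json::from_str(
        r#"{"role": "user", "content": [
            {"type": "text", "text": "Whats in this image?"},
            {"type": "image_url", "image_url": {"url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"}}
        ]}"#,
    )
    .unwrap();
    assert_eq!(chunked.content.len(), 2);
    println!("{:?}\n{:?}", simple, chunked);
}
```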