Fixing types. (#1906)

# What does this PR do?

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)


## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?


## Who can review?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @


@OlivierDehaene OR @Narsil

 -->
This commit is contained in:
Nicolas Patry 2024-05-16 16:59:05 +02:00 committed by GitHub
parent d8402eaf67
commit f5d43414c2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 256 additions and 229 deletions

View File

@ -2,7 +2,7 @@
use crate::validation::{Validation, ValidationError}; use crate::validation::{Validation, ValidationError};
use crate::{ use crate::{
ChatTemplateInputs, ChatTemplateVersions, Entry, GenerateRequest, GenerateStreamResponse, ChatTemplateInputs, ChatTemplateVersions, Entry, GenerateRequest, GenerateStreamResponse,
HubTokenizerConfig, Message, PrefillToken, Queue, Token, HubTokenizerConfig, Message, MessageChunk, PrefillToken, Queue, Text, TextMessage, Token,
}; };
use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools}; use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools};
use futures::future::try_join_all; use futures::future::try_join_all;
@ -362,16 +362,15 @@ impl ChatTemplate {
if self.use_default_tool_template { if self.use_default_tool_template {
if let Some(last_message) = messages.last_mut() { if let Some(last_message) = messages.last_mut() {
if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt { if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
last_message.content = Some(format!( last_message.content.push(MessageChunk::Text(Text {
"{}\n---\n{}\n{}", text: format!("\n---\n{}\n{}", tool_prompt, tools),
last_message.content.as_deref().unwrap_or_default(), }));
tool_prompt,
tools
));
} }
} }
} }
let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
self.template self.template
.render(ChatTemplateInputs { .render(ChatTemplateInputs {
messages, messages,
@ -939,8 +938,7 @@ impl InferError {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::infer::raise_exception; use crate::infer::raise_exception;
use crate::ChatTemplateInputs; use crate::{ChatTemplateInputs, TextMessage};
use crate::Message;
use minijinja::Environment; use minijinja::Environment;
#[test] #[test]
@ -974,33 +972,21 @@ mod tests {
let chat_template_inputs = ChatTemplateInputs { let chat_template_inputs = ChatTemplateInputs {
messages: vec![ messages: vec![
Message { TextMessage {
role: "user".to_string(), role: "user".to_string(),
content: Some("Hi!".to_string()), content: "Hi!".to_string(),
name: None,
tool_calls: None,
tool_call_id: None,
}, },
Message { TextMessage {
role: "assistant".to_string(), role: "assistant".to_string(),
content: Some("Hello how can I help?".to_string()), content: "Hello how can I help?".to_string(),
name: None,
tool_calls: None,
tool_call_id: None,
}, },
Message { TextMessage {
role: "user".to_string(), role: "user".to_string(),
content: Some("What is Deep Learning?".to_string()), content: "What is Deep Learning?".to_string(),
name: None,
tool_calls: None,
tool_call_id: None,
}, },
Message { TextMessage {
role: "assistant".to_string(), role: "assistant".to_string(),
content: Some("magic!".to_string()), content: "magic!".to_string(),
name: None,
tool_calls: None,
tool_call_id: None,
}, },
], ],
bos_token: Some("[BOS]"), bos_token: Some("[BOS]"),
@ -1048,40 +1034,25 @@ mod tests {
let chat_template_inputs = ChatTemplateInputs { let chat_template_inputs = ChatTemplateInputs {
messages: vec![ messages: vec![
Message { TextMessage {
role: "user".to_string(), role: "user".to_string(),
content: Some("Hi!".to_string()), content: "Hi!".to_string(),
name: None,
tool_calls: None,
tool_call_id: None,
}, },
Message { TextMessage {
role: "user".to_string(), role: "user".to_string(),
content: Some("Hi again!".to_string()), content: "Hi again!".to_string(),
name: None,
tool_calls: None,
tool_call_id: None,
}, },
Message { TextMessage {
role: "assistant".to_string(), role: "assistant".to_string(),
content: Some("Hello how can I help?".to_string()), content: "Hello how can I help?".to_string(),
name: None,
tool_calls: None,
tool_call_id: None,
}, },
Message { TextMessage {
role: "user".to_string(), role: "user".to_string(),
content: Some("What is Deep Learning?".to_string()), content: "What is Deep Learning?".to_string(),
name: None,
tool_calls: None,
tool_call_id: None,
}, },
Message { TextMessage {
role: "assistant".to_string(), role: "assistant".to_string(),
content: Some("magic!".to_string()), content: "magic!".to_string(),
name: None,
tool_calls: None,
tool_call_id: None,
}, },
], ],
bos_token: Some("[BOS]"), bos_token: Some("[BOS]"),
@ -1134,33 +1105,21 @@ mod tests {
let chat_template_inputs = ChatTemplateInputs { let chat_template_inputs = ChatTemplateInputs {
messages: vec![ messages: vec![
Message { TextMessage {
role: "user".to_string(), role: "user".to_string(),
content: Some("Hi!".to_string()), content: "Hi!".to_string(),
name: None,
tool_calls: None,
tool_call_id: None,
}, },
Message { TextMessage {
role: "assistant".to_string(), role: "assistant".to_string(),
content: Some("Hello how can I help?".to_string()), content: "Hello how can I help?".to_string(),
name: None,
tool_calls: None,
tool_call_id: None,
}, },
Message { TextMessage {
role: "user".to_string(), role: "user".to_string(),
content: Some("What is Deep Learning?".to_string()), content: "What is Deep Learning?".to_string(),
name: None,
tool_calls: None,
tool_call_id: None,
}, },
Message { TextMessage {
role: "assistant".to_string(), role: "assistant".to_string(),
content: Some("magic!".to_string()), content: "magic!".to_string(),
name: None,
tool_calls: None,
tool_call_id: None,
}, },
], ],
bos_token: Some("[BOS]"), bos_token: Some("[BOS]"),
@ -1197,33 +1156,21 @@ mod tests {
let chat_template_inputs = ChatTemplateInputs { let chat_template_inputs = ChatTemplateInputs {
messages: vec![ messages: vec![
Message { TextMessage {
role: "user".to_string(), role: "user".to_string(),
content: Some("Hi!".to_string()), content: "Hi!".to_string(),
name: None,
tool_calls: None,
tool_call_id: None,
}, },
Message { TextMessage {
role: "assistant".to_string(), role: "assistant".to_string(),
content: Some("Hello how can I help?".to_string()), content: "Hello how can I help?".to_string(),
name: None,
tool_calls: None,
tool_call_id: None,
}, },
Message { TextMessage {
role: "user".to_string(), role: "user".to_string(),
content: Some("What is Deep Learning?".to_string()), content: "What is Deep Learning?".to_string(),
name: None,
tool_calls: None,
tool_call_id: None,
}, },
Message { TextMessage {
role: "assistant".to_string(), role: "assistant".to_string(),
content: Some("magic!".to_string()), content: "magic!".to_string(),
name: None,
tool_calls: None,
tool_call_id: None,
}, },
], ],
bos_token: Some("[BOS]"), bos_token: Some("[BOS]"),
@ -1246,38 +1193,24 @@ mod tests {
#[test] #[test]
fn test_many_chat_templates() { fn test_many_chat_templates() {
let example_chat = vec![ let example_chat = vec![
Message { TextMessage {
role: "user".to_string(), role: "user".to_string(),
content: Some("Hello, how are you?".to_string()), content: "Hello, how are you?".to_string(),
name: None,
tool_calls: None,
tool_call_id: None,
}, },
Message { TextMessage {
role: "assistant".to_string(), role: "assistant".to_string(),
content: Some("I'm doing great. How can I help you today?".to_string()), content: "I'm doing great. How can I help you today?".to_string(),
name: None,
tool_calls: None,
tool_call_id: None,
}, },
Message { TextMessage {
role: "user".to_string(), role: "user".to_string(),
content: Some("I'd like to show off how chat templating works!".to_string()), content: "I'd like to show off how chat templating works!".to_string(),
name: None,
tool_calls: None,
tool_call_id: None,
}, },
]; ];
let example_chat_with_system = [Message { let example_chat_with_system = [TextMessage {
role: "system".to_string(), role: "system".to_string(),
content: Some( content: "You are a friendly chatbot who always responds in the style of a pirate"
"You are a friendly chatbot who always responds in the style of a pirate" .to_string(),
.to_string(),
),
name: None,
tool_calls: None,
tool_call_id: None,
}] }]
.iter() .iter()
.chain(&example_chat) .chain(&example_chat)
@ -1417,19 +1350,13 @@ mod tests {
chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}", chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}",
input: ChatTemplateInputs { input: ChatTemplateInputs {
messages: vec![ messages: vec![
Message { TextMessage{
role: "system".to_string(), role: "system".to_string(),
content: Some("You are a friendly chatbot who always responds in the style of a pirate".to_string()), content: "You are a friendly chatbot who always responds in the style of a pirate".to_string(),
name: None,
tool_calls: None,
tool_call_id: None,
}, },
Message { TextMessage{
role: "user".to_string(), role: "user".to_string(),
content: Some("How many helicopters can a human eat in one sitting?".to_string()), content: "How many helicopters can a human eat in one sitting?".to_string(),
name: None,
tool_calls: None,
tool_call_id: None,
}, },
], ],
add_generation_prompt: true, add_generation_prompt: true,

View File

@ -11,6 +11,7 @@ use queue::{Entry, Queue};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokio::sync::OwnedSemaphorePermit; use tokio::sync::OwnedSemaphorePermit;
use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::warn;
use utoipa::ToSchema; use utoipa::ToSchema;
use validation::Validation; use validation::Validation;
@ -440,7 +441,7 @@ pub(crate) struct ChatCompletion {
#[derive(Clone, Deserialize, Serialize, ToSchema)] #[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletionComplete { pub(crate) struct ChatCompletionComplete {
pub index: u32, pub index: u32,
pub message: Message, pub message: OutputMessage,
pub logprobs: Option<ChatCompletionLogprobs>, pub logprobs: Option<ChatCompletionLogprobs>,
pub finish_reason: String, pub finish_reason: String,
} }
@ -533,6 +534,30 @@ impl ChatCompletion {
return_logprobs: bool, return_logprobs: bool,
tool_calls: Option<Vec<ToolCall>>, tool_calls: Option<Vec<ToolCall>>,
) -> Self { ) -> Self {
let message = match (output, tool_calls) {
(Some(content), None) => OutputMessage::ChatMessage(TextMessage {
role: "assistant".into(),
content,
}),
(None, Some(tool_calls)) => OutputMessage::ToolCall(ToolCallMessage {
role: "assistant".to_string(),
tool_calls,
}),
(Some(output), Some(_)) => {
warn!("Received both chat and tool call");
OutputMessage::ChatMessage(TextMessage {
role: "assistant".into(),
content: output,
})
}
(None, None) => {
warn!("Didn't receive an answer");
OutputMessage::ChatMessage(TextMessage {
role: "assistant".into(),
content: "".to_string(),
})
}
};
Self { Self {
id: String::new(), id: String::new(),
object: "text_completion".into(), object: "text_completion".into(),
@ -541,13 +566,7 @@ impl ChatCompletion {
system_fingerprint, system_fingerprint,
choices: vec![ChatCompletionComplete { choices: vec![ChatCompletionComplete {
index: 0, index: 0,
message: Message { message,
role: "assistant".into(),
content: output,
name: None,
tool_calls,
tool_call_id: None,
},
logprobs: return_logprobs logprobs: return_logprobs
.then(|| ChatCompletionLogprobs::from((details.tokens, details.top_tokens))), .then(|| ChatCompletionLogprobs::from((details.tokens, details.top_tokens))),
finish_reason: details.finish_reason.to_string(), finish_reason: details.finish_reason.to_string(),
@ -569,6 +588,7 @@ pub(crate) struct CompletionCompleteChunk {
pub model: String, pub model: String,
pub system_fingerprint: String, pub system_fingerprint: String,
} }
#[derive(Clone, Deserialize, Serialize, ToSchema)] #[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletionChunk { pub(crate) struct ChatCompletionChunk {
pub id: String, pub id: String,
@ -589,21 +609,20 @@ pub(crate) struct ChatCompletionChoice {
pub finish_reason: Option<String>, pub finish_reason: Option<String>,
} }
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)] #[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
pub(crate) struct ChatCompletionDelta { pub struct ToolCallDelta {
#[schema(example = "user")] #[schema(example = "assistant")]
// TODO Modify this to a true enum. role: String,
#[serde(default, skip_serializing_if = "Option::is_none")] tool_calls: DeltaToolCall,
pub role: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
#[schema(example = "What is Deep Learning?")]
pub content: Option<String>,
// default to None
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<DeltaToolCall>,
} }
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)] #[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
enum ChatCompletionDelta {
Chat(TextMessage),
Tool(ToolCallDelta),
}
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)]
pub(crate) struct DeltaToolCall { pub(crate) struct DeltaToolCall {
pub index: u32, pub index: u32,
pub id: String, pub id: String,
@ -611,7 +630,7 @@ pub(crate) struct DeltaToolCall {
pub function: Function, pub function: Function,
} }
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)] #[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)]
pub(crate) struct Function { pub(crate) struct Function {
pub name: Option<String>, pub name: Option<String>,
pub arguments: String, pub arguments: String,
@ -629,15 +648,13 @@ impl ChatCompletionChunk {
finish_reason: Option<String>, finish_reason: Option<String>,
) -> Self { ) -> Self {
let delta = match (delta, tool_calls) { let delta = match (delta, tool_calls) {
(Some(delta), _) => ChatCompletionDelta { (Some(delta), _) => ChatCompletionDelta::Chat(TextMessage {
role: Some("assistant".to_string()), role: "assistant".to_string(),
content: Some(delta), content: delta,
tool_calls: None, }),
}, (None, Some(tool_calls)) => ChatCompletionDelta::Tool(ToolCallDelta {
(None, Some(tool_calls)) => ChatCompletionDelta { role: "assistant".to_string(),
role: Some("assistant".to_string()), tool_calls: DeltaToolCall {
content: None,
tool_calls: Some(DeltaToolCall {
index: 0, index: 0,
id: String::new(), id: String::new(),
r#type: "function".to_string(), r#type: "function".to_string(),
@ -645,13 +662,12 @@ impl ChatCompletionChunk {
name: None, name: None,
arguments: tool_calls[0].to_string(), arguments: tool_calls[0].to_string(),
}, },
}), },
}, }),
(None, None) => ChatCompletionDelta { (None, None) => ChatCompletionDelta::Chat(TextMessage {
role: None, role: "assistant".to_string(),
content: None, content: "".to_string(),
tool_calls: None, }),
},
}; };
Self { Self {
id: String::new(), id: String::new(),
@ -852,7 +868,7 @@ where
state.end() state.end()
} }
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default)] #[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default, PartialEq)]
pub(crate) struct FunctionDefinition { pub(crate) struct FunctionDefinition {
#[serde(default)] #[serde(default)]
pub description: Option<String>, pub description: Option<String>,
@ -872,7 +888,7 @@ pub(crate) struct Tool {
#[derive(Clone, Serialize, Deserialize, Default)] #[derive(Clone, Serialize, Deserialize, Default)]
pub(crate) struct ChatTemplateInputs<'a> { pub(crate) struct ChatTemplateInputs<'a> {
messages: Vec<Message>, messages: Vec<TextMessage>,
bos_token: Option<&'a str>, bos_token: Option<&'a str>,
eos_token: Option<&'a str>, eos_token: Option<&'a str>,
add_generation_prompt: bool, add_generation_prompt: bool,
@ -880,91 +896,113 @@ pub(crate) struct ChatTemplateInputs<'a> {
tools_prompt: Option<&'a str>, tools_prompt: Option<&'a str>,
} }
#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)] #[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug, PartialEq)]
pub(crate) struct ToolCall { pub(crate) struct ToolCall {
pub id: String, pub id: String,
pub r#type: String, pub r#type: String,
pub function: FunctionDefinition, pub function: FunctionDefinition,
} }
#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)] #[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
pub(crate) struct Text { struct Url {
#[serde(default)] url: String,
pub text: String,
} }
#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)] #[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
pub(crate) struct ImageUrl { struct ImageUrl {
#[serde(default)] image_url: Url,
pub url: String,
} }
#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)] #[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
pub(crate) struct Content { struct Text {
pub r#type: String, text: String,
}
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
enum MessageChunk {
Text(Text),
ImageUrl(ImageUrl),
}
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
pub struct Message {
#[schema(example = "user")]
role: String,
#[schema(example = "My name is David and I")]
#[serde(deserialize_with = "message_content_serde::deserialize")]
content: Vec<MessageChunk>,
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
pub text: Option<String>, #[schema(example = "\"David\"")]
#[serde(default, skip_serializing_if = "Option::is_none")] name: Option<String>,
pub image_url: Option<ImageUrl>,
} }
mod message_content_serde { mod message_content_serde {
use super::*; use super::*;
use serde::de; use serde::{Deserialize, Deserializer};
use serde::Deserializer;
use serde_json::Value;
pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<String>, D::Error> pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<MessageChunk>, D::Error>
where where
D: Deserializer<'de>, D: Deserializer<'de>,
{ {
let value = Value::deserialize(deserializer)?; #[derive(Deserialize)]
match value { #[serde(untagged)]
Value::String(s) => Ok(Some(s)), enum Message {
Value::Array(arr) => { Text(String),
let results: Result<Vec<String>, _> = arr Chunks(Vec<MessageChunk>),
.into_iter() }
.map(|v| { let message: Message = Deserialize::deserialize(deserializer)?;
let content: Content = let chunks = match message {
serde_json::from_value(v).map_err(de::Error::custom)?; Message::Text(text) => {
match content.r#type.as_str() { vec![MessageChunk::Text(Text { text })]
"text" => Ok(content.text.unwrap_or_default()),
"image_url" => {
if let Some(url) = content.image_url {
Ok(format!("![]({})", url.url))
} else {
Ok(String::new())
}
}
_ => Err(de::Error::custom("invalid content type")),
}
})
.collect();
results.map(|strings| Some(strings.join("")))
} }
Value::Null => Ok(None), Message::Chunks(s) => s,
_ => Err(de::Error::custom("invalid token format")), };
Ok(chunks)
}
}
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
pub struct TextMessage {
#[schema(example = "user")]
pub role: String,
#[schema(example = "My name is David and I")]
pub content: String,
}
impl From<Message> for TextMessage {
fn from(value: Message) -> Self {
TextMessage {
role: value.role,
content: value
.content
.into_iter()
.map(|c| match c {
MessageChunk::Text(Text { text }) => text,
MessageChunk::ImageUrl(image) => {
let url = image.image_url.url;
format!("![]({url})")
}
})
.collect::<Vec<_>>()
.join(""),
} }
} }
} }
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug)] #[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
pub(crate) struct Message { pub struct ToolCallMessage {
#[schema(example = "user")] #[schema(example = "assistant")]
pub role: String, role: String,
#[serde(skip_serializing_if = "Option::is_none")] tool_calls: Vec<ToolCall>,
#[schema(example = "My name is David and I")] }
#[serde(default, deserialize_with = "message_content_serde::deserialize")]
pub content: Option<String>, #[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(untagged)]
#[schema(example = "\"David\"")] pub(crate) enum OutputMessage {
pub name: Option<String>, ChatMessage(TextMessage),
#[serde(default, skip_serializing_if = "Option::is_none")] ToolCall(ToolCallMessage),
pub tool_calls: Option<Vec<ToolCall>>,
#[serde(default, skip_serializing_if = "Option::is_none")]
#[schema(example = "\"get_weather\"")]
pub tool_call_id: Option<String>,
} }
#[derive(Clone, Debug, Deserialize, ToSchema)] #[derive(Clone, Debug, Deserialize, ToSchema)]
@ -1127,7 +1165,7 @@ pub(crate) struct ErrorResponse {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use serde_json::json;
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
pub(crate) async fn get_tokenizer() -> Tokenizer { pub(crate) async fn get_tokenizer() -> Tokenizer {
@ -1195,4 +1233,66 @@ mod tests {
); );
assert_eq!(config.eos_token, Some("<end▁of▁sentence>".to_string())); assert_eq!(config.eos_token, Some("<end▁of▁sentence>".to_string()));
} }
#[test]
fn test_chat_simple_string() {
let json = json!(
{
"model": "",
"messages": [
{"role": "user",
"content": "What is Deep Learning?"
}
]
});
let request: ChatRequest = serde_json::from_str(json.to_string().as_str()).unwrap();
assert_eq!(
request.messages[0],
Message {
role: "user".to_string(),
content: vec![MessageChunk::Text(Text {
text: "What is Deep Learning?".to_string()
}),],
name: None
}
);
}
#[test]
fn test_chat_request() {
let json = json!(
{
"model": "",
"messages": [
{"role": "user",
"content": [
{"type": "text", "text": "Whats in this image?"},
{
"type": "image_url",
"image_url": {
"url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"
},
},
]
}
]
});
let request: ChatRequest = serde_json::from_str(json.to_string().as_str()).unwrap();
assert_eq!(
request.messages[0],
Message{
role: "user".to_string(),
content: vec![
MessageChunk::Text(Text { text: "Whats in this image?".to_string() }),
MessageChunk::ImageUrl(ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() } })
],
name: None
}
);
}
} }