Fixing types. (#1906)

# What does this PR do?

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same people — sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)


## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a GitHub issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?


## Who can review?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @


@OlivierDehaene OR @Narsil

 -->
This commit is contained in:
Nicolas Patry 2024-05-16 16:59:05 +02:00 committed by GitHub
parent d8402eaf67
commit f5d43414c2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 256 additions and 229 deletions

View File

@ -2,7 +2,7 @@
use crate::validation::{Validation, ValidationError};
use crate::{
ChatTemplateInputs, ChatTemplateVersions, Entry, GenerateRequest, GenerateStreamResponse,
HubTokenizerConfig, Message, PrefillToken, Queue, Token,
HubTokenizerConfig, Message, MessageChunk, PrefillToken, Queue, Text, TextMessage, Token,
};
use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools};
use futures::future::try_join_all;
@ -362,16 +362,15 @@ impl ChatTemplate {
if self.use_default_tool_template {
if let Some(last_message) = messages.last_mut() {
if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
last_message.content = Some(format!(
"{}\n---\n{}\n{}",
last_message.content.as_deref().unwrap_or_default(),
tool_prompt,
tools
));
last_message.content.push(MessageChunk::Text(Text {
text: format!("\n---\n{}\n{}", tool_prompt, tools),
}));
}
}
}
let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
self.template
.render(ChatTemplateInputs {
messages,
@ -939,8 +938,7 @@ impl InferError {
#[cfg(test)]
mod tests {
use crate::infer::raise_exception;
use crate::ChatTemplateInputs;
use crate::Message;
use crate::{ChatTemplateInputs, TextMessage};
use minijinja::Environment;
#[test]
@ -974,33 +972,21 @@ mod tests {
let chat_template_inputs = ChatTemplateInputs {
messages: vec![
Message {
TextMessage {
role: "user".to_string(),
content: Some("Hi!".to_string()),
name: None,
tool_calls: None,
tool_call_id: None,
content: "Hi!".to_string(),
},
Message {
TextMessage {
role: "assistant".to_string(),
content: Some("Hello how can I help?".to_string()),
name: None,
tool_calls: None,
tool_call_id: None,
content: "Hello how can I help?".to_string(),
},
Message {
TextMessage {
role: "user".to_string(),
content: Some("What is Deep Learning?".to_string()),
name: None,
tool_calls: None,
tool_call_id: None,
content: "What is Deep Learning?".to_string(),
},
Message {
TextMessage {
role: "assistant".to_string(),
content: Some("magic!".to_string()),
name: None,
tool_calls: None,
tool_call_id: None,
content: "magic!".to_string(),
},
],
bos_token: Some("[BOS]"),
@ -1048,40 +1034,25 @@ mod tests {
let chat_template_inputs = ChatTemplateInputs {
messages: vec![
Message {
TextMessage {
role: "user".to_string(),
content: Some("Hi!".to_string()),
name: None,
tool_calls: None,
tool_call_id: None,
content: "Hi!".to_string(),
},
Message {
TextMessage {
role: "user".to_string(),
content: Some("Hi again!".to_string()),
name: None,
tool_calls: None,
tool_call_id: None,
content: "Hi again!".to_string(),
},
Message {
TextMessage {
role: "assistant".to_string(),
content: Some("Hello how can I help?".to_string()),
name: None,
tool_calls: None,
tool_call_id: None,
content: "Hello how can I help?".to_string(),
},
Message {
TextMessage {
role: "user".to_string(),
content: Some("What is Deep Learning?".to_string()),
name: None,
tool_calls: None,
tool_call_id: None,
content: "What is Deep Learning?".to_string(),
},
Message {
TextMessage {
role: "assistant".to_string(),
content: Some("magic!".to_string()),
name: None,
tool_calls: None,
tool_call_id: None,
content: "magic!".to_string(),
},
],
bos_token: Some("[BOS]"),
@ -1134,33 +1105,21 @@ mod tests {
let chat_template_inputs = ChatTemplateInputs {
messages: vec![
Message {
TextMessage {
role: "user".to_string(),
content: Some("Hi!".to_string()),
name: None,
tool_calls: None,
tool_call_id: None,
content: "Hi!".to_string(),
},
Message {
TextMessage {
role: "assistant".to_string(),
content: Some("Hello how can I help?".to_string()),
name: None,
tool_calls: None,
tool_call_id: None,
content: "Hello how can I help?".to_string(),
},
Message {
TextMessage {
role: "user".to_string(),
content: Some("What is Deep Learning?".to_string()),
name: None,
tool_calls: None,
tool_call_id: None,
content: "What is Deep Learning?".to_string(),
},
Message {
TextMessage {
role: "assistant".to_string(),
content: Some("magic!".to_string()),
name: None,
tool_calls: None,
tool_call_id: None,
content: "magic!".to_string(),
},
],
bos_token: Some("[BOS]"),
@ -1197,33 +1156,21 @@ mod tests {
let chat_template_inputs = ChatTemplateInputs {
messages: vec![
Message {
TextMessage {
role: "user".to_string(),
content: Some("Hi!".to_string()),
name: None,
tool_calls: None,
tool_call_id: None,
content: "Hi!".to_string(),
},
Message {
TextMessage {
role: "assistant".to_string(),
content: Some("Hello how can I help?".to_string()),
name: None,
tool_calls: None,
tool_call_id: None,
content: "Hello how can I help?".to_string(),
},
Message {
TextMessage {
role: "user".to_string(),
content: Some("What is Deep Learning?".to_string()),
name: None,
tool_calls: None,
tool_call_id: None,
content: "What is Deep Learning?".to_string(),
},
Message {
TextMessage {
role: "assistant".to_string(),
content: Some("magic!".to_string()),
name: None,
tool_calls: None,
tool_call_id: None,
content: "magic!".to_string(),
},
],
bos_token: Some("[BOS]"),
@ -1246,38 +1193,24 @@ mod tests {
#[test]
fn test_many_chat_templates() {
let example_chat = vec![
Message {
TextMessage {
role: "user".to_string(),
content: Some("Hello, how are you?".to_string()),
name: None,
tool_calls: None,
tool_call_id: None,
content: "Hello, how are you?".to_string(),
},
Message {
TextMessage {
role: "assistant".to_string(),
content: Some("I'm doing great. How can I help you today?".to_string()),
name: None,
tool_calls: None,
tool_call_id: None,
content: "I'm doing great. How can I help you today?".to_string(),
},
Message {
TextMessage {
role: "user".to_string(),
content: Some("I'd like to show off how chat templating works!".to_string()),
name: None,
tool_calls: None,
tool_call_id: None,
content: "I'd like to show off how chat templating works!".to_string(),
},
];
let example_chat_with_system = [Message {
let example_chat_with_system = [TextMessage {
role: "system".to_string(),
content: Some(
"You are a friendly chatbot who always responds in the style of a pirate"
content: "You are a friendly chatbot who always responds in the style of a pirate"
.to_string(),
),
name: None,
tool_calls: None,
tool_call_id: None,
}]
.iter()
.chain(&example_chat)
@ -1417,19 +1350,13 @@ mod tests {
chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}",
input: ChatTemplateInputs {
messages: vec![
Message {
TextMessage{
role: "system".to_string(),
content: Some("You are a friendly chatbot who always responds in the style of a pirate".to_string()),
name: None,
tool_calls: None,
tool_call_id: None,
content: "You are a friendly chatbot who always responds in the style of a pirate".to_string(),
},
Message {
TextMessage{
role: "user".to_string(),
content: Some("How many helicopters can a human eat in one sitting?".to_string()),
name: None,
tool_calls: None,
tool_call_id: None,
content: "How many helicopters can a human eat in one sitting?".to_string(),
},
],
add_generation_prompt: true,

View File

@ -11,6 +11,7 @@ use queue::{Entry, Queue};
use serde::{Deserialize, Serialize};
use tokio::sync::OwnedSemaphorePermit;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::warn;
use utoipa::ToSchema;
use validation::Validation;
@ -440,7 +441,7 @@ pub(crate) struct ChatCompletion {
#[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletionComplete {
pub index: u32,
pub message: Message,
pub message: OutputMessage,
pub logprobs: Option<ChatCompletionLogprobs>,
pub finish_reason: String,
}
@ -533,6 +534,30 @@ impl ChatCompletion {
return_logprobs: bool,
tool_calls: Option<Vec<ToolCall>>,
) -> Self {
let message = match (output, tool_calls) {
(Some(content), None) => OutputMessage::ChatMessage(TextMessage {
role: "assistant".into(),
content,
}),
(None, Some(tool_calls)) => OutputMessage::ToolCall(ToolCallMessage {
role: "assistant".to_string(),
tool_calls,
}),
(Some(output), Some(_)) => {
warn!("Received both chat and tool call");
OutputMessage::ChatMessage(TextMessage {
role: "assistant".into(),
content: output,
})
}
(None, None) => {
warn!("Didn't receive an answer");
OutputMessage::ChatMessage(TextMessage {
role: "assistant".into(),
content: "".to_string(),
})
}
};
Self {
id: String::new(),
object: "text_completion".into(),
@ -541,13 +566,7 @@ impl ChatCompletion {
system_fingerprint,
choices: vec![ChatCompletionComplete {
index: 0,
message: Message {
role: "assistant".into(),
content: output,
name: None,
tool_calls,
tool_call_id: None,
},
message,
logprobs: return_logprobs
.then(|| ChatCompletionLogprobs::from((details.tokens, details.top_tokens))),
finish_reason: details.finish_reason.to_string(),
@ -569,6 +588,7 @@ pub(crate) struct CompletionCompleteChunk {
pub model: String,
pub system_fingerprint: String,
}
#[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletionChunk {
pub id: String,
@ -589,21 +609,20 @@ pub(crate) struct ChatCompletionChoice {
pub finish_reason: Option<String>,
}
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletionDelta {
#[schema(example = "user")]
// TODO Modify this to a true enum.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub role: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
#[schema(example = "What is Deep Learning?")]
pub content: Option<String>,
// default to None
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<DeltaToolCall>,
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
pub struct ToolCallDelta {
#[schema(example = "assistant")]
role: String,
tool_calls: DeltaToolCall,
}
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
enum ChatCompletionDelta {
Chat(TextMessage),
Tool(ToolCallDelta),
}
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)]
pub(crate) struct DeltaToolCall {
pub index: u32,
pub id: String,
@ -611,7 +630,7 @@ pub(crate) struct DeltaToolCall {
pub function: Function,
}
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)]
pub(crate) struct Function {
pub name: Option<String>,
pub arguments: String,
@ -629,15 +648,13 @@ impl ChatCompletionChunk {
finish_reason: Option<String>,
) -> Self {
let delta = match (delta, tool_calls) {
(Some(delta), _) => ChatCompletionDelta {
role: Some("assistant".to_string()),
content: Some(delta),
tool_calls: None,
},
(None, Some(tool_calls)) => ChatCompletionDelta {
role: Some("assistant".to_string()),
content: None,
tool_calls: Some(DeltaToolCall {
(Some(delta), _) => ChatCompletionDelta::Chat(TextMessage {
role: "assistant".to_string(),
content: delta,
}),
(None, Some(tool_calls)) => ChatCompletionDelta::Tool(ToolCallDelta {
role: "assistant".to_string(),
tool_calls: DeltaToolCall {
index: 0,
id: String::new(),
r#type: "function".to_string(),
@ -645,13 +662,12 @@ impl ChatCompletionChunk {
name: None,
arguments: tool_calls[0].to_string(),
},
},
}),
(None, None) => ChatCompletionDelta::Chat(TextMessage {
role: "assistant".to_string(),
content: "".to_string(),
}),
},
(None, None) => ChatCompletionDelta {
role: None,
content: None,
tool_calls: None,
},
};
Self {
id: String::new(),
@ -852,7 +868,7 @@ where
state.end()
}
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default)]
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default, PartialEq)]
pub(crate) struct FunctionDefinition {
#[serde(default)]
pub description: Option<String>,
@ -872,7 +888,7 @@ pub(crate) struct Tool {
#[derive(Clone, Serialize, Deserialize, Default)]
pub(crate) struct ChatTemplateInputs<'a> {
messages: Vec<Message>,
messages: Vec<TextMessage>,
bos_token: Option<&'a str>,
eos_token: Option<&'a str>,
add_generation_prompt: bool,
@ -880,91 +896,113 @@ pub(crate) struct ChatTemplateInputs<'a> {
tools_prompt: Option<&'a str>,
}
#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)]
#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug, PartialEq)]
pub(crate) struct ToolCall {
pub id: String,
pub r#type: String,
pub function: FunctionDefinition,
}
#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)]
pub(crate) struct Text {
#[serde(default)]
pub text: String,
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
struct Url {
url: String,
}
#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)]
pub(crate) struct ImageUrl {
#[serde(default)]
pub url: String,
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
struct ImageUrl {
image_url: Url,
}
#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)]
pub(crate) struct Content {
pub r#type: String,
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
struct Text {
text: String,
}
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
enum MessageChunk {
Text(Text),
ImageUrl(ImageUrl),
}
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
pub struct Message {
#[schema(example = "user")]
role: String,
#[schema(example = "My name is David and I")]
#[serde(deserialize_with = "message_content_serde::deserialize")]
content: Vec<MessageChunk>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub text: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub image_url: Option<ImageUrl>,
#[schema(example = "\"David\"")]
name: Option<String>,
}
mod message_content_serde {
use super::*;
use serde::de;
use serde::Deserializer;
use serde_json::Value;
use serde::{Deserialize, Deserializer};
pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<String>, D::Error>
pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<MessageChunk>, D::Error>
where
D: Deserializer<'de>,
{
let value = Value::deserialize(deserializer)?;
match value {
Value::String(s) => Ok(Some(s)),
Value::Array(arr) => {
let results: Result<Vec<String>, _> = arr
#[derive(Deserialize)]
#[serde(untagged)]
enum Message {
Text(String),
Chunks(Vec<MessageChunk>),
}
let message: Message = Deserialize::deserialize(deserializer)?;
let chunks = match message {
Message::Text(text) => {
vec![MessageChunk::Text(Text { text })]
}
Message::Chunks(s) => s,
};
Ok(chunks)
}
}
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
pub struct TextMessage {
#[schema(example = "user")]
pub role: String,
#[schema(example = "My name is David and I")]
pub content: String,
}
impl From<Message> for TextMessage {
fn from(value: Message) -> Self {
TextMessage {
role: value.role,
content: value
.content
.into_iter()
.map(|v| {
let content: Content =
serde_json::from_value(v).map_err(de::Error::custom)?;
match content.r#type.as_str() {
"text" => Ok(content.text.unwrap_or_default()),
"image_url" => {
if let Some(url) = content.image_url {
Ok(format!("![]({})", url.url))
} else {
Ok(String::new())
}
}
_ => Err(de::Error::custom("invalid content type")),
.map(|c| match c {
MessageChunk::Text(Text { text }) => text,
MessageChunk::ImageUrl(image) => {
let url = image.image_url.url;
format!("![]({url})")
}
})
.collect();
results.map(|strings| Some(strings.join("")))
}
Value::Null => Ok(None),
_ => Err(de::Error::custom("invalid token format")),
.collect::<Vec<_>>()
.join(""),
}
}
}
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug)]
pub(crate) struct Message {
#[schema(example = "user")]
pub role: String,
#[serde(skip_serializing_if = "Option::is_none")]
#[schema(example = "My name is David and I")]
#[serde(default, deserialize_with = "message_content_serde::deserialize")]
pub content: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
#[schema(example = "\"David\"")]
pub name: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
#[serde(default, skip_serializing_if = "Option::is_none")]
#[schema(example = "\"get_weather\"")]
pub tool_call_id: Option<String>,
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
pub struct ToolCallMessage {
#[schema(example = "assistant")]
role: String,
tool_calls: Vec<ToolCall>,
}
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
#[serde(untagged)]
pub(crate) enum OutputMessage {
ChatMessage(TextMessage),
ToolCall(ToolCallMessage),
}
#[derive(Clone, Debug, Deserialize, ToSchema)]
@ -1127,7 +1165,7 @@ pub(crate) struct ErrorResponse {
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
use tokenizers::Tokenizer;
pub(crate) async fn get_tokenizer() -> Tokenizer {
@ -1195,4 +1233,66 @@ mod tests {
);
assert_eq!(config.eos_token, Some("<end▁of▁sentence>".to_string()));
}
#[test]
fn test_chat_simple_string() {
let json = json!(
{
"model": "",
"messages": [
{"role": "user",
"content": "What is Deep Learning?"
}
]
});
let request: ChatRequest = serde_json::from_str(json.to_string().as_str()).unwrap();
assert_eq!(
request.messages[0],
Message {
role: "user".to_string(),
content: vec![MessageChunk::Text(Text {
text: "What is Deep Learning?".to_string()
}),],
name: None
}
);
}
#[test]
fn test_chat_request() {
let json = json!(
{
"model": "",
"messages": [
{"role": "user",
"content": [
{"type": "text", "text": "Whats in this image?"},
{
"type": "image_url",
"image_url": {
"url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"
},
},
]
}
]
});
let request: ChatRequest = serde_json::from_str(json.to_string().as_str()).unwrap();
assert_eq!(
request.messages[0],
Message{
role: "user".to_string(),
content: vec![
MessageChunk::Text(Text { text: "Whats in this image?".to_string() }),
MessageChunk::ImageUrl(ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() } })
],
name: None
}
);
}
}