fix: prefer serde structs over custom functions (#2127)

* fix: prefer enum for chat object

* fix: adjust typo

* fix: enum CompletionType not ObjectType

* fix: adjust typo

* feat: leverage serde for conditional deser

* fix: adjust HubTokenizerConfig after rebase

* fix: update create_post_processor logic for token type

* fix: adjust unwrap syntax in template

* Fixing the post processor.

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
drbh 2024-07-01 09:08:05 -04:00 committed by GitHub
parent 5da4cfab1c
commit 9eefb2f672
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 207 additions and 203 deletions

View File

@ -7,9 +7,11 @@ pub(crate) use health::HealthCheck;
use crate::validation::{ValidGenerateRequest, Validation, ValidationError}; use crate::validation::{ValidGenerateRequest, Validation, ValidationError};
use crate::{ use crate::{
ChatTemplateInputs, ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, ChatTemplateInputs, ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig,
HubTokenizerConfig, Message, MessageChunk, PrefillToken, Text, TextMessage, Token, HubTokenizerConfig, Message, MessageChunk, PrefillToken, TextMessage, Token,
};
use crate::{
FunctionRef, FunctionsMap, GrammarType, Properties, TokenizerConfigToken, Tool, ToolType, Tools,
}; };
use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools};
use futures::future::try_join_all; use futures::future::try_join_all;
use minijinja::{Environment, ErrorKind, Template}; use minijinja::{Environment, ErrorKind, Template};
use minijinja_contrib::pycompat; use minijinja_contrib::pycompat;
@ -270,7 +272,11 @@ struct ChatTemplate {
} }
impl ChatTemplate { impl ChatTemplate {
fn new(template: String, bos_token: Option<String>, eos_token: Option<String>) -> Self { fn new(
template: String,
bos_token: Option<TokenizerConfigToken>,
eos_token: Option<TokenizerConfigToken>,
) -> Self {
let mut env = Box::new(Environment::new()); let mut env = Box::new(Environment::new());
// enable things like .strip() or .capitalize() // enable things like .strip() or .capitalize()
env.set_unknown_method_callback(pycompat::unknown_method_callback); env.set_unknown_method_callback(pycompat::unknown_method_callback);
@ -287,8 +293,8 @@ impl ChatTemplate {
Self { Self {
template, template,
bos_token, bos_token: bos_token.map(|token| token.as_str().to_string()),
eos_token, eos_token: eos_token.map(|token| token.as_str().to_string()),
use_default_tool_template, use_default_tool_template,
} }
} }
@ -301,9 +307,9 @@ impl ChatTemplate {
if self.use_default_tool_template { if self.use_default_tool_template {
if let Some(last_message) = messages.last_mut() { if let Some(last_message) = messages.last_mut() {
if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt { if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
last_message.content.push(MessageChunk::Text(Text { last_message.content.push(MessageChunk::Text {
text: format!("\n---\n{}\n{}", tool_prompt, tools), text: format!("\n---\n{}\n{}", tool_prompt, tools),
})); });
} }
} }
} }
@ -340,6 +346,14 @@ impl ToolGrammar {
.unwrap_or_else(|| panic!("Tool with name {} not found", name)) .unwrap_or_else(|| panic!("Tool with name {} not found", name))
.clone()] .clone()]
} }
ToolType::Function { function } => {
let tool = req_tools
.iter()
.find(|tool| tool.function.name == function.name)
.unwrap_or_else(|| panic!("Tool with name {} not found", function.name))
.clone();
vec![tool]
}
ToolType::OneOf => req_tools.to_owned(), ToolType::OneOf => req_tools.to_owned(),
}; };

View File

@ -53,23 +53,40 @@ pub enum ChatTemplateVersions {
Multiple(Vec<ChatTemplate>), Multiple(Vec<ChatTemplate>),
} }
use std::path::Path;
#[derive(Debug, Clone, Deserialize, Default)] #[derive(Debug, Clone, Deserialize, Default)]
pub struct HubTokenizerConfig { pub struct HubTokenizerConfig {
pub chat_template: Option<ChatTemplateVersions>, pub chat_template: Option<ChatTemplateVersions>,
pub completion_template: Option<String>, pub completion_template: Option<String>,
#[serde(deserialize_with = "token_serde::deserialize")] pub bos_token: Option<TokenizerConfigToken>,
pub bos_token: Option<String>, pub eos_token: Option<TokenizerConfigToken>,
#[serde(deserialize_with = "token_serde::deserialize")]
pub eos_token: Option<String>,
pub tokenizer_class: Option<String>, pub tokenizer_class: Option<String>,
pub add_bos_token: Option<bool>, pub add_bos_token: Option<bool>,
pub add_eos_token: Option<bool>, pub add_eos_token: Option<bool>,
} }
impl HubTokenizerConfig { impl HubTokenizerConfig {
pub fn from_file<P: AsRef<std::path::Path>>(filename: P) -> Option<Self> { pub fn from_file<P: AsRef<Path>>(filename: P) -> Option<Self> {
let content = std::fs::read_to_string(filename).ok()?; std::fs::read_to_string(filename)
serde_json::from_str(&content).ok() .ok()
.and_then(|content| serde_json::from_str(&content).ok())
}
}
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum TokenizerConfigToken {
String(String),
Object { content: String },
}
impl TokenizerConfigToken {
pub fn as_str(&self) -> &str {
match self {
TokenizerConfigToken::String(s) => s,
TokenizerConfigToken::Object { content } => content,
}
} }
} }
@ -100,9 +117,10 @@ pub struct HubProcessorConfig {
} }
impl HubProcessorConfig { impl HubProcessorConfig {
pub fn from_file<P: AsRef<std::path::Path>>(filename: P) -> Option<Self> { pub fn from_file<P: AsRef<Path>>(filename: P) -> Option<Self> {
let content = std::fs::read_to_string(filename).ok()?; std::fs::read_to_string(filename)
serde_json::from_str(&content).ok() .ok()
.and_then(|content| serde_json::from_str(&content).ok())
} }
} }
@ -121,35 +139,6 @@ pub(crate) enum GrammarType {
Regex(String), Regex(String),
} }
mod token_serde {
use super::*;
use serde::de;
use serde::Deserializer;
use serde_json::Value;
pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<String>, D::Error>
where
D: Deserializer<'de>,
{
let value = Value::deserialize(deserializer)?;
match value {
Value::String(s) => Ok(Some(s)),
Value::Object(map) => {
if let Some(content) = map.get("content").and_then(|v| v.as_str()) {
Ok(Some(content.to_string()))
} else {
Err(de::Error::custom(
"content key not found in structured token",
))
}
}
Value::Null => Ok(None),
_ => Err(de::Error::custom("invalid token format")),
}
}
}
#[derive(Clone, Debug, Serialize, ToSchema)] #[derive(Clone, Debug, Serialize, ToSchema)]
pub struct Info { pub struct Info {
/// Model info /// Model info
@ -359,30 +348,33 @@ fn default_parameters() -> GenerateParameters {
} }
} }
mod prompt_serde { #[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
use serde::{self, Deserialize, Deserializer}; #[serde(try_from = "PromptDeserializer")]
use serde_json::Value; pub struct Prompt(pub Vec<String>);
pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error> #[derive(Deserialize)]
where #[serde(untagged)]
D: Deserializer<'de>, enum PromptDeserializer {
{ Single(String),
let value = Value::deserialize(deserializer)?; Multiple(Vec<String>),
}
impl TryFrom<PromptDeserializer> for Prompt {
type Error = String;
fn try_from(value: PromptDeserializer) -> Result<Self, Self::Error> {
match value { match value {
Value::String(s) => Ok(vec![s]), PromptDeserializer::Single(s) => Ok(Prompt(vec![s])),
Value::Array(arr) if arr.is_empty() => Err(serde::de::Error::custom( PromptDeserializer::Multiple(v) => {
"Empty array detected. Do not use an empty array for the prompt.", if v.is_empty() {
)), Err(
Value::Array(arr) => arr "Empty array detected. Do not use an empty array for the prompt."
.iter() .to_string(),
.map(|v| match v { )
Value::String(s) => Ok(s.to_owned()), } else {
_ => Err(serde::de::Error::custom("Expected a string")), Ok(Prompt(v))
}) }
.collect(), }
_ => Err(serde::de::Error::custom(
"Expected a string or an array of strings",
)),
} }
} }
} }
@ -396,8 +388,7 @@ pub struct CompletionRequest {
/// The prompt to generate completions for. /// The prompt to generate completions for.
#[schema(example = "What is Deep Learning?")] #[schema(example = "What is Deep Learning?")]
#[serde(deserialize_with = "prompt_serde::deserialize")] pub prompt: Prompt,
pub prompt: Vec<String>,
/// The maximum number of tokens that can be generated in the chat completion. /// The maximum number of tokens that can be generated in the chat completion.
#[serde(default)] #[serde(default)]
@ -445,7 +436,6 @@ pub struct CompletionRequest {
#[derive(Clone, Deserialize, Serialize, ToSchema, Default)] #[derive(Clone, Deserialize, Serialize, ToSchema, Default)]
pub(crate) struct Completion { pub(crate) struct Completion {
pub id: String, pub id: String,
pub object: String,
#[schema(example = "1706270835")] #[schema(example = "1706270835")]
pub created: u64, pub created: u64,
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")] #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
@ -466,7 +456,6 @@ pub(crate) struct CompletionComplete {
#[derive(Clone, Deserialize, Serialize, ToSchema)] #[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletion { pub(crate) struct ChatCompletion {
pub id: String, pub id: String,
pub object: String,
#[schema(example = "1706270835")] #[schema(example = "1706270835")]
pub created: u64, pub created: u64,
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")] #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
@ -562,6 +551,15 @@ pub(crate) struct Usage {
pub total_tokens: u32, pub total_tokens: u32,
} }
#[derive(Clone, Serialize, ToSchema)]
#[serde(tag = "object")]
enum CompletionType {
#[serde(rename = "chat.completion.chunk")]
ChatCompletionChunk(ChatCompletionChunk),
#[serde(rename = "chat.completion")]
ChatCompletion(ChatCompletion),
}
impl ChatCompletion { impl ChatCompletion {
pub(crate) fn new( pub(crate) fn new(
model: String, model: String,
@ -598,7 +596,6 @@ impl ChatCompletion {
}; };
Self { Self {
id: String::new(), id: String::new(),
object: "chat.completion".into(),
created, created,
model, model,
system_fingerprint, system_fingerprint,
@ -620,7 +617,6 @@ impl ChatCompletion {
#[derive(Clone, Deserialize, Serialize, ToSchema)] #[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct CompletionCompleteChunk { pub(crate) struct CompletionCompleteChunk {
pub id: String, pub id: String,
pub object: String,
pub created: u64, pub created: u64,
pub choices: Vec<CompletionComplete>, pub choices: Vec<CompletionComplete>,
pub model: String, pub model: String,
@ -630,7 +626,6 @@ pub(crate) struct CompletionCompleteChunk {
#[derive(Clone, Serialize, ToSchema)] #[derive(Clone, Serialize, ToSchema)]
pub(crate) struct ChatCompletionChunk { pub(crate) struct ChatCompletionChunk {
pub id: String, pub id: String,
pub object: String,
#[schema(example = "1706270978")] #[schema(example = "1706270978")]
pub created: u64, pub created: u64,
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")] #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
@ -710,7 +705,6 @@ impl ChatCompletionChunk {
}; };
Self { Self {
id: String::new(), id: String::new(),
object: "chat.completion.chunk".to_string(),
created, created,
model, model,
system_fingerprint, system_fingerprint,
@ -821,7 +815,6 @@ pub(crate) struct ChatRequest {
/// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter. /// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
#[serde(default)] #[serde(default)]
#[schema(nullable = true, example = "null")] #[schema(nullable = true, example = "null")]
#[serde(deserialize_with = "deserialize_tool_choice::deserialize")]
pub tool_choice: Option<ToolType>, pub tool_choice: Option<ToolType>,
/// Response format constraints for the generation. /// Response format constraints for the generation.
@ -837,44 +830,41 @@ fn default_tool_prompt() -> Option<String> {
"\nYou will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n".to_string(), "\nYou will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n".to_string(),
) )
} }
#[derive(Clone, Deserialize, ToSchema, Serialize)]
enum ToolType { #[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)]
FunctionName(String), #[serde(untagged)]
pub enum ToolType {
OneOf, OneOf,
FunctionName(String),
Function { function: FunctionName },
} }
/// Deserialize the tool choice from the JSON input or from the function name ("none" is allowed but mapped to None) #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
mod deserialize_tool_choice { pub struct FunctionName {
use super::*; pub name: String,
use serde::de; }
use serde::Deserializer;
use serde_json::Value;
pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<ToolType>, D::Error> #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
where #[serde(from = "ToolTypeDeserializer")]
D: Deserializer<'de>, pub struct ToolChoice(pub Option<ToolType>);
{
let value = Value::deserialize(deserializer)?;
#[derive(Deserialize)]
#[serde(untagged)]
enum ToolTypeDeserializer {
None(Option<String>),
Some(ToolType),
}
impl From<ToolTypeDeserializer> for ToolChoice {
fn from(value: ToolTypeDeserializer) -> Self {
match value { match value {
Value::String(s) => match s.as_str() { ToolTypeDeserializer::None(opt) => match opt.as_deref() {
"none" => Ok(None), Some("none") => ToolChoice(None),
"auto" => Ok(Some(ToolType::OneOf)), Some("auto") => ToolChoice(Some(ToolType::OneOf)),
_ => Ok(Some(ToolType::FunctionName(s))), Some(s) => ToolChoice(Some(ToolType::FunctionName(s.to_string()))),
None => ToolChoice(Some(ToolType::OneOf)),
}, },
Value::Object(map) => { ToolTypeDeserializer::Some(tool_type) => ToolChoice(Some(tool_type)),
if let Some(content) = map
.get("function")
.and_then(|v| v.get("name"))
.and_then(|v| v.as_str())
{
Ok(Some(ToolType::FunctionName(content.to_string())))
} else {
Err(de::Error::custom("function key not found in tool choice"))
}
}
Value::Null => Ok(Some(ToolType::OneOf)),
_ => Err(de::Error::custom("invalid token format")),
} }
} }
} }
@ -950,26 +940,16 @@ pub(crate) struct ToolCall {
} }
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)] #[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
struct Url { pub struct Url {
url: String, url: String,
} }
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
struct ImageUrl {
image_url: Url,
}
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
struct Text {
text: String,
}
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)] #[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
#[serde(tag = "type")] #[serde(tag = "type")]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
enum MessageChunk { pub enum MessageChunk {
Text(Text), Text { text: String },
ImageUrl(ImageUrl), ImageUrl { image_url: Url },
} }
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)] #[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
@ -977,35 +957,31 @@ pub struct Message {
#[schema(example = "user")] #[schema(example = "user")]
role: String, role: String,
#[schema(example = "My name is David and I")] #[schema(example = "My name is David and I")]
#[serde(deserialize_with = "message_content_serde::deserialize")] pub content: MessageContent,
content: Vec<MessageChunk>,
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
#[schema(example = "\"David\"")] #[schema(example = "\"David\"")]
name: Option<String>, name: Option<String>,
} }
mod message_content_serde { #[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)]
use super::*; #[serde(untagged)]
use serde::{Deserialize, Deserializer}; pub enum MessageContent {
SingleText(String),
MultipleChunks(Vec<MessageChunk>),
}
pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<MessageChunk>, D::Error> // Pushing a chunk to a single text message will convert it to a multiple chunks message
where impl MessageContent {
D: Deserializer<'de>, pub fn push(&mut self, chunk: MessageChunk) {
{ match self {
#[derive(Deserialize)] MessageContent::SingleText(text) => {
#[serde(untagged)] *self =
enum Message { MessageContent::MultipleChunks(vec![MessageChunk::Text { text: text.clone() }]);
Text(String), }
Chunks(Vec<MessageChunk>), MessageContent::MultipleChunks(chunks) => {
chunks.push(chunk);
} }
let message: Message = Deserialize::deserialize(deserializer)?;
let chunks = match message {
Message::Text(text) => {
vec![MessageChunk::Text(Text { text })]
} }
Message::Chunks(s) => s,
};
Ok(chunks)
} }
} }
@ -1021,18 +997,17 @@ impl From<Message> for TextMessage {
fn from(value: Message) -> Self { fn from(value: Message) -> Self {
TextMessage { TextMessage {
role: value.role, role: value.role,
content: value content: match value.content {
.content MessageContent::SingleText(text) => text,
MessageContent::MultipleChunks(chunks) => chunks
.into_iter() .into_iter()
.map(|c| match c { .map(|chunk| match chunk {
MessageChunk::Text(Text { text }) => text, MessageChunk::Text { text } => text,
MessageChunk::ImageUrl(image) => { MessageChunk::ImageUrl { image_url } => format!("![]({})", image_url.url),
let url = image.image_url.url;
format!("![]({url})")
}
}) })
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join(""), .join(""),
},
} }
} }
} }
@ -1240,9 +1215,16 @@ mod tests {
); );
assert_eq!( assert_eq!(
config.bos_token, config.bos_token,
Some("<begin▁of▁sentence>".to_string()) Some(TokenizerConfigToken::String(
"<begin▁of▁sentence>".to_string()
))
);
assert_eq!(
config.eos_token,
Some(TokenizerConfigToken::String(
"<end▁of▁sentence>".to_string()
))
); );
assert_eq!(config.eos_token, Some("<end▁of▁sentence>".to_string()));
// in this case we expect the tokens to be encoded as structured tokens // in this case we expect the tokens to be encoded as structured tokens
// we want the content of the structured token // we want the content of the structured token
@ -1275,9 +1257,16 @@ mod tests {
); );
assert_eq!( assert_eq!(
config.bos_token, config.bos_token,
Some("<begin▁of▁sentence>".to_string()) Some(TokenizerConfigToken::Object {
content: "<begin▁of▁sentence>".to_string()
})
);
assert_eq!(
config.eos_token,
Some(TokenizerConfigToken::Object {
content: "<end▁of▁sentence>".to_string()
})
); );
assert_eq!(config.eos_token, Some("<end▁of▁sentence>".to_string()));
} }
#[test] #[test]
@ -1295,9 +1284,7 @@ mod tests {
request.messages[0], request.messages[0],
Message { Message {
role: "user".to_string(), role: "user".to_string(),
content: vec![MessageChunk::Text(Text { content: MessageContent::SingleText("What is Deep Learning?".to_string()),
text: "What is Deep Learning?".to_string()
}),],
name: None name: None
} }
); );
@ -1321,10 +1308,10 @@ mod tests {
request.messages[0], request.messages[0],
Message{ Message{
role: "user".to_string(), role: "user".to_string(),
content: vec![ content: MessageContent::MultipleChunks(vec![
MessageChunk::Text(Text { text: "Whats in this image?".to_string() }), MessageChunk::Text { text: "Whats in this image?".to_string() },
MessageChunk::ImageUrl(ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() } }) MessageChunk::ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() }},
], ]),
name: None name: None
} }
); );
@ -1334,10 +1321,10 @@ mod tests {
fn text_message_convert() { fn text_message_convert() {
let message = Message{ let message = Message{
role: "user".to_string(), role: "user".to_string(),
content: vec![ content: MessageContent::MultipleChunks(vec![
MessageChunk::Text(Text { text: "Whats in this image?".to_string() }), MessageChunk::Text { text: "Whats in this image?".to_string() },
MessageChunk::ImageUrl(ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() } }) MessageChunk::ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() } }
], ]),
name: None name: None
}; };
let textmsg: TextMessage = message.into(); let textmsg: TextMessage = message.into();

View File

@ -553,11 +553,11 @@ pub fn create_post_processor(
if add_bos_token { if add_bos_token {
if let Some(bos) = bos_token { if let Some(bos) = bos_token {
let bos_token_id = tokenizer let bos_token_id = tokenizer
.token_to_id(bos) .token_to_id(bos.as_str())
.expect("Should have found the bos token id"); .expect("Should have found the bos token id");
special_tokens.push((bos.clone(), bos_token_id)); special_tokens.push((bos.as_str(), bos_token_id));
single.push(format!("{}:0", bos)); single.push(format!("{}:0", bos.as_str()));
pair.push(format!("{}:0", bos)); pair.push(format!("{}:0", bos.as_str()));
} }
} }
@ -567,17 +567,17 @@ pub fn create_post_processor(
if add_eos_token { if add_eos_token {
if let Some(eos) = eos_token { if let Some(eos) = eos_token {
let eos_token_id = tokenizer let eos_token_id = tokenizer
.token_to_id(eos) .token_to_id(eos.as_str())
.expect("Should have found the eos token id"); .expect("Should have found the eos token id");
special_tokens.push((eos.clone(), eos_token_id)); special_tokens.push((eos.as_str(), eos_token_id));
single.push(format!("{}:0", eos)); single.push(format!("{}:0", eos.as_str()));
pair.push(format!("{}:0", eos)); pair.push(format!("{}:0", eos.as_str()));
} }
} }
if add_bos_token { if add_bos_token {
if let Some(bos) = bos_token { if let Some(bos) = bos_token {
pair.push(format!("{}:1", bos)); pair.push(format!("{}:1", bos.as_str()));
} }
} }
@ -585,7 +585,7 @@ pub fn create_post_processor(
if add_eos_token { if add_eos_token {
if let Some(eos) = eos_token { if let Some(eos) = eos_token {
pair.push(format!("{}:1", eos)); pair.push(format!("{}:1", eos.as_str()));
} }
} }
@ -611,14 +611,15 @@ enum RouterError {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use text_generation_router::TokenizerConfigToken;
#[test] #[test]
fn test_create_post_processor() { fn test_create_post_processor() {
let tokenizer_config = HubTokenizerConfig { let tokenizer_config = HubTokenizerConfig {
add_bos_token: None, add_bos_token: None,
add_eos_token: None, add_eos_token: None,
bos_token: Some("<s>".to_string()), bos_token: Some(TokenizerConfigToken::String("<s>".to_string())),
eos_token: Some("</s>".to_string()), eos_token: Some(TokenizerConfigToken::String("</s>".to_string())),
chat_template: None, chat_template: None,
tokenizer_class: None, tokenizer_class: None,
completion_template: None, completion_template: None,
@ -629,9 +630,9 @@ mod tests {
let post_processor = create_post_processor(&tokenizer, &tokenizer_config).unwrap(); let post_processor = create_post_processor(&tokenizer, &tokenizer_config).unwrap();
let expected = TemplateProcessing::builder() let expected = TemplateProcessing::builder()
.try_single("<s>:0 $A:0 <s>:1") .try_single("<s>:0 $A:0")
.unwrap() .unwrap()
.try_pair("<s>:0 $A:0 $B:1") .try_pair("<s>:0 $A:0 <s>:1 $B:1")
.unwrap() .unwrap()
.special_tokens(vec![("<s>".to_string(), 1)]) .special_tokens(vec![("<s>".to_string(), 1)])
.build() .build()

View File

@ -12,17 +12,18 @@ use crate::kserve::{
use crate::validation::ValidationError; use crate::validation::ValidationError;
use crate::{ use crate::{
BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
GenerateResponse, GrammarType, HubModelInfo, HubPreprocessorConfig, HubProcessorConfig, GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig, HubTokenizerConfig, Info,
HubTokenizerConfig, Info, Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse,
Token, TokenizeResponse, Usage, Validation, Usage, Validation,
}; };
use crate::{ use crate::{
ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob, ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob,
ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk, ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk,
CompletionRequest, DeltaToolCall, Function, Tool, VertexRequest, VertexResponse, CompletionRequest, CompletionType, DeltaToolCall, Function, Tool, VertexRequest,
VertexResponse,
}; };
use crate::{FunctionDefinition, ToolCall, ToolType}; use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolType};
use async_stream::__private::AsyncStream; use async_stream::__private::AsyncStream;
use axum::extract::Extension; use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode}; use axum::http::{HeaderMap, Method, StatusCode};
@ -635,7 +636,7 @@ async fn completions(
)); ));
} }
if req.prompt.len() > info.max_client_batch_size { if req.prompt.0.len() > info.max_client_batch_size {
metrics::increment_counter!("tgi_request_failure", "err" => "validation"); metrics::increment_counter!("tgi_request_failure", "err" => "validation");
return Err(( return Err((
StatusCode::UNPROCESSABLE_ENTITY, StatusCode::UNPROCESSABLE_ENTITY,
@ -651,6 +652,7 @@ async fn completions(
let generate_requests: Vec<GenerateRequest> = req let generate_requests: Vec<GenerateRequest> = req
.prompt .prompt
.0
.iter() .iter()
.map(|prompt| GenerateRequest { .map(|prompt| GenerateRequest {
inputs: prompt.to_string(), inputs: prompt.to_string(),
@ -705,7 +707,6 @@ async fn completions(
event event
.json_data(CompletionCompleteChunk { .json_data(CompletionCompleteChunk {
id: "".to_string(), id: "".to_string(),
object: "text_completion".to_string(),
created: current_time, created: current_time,
choices: vec![CompletionComplete { choices: vec![CompletionComplete {
@ -932,7 +933,6 @@ async fn completions(
let response = Completion { let response = Completion {
id: "".to_string(), id: "".to_string(),
object: "text_completion".to_string(),
created: current_time, created: current_time,
model: info.model_id.clone(), model: info.model_id.clone(),
system_fingerprint: format!( system_fingerprint: format!(
@ -1153,7 +1153,8 @@ async fn chat_completions(
}; };
event event
.json_data(ChatCompletionChunk::new( .json_data(CompletionType::ChatCompletionChunk(
ChatCompletionChunk::new(
model_id.clone(), model_id.clone(),
system_fingerprint.clone(), system_fingerprint.clone(),
content, content,
@ -1161,6 +1162,7 @@ async fn chat_completions(
current_time, current_time,
logprobs, logprobs,
stream_token.details.map(|d| d.finish_reason.to_string()), stream_token.details.map(|d| d.finish_reason.to_string()),
),
)) ))
.unwrap_or_else(|e| { .unwrap_or_else(|e| {
println!("Failed to serialize ChatCompletionChunk: {:?}", e); println!("Failed to serialize ChatCompletionChunk: {:?}", e);
@ -1228,7 +1230,7 @@ async fn chat_completions(
(None, Some(generation.generated_text)) (None, Some(generation.generated_text))
}; };
// build the complete response object with the full text // build the complete response object with the full text
let response = ChatCompletion::new( let response = CompletionType::ChatCompletion(ChatCompletion::new(
model_id, model_id,
system_fingerprint, system_fingerprint,
output, output,
@ -1236,7 +1238,7 @@ async fn chat_completions(
generation.details.unwrap(), generation.details.unwrap(),
logprobs, logprobs,
tool_calls, tool_calls,
); ));
// wrap generation inside a Vec to match api-inference // wrap generation inside a Vec to match api-inference
Ok((headers, Json(response)).into_response()) Ok((headers, Json(response)).into_response())