fix: prefer serde structs over custom functions (#2127)
* fix: prefer enum for chat object * fix: adjust typo * fix: enum CompletionType not ObjectType * fix: adjust typo * feat: leverage serde for conditional deser * fix: adjust HubTokenizerConfig after rebase * fix: update create_post_processor logic for token type * fix: adjust unwrap syntax in template * Fixing the post processor. --------- Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
parent
5da4cfab1c
commit
9eefb2f672
|
@ -7,9 +7,11 @@ pub(crate) use health::HealthCheck;
|
||||||
use crate::validation::{ValidGenerateRequest, Validation, ValidationError};
|
use crate::validation::{ValidGenerateRequest, Validation, ValidationError};
|
||||||
use crate::{
|
use crate::{
|
||||||
ChatTemplateInputs, ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig,
|
ChatTemplateInputs, ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig,
|
||||||
HubTokenizerConfig, Message, MessageChunk, PrefillToken, Text, TextMessage, Token,
|
HubTokenizerConfig, Message, MessageChunk, PrefillToken, TextMessage, Token,
|
||||||
|
};
|
||||||
|
use crate::{
|
||||||
|
FunctionRef, FunctionsMap, GrammarType, Properties, TokenizerConfigToken, Tool, ToolType, Tools,
|
||||||
};
|
};
|
||||||
use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools};
|
|
||||||
use futures::future::try_join_all;
|
use futures::future::try_join_all;
|
||||||
use minijinja::{Environment, ErrorKind, Template};
|
use minijinja::{Environment, ErrorKind, Template};
|
||||||
use minijinja_contrib::pycompat;
|
use minijinja_contrib::pycompat;
|
||||||
|
@ -270,7 +272,11 @@ struct ChatTemplate {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ChatTemplate {
|
impl ChatTemplate {
|
||||||
fn new(template: String, bos_token: Option<String>, eos_token: Option<String>) -> Self {
|
fn new(
|
||||||
|
template: String,
|
||||||
|
bos_token: Option<TokenizerConfigToken>,
|
||||||
|
eos_token: Option<TokenizerConfigToken>,
|
||||||
|
) -> Self {
|
||||||
let mut env = Box::new(Environment::new());
|
let mut env = Box::new(Environment::new());
|
||||||
// enable things like .strip() or .capitalize()
|
// enable things like .strip() or .capitalize()
|
||||||
env.set_unknown_method_callback(pycompat::unknown_method_callback);
|
env.set_unknown_method_callback(pycompat::unknown_method_callback);
|
||||||
|
@ -287,8 +293,8 @@ impl ChatTemplate {
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
template,
|
template,
|
||||||
bos_token,
|
bos_token: bos_token.map(|token| token.as_str().to_string()),
|
||||||
eos_token,
|
eos_token: eos_token.map(|token| token.as_str().to_string()),
|
||||||
use_default_tool_template,
|
use_default_tool_template,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -301,9 +307,9 @@ impl ChatTemplate {
|
||||||
if self.use_default_tool_template {
|
if self.use_default_tool_template {
|
||||||
if let Some(last_message) = messages.last_mut() {
|
if let Some(last_message) = messages.last_mut() {
|
||||||
if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
|
if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
|
||||||
last_message.content.push(MessageChunk::Text(Text {
|
last_message.content.push(MessageChunk::Text {
|
||||||
text: format!("\n---\n{}\n{}", tool_prompt, tools),
|
text: format!("\n---\n{}\n{}", tool_prompt, tools),
|
||||||
}));
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -340,6 +346,14 @@ impl ToolGrammar {
|
||||||
.unwrap_or_else(|| panic!("Tool with name {} not found", name))
|
.unwrap_or_else(|| panic!("Tool with name {} not found", name))
|
||||||
.clone()]
|
.clone()]
|
||||||
}
|
}
|
||||||
|
ToolType::Function { function } => {
|
||||||
|
let tool = req_tools
|
||||||
|
.iter()
|
||||||
|
.find(|tool| tool.function.name == function.name)
|
||||||
|
.unwrap_or_else(|| panic!("Tool with name {} not found", function.name))
|
||||||
|
.clone();
|
||||||
|
vec![tool]
|
||||||
|
}
|
||||||
ToolType::OneOf => req_tools.to_owned(),
|
ToolType::OneOf => req_tools.to_owned(),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -53,23 +53,40 @@ pub enum ChatTemplateVersions {
|
||||||
Multiple(Vec<ChatTemplate>),
|
Multiple(Vec<ChatTemplate>),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Default)]
|
#[derive(Debug, Clone, Deserialize, Default)]
|
||||||
pub struct HubTokenizerConfig {
|
pub struct HubTokenizerConfig {
|
||||||
pub chat_template: Option<ChatTemplateVersions>,
|
pub chat_template: Option<ChatTemplateVersions>,
|
||||||
pub completion_template: Option<String>,
|
pub completion_template: Option<String>,
|
||||||
#[serde(deserialize_with = "token_serde::deserialize")]
|
pub bos_token: Option<TokenizerConfigToken>,
|
||||||
pub bos_token: Option<String>,
|
pub eos_token: Option<TokenizerConfigToken>,
|
||||||
#[serde(deserialize_with = "token_serde::deserialize")]
|
|
||||||
pub eos_token: Option<String>,
|
|
||||||
pub tokenizer_class: Option<String>,
|
pub tokenizer_class: Option<String>,
|
||||||
pub add_bos_token: Option<bool>,
|
pub add_bos_token: Option<bool>,
|
||||||
pub add_eos_token: Option<bool>,
|
pub add_eos_token: Option<bool>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl HubTokenizerConfig {
|
impl HubTokenizerConfig {
|
||||||
pub fn from_file<P: AsRef<std::path::Path>>(filename: P) -> Option<Self> {
|
pub fn from_file<P: AsRef<Path>>(filename: P) -> Option<Self> {
|
||||||
let content = std::fs::read_to_string(filename).ok()?;
|
std::fs::read_to_string(filename)
|
||||||
serde_json::from_str(&content).ok()
|
.ok()
|
||||||
|
.and_then(|content| serde_json::from_str(&content).ok())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub enum TokenizerConfigToken {
|
||||||
|
String(String),
|
||||||
|
Object { content: String },
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TokenizerConfigToken {
|
||||||
|
pub fn as_str(&self) -> &str {
|
||||||
|
match self {
|
||||||
|
TokenizerConfigToken::String(s) => s,
|
||||||
|
TokenizerConfigToken::Object { content } => content,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -100,9 +117,10 @@ pub struct HubProcessorConfig {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl HubProcessorConfig {
|
impl HubProcessorConfig {
|
||||||
pub fn from_file<P: AsRef<std::path::Path>>(filename: P) -> Option<Self> {
|
pub fn from_file<P: AsRef<Path>>(filename: P) -> Option<Self> {
|
||||||
let content = std::fs::read_to_string(filename).ok()?;
|
std::fs::read_to_string(filename)
|
||||||
serde_json::from_str(&content).ok()
|
.ok()
|
||||||
|
.and_then(|content| serde_json::from_str(&content).ok())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -121,35 +139,6 @@ pub(crate) enum GrammarType {
|
||||||
Regex(String),
|
Regex(String),
|
||||||
}
|
}
|
||||||
|
|
||||||
mod token_serde {
|
|
||||||
use super::*;
|
|
||||||
use serde::de;
|
|
||||||
use serde::Deserializer;
|
|
||||||
use serde_json::Value;
|
|
||||||
|
|
||||||
pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<String>, D::Error>
|
|
||||||
where
|
|
||||||
D: Deserializer<'de>,
|
|
||||||
{
|
|
||||||
let value = Value::deserialize(deserializer)?;
|
|
||||||
|
|
||||||
match value {
|
|
||||||
Value::String(s) => Ok(Some(s)),
|
|
||||||
Value::Object(map) => {
|
|
||||||
if let Some(content) = map.get("content").and_then(|v| v.as_str()) {
|
|
||||||
Ok(Some(content.to_string()))
|
|
||||||
} else {
|
|
||||||
Err(de::Error::custom(
|
|
||||||
"content key not found in structured token",
|
|
||||||
))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Value::Null => Ok(None),
|
|
||||||
_ => Err(de::Error::custom("invalid token format")),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, Serialize, ToSchema)]
|
#[derive(Clone, Debug, Serialize, ToSchema)]
|
||||||
pub struct Info {
|
pub struct Info {
|
||||||
/// Model info
|
/// Model info
|
||||||
|
@ -359,30 +348,33 @@ fn default_parameters() -> GenerateParameters {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
mod prompt_serde {
|
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
|
||||||
use serde::{self, Deserialize, Deserializer};
|
#[serde(try_from = "PromptDeserializer")]
|
||||||
use serde_json::Value;
|
pub struct Prompt(pub Vec<String>);
|
||||||
|
|
||||||
pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
|
#[derive(Deserialize)]
|
||||||
where
|
#[serde(untagged)]
|
||||||
D: Deserializer<'de>,
|
enum PromptDeserializer {
|
||||||
{
|
Single(String),
|
||||||
let value = Value::deserialize(deserializer)?;
|
Multiple(Vec<String>),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TryFrom<PromptDeserializer> for Prompt {
|
||||||
|
type Error = String;
|
||||||
|
|
||||||
|
fn try_from(value: PromptDeserializer) -> Result<Self, Self::Error> {
|
||||||
match value {
|
match value {
|
||||||
Value::String(s) => Ok(vec![s]),
|
PromptDeserializer::Single(s) => Ok(Prompt(vec![s])),
|
||||||
Value::Array(arr) if arr.is_empty() => Err(serde::de::Error::custom(
|
PromptDeserializer::Multiple(v) => {
|
||||||
"Empty array detected. Do not use an empty array for the prompt.",
|
if v.is_empty() {
|
||||||
)),
|
Err(
|
||||||
Value::Array(arr) => arr
|
"Empty array detected. Do not use an empty array for the prompt."
|
||||||
.iter()
|
.to_string(),
|
||||||
.map(|v| match v {
|
)
|
||||||
Value::String(s) => Ok(s.to_owned()),
|
} else {
|
||||||
_ => Err(serde::de::Error::custom("Expected a string")),
|
Ok(Prompt(v))
|
||||||
})
|
}
|
||||||
.collect(),
|
}
|
||||||
_ => Err(serde::de::Error::custom(
|
|
||||||
"Expected a string or an array of strings",
|
|
||||||
)),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -396,8 +388,7 @@ pub struct CompletionRequest {
|
||||||
|
|
||||||
/// The prompt to generate completions for.
|
/// The prompt to generate completions for.
|
||||||
#[schema(example = "What is Deep Learning?")]
|
#[schema(example = "What is Deep Learning?")]
|
||||||
#[serde(deserialize_with = "prompt_serde::deserialize")]
|
pub prompt: Prompt,
|
||||||
pub prompt: Vec<String>,
|
|
||||||
|
|
||||||
/// The maximum number of tokens that can be generated in the chat completion.
|
/// The maximum number of tokens that can be generated in the chat completion.
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
|
@ -445,7 +436,6 @@ pub struct CompletionRequest {
|
||||||
#[derive(Clone, Deserialize, Serialize, ToSchema, Default)]
|
#[derive(Clone, Deserialize, Serialize, ToSchema, Default)]
|
||||||
pub(crate) struct Completion {
|
pub(crate) struct Completion {
|
||||||
pub id: String,
|
pub id: String,
|
||||||
pub object: String,
|
|
||||||
#[schema(example = "1706270835")]
|
#[schema(example = "1706270835")]
|
||||||
pub created: u64,
|
pub created: u64,
|
||||||
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
|
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
|
||||||
|
@ -466,7 +456,6 @@ pub(crate) struct CompletionComplete {
|
||||||
#[derive(Clone, Deserialize, Serialize, ToSchema)]
|
#[derive(Clone, Deserialize, Serialize, ToSchema)]
|
||||||
pub(crate) struct ChatCompletion {
|
pub(crate) struct ChatCompletion {
|
||||||
pub id: String,
|
pub id: String,
|
||||||
pub object: String,
|
|
||||||
#[schema(example = "1706270835")]
|
#[schema(example = "1706270835")]
|
||||||
pub created: u64,
|
pub created: u64,
|
||||||
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
|
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
|
||||||
|
@ -562,6 +551,15 @@ pub(crate) struct Usage {
|
||||||
pub total_tokens: u32,
|
pub total_tokens: u32,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Serialize, ToSchema)]
|
||||||
|
#[serde(tag = "object")]
|
||||||
|
enum CompletionType {
|
||||||
|
#[serde(rename = "chat.completion.chunk")]
|
||||||
|
ChatCompletionChunk(ChatCompletionChunk),
|
||||||
|
#[serde(rename = "chat.completion")]
|
||||||
|
ChatCompletion(ChatCompletion),
|
||||||
|
}
|
||||||
|
|
||||||
impl ChatCompletion {
|
impl ChatCompletion {
|
||||||
pub(crate) fn new(
|
pub(crate) fn new(
|
||||||
model: String,
|
model: String,
|
||||||
|
@ -598,7 +596,6 @@ impl ChatCompletion {
|
||||||
};
|
};
|
||||||
Self {
|
Self {
|
||||||
id: String::new(),
|
id: String::new(),
|
||||||
object: "chat.completion".into(),
|
|
||||||
created,
|
created,
|
||||||
model,
|
model,
|
||||||
system_fingerprint,
|
system_fingerprint,
|
||||||
|
@ -620,7 +617,6 @@ impl ChatCompletion {
|
||||||
#[derive(Clone, Deserialize, Serialize, ToSchema)]
|
#[derive(Clone, Deserialize, Serialize, ToSchema)]
|
||||||
pub(crate) struct CompletionCompleteChunk {
|
pub(crate) struct CompletionCompleteChunk {
|
||||||
pub id: String,
|
pub id: String,
|
||||||
pub object: String,
|
|
||||||
pub created: u64,
|
pub created: u64,
|
||||||
pub choices: Vec<CompletionComplete>,
|
pub choices: Vec<CompletionComplete>,
|
||||||
pub model: String,
|
pub model: String,
|
||||||
|
@ -630,7 +626,6 @@ pub(crate) struct CompletionCompleteChunk {
|
||||||
#[derive(Clone, Serialize, ToSchema)]
|
#[derive(Clone, Serialize, ToSchema)]
|
||||||
pub(crate) struct ChatCompletionChunk {
|
pub(crate) struct ChatCompletionChunk {
|
||||||
pub id: String,
|
pub id: String,
|
||||||
pub object: String,
|
|
||||||
#[schema(example = "1706270978")]
|
#[schema(example = "1706270978")]
|
||||||
pub created: u64,
|
pub created: u64,
|
||||||
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
|
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
|
||||||
|
@ -710,7 +705,6 @@ impl ChatCompletionChunk {
|
||||||
};
|
};
|
||||||
Self {
|
Self {
|
||||||
id: String::new(),
|
id: String::new(),
|
||||||
object: "chat.completion.chunk".to_string(),
|
|
||||||
created,
|
created,
|
||||||
model,
|
model,
|
||||||
system_fingerprint,
|
system_fingerprint,
|
||||||
|
@ -821,7 +815,6 @@ pub(crate) struct ChatRequest {
|
||||||
/// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
|
/// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
#[schema(nullable = true, example = "null")]
|
#[schema(nullable = true, example = "null")]
|
||||||
#[serde(deserialize_with = "deserialize_tool_choice::deserialize")]
|
|
||||||
pub tool_choice: Option<ToolType>,
|
pub tool_choice: Option<ToolType>,
|
||||||
|
|
||||||
/// Response format constraints for the generation.
|
/// Response format constraints for the generation.
|
||||||
|
@ -837,44 +830,41 @@ fn default_tool_prompt() -> Option<String> {
|
||||||
"\nYou will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n".to_string(),
|
"\nYou will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n".to_string(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
#[derive(Clone, Deserialize, ToSchema, Serialize)]
|
|
||||||
enum ToolType {
|
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)]
|
||||||
FunctionName(String),
|
#[serde(untagged)]
|
||||||
|
pub enum ToolType {
|
||||||
OneOf,
|
OneOf,
|
||||||
|
FunctionName(String),
|
||||||
|
Function { function: FunctionName },
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Deserialize the tool choice from the JSON input or from the function name ("none" is allowed but mapped to None)
|
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||||
mod deserialize_tool_choice {
|
pub struct FunctionName {
|
||||||
use super::*;
|
pub name: String,
|
||||||
use serde::de;
|
}
|
||||||
use serde::Deserializer;
|
|
||||||
use serde_json::Value;
|
|
||||||
|
|
||||||
pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<ToolType>, D::Error>
|
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||||
where
|
#[serde(from = "ToolTypeDeserializer")]
|
||||||
D: Deserializer<'de>,
|
pub struct ToolChoice(pub Option<ToolType>);
|
||||||
{
|
|
||||||
let value = Value::deserialize(deserializer)?;
|
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
enum ToolTypeDeserializer {
|
||||||
|
None(Option<String>),
|
||||||
|
Some(ToolType),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<ToolTypeDeserializer> for ToolChoice {
|
||||||
|
fn from(value: ToolTypeDeserializer) -> Self {
|
||||||
match value {
|
match value {
|
||||||
Value::String(s) => match s.as_str() {
|
ToolTypeDeserializer::None(opt) => match opt.as_deref() {
|
||||||
"none" => Ok(None),
|
Some("none") => ToolChoice(None),
|
||||||
"auto" => Ok(Some(ToolType::OneOf)),
|
Some("auto") => ToolChoice(Some(ToolType::OneOf)),
|
||||||
_ => Ok(Some(ToolType::FunctionName(s))),
|
Some(s) => ToolChoice(Some(ToolType::FunctionName(s.to_string()))),
|
||||||
|
None => ToolChoice(Some(ToolType::OneOf)),
|
||||||
},
|
},
|
||||||
Value::Object(map) => {
|
ToolTypeDeserializer::Some(tool_type) => ToolChoice(Some(tool_type)),
|
||||||
if let Some(content) = map
|
|
||||||
.get("function")
|
|
||||||
.and_then(|v| v.get("name"))
|
|
||||||
.and_then(|v| v.as_str())
|
|
||||||
{
|
|
||||||
Ok(Some(ToolType::FunctionName(content.to_string())))
|
|
||||||
} else {
|
|
||||||
Err(de::Error::custom("function key not found in tool choice"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Value::Null => Ok(Some(ToolType::OneOf)),
|
|
||||||
_ => Err(de::Error::custom("invalid token format")),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -950,26 +940,16 @@ pub(crate) struct ToolCall {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
|
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
|
||||||
struct Url {
|
pub struct Url {
|
||||||
url: String,
|
url: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
|
|
||||||
struct ImageUrl {
|
|
||||||
image_url: Url,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
|
|
||||||
struct Text {
|
|
||||||
text: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
|
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
|
||||||
#[serde(tag = "type")]
|
#[serde(tag = "type")]
|
||||||
#[serde(rename_all = "snake_case")]
|
#[serde(rename_all = "snake_case")]
|
||||||
enum MessageChunk {
|
pub enum MessageChunk {
|
||||||
Text(Text),
|
Text { text: String },
|
||||||
ImageUrl(ImageUrl),
|
ImageUrl { image_url: Url },
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
|
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
|
||||||
|
@ -977,35 +957,31 @@ pub struct Message {
|
||||||
#[schema(example = "user")]
|
#[schema(example = "user")]
|
||||||
role: String,
|
role: String,
|
||||||
#[schema(example = "My name is David and I")]
|
#[schema(example = "My name is David and I")]
|
||||||
#[serde(deserialize_with = "message_content_serde::deserialize")]
|
pub content: MessageContent,
|
||||||
content: Vec<MessageChunk>,
|
|
||||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
#[schema(example = "\"David\"")]
|
#[schema(example = "\"David\"")]
|
||||||
name: Option<String>,
|
name: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
mod message_content_serde {
|
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)]
|
||||||
use super::*;
|
|
||||||
use serde::{Deserialize, Deserializer};
|
|
||||||
|
|
||||||
pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<MessageChunk>, D::Error>
|
|
||||||
where
|
|
||||||
D: Deserializer<'de>,
|
|
||||||
{
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
#[serde(untagged)]
|
#[serde(untagged)]
|
||||||
enum Message {
|
pub enum MessageContent {
|
||||||
Text(String),
|
SingleText(String),
|
||||||
Chunks(Vec<MessageChunk>),
|
MultipleChunks(Vec<MessageChunk>),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pushing a chunk to a single text message will convert it to a multiple chunks message
|
||||||
|
impl MessageContent {
|
||||||
|
pub fn push(&mut self, chunk: MessageChunk) {
|
||||||
|
match self {
|
||||||
|
MessageContent::SingleText(text) => {
|
||||||
|
*self =
|
||||||
|
MessageContent::MultipleChunks(vec![MessageChunk::Text { text: text.clone() }]);
|
||||||
|
}
|
||||||
|
MessageContent::MultipleChunks(chunks) => {
|
||||||
|
chunks.push(chunk);
|
||||||
}
|
}
|
||||||
let message: Message = Deserialize::deserialize(deserializer)?;
|
|
||||||
let chunks = match message {
|
|
||||||
Message::Text(text) => {
|
|
||||||
vec![MessageChunk::Text(Text { text })]
|
|
||||||
}
|
}
|
||||||
Message::Chunks(s) => s,
|
|
||||||
};
|
|
||||||
Ok(chunks)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1021,18 +997,17 @@ impl From<Message> for TextMessage {
|
||||||
fn from(value: Message) -> Self {
|
fn from(value: Message) -> Self {
|
||||||
TextMessage {
|
TextMessage {
|
||||||
role: value.role,
|
role: value.role,
|
||||||
content: value
|
content: match value.content {
|
||||||
.content
|
MessageContent::SingleText(text) => text,
|
||||||
|
MessageContent::MultipleChunks(chunks) => chunks
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|c| match c {
|
.map(|chunk| match chunk {
|
||||||
MessageChunk::Text(Text { text }) => text,
|
MessageChunk::Text { text } => text,
|
||||||
MessageChunk::ImageUrl(image) => {
|
MessageChunk::ImageUrl { image_url } => format!("![]({})", image_url.url),
|
||||||
let url = image.image_url.url;
|
|
||||||
format!("![]({url})")
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
.join(""),
|
.join(""),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1240,9 +1215,16 @@ mod tests {
|
||||||
);
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
config.bos_token,
|
config.bos_token,
|
||||||
Some("<|begin▁of▁sentence|>".to_string())
|
Some(TokenizerConfigToken::String(
|
||||||
|
"<|begin▁of▁sentence|>".to_string()
|
||||||
|
))
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
config.eos_token,
|
||||||
|
Some(TokenizerConfigToken::String(
|
||||||
|
"<|end▁of▁sentence|>".to_string()
|
||||||
|
))
|
||||||
);
|
);
|
||||||
assert_eq!(config.eos_token, Some("<|end▁of▁sentence|>".to_string()));
|
|
||||||
|
|
||||||
// in this case we expect the tokens to be encoded as structured tokens
|
// in this case we expect the tokens to be encoded as structured tokens
|
||||||
// we want the content of the structured token
|
// we want the content of the structured token
|
||||||
|
@ -1275,9 +1257,16 @@ mod tests {
|
||||||
);
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
config.bos_token,
|
config.bos_token,
|
||||||
Some("<|begin▁of▁sentence|>".to_string())
|
Some(TokenizerConfigToken::Object {
|
||||||
|
content: "<|begin▁of▁sentence|>".to_string()
|
||||||
|
})
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
config.eos_token,
|
||||||
|
Some(TokenizerConfigToken::Object {
|
||||||
|
content: "<|end▁of▁sentence|>".to_string()
|
||||||
|
})
|
||||||
);
|
);
|
||||||
assert_eq!(config.eos_token, Some("<|end▁of▁sentence|>".to_string()));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
@ -1295,9 +1284,7 @@ mod tests {
|
||||||
request.messages[0],
|
request.messages[0],
|
||||||
Message {
|
Message {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: vec![MessageChunk::Text(Text {
|
content: MessageContent::SingleText("What is Deep Learning?".to_string()),
|
||||||
text: "What is Deep Learning?".to_string()
|
|
||||||
}),],
|
|
||||||
name: None
|
name: None
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
@ -1321,10 +1308,10 @@ mod tests {
|
||||||
request.messages[0],
|
request.messages[0],
|
||||||
Message{
|
Message{
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: vec![
|
content: MessageContent::MultipleChunks(vec![
|
||||||
MessageChunk::Text(Text { text: "Whats in this image?".to_string() }),
|
MessageChunk::Text { text: "Whats in this image?".to_string() },
|
||||||
MessageChunk::ImageUrl(ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() } })
|
MessageChunk::ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() }},
|
||||||
],
|
]),
|
||||||
name: None
|
name: None
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
@ -1334,10 +1321,10 @@ mod tests {
|
||||||
fn text_message_convert() {
|
fn text_message_convert() {
|
||||||
let message = Message{
|
let message = Message{
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: vec![
|
content: MessageContent::MultipleChunks(vec![
|
||||||
MessageChunk::Text(Text { text: "Whats in this image?".to_string() }),
|
MessageChunk::Text { text: "Whats in this image?".to_string() },
|
||||||
MessageChunk::ImageUrl(ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() } })
|
MessageChunk::ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() } }
|
||||||
],
|
]),
|
||||||
name: None
|
name: None
|
||||||
};
|
};
|
||||||
let textmsg: TextMessage = message.into();
|
let textmsg: TextMessage = message.into();
|
||||||
|
|
|
@ -553,11 +553,11 @@ pub fn create_post_processor(
|
||||||
if add_bos_token {
|
if add_bos_token {
|
||||||
if let Some(bos) = bos_token {
|
if let Some(bos) = bos_token {
|
||||||
let bos_token_id = tokenizer
|
let bos_token_id = tokenizer
|
||||||
.token_to_id(bos)
|
.token_to_id(bos.as_str())
|
||||||
.expect("Should have found the bos token id");
|
.expect("Should have found the bos token id");
|
||||||
special_tokens.push((bos.clone(), bos_token_id));
|
special_tokens.push((bos.as_str(), bos_token_id));
|
||||||
single.push(format!("{}:0", bos));
|
single.push(format!("{}:0", bos.as_str()));
|
||||||
pair.push(format!("{}:0", bos));
|
pair.push(format!("{}:0", bos.as_str()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -567,17 +567,17 @@ pub fn create_post_processor(
|
||||||
if add_eos_token {
|
if add_eos_token {
|
||||||
if let Some(eos) = eos_token {
|
if let Some(eos) = eos_token {
|
||||||
let eos_token_id = tokenizer
|
let eos_token_id = tokenizer
|
||||||
.token_to_id(eos)
|
.token_to_id(eos.as_str())
|
||||||
.expect("Should have found the eos token id");
|
.expect("Should have found the eos token id");
|
||||||
special_tokens.push((eos.clone(), eos_token_id));
|
special_tokens.push((eos.as_str(), eos_token_id));
|
||||||
single.push(format!("{}:0", eos));
|
single.push(format!("{}:0", eos.as_str()));
|
||||||
pair.push(format!("{}:0", eos));
|
pair.push(format!("{}:0", eos.as_str()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if add_bos_token {
|
if add_bos_token {
|
||||||
if let Some(bos) = bos_token {
|
if let Some(bos) = bos_token {
|
||||||
pair.push(format!("{}:1", bos));
|
pair.push(format!("{}:1", bos.as_str()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -585,7 +585,7 @@ pub fn create_post_processor(
|
||||||
|
|
||||||
if add_eos_token {
|
if add_eos_token {
|
||||||
if let Some(eos) = eos_token {
|
if let Some(eos) = eos_token {
|
||||||
pair.push(format!("{}:1", eos));
|
pair.push(format!("{}:1", eos.as_str()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -611,14 +611,15 @@ enum RouterError {
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use text_generation_router::TokenizerConfigToken;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_create_post_processor() {
|
fn test_create_post_processor() {
|
||||||
let tokenizer_config = HubTokenizerConfig {
|
let tokenizer_config = HubTokenizerConfig {
|
||||||
add_bos_token: None,
|
add_bos_token: None,
|
||||||
add_eos_token: None,
|
add_eos_token: None,
|
||||||
bos_token: Some("<s>".to_string()),
|
bos_token: Some(TokenizerConfigToken::String("<s>".to_string())),
|
||||||
eos_token: Some("</s>".to_string()),
|
eos_token: Some(TokenizerConfigToken::String("</s>".to_string())),
|
||||||
chat_template: None,
|
chat_template: None,
|
||||||
tokenizer_class: None,
|
tokenizer_class: None,
|
||||||
completion_template: None,
|
completion_template: None,
|
||||||
|
@ -629,9 +630,9 @@ mod tests {
|
||||||
let post_processor = create_post_processor(&tokenizer, &tokenizer_config).unwrap();
|
let post_processor = create_post_processor(&tokenizer, &tokenizer_config).unwrap();
|
||||||
|
|
||||||
let expected = TemplateProcessing::builder()
|
let expected = TemplateProcessing::builder()
|
||||||
.try_single("<s>:0 $A:0 <s>:1")
|
.try_single("<s>:0 $A:0")
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.try_pair("<s>:0 $A:0 $B:1")
|
.try_pair("<s>:0 $A:0 <s>:1 $B:1")
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.special_tokens(vec![("<s>".to_string(), 1)])
|
.special_tokens(vec![("<s>".to_string(), 1)])
|
||||||
.build()
|
.build()
|
||||||
|
|
|
@ -12,17 +12,18 @@ use crate::kserve::{
|
||||||
use crate::validation::ValidationError;
|
use crate::validation::ValidationError;
|
||||||
use crate::{
|
use crate::{
|
||||||
BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
|
BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
|
||||||
GenerateResponse, GrammarType, HubModelInfo, HubPreprocessorConfig, HubProcessorConfig,
|
GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig, HubTokenizerConfig, Info,
|
||||||
HubTokenizerConfig, Info, Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse,
|
Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse,
|
||||||
Token, TokenizeResponse, Usage, Validation,
|
Usage, Validation,
|
||||||
};
|
};
|
||||||
use crate::{
|
use crate::{
|
||||||
ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
|
ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
|
||||||
ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob,
|
ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob,
|
||||||
ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk,
|
ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk,
|
||||||
CompletionRequest, DeltaToolCall, Function, Tool, VertexRequest, VertexResponse,
|
CompletionRequest, CompletionType, DeltaToolCall, Function, Tool, VertexRequest,
|
||||||
|
VertexResponse,
|
||||||
};
|
};
|
||||||
use crate::{FunctionDefinition, ToolCall, ToolType};
|
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolType};
|
||||||
use async_stream::__private::AsyncStream;
|
use async_stream::__private::AsyncStream;
|
||||||
use axum::extract::Extension;
|
use axum::extract::Extension;
|
||||||
use axum::http::{HeaderMap, Method, StatusCode};
|
use axum::http::{HeaderMap, Method, StatusCode};
|
||||||
|
@ -635,7 +636,7 @@ async fn completions(
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.prompt.len() > info.max_client_batch_size {
|
if req.prompt.0.len() > info.max_client_batch_size {
|
||||||
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
|
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
|
||||||
return Err((
|
return Err((
|
||||||
StatusCode::UNPROCESSABLE_ENTITY,
|
StatusCode::UNPROCESSABLE_ENTITY,
|
||||||
|
@ -651,6 +652,7 @@ async fn completions(
|
||||||
|
|
||||||
let generate_requests: Vec<GenerateRequest> = req
|
let generate_requests: Vec<GenerateRequest> = req
|
||||||
.prompt
|
.prompt
|
||||||
|
.0
|
||||||
.iter()
|
.iter()
|
||||||
.map(|prompt| GenerateRequest {
|
.map(|prompt| GenerateRequest {
|
||||||
inputs: prompt.to_string(),
|
inputs: prompt.to_string(),
|
||||||
|
@ -705,7 +707,6 @@ async fn completions(
|
||||||
event
|
event
|
||||||
.json_data(CompletionCompleteChunk {
|
.json_data(CompletionCompleteChunk {
|
||||||
id: "".to_string(),
|
id: "".to_string(),
|
||||||
object: "text_completion".to_string(),
|
|
||||||
created: current_time,
|
created: current_time,
|
||||||
|
|
||||||
choices: vec![CompletionComplete {
|
choices: vec![CompletionComplete {
|
||||||
|
@ -932,7 +933,6 @@ async fn completions(
|
||||||
|
|
||||||
let response = Completion {
|
let response = Completion {
|
||||||
id: "".to_string(),
|
id: "".to_string(),
|
||||||
object: "text_completion".to_string(),
|
|
||||||
created: current_time,
|
created: current_time,
|
||||||
model: info.model_id.clone(),
|
model: info.model_id.clone(),
|
||||||
system_fingerprint: format!(
|
system_fingerprint: format!(
|
||||||
|
@ -1153,7 +1153,8 @@ async fn chat_completions(
|
||||||
};
|
};
|
||||||
|
|
||||||
event
|
event
|
||||||
.json_data(ChatCompletionChunk::new(
|
.json_data(CompletionType::ChatCompletionChunk(
|
||||||
|
ChatCompletionChunk::new(
|
||||||
model_id.clone(),
|
model_id.clone(),
|
||||||
system_fingerprint.clone(),
|
system_fingerprint.clone(),
|
||||||
content,
|
content,
|
||||||
|
@ -1161,6 +1162,7 @@ async fn chat_completions(
|
||||||
current_time,
|
current_time,
|
||||||
logprobs,
|
logprobs,
|
||||||
stream_token.details.map(|d| d.finish_reason.to_string()),
|
stream_token.details.map(|d| d.finish_reason.to_string()),
|
||||||
|
),
|
||||||
))
|
))
|
||||||
.unwrap_or_else(|e| {
|
.unwrap_or_else(|e| {
|
||||||
println!("Failed to serialize ChatCompletionChunk: {:?}", e);
|
println!("Failed to serialize ChatCompletionChunk: {:?}", e);
|
||||||
|
@ -1228,7 +1230,7 @@ async fn chat_completions(
|
||||||
(None, Some(generation.generated_text))
|
(None, Some(generation.generated_text))
|
||||||
};
|
};
|
||||||
// build the complete response object with the full text
|
// build the complete response object with the full text
|
||||||
let response = ChatCompletion::new(
|
let response = CompletionType::ChatCompletion(ChatCompletion::new(
|
||||||
model_id,
|
model_id,
|
||||||
system_fingerprint,
|
system_fingerprint,
|
||||||
output,
|
output,
|
||||||
|
@ -1236,7 +1238,7 @@ async fn chat_completions(
|
||||||
generation.details.unwrap(),
|
generation.details.unwrap(),
|
||||||
logprobs,
|
logprobs,
|
||||||
tool_calls,
|
tool_calls,
|
||||||
);
|
));
|
||||||
|
|
||||||
// wrap generation inside a Vec to match api-inference
|
// wrap generation inside a Vec to match api-inference
|
||||||
Ok((headers, Json(response)).into_response())
|
Ok((headers, Json(response)).into_response())
|
||||||
|
|
Loading…
Reference in New Issue