diff --git a/router/src/infer.rs b/router/src/infer.rs
index 6de07982..8a9875eb 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -1,8 +1,9 @@
 /// Batching and inference logic
 use crate::validation::{Validation, ValidationError};
-use crate::HubTokenizerConfig;
-use crate::{ChatRequest, GenerateRequest, GenerateStreamResponse, PrefillToken};
-use crate::{Entry, Queue, Token};
+use crate::{
+    ChatTemplateInputs, Entry, GenerateRequest, GenerateStreamResponse, HubTokenizerConfig,
+    Message, PrefillToken, Queue, Token,
+};
 use futures::future::try_join_all;
 use minijinja::{Environment, ErrorKind, Template};
 use nohash_hasher::IntMap;
@@ -32,8 +33,12 @@ pub struct Infer {
     shared: Arc<Shared>,
     /// Inference limit
     limit_concurrent_requests: Arc<Semaphore>,
-    /// Chat template
-    template: Option<Template<'static, 'static>>,
+    /// Chat template (template, bos_token, eos_token)
+    template: (
+        Option<Template<'static, 'static>>,
+        Option<String>,
+        Option<String>,
+    ),
 }
 
 /// Infer shared state
@@ -42,6 +47,11 @@ struct Shared {
     batching_task: Notify,
 }
 
+/// Raise an exception (custom function) used in the chat templates
+fn raise_exception(err_text: String) -> Result<String, minijinja::Error> {
+    Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text))
+}
+
 impl Infer {
     #[allow(clippy::too_many_arguments)]
     pub(crate) fn new(
@@ -80,20 +90,28 @@ impl Infer {
         let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
 
         let template = tokenizer_config.chat_template.map(|t| {
-            let env = Box::new(Environment::new());
+            let mut env = Box::new(Environment::new());
             let template_str = t.into_boxed_str();
+            env.add_function("raise_exception", raise_exception);
             // leaking env and template_str as read-only, static resources for performance.
             Box::leak(env)
                 .template_from_str(Box::leak(template_str))
                 .unwrap()
         });
-
+        let eos_token = tokenizer_config
+            .eos_token
+            .map_or_else(String::new, |t| t)
+            .into();
+        let bos_token = tokenizer_config
+            .bos_token
+            .map_or_else(String::new, |t| t)
+            .into();
         Self {
             validation,
             queue,
             shared,
             limit_concurrent_requests: semaphore,
-            template,
+            template: (template, bos_token, eos_token),
         }
     }
 
@@ -149,11 +167,16 @@ impl Infer {
 
     /// Apply the chat template to the chat request
     #[instrument(skip_all)]
-    pub(crate) fn apply_chat_template(&self, chat: ChatRequest) -> Result<String, InferError> {
-        self.template
+    pub(crate) fn apply_chat_template(&self, messages: Vec<Message>) -> Result<String, InferError> {
+        let (template, bos_token, eos_token) = &self.template;
+        template
             .as_ref()
             .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
-            .render(chat)
+            .render(ChatTemplateInputs {
+                messages,
+                eos_token: eos_token.as_deref(),
+                bos_token: bos_token.as_deref(),
+            })
             .map_err(|e| {
                 metrics::increment_counter!("tgi_request_failure", "err" => "template");
                 tracing::error!("{e}");
@@ -702,3 +725,205 @@ impl InferError {
         }
     }
 }
+
+// tests
+#[cfg(test)]
+mod tests {
+    use crate::infer::raise_exception;
+    use crate::ChatTemplateInputs;
+    use crate::Message;
+    use minijinja::Environment;
+
+    #[test]
+    fn test_chat_template() {
+        let env = Environment::new();
+
+        let source = r#"
+        {% for message in messages %}
+        {% if message['role'] == 'system' %}
+        {% if message['content']%}
+        {{'### System:\n' + message['content']+'\n\n'}}
+        {% endif %}
+        {% elif message['role'] == 'user' %}
+        {{'### User:\n' + message['content']+'\n\n'}}
+        {% elif message['role'] == 'assistant' %}
+        {{'### Assistant:\n' + message['content']}}
+        {% endif %}
+        {% if loop.last and add_generation_prompt %}
+        {{ '### Assistant:\n' }}
+        {% endif %}
+        {% endfor %}"#;
+
+        // trim all the whitespace
+        let source = source
+            .lines()
+            .map(|line| line.trim())
+            .collect::<Vec<&str>>()
+            .join("");
+
+        let tmpl = env.template_from_str(&source);
+
+        let chat_template_inputs = ChatTemplateInputs {
+            messages: vec![
+                Message {
+                    role: "user".to_string(),
+                    content: "Hi!".to_string(),
+                },
+                Message {
+                    role: "assistant".to_string(),
+                    content: "Hello how can I help?".to_string(),
+                },
+                Message {
+                    role: "user".to_string(),
+                    content: "What is Deep Learning?".to_string(),
+                },
+                Message {
+                    role: "assistant".to_string(),
+                    content: "magic!".to_string(),
+                },
+            ],
+            bos_token: Some("[BOS]"),
+            eos_token: Some("[EOS]"),
+        };
+
+        let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
+
+        assert_eq!(
+            result,
+            r#"### User:
+Hi!
+
+### Assistant:
+Hello how can I help?### User:
+What is Deep Learning?
+
+### Assistant:
+magic!"#
+        );
+    }
+
+    #[test]
+    fn test_chat_template_invalid_with_raise() {
+        let mut env = Environment::new();
+        env.add_function("raise_exception", raise_exception);
+
+        let source = r#"
+        {{ bos_token }}
+        {% for message in messages %}
+        {% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}
+        {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
+        {% endif %}
+        {% if message['role'] == 'user' %}
+        {{ '[INST] ' + message['content'] + ' [/INST]' }}
+        {% elif message['role'] == 'assistant' %}
+        {{ message['content'] + eos_token}}
+        {% else %}
+        {{ raise_exception('Only user and assistant roles are supported!') }}
+        {% endif %}
+        {% endfor %}"#;
+
+        // trim all the whitespace
+        let source = source
+            .lines()
+            .map(|line| line.trim())
+            .collect::<Vec<&str>>()
+            .join("");
+
+        let tmpl = env.template_from_str(&source);
+
+        let chat_template_inputs = ChatTemplateInputs {
+            messages: vec![
+                Message {
+                    role: "user".to_string(),
+                    content: "Hi!".to_string(),
+                },
+                Message {
+                    role: "user".to_string(),
+                    content: "Hi again!".to_string(),
+                },
+                Message {
+                    role: "assistant".to_string(),
+                    content: "Hello how can I help?".to_string(),
+                },
+                Message {
+                    role: "user".to_string(),
+                    content: "What is Deep Learning?".to_string(),
+                },
+                Message {
+                    role: "assistant".to_string(),
+                    content: "magic!".to_string(),
+                },
+            ],
+            bos_token: Some("[BOS]"),
+            eos_token: Some("[EOS]"),
+        };
+
+        let result = tmpl.unwrap().render(chat_template_inputs);
+
+        match result {
+            Ok(_) => panic!("Should have failed"),
+            Err(e) => {
+                assert_eq!(
+                    e.detail().unwrap(),
+                    "Conversation roles must alternate user/assistant/user/assistant/..."
+                );
+            }
+        }
+    }
+
+    #[test]
+    fn test_chat_template_valid_with_raise() {
+        let mut env = Environment::new();
+        env.add_function("raise_exception", raise_exception);
+
+        let source = r#"
+        {{ bos_token }}
+        {% for message in messages %}
+        {% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}
+        {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
+        {% endif %}
+        {% if message['role'] == 'user' %}
+        {{ '[INST] ' + message['content'] + ' [/INST]' }}
+        {% elif message['role'] == 'assistant' %}
+        {{ message['content'] + eos_token}}
+        {% else %}
+        {{ raise_exception('Only user and assistant roles are supported!') }}
+        {% endif %}
+        {% endfor %}"#;
+
+        // trim all the whitespace
+        let source = source
+            .lines()
+            .map(|line| line.trim())
+            .collect::<Vec<&str>>()
+            .join("");
+
+        let tmpl = env.template_from_str(&source);
+
+        let chat_template_inputs = ChatTemplateInputs {
+            messages: vec![
+                Message {
+                    role: "user".to_string(),
+                    content: "Hi!".to_string(),
+                },
+                Message {
+                    role: "assistant".to_string(),
+                    content: "Hello how can I help?".to_string(),
+                },
+                Message {
+                    role: "user".to_string(),
+                    content: "What is Deep Learning?".to_string(),
+                },
+                Message {
+                    role: "assistant".to_string(),
+                    content: "magic!".to_string(),
+                },
+            ],
+            bos_token: Some("[BOS]"),
+            eos_token: Some("[EOS]"),
+        };
+
+        let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
+        assert_eq!(result, "[BOS][INST] Hi! [/INST]Hello how can I help?[EOS][INST] What is Deep Learning? [/INST]magic![EOS]");
+    }
+}
diff --git a/router/src/lib.rs b/router/src/lib.rs
index f6f8276f..983079d6 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -31,8 +31,9 @@ pub struct HubModelInfo {
 
 #[derive(Clone, Deserialize, Default)]
 pub struct HubTokenizerConfig {
-    #[serde(default)]
     pub chat_template: Option<String>,
+    pub bos_token: Option<String>,
+    pub eos_token: Option<String>,
 }
 
 impl HubTokenizerConfig {
@@ -366,6 +367,13 @@ pub(crate) struct ChatRequest {
     pub seed: Option<u64>,
 }
 
+#[derive(Clone, Serialize, Deserialize)]
+pub(crate) struct ChatTemplateInputs<'a> {
+    messages: Vec<Message>,
+    bos_token: Option<&'a str>,
+    eos_token: Option<&'a str>,
+}
+
 #[derive(Clone, Deserialize, ToSchema, Serialize)]
 pub(crate) struct Message {
     #[schema(example = "user")]
diff --git a/router/src/server.rs b/router/src/server.rs
index fe1827c4..530a935b 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -2,11 +2,11 @@
 use crate::health::Health;
 use crate::infer::{InferError, InferResponse, InferStreamResponse};
 use crate::validation::ValidationError;
-use crate::HubTokenizerConfig;
 use crate::{
     BestOfSequence, ChatCompletion, ChatCompletionChunk, ChatRequest, CompatGenerateRequest,
     Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse,
-    HubModelInfo, Infer, Info, PrefillToken, StreamDetails, StreamResponse, Token, Validation,
+    HubModelInfo, HubTokenizerConfig, Infer, Info, PrefillToken, StreamDetails, StreamResponse,
+    Token, Validation,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};
@@ -572,7 +572,7 @@ async fn chat_completions(
     let seed = req.seed;
 
     // apply chat template to flatten the request into a single input
-    let inputs = match infer.apply_chat_template(req) {
+    let inputs = match infer.apply_chat_template(req.messages) {
         Ok(inputs) => inputs,
        Err(err) => {
             metrics::increment_counter!("tgi_request_failure", "err" => "validation");
@@ -659,9 +659,9 @@
 
         // build the complete response object with the full text
         let response = ChatCompletion::new(
-            generation.generated_text,
             model_id,
             system_fingerprint,
+            generation.generated_text,
             current_time,
             generation.details.unwrap(),
             logprobs,
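
For reference, a minimal sketch of how the new `bos_token` / `eos_token` fields on `HubTokenizerConfig` map onto a model's tokenizer_config.json. This is not part of the diff: the struct below mirrors the one in router/src/lib.rs, and the JSON excerpt and token values are illustrative (real config files carry many extra keys, which serde simply ignores).

use serde::Deserialize;

// Mirror of the HubTokenizerConfig struct added in router/src/lib.rs.
#[derive(Clone, Deserialize, Default)]
struct HubTokenizerConfig {
    chat_template: Option<String>,
    bos_token: Option<String>,
    eos_token: Option<String>,
}

fn main() {
    // Illustrative excerpt of a tokenizer_config.json; unknown keys are ignored.
    let raw = r#"{
        "bos_token": "<s>",
        "eos_token": "</s>",
        "chat_template": "{{ bos_token }}{% for m in messages %}{{ m['content'] }}{% endfor %}"
    }"#;

    let config: HubTokenizerConfig = serde_json::from_str(raw).unwrap();
    assert_eq!(config.bos_token.as_deref(), Some("<s>"));
    assert_eq!(config.eos_token.as_deref(), Some("</s>"));
    assert!(config.chat_template.is_some());
}

One caveat: some tokenizer configs serialize bos_token/eos_token as objects (e.g. {"content": "<s>", ...}) rather than plain strings, and the plain Option<String> fields declared in this diff would fail to deserialize those; handling that form would need a custom deserializer.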