From 0f2daad8b959361aa41d5500d3778e23b1118bdc Mon Sep 17 00:00:00 2001
From: OlivierDehaene
Date: Fri, 16 Feb 2024 16:37:32 +0100
Subject: [PATCH] feat: add chat template struct to avoid tuple ordering errors
 (#1570)

---
 router/src/infer.rs | 79 ++++++++++++++++++++++++++-------------------
 1 file changed, 46 insertions(+), 33 deletions(-)

diff --git a/router/src/infer.rs b/router/src/infer.rs
index 655759b9..472b7d66 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -31,14 +31,10 @@ pub struct Infer {
     queue: Queue,
     /// Shared state
     shared: Arc<Shared>,
+    /// Chat template
+    chat_template: Option<ChatTemplate>,
     /// Inference limit
     limit_concurrent_requests: Arc<Semaphore>,
-    /// Chat template (template, bos_token, eos_token)
-    template: (
-        Option<Template<'static, 'static>>,
-        Option<String>,
-        Option<String>,
-    ),
 }
 
 /// Infer shared state
@@ -88,32 +84,19 @@ impl Infer {
             generation_health,
         ));
 
+        let chat_template = tokenizer_config
+            .chat_template
+            .map(|t| ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token));
+
         // Inference limit with a semaphore
        let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
 
-        let template = tokenizer_config.chat_template.map(|t| {
-            let mut env = Box::new(Environment::new());
-            let template_str = t.into_boxed_str();
-            env.add_function("raise_exception", raise_exception);
-            // leaking env and template_str as read-only, static resources for performance.
-            Box::leak(env)
-                .template_from_str(Box::leak(template_str))
-                .unwrap()
-        });
-        let eos_token = tokenizer_config
-            .eos_token
-            .map_or_else(String::new, |t| t)
-            .into();
-        let bos_token = tokenizer_config
-            .bos_token
-            .map_or_else(String::new, |t| t)
-            .into();
         Self {
             validation,
             queue,
             shared,
+            chat_template,
             limit_concurrent_requests: semaphore,
-            template: (template, bos_token, eos_token),
         }
     }
 
@@ -192,20 +175,14 @@ impl Infer {
     /// Apply the chat template to the chat request
     #[instrument(skip_all)]
     pub(crate) fn apply_chat_template(&self, messages: Vec<Message>) -> Result<String, InferError> {
-        let (template, bos_token, eos_token) = &self.template;
-        template
+        self.chat_template
             .as_ref()
             .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
-            .render(ChatTemplateInputs {
-                messages,
-                eos_token: eos_token.as_deref(),
-                bos_token: bos_token.as_deref(),
-                add_generation_prompt: true,
-            })
+            .apply(messages)
             .map_err(|e| {
                 metrics::increment_counter!("tgi_request_failure", "err" => "template");
                 tracing::error!("{e}");
-                InferError::TemplateError(e)
+                e
             })
     }
 
@@ -329,6 +306,42 @@ impl Infer {
     }
 }
 
+#[derive(Clone)]
+struct ChatTemplate {
+    template: Template<'static, 'static>,
+    bos_token: Option<String>,
+    eos_token: Option<String>,
+}
+
+impl ChatTemplate {
+    fn new(template: String, bos_token: Option<String>, eos_token: Option<String>) -> Self {
+        let mut env = Box::new(Environment::new());
+        let template_str = template.into_boxed_str();
+        env.add_function("raise_exception", raise_exception);
+        // leaking env and template_str as read-only, static resources for performance.
+        let template = Box::leak(env)
+            .template_from_str(Box::leak(template_str))
+            .unwrap();
+
+        Self {
+            template,
+            bos_token,
+            eos_token,
+        }
+    }
+
+    fn apply(&self, messages: Vec<Message>) -> Result<String, InferError> {
+        self.template
+            .render(ChatTemplateInputs {
+                messages,
+                bos_token: self.bos_token.as_deref(),
+                eos_token: self.eos_token.as_deref(),
+                add_generation_prompt: true,
+            })
+            .map_err(InferError::TemplateError)
+    }
+}
+
 /// Batching logic
 /// Will be launched in a background Tokio task
 ///