feat: support raise_exception, bos and eos tokens (#1450)

This PR adds support to handle the custom jinja function
`raise_exception` and passes the `bos` and `eos` tokens into the
template

Additionally this PR adds 3 tests to validate and show examples of what
can and cannot be parsed currently.

```bash
cargo test --package text-generation-router --lib -- infer::tests --nocapture
#     Finished test [unoptimized + debuginfo] target(s) in 7.82s
#      Running unittests src/lib.rs (target/debug/deps/text_generation_router-18a0bbf99c2ca1b4)

# running 3 tests
# test infer::tests::test_chat_template_valid_with_raise ... ok
# test infer::tests::test_chat_template ... ok
# test infer::tests::test_chat_template_invalid_with_raise ... ok

# test result: ok. 3 passed; 0 failed; 0 ignored; 0 measured; 15 filtered out; finished in 0.00s
```
This commit is contained in:
drbh 2024-01-18 06:31:56 -05:00 committed by GitHub
parent 0eabc83541
commit 3ccb3bb0b5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 249 additions and 16 deletions

View File

@ -1,8 +1,9 @@
/// Batching and inference logic /// Batching and inference logic
use crate::validation::{Validation, ValidationError}; use crate::validation::{Validation, ValidationError};
use crate::HubTokenizerConfig; use crate::{
use crate::{ChatRequest, GenerateRequest, GenerateStreamResponse, PrefillToken}; ChatTemplateInputs, Entry, GenerateRequest, GenerateStreamResponse, HubTokenizerConfig,
use crate::{Entry, Queue, Token}; Message, PrefillToken, Queue, Token,
};
use futures::future::try_join_all; use futures::future::try_join_all;
use minijinja::{Environment, ErrorKind, Template}; use minijinja::{Environment, ErrorKind, Template};
use nohash_hasher::IntMap; use nohash_hasher::IntMap;
@ -32,8 +33,12 @@ pub struct Infer {
shared: Arc<Shared>, shared: Arc<Shared>,
/// Inference limit /// Inference limit
limit_concurrent_requests: Arc<Semaphore>, limit_concurrent_requests: Arc<Semaphore>,
/// Chat template /// Chat template (template, bos_token, eos_token)
template: Option<Template<'static, 'static>>, template: (
Option<Template<'static, 'static>>,
Option<String>,
Option<String>,
),
} }
/// Infer shared state /// Infer shared state
@ -42,6 +47,11 @@ struct Shared {
batching_task: Notify, batching_task: Notify,
} }
/// Raise a exception (custom function) used in the chat templates
fn raise_exception(err_text: String) -> Result<String, minijinja::Error> {
Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text))
}
impl Infer { impl Infer {
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub(crate) fn new( pub(crate) fn new(
@ -80,20 +90,28 @@ impl Infer {
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests)); let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
let template = tokenizer_config.chat_template.map(|t| { let template = tokenizer_config.chat_template.map(|t| {
let env = Box::new(Environment::new()); let mut env = Box::new(Environment::new());
let template_str = t.into_boxed_str(); let template_str = t.into_boxed_str();
env.add_function("raise_exception", raise_exception);
// leaking env and template_str as read-only, static resources for performance. // leaking env and template_str as read-only, static resources for performance.
Box::leak(env) Box::leak(env)
.template_from_str(Box::leak(template_str)) .template_from_str(Box::leak(template_str))
.unwrap() .unwrap()
}); });
let eos_token = tokenizer_config
.eos_token
.map_or_else(String::new, |t| t)
.into();
let bos_token = tokenizer_config
.bos_token
.map_or_else(String::new, |t| t)
.into();
Self { Self {
validation, validation,
queue, queue,
shared, shared,
limit_concurrent_requests: semaphore, limit_concurrent_requests: semaphore,
template, template: (template, eos_token, bos_token),
} }
} }
@ -149,11 +167,16 @@ impl Infer {
/// Apply the chat template to the chat request /// Apply the chat template to the chat request
#[instrument(skip_all)] #[instrument(skip_all)]
pub(crate) fn apply_chat_template(&self, chat: ChatRequest) -> Result<String, InferError> { pub(crate) fn apply_chat_template(&self, messages: Vec<Message>) -> Result<String, InferError> {
self.template let (template, bos_token, eos_token) = &self.template;
template
.as_ref() .as_ref()
.ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))? .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
.render(chat) .render(ChatTemplateInputs {
messages,
eos_token: eos_token.as_deref(),
bos_token: bos_token.as_deref(),
})
.map_err(|e| { .map_err(|e| {
metrics::increment_counter!("tgi_request_failure", "err" => "template"); metrics::increment_counter!("tgi_request_failure", "err" => "template");
tracing::error!("{e}"); tracing::error!("{e}");
@ -702,3 +725,205 @@ impl InferError {
} }
} }
} }
// tests
#[cfg(test)]
mod tests {
use crate::infer::raise_exception;
use crate::ChatTemplateInputs;
use crate::Message;
use minijinja::Environment;
#[test]
fn test_chat_template() {
let env = Environment::new();
let source = r#"
{% for message in messages %}
{% if message['role'] == 'system' %}
{% if message['content']%}
{{'### System:\n' + message['content']+'\n\n'}}
{% endif %}
{% elif message['role'] == 'user' %}
{{'### User:\n' + message['content']+'\n\n'}}
{% elif message['role'] == 'assistant' %}
{{'### Assistant:\n' + message['content']}}
{% endif %}
{% if loop.last and add_generation_prompt %}
{{ '### Assistant:\n' }}
{% endif %}
{% endfor %}"#;
// trim all the whitespace
let source = source
.lines()
.map(|line| line.trim())
.collect::<Vec<&str>>()
.join("");
let tmpl = env.template_from_str(&source);
let chat_template_inputs = ChatTemplateInputs {
messages: vec![
Message {
role: "user".to_string(),
content: "Hi!".to_string(),
},
Message {
role: "assistant".to_string(),
content: "Hello how can I help?".to_string(),
},
Message {
role: "user".to_string(),
content: "What is Deep Learning?".to_string(),
},
Message {
role: "assistant".to_string(),
content: "magic!".to_string(),
},
],
bos_token: Some("[BOS]"),
eos_token: Some("[EOS]"),
};
let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
assert_eq!(
result,
r#"### User:
Hi!
### Assistant:
Hello how can I help?### User:
What is Deep Learning?
### Assistant:
magic!"#
);
}
#[test]
fn test_chat_template_invalid_with_raise() {
let mut env = Environment::new();
env.add_function("raise_exception", raise_exception);
let source = r#"
{{ bos_token }}
{% for message in messages %}
{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}
{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
{% endif %}
{% if message['role'] == 'user' %}
{{ '[INST] ' + message['content'] + ' [/INST]' }}
{% elif message['role'] == 'assistant' %}
{{ message['content'] + eos_token}}
{% else %}
{{ raise_exception('Only user and assistant roles are supported!') }}
{% endif %}
{% endfor %}"#;
// trim all the whitespace
let source = source
.lines()
.map(|line| line.trim())
.collect::<Vec<&str>>()
.join("");
let tmpl = env.template_from_str(&source);
let chat_template_inputs = ChatTemplateInputs {
messages: vec![
Message {
role: "user".to_string(),
content: "Hi!".to_string(),
},
Message {
role: "user".to_string(),
content: "Hi again!".to_string(),
},
Message {
role: "assistant".to_string(),
content: "Hello how can I help?".to_string(),
},
Message {
role: "user".to_string(),
content: "What is Deep Learning?".to_string(),
},
Message {
role: "assistant".to_string(),
content: "magic!".to_string(),
},
],
bos_token: Some("[BOS]"),
eos_token: Some("[EOS]"),
};
let result = tmpl.unwrap().render(chat_template_inputs); //.err().unwrap();
match result {
Ok(_) => panic!("Should have failed"),
Err(e) => {
assert_eq!(
e.detail().unwrap(),
"Conversation roles must alternate user/assistant/user/assistant/..."
);
}
}
}
#[test]
fn test_chat_template_valid_with_raise() {
let mut env = Environment::new();
env.add_function("raise_exception", raise_exception);
let source = r#"
{{ bos_token }}
{% for message in messages %}
{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}
{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
{% endif %}
{% if message['role'] == 'user' %}
{{ '[INST] ' + message['content'] + ' [/INST]' }}
{% elif message['role'] == 'assistant' %}
{{ message['content'] + eos_token}}
{% else %}
{{ raise_exception('Only user and assistant roles are supported!') }}
{% endif %}
{% endfor %}"#;
// trim all the whitespace
let source = source
.lines()
.map(|line| line.trim())
.collect::<Vec<&str>>()
.join("");
let tmpl = env.template_from_str(&source);
let chat_template_inputs = ChatTemplateInputs {
messages: vec![
Message {
role: "user".to_string(),
content: "Hi!".to_string(),
},
Message {
role: "assistant".to_string(),
content: "Hello how can I help?".to_string(),
},
Message {
role: "user".to_string(),
content: "What is Deep Learning?".to_string(),
},
Message {
role: "assistant".to_string(),
content: "magic!".to_string(),
},
],
bos_token: Some("[BOS]"),
eos_token: Some("[EOS]"),
};
let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
assert_eq!(result, "[BOS][INST] Hi! [/INST]Hello how can I help?[EOS][INST] What is Deep Learning? [/INST]magic![EOS]");
}
}

View File

@ -31,8 +31,9 @@ pub struct HubModelInfo {
#[derive(Clone, Deserialize, Default)] #[derive(Clone, Deserialize, Default)]
pub struct HubTokenizerConfig { pub struct HubTokenizerConfig {
#[serde(default)]
pub chat_template: Option<String>, pub chat_template: Option<String>,
pub bos_token: Option<String>,
pub eos_token: Option<String>,
} }
impl HubTokenizerConfig { impl HubTokenizerConfig {
@ -366,6 +367,13 @@ pub(crate) struct ChatRequest {
pub seed: Option<u64>, pub seed: Option<u64>,
} }
#[derive(Clone, Serialize, Deserialize)]
pub(crate) struct ChatTemplateInputs<'a> {
messages: Vec<Message>,
bos_token: Option<&'a str>,
eos_token: Option<&'a str>,
}
#[derive(Clone, Deserialize, ToSchema, Serialize)] #[derive(Clone, Deserialize, ToSchema, Serialize)]
pub(crate) struct Message { pub(crate) struct Message {
#[schema(example = "user")] #[schema(example = "user")]

View File

@ -2,11 +2,11 @@
use crate::health::Health; use crate::health::Health;
use crate::infer::{InferError, InferResponse, InferStreamResponse}; use crate::infer::{InferError, InferResponse, InferStreamResponse};
use crate::validation::ValidationError; use crate::validation::ValidationError;
use crate::HubTokenizerConfig;
use crate::{ use crate::{
BestOfSequence, ChatCompletion, ChatCompletionChunk, ChatRequest, CompatGenerateRequest, BestOfSequence, ChatCompletion, ChatCompletionChunk, ChatRequest, CompatGenerateRequest,
Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse,
HubModelInfo, Infer, Info, PrefillToken, StreamDetails, StreamResponse, Token, Validation, HubModelInfo, HubTokenizerConfig, Infer, Info, PrefillToken, StreamDetails, StreamResponse,
Token, Validation,
}; };
use axum::extract::Extension; use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode}; use axum::http::{HeaderMap, Method, StatusCode};
@ -572,7 +572,7 @@ async fn chat_completions(
let seed = req.seed; let seed = req.seed;
// apply chat template to flatten the request into a single input // apply chat template to flatten the request into a single input
let inputs = match infer.apply_chat_template(req) { let inputs = match infer.apply_chat_template(req.messages) {
Ok(inputs) => inputs, Ok(inputs) => inputs,
Err(err) => { Err(err) => {
metrics::increment_counter!("tgi_request_failure", "err" => "validation"); metrics::increment_counter!("tgi_request_failure", "err" => "validation");
@ -659,9 +659,9 @@ async fn chat_completions(
// build the complete response object with the full text // build the complete response object with the full text
let response = ChatCompletion::new( let response = ChatCompletion::new(
generation.generated_text,
model_id, model_id,
system_fingerprint, system_fingerprint,
generation.generated_text,
current_time, current_time,
generation.details.unwrap(), generation.details.unwrap(),
logprobs, logprobs,