feat: support raise_exception, bos and eos tokens (#1450)
This PR adds support for the custom jinja function `raise_exception` and passes the `bos` and `eos` tokens into the template.

Additionally, this PR adds 3 tests to validate and show examples of what can and cannot be parsed currently.

```bash
cargo test --package text-generation-router --lib -- infer::tests --nocapture
# Finished test [unoptimized + debuginfo] target(s) in 7.82s
# Running unittests src/lib.rs (target/debug/deps/text_generation_router-18a0bbf99c2ca1b4)
# running 3 tests
# test infer::tests::test_chat_template_valid_with_raise ... ok
# test infer::tests::test_chat_template ... ok
# test infer::tests::test_chat_template_invalid_with_raise ... ok
# test result: ok. 3 passed; 0 failed; 0 ignored; 0 measured; 15 filtered out; finished in 0.00s
```
This commit is contained in:
parent
0eabc83541
commit
3ccb3bb0b5
|
@ -1,8 +1,9 @@
|
||||||
/// Batching and inference logic
|
/// Batching and inference logic
|
||||||
use crate::validation::{Validation, ValidationError};
|
use crate::validation::{Validation, ValidationError};
|
||||||
use crate::HubTokenizerConfig;
|
use crate::{
|
||||||
use crate::{ChatRequest, GenerateRequest, GenerateStreamResponse, PrefillToken};
|
ChatTemplateInputs, Entry, GenerateRequest, GenerateStreamResponse, HubTokenizerConfig,
|
||||||
use crate::{Entry, Queue, Token};
|
Message, PrefillToken, Queue, Token,
|
||||||
|
};
|
||||||
use futures::future::try_join_all;
|
use futures::future::try_join_all;
|
||||||
use minijinja::{Environment, ErrorKind, Template};
|
use minijinja::{Environment, ErrorKind, Template};
|
||||||
use nohash_hasher::IntMap;
|
use nohash_hasher::IntMap;
|
||||||
|
@ -32,8 +33,12 @@ pub struct Infer {
|
||||||
shared: Arc<Shared>,
|
shared: Arc<Shared>,
|
||||||
/// Inference limit
|
/// Inference limit
|
||||||
limit_concurrent_requests: Arc<Semaphore>,
|
limit_concurrent_requests: Arc<Semaphore>,
|
||||||
/// Chat template
|
/// Chat template (template, bos_token, eos_token)
|
||||||
template: Option<Template<'static, 'static>>,
|
template: (
|
||||||
|
Option<Template<'static, 'static>>,
|
||||||
|
Option<String>,
|
||||||
|
Option<String>,
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Infer shared state
|
/// Infer shared state
|
||||||
|
@ -42,6 +47,11 @@ struct Shared {
|
||||||
batching_task: Notify,
|
batching_task: Notify,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Raise an exception (custom function) used in the chat templates
|
||||||
|
fn raise_exception(err_text: String) -> Result<String, minijinja::Error> {
|
||||||
|
Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text))
|
||||||
|
}
|
||||||
|
|
||||||
impl Infer {
|
impl Infer {
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub(crate) fn new(
|
pub(crate) fn new(
|
||||||
|
@ -80,20 +90,28 @@ impl Infer {
|
||||||
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
|
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
|
||||||
|
|
||||||
let template = tokenizer_config.chat_template.map(|t| {
|
let template = tokenizer_config.chat_template.map(|t| {
|
||||||
let env = Box::new(Environment::new());
|
let mut env = Box::new(Environment::new());
|
||||||
let template_str = t.into_boxed_str();
|
let template_str = t.into_boxed_str();
|
||||||
|
env.add_function("raise_exception", raise_exception);
|
||||||
// leaking env and template_str as read-only, static resources for performance.
|
// leaking env and template_str as read-only, static resources for performance.
|
||||||
Box::leak(env)
|
Box::leak(env)
|
||||||
.template_from_str(Box::leak(template_str))
|
.template_from_str(Box::leak(template_str))
|
||||||
.unwrap()
|
.unwrap()
|
||||||
});
|
});
|
||||||
|
let eos_token = tokenizer_config
|
||||||
|
.eos_token
|
||||||
|
.map_or_else(String::new, |t| t)
|
||||||
|
.into();
|
||||||
|
let bos_token = tokenizer_config
|
||||||
|
.bos_token
|
||||||
|
.map_or_else(String::new, |t| t)
|
||||||
|
.into();
|
||||||
Self {
|
Self {
|
||||||
validation,
|
validation,
|
||||||
queue,
|
queue,
|
||||||
shared,
|
shared,
|
||||||
limit_concurrent_requests: semaphore,
|
limit_concurrent_requests: semaphore,
|
||||||
template,
|
template: (template, eos_token, bos_token),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -149,11 +167,16 @@ impl Infer {
|
||||||
|
|
||||||
/// Apply the chat template to the chat request
|
/// Apply the chat template to the chat request
|
||||||
#[instrument(skip_all)]
|
#[instrument(skip_all)]
|
||||||
pub(crate) fn apply_chat_template(&self, chat: ChatRequest) -> Result<String, InferError> {
|
pub(crate) fn apply_chat_template(&self, messages: Vec<Message>) -> Result<String, InferError> {
|
||||||
self.template
|
let (template, bos_token, eos_token) = &self.template;
|
||||||
|
template
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
|
.ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
|
||||||
.render(chat)
|
.render(ChatTemplateInputs {
|
||||||
|
messages,
|
||||||
|
eos_token: eos_token.as_deref(),
|
||||||
|
bos_token: bos_token.as_deref(),
|
||||||
|
})
|
||||||
.map_err(|e| {
|
.map_err(|e| {
|
||||||
metrics::increment_counter!("tgi_request_failure", "err" => "template");
|
metrics::increment_counter!("tgi_request_failure", "err" => "template");
|
||||||
tracing::error!("{e}");
|
tracing::error!("{e}");
|
||||||
|
@ -702,3 +725,205 @@ impl InferError {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// tests
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use crate::infer::raise_exception;
|
||||||
|
use crate::ChatTemplateInputs;
|
||||||
|
use crate::Message;
|
||||||
|
use minijinja::Environment;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_chat_template() {
|
||||||
|
let env = Environment::new();
|
||||||
|
|
||||||
|
let source = r#"
|
||||||
|
{% for message in messages %}
|
||||||
|
{% if message['role'] == 'system' %}
|
||||||
|
{% if message['content']%}
|
||||||
|
{{'### System:\n' + message['content']+'\n\n'}}
|
||||||
|
{% endif %}
|
||||||
|
{% elif message['role'] == 'user' %}
|
||||||
|
{{'### User:\n' + message['content']+'\n\n'}}
|
||||||
|
{% elif message['role'] == 'assistant' %}
|
||||||
|
{{'### Assistant:\n' + message['content']}}
|
||||||
|
{% endif %}
|
||||||
|
{% if loop.last and add_generation_prompt %}
|
||||||
|
{{ '### Assistant:\n' }}
|
||||||
|
{% endif %}
|
||||||
|
{% endfor %}"#;
|
||||||
|
|
||||||
|
// trim all the whitespace
|
||||||
|
let source = source
|
||||||
|
.lines()
|
||||||
|
.map(|line| line.trim())
|
||||||
|
.collect::<Vec<&str>>()
|
||||||
|
.join("");
|
||||||
|
|
||||||
|
let tmpl = env.template_from_str(&source);
|
||||||
|
|
||||||
|
let chat_template_inputs = ChatTemplateInputs {
|
||||||
|
messages: vec![
|
||||||
|
Message {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: "Hi!".to_string(),
|
||||||
|
},
|
||||||
|
Message {
|
||||||
|
role: "assistant".to_string(),
|
||||||
|
content: "Hello how can I help?".to_string(),
|
||||||
|
},
|
||||||
|
Message {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: "What is Deep Learning?".to_string(),
|
||||||
|
},
|
||||||
|
Message {
|
||||||
|
role: "assistant".to_string(),
|
||||||
|
content: "magic!".to_string(),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
bos_token: Some("[BOS]"),
|
||||||
|
eos_token: Some("[EOS]"),
|
||||||
|
};
|
||||||
|
|
||||||
|
let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
result,
|
||||||
|
r#"### User:
|
||||||
|
Hi!
|
||||||
|
|
||||||
|
### Assistant:
|
||||||
|
Hello how can I help?### User:
|
||||||
|
What is Deep Learning?
|
||||||
|
|
||||||
|
### Assistant:
|
||||||
|
magic!"#
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_chat_template_invalid_with_raise() {
|
||||||
|
let mut env = Environment::new();
|
||||||
|
env.add_function("raise_exception", raise_exception);
|
||||||
|
|
||||||
|
let source = r#"
|
||||||
|
{{ bos_token }}
|
||||||
|
{% for message in messages %}
|
||||||
|
{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}
|
||||||
|
{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
|
||||||
|
{% endif %}
|
||||||
|
{% if message['role'] == 'user' %}
|
||||||
|
{{ '[INST] ' + message['content'] + ' [/INST]' }}
|
||||||
|
{% elif message['role'] == 'assistant' %}
|
||||||
|
{{ message['content'] + eos_token}}
|
||||||
|
{% else %}
|
||||||
|
{{ raise_exception('Only user and assistant roles are supported!') }}
|
||||||
|
{% endif %}
|
||||||
|
{% endfor %}"#;
|
||||||
|
|
||||||
|
// trim all the whitespace
|
||||||
|
let source = source
|
||||||
|
.lines()
|
||||||
|
.map(|line| line.trim())
|
||||||
|
.collect::<Vec<&str>>()
|
||||||
|
.join("");
|
||||||
|
|
||||||
|
let tmpl = env.template_from_str(&source);
|
||||||
|
|
||||||
|
let chat_template_inputs = ChatTemplateInputs {
|
||||||
|
messages: vec![
|
||||||
|
Message {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: "Hi!".to_string(),
|
||||||
|
},
|
||||||
|
Message {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: "Hi again!".to_string(),
|
||||||
|
},
|
||||||
|
Message {
|
||||||
|
role: "assistant".to_string(),
|
||||||
|
content: "Hello how can I help?".to_string(),
|
||||||
|
},
|
||||||
|
Message {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: "What is Deep Learning?".to_string(),
|
||||||
|
},
|
||||||
|
Message {
|
||||||
|
role: "assistant".to_string(),
|
||||||
|
content: "magic!".to_string(),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
bos_token: Some("[BOS]"),
|
||||||
|
eos_token: Some("[EOS]"),
|
||||||
|
};
|
||||||
|
|
||||||
|
let result = tmpl.unwrap().render(chat_template_inputs); //.err().unwrap();
|
||||||
|
|
||||||
|
match result {
|
||||||
|
Ok(_) => panic!("Should have failed"),
|
||||||
|
Err(e) => {
|
||||||
|
assert_eq!(
|
||||||
|
e.detail().unwrap(),
|
||||||
|
"Conversation roles must alternate user/assistant/user/assistant/..."
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_chat_template_valid_with_raise() {
|
||||||
|
let mut env = Environment::new();
|
||||||
|
env.add_function("raise_exception", raise_exception);
|
||||||
|
|
||||||
|
let source = r#"
|
||||||
|
{{ bos_token }}
|
||||||
|
{% for message in messages %}
|
||||||
|
{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}
|
||||||
|
{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
|
||||||
|
{% endif %}
|
||||||
|
{% if message['role'] == 'user' %}
|
||||||
|
{{ '[INST] ' + message['content'] + ' [/INST]' }}
|
||||||
|
{% elif message['role'] == 'assistant' %}
|
||||||
|
{{ message['content'] + eos_token}}
|
||||||
|
{% else %}
|
||||||
|
{{ raise_exception('Only user and assistant roles are supported!') }}
|
||||||
|
{% endif %}
|
||||||
|
{% endfor %}"#;
|
||||||
|
|
||||||
|
// trim all the whitespace
|
||||||
|
let source = source
|
||||||
|
.lines()
|
||||||
|
.map(|line| line.trim())
|
||||||
|
.collect::<Vec<&str>>()
|
||||||
|
.join("");
|
||||||
|
|
||||||
|
let tmpl = env.template_from_str(&source);
|
||||||
|
|
||||||
|
let chat_template_inputs = ChatTemplateInputs {
|
||||||
|
messages: vec![
|
||||||
|
Message {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: "Hi!".to_string(),
|
||||||
|
},
|
||||||
|
Message {
|
||||||
|
role: "assistant".to_string(),
|
||||||
|
content: "Hello how can I help?".to_string(),
|
||||||
|
},
|
||||||
|
Message {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: "What is Deep Learning?".to_string(),
|
||||||
|
},
|
||||||
|
Message {
|
||||||
|
role: "assistant".to_string(),
|
||||||
|
content: "magic!".to_string(),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
bos_token: Some("[BOS]"),
|
||||||
|
eos_token: Some("[EOS]"),
|
||||||
|
};
|
||||||
|
|
||||||
|
let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
|
||||||
|
assert_eq!(result, "[BOS][INST] Hi! [/INST]Hello how can I help?[EOS][INST] What is Deep Learning? [/INST]magic![EOS]");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -31,8 +31,9 @@ pub struct HubModelInfo {
|
||||||
|
|
||||||
#[derive(Clone, Deserialize, Default)]
|
#[derive(Clone, Deserialize, Default)]
|
||||||
pub struct HubTokenizerConfig {
|
pub struct HubTokenizerConfig {
|
||||||
#[serde(default)]
|
|
||||||
pub chat_template: Option<String>,
|
pub chat_template: Option<String>,
|
||||||
|
pub bos_token: Option<String>,
|
||||||
|
pub eos_token: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl HubTokenizerConfig {
|
impl HubTokenizerConfig {
|
||||||
|
@ -366,6 +367,13 @@ pub(crate) struct ChatRequest {
|
||||||
pub seed: Option<u64>,
|
pub seed: Option<u64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Serialize, Deserialize)]
|
||||||
|
pub(crate) struct ChatTemplateInputs<'a> {
|
||||||
|
messages: Vec<Message>,
|
||||||
|
bos_token: Option<&'a str>,
|
||||||
|
eos_token: Option<&'a str>,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone, Deserialize, ToSchema, Serialize)]
|
#[derive(Clone, Deserialize, ToSchema, Serialize)]
|
||||||
pub(crate) struct Message {
|
pub(crate) struct Message {
|
||||||
#[schema(example = "user")]
|
#[schema(example = "user")]
|
||||||
|
|
|
@ -2,11 +2,11 @@
|
||||||
use crate::health::Health;
|
use crate::health::Health;
|
||||||
use crate::infer::{InferError, InferResponse, InferStreamResponse};
|
use crate::infer::{InferError, InferResponse, InferStreamResponse};
|
||||||
use crate::validation::ValidationError;
|
use crate::validation::ValidationError;
|
||||||
use crate::HubTokenizerConfig;
|
|
||||||
use crate::{
|
use crate::{
|
||||||
BestOfSequence, ChatCompletion, ChatCompletionChunk, ChatRequest, CompatGenerateRequest,
|
BestOfSequence, ChatCompletion, ChatCompletionChunk, ChatRequest, CompatGenerateRequest,
|
||||||
Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse,
|
Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse,
|
||||||
HubModelInfo, Infer, Info, PrefillToken, StreamDetails, StreamResponse, Token, Validation,
|
HubModelInfo, HubTokenizerConfig, Infer, Info, PrefillToken, StreamDetails, StreamResponse,
|
||||||
|
Token, Validation,
|
||||||
};
|
};
|
||||||
use axum::extract::Extension;
|
use axum::extract::Extension;
|
||||||
use axum::http::{HeaderMap, Method, StatusCode};
|
use axum::http::{HeaderMap, Method, StatusCode};
|
||||||
|
@ -572,7 +572,7 @@ async fn chat_completions(
|
||||||
let seed = req.seed;
|
let seed = req.seed;
|
||||||
|
|
||||||
// apply chat template to flatten the request into a single input
|
// apply chat template to flatten the request into a single input
|
||||||
let inputs = match infer.apply_chat_template(req) {
|
let inputs = match infer.apply_chat_template(req.messages) {
|
||||||
Ok(inputs) => inputs,
|
Ok(inputs) => inputs,
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
|
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
|
||||||
|
@ -659,9 +659,9 @@ async fn chat_completions(
|
||||||
|
|
||||||
// build the complete response object with the full text
|
// build the complete response object with the full text
|
||||||
let response = ChatCompletion::new(
|
let response = ChatCompletion::new(
|
||||||
generation.generated_text,
|
|
||||||
model_id,
|
model_id,
|
||||||
system_fingerprint,
|
system_fingerprint,
|
||||||
|
generation.generated_text,
|
||||||
current_time,
|
current_time,
|
||||||
generation.details.unwrap(),
|
generation.details.unwrap(),
|
||||||
logprobs,
|
logprobs,
|
||||||
|
|
Loading…
Reference in New Issue