diff --git a/Cargo.lock b/Cargo.lock index a03da8b2..090e2e80 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3762,7 +3762,7 @@ dependencies = [ [[package]] name = "text-generation-benchmark" -version = "2.0.5-dev0" +version = "2.1.1-dev0" dependencies = [ "average", "clap", @@ -3783,7 +3783,7 @@ dependencies = [ [[package]] name = "text-generation-client" -version = "2.0.5-dev0" +version = "2.1.1-dev0" dependencies = [ "async-trait", "base64 0.22.1", @@ -3801,7 +3801,7 @@ dependencies = [ [[package]] name = "text-generation-launcher" -version = "2.0.5-dev0" +version = "2.1.1-dev0" dependencies = [ "clap", "ctrlc", @@ -3820,7 +3820,7 @@ dependencies = [ [[package]] name = "text-generation-router" -version = "2.0.5-dev0" +version = "2.1.1-dev0" dependencies = [ "async-stream", "axum 0.7.5", @@ -3855,8 +3855,6 @@ dependencies = [ "tokio-stream", "tower-http", "tracing", - "tracing-core", - "tracing-log 0.2.0", "tracing-opentelemetry 0.21.0", "tracing-subscriber", "utoipa", diff --git a/router/src/config.rs b/router/src/config.rs index ccbdd8b2..7737165e 100644 --- a/router/src/config.rs +++ b/router/src/config.rs @@ -162,6 +162,7 @@ pub enum Config { Baichuan, Paligemma(Paligemma), Gemma, + Gemma2, Cohere, Drbx, Falcon, diff --git a/router/src/lib.rs b/router/src/lib.rs index 4ba76f5f..a5b97af3 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -61,6 +61,9 @@ pub struct HubTokenizerConfig { pub bos_token: Option, #[serde(deserialize_with = "token_serde::deserialize")] pub eos_token: Option, + pub tokenizer_class: Option, + pub add_bos_token: Option, + pub add_eos_token: Option, } impl HubTokenizerConfig { diff --git a/router/src/main.rs b/router/src/main.rs index 68b6b1fc..3aa5a6bf 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -17,7 +17,7 @@ use text_generation_router::{ server, HubModelInfo, HubPreprocessorConfig, HubProcessorConfig, HubTokenizerConfig, }; use thiserror::Error; -use tokenizers::Tokenizer; +use tokenizers::{processors::template::TemplateProcessing, Tokenizer}; use tower_http::cors::AllowOrigin; use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::util::SubscriberInitExt; @@ -275,8 +275,6 @@ async fn main() -> Result<(), RouterError> { ) } }; - let tokenizer: Option = - tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok()); let config: Option = config_filename.and_then(|filename| { std::fs::read_to_string(filename) .ok() @@ -306,6 +304,37 @@ async fn main() -> Result<(), RouterError> { tracing::warn!("Could not find tokenizer config locally and no API specified"); HubTokenizerConfig::default() }); + let tokenizer: Option = + tokenizer_filename.and_then(|filename| { + let mut tokenizer = Tokenizer::from_file(filename).ok(); + if let Some(tokenizer) = &mut tokenizer{ + if let Some(class) = &tokenizer_config.tokenizer_class{ + if class == "LlamaTokenizer" || class == "LlamaTokenizerFast" { + tracing::info!("Overriding LllamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205"); + let mut single = vec![]; + let mut special_tokens = vec![]; + if let Some(true) = &tokenizer_config.add_bos_token{ + if let Some(bos_token) = &tokenizer_config.bos_token{ + let bos_token_id = tokenizer.token_to_id(&bos_token).expect("Should have found the bos token id"); + special_tokens.push((bos_token.clone(), bos_token_id)); + single.push(bos_token.to_string()); + } + } + single.push("$0".to_string()); + if let Some(true) = &tokenizer_config.add_eos_token{ + if let Some(eos_token) = &tokenizer_config.eos_token{ + let eos_token_id = tokenizer.token_to_id(&eos_token).expect("Should have found the eos token id"); + special_tokens.push((eos_token.clone(), eos_token_id)); + single.push(eos_token.to_string()); + } + } + let post_processor = TemplateProcessing::builder().try_single(single).unwrap().special_tokens(special_tokens).build().unwrap(); + tokenizer.with_post_processor(post_processor); + }} + } + tokenizer + + }); let preprocessor_config = preprocessor_config_filename.and_then(HubPreprocessorConfig::from_file);