Fixing malformed Rust tokenizers (#2134)

* Fixing malformed Rust tokenizers

* Fix for DeepSeek too.
Author: Nicolas Patry, 2024-06-27 16:04:03 +02:00 (committed by GitHub)
parent dd2d91b043
commit 0e4ab6d31c
4 changed files with 40 additions and 9 deletions

Cargo.lock (generated)

@@ -3762,7 +3762,7 @@ dependencies = [

 [[package]]
 name = "text-generation-benchmark"
-version = "2.0.5-dev0"
+version = "2.1.1-dev0"
 dependencies = [
  "average",
  "clap",
@@ -3783,7 +3783,7 @@ dependencies = [

 [[package]]
 name = "text-generation-client"
-version = "2.0.5-dev0"
+version = "2.1.1-dev0"
 dependencies = [
  "async-trait",
  "base64 0.22.1",
@@ -3801,7 +3801,7 @@ dependencies = [

 [[package]]
 name = "text-generation-launcher"
-version = "2.0.5-dev0"
+version = "2.1.1-dev0"
 dependencies = [
  "clap",
  "ctrlc",
@@ -3820,7 +3820,7 @@ dependencies = [

 [[package]]
 name = "text-generation-router"
-version = "2.0.5-dev0"
+version = "2.1.1-dev0"
 dependencies = [
  "async-stream",
  "axum 0.7.5",
@@ -3855,8 +3855,6 @@ dependencies = [
  "tokio-stream",
  "tower-http",
  "tracing",
- "tracing-core",
- "tracing-log 0.2.0",
  "tracing-opentelemetry 0.21.0",
  "tracing-subscriber",
  "utoipa",

router/src/config.rs

@@ -162,6 +162,7 @@ pub enum Config {
     Baichuan,
     Paligemma(Paligemma),
     Gemma,
+    Gemma2,
     Cohere,
     Drbx,
     Falcon,

router/src/lib.rs

@@ -61,6 +61,9 @@ pub struct HubTokenizerConfig {
     pub bos_token: Option<String>,
     #[serde(deserialize_with = "token_serde::deserialize")]
     pub eos_token: Option<String>,
+    pub tokenizer_class: Option<String>,
+    pub add_bos_token: Option<bool>,
+    pub add_eos_token: Option<bool>,
 }

 impl HubTokenizerConfig {
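
These three optional fields mirror keys commonly found in a model's tokenizer_config.json on the Hub. A minimal deserialization sketch (the struct is a trimmed-down, hypothetical mirror of HubTokenizerConfig, and the JSON excerpt is illustrative rather than from any specific model):

use serde::Deserialize;

// Trimmed-down, hypothetical mirror of HubTokenizerConfig; the real struct
// also carries the chat template and custom token deserializers.
#[derive(Debug, Deserialize, Default)]
struct TokenizerConfigSketch {
    bos_token: Option<String>,
    eos_token: Option<String>,
    tokenizer_class: Option<String>,
    add_bos_token: Option<bool>,
    add_eos_token: Option<bool>,
}

fn main() {
    // Illustrative tokenizer_config.json excerpt for a Llama-style model.
    let raw = r#"{
        "tokenizer_class": "LlamaTokenizerFast",
        "bos_token": "<s>",
        "eos_token": "</s>",
        "add_bos_token": true,
        "add_eos_token": false
    }"#;
    let config: TokenizerConfigSketch = serde_json::from_str(raw).unwrap();
    assert_eq!(config.tokenizer_class.as_deref(), Some("LlamaTokenizerFast"));
    // All fields are Option-al, so older configs missing these keys still parse.
    assert_eq!(config.add_eos_token, Some(false));
}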

router/src/main.rs

@@ -17,7 +17,7 @@ use text_generation_router::{
     server, HubModelInfo, HubPreprocessorConfig, HubProcessorConfig, HubTokenizerConfig,
 };
 use thiserror::Error;
-use tokenizers::Tokenizer;
+use tokenizers::{processors::template::TemplateProcessing, Tokenizer};
 use tower_http::cors::AllowOrigin;
 use tracing_subscriber::layer::SubscriberExt;
 use tracing_subscriber::util::SubscriberInitExt;
@@ -275,8 +275,6 @@ async fn main() -> Result<(), RouterError> {
             )
         }
     };
-    let tokenizer: Option<Tokenizer> =
-        tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok());
     let config: Option<Config> = config_filename.and_then(|filename| {
         std::fs::read_to_string(filename)
             .ok()
@@ -306,6 +304,37 @@ async fn main() -> Result<(), RouterError> {
         tracing::warn!("Could not find tokenizer config locally and no API specified");
         HubTokenizerConfig::default()
     });
+    let tokenizer: Option<Tokenizer> = tokenizer_filename.and_then(|filename| {
+        let mut tokenizer = Tokenizer::from_file(filename).ok();
+        if let Some(tokenizer) = &mut tokenizer {
+            if let Some(class) = &tokenizer_config.tokenizer_class {
+                if class == "LlamaTokenizer" || class == "LlamaTokenizerFast" {
+                    tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow Python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205");
+                    // Rebuild the post-processor so that BOS/EOS insertion matches
+                    // the Python tokenizer's behavior.
+                    let mut single = vec![];
+                    let mut special_tokens = vec![];
+                    if let Some(true) = &tokenizer_config.add_bos_token {
+                        if let Some(bos_token) = &tokenizer_config.bos_token {
+                            let bos_token_id = tokenizer
+                                .token_to_id(bos_token)
+                                .expect("Should have found the bos token id");
+                            special_tokens.push((bos_token.clone(), bos_token_id));
+                            single.push(bos_token.to_string());
+                        }
+                    }
+                    // "$0" is TemplateProcessing's placeholder for the input sequence.
+                    single.push("$0".to_string());
+                    if let Some(true) = &tokenizer_config.add_eos_token {
+                        if let Some(eos_token) = &tokenizer_config.eos_token {
+                            let eos_token_id = tokenizer
+                                .token_to_id(eos_token)
+                                .expect("Should have found the eos token id");
+                            special_tokens.push((eos_token.clone(), eos_token_id));
+                            single.push(eos_token.to_string());
+                        }
+                    }
+                    let post_processor = TemplateProcessing::builder()
+                        .try_single(single)
+                        .unwrap()
+                        .special_tokens(special_tokens)
+                        .build()
+                        .unwrap();
+                    tokenizer.with_post_processor(post_processor);
+                }
+            }
+        }
+        tokenizer
+    });
     let preprocessor_config =
         preprocessor_config_filename.and_then(HubPreprocessorConfig::from_file);
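
For intuition about what the new post-processor does at encode time, here is a minimal, self-contained sketch of TemplateProcessing against a toy vocabulary (assuming the tokenizers 0.19 API this commit uses; the vocabulary and token IDs are invented for illustration). It mirrors the add_bos_token = true, add_eos_token = false case typical of Llama configs:

use std::collections::HashMap;

use tokenizers::models::wordlevel::WordLevel;
use tokenizers::processors::template::TemplateProcessing;
use tokenizers::Tokenizer;

fn main() {
    // Toy vocabulary standing in for a real model's tokenizer.json.
    let vocab = HashMap::from([
        ("<s>".to_string(), 0),
        ("</s>".to_string(), 1),
        ("hello".to_string(), 2),
        ("[UNK]".to_string(), 3),
    ]);
    let model = WordLevel::builder()
        .vocab(vocab)
        .unk_token("[UNK]".to_string())
        .build()
        .unwrap();
    let mut tokenizer = Tokenizer::new(model);

    // Same shape the router builds: template "<s> $0", i.e. BOS prepended, no EOS.
    let post_processor = TemplateProcessing::builder()
        .try_single(vec!["<s>".to_string(), "$0".to_string()])
        .unwrap()
        .special_tokens(vec![("<s>".to_string(), 0)])
        .build()
        .unwrap();
    tokenizer.with_post_processor(post_processor);

    // With the post-processor in place, encoding now matches the Python
    // tokenizer: the BOS id is inserted in front of the sequence.
    let encoding = tokenizer.encode("hello", true).unwrap();
    assert_eq!(encoding.get_ids(), &[0, 2]);
}

With add_eos_token = true the template would become "<s> $0 </s>" and the EOS id would be appended as well, which is exactly what the single vector in the diff above builds up.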