Fixing malformed rust tokenizers (#2134)
* Fixing malformed rust tokenizers
* Fix for deepseek too.
This commit is contained in:
parent dd2d91b043
commit 0e4ab6d31c
Cargo.lock
@@ -3762,7 +3762,7 @@ dependencies = [

 [[package]]
 name = "text-generation-benchmark"
-version = "2.0.5-dev0"
+version = "2.1.1-dev0"
 dependencies = [
  "average",
  "clap",
@@ -3783,7 +3783,7 @@ dependencies = [

 [[package]]
 name = "text-generation-client"
-version = "2.0.5-dev0"
+version = "2.1.1-dev0"
 dependencies = [
  "async-trait",
  "base64 0.22.1",
@@ -3801,7 +3801,7 @@ dependencies = [

 [[package]]
 name = "text-generation-launcher"
-version = "2.0.5-dev0"
+version = "2.1.1-dev0"
 dependencies = [
  "clap",
  "ctrlc",
@@ -3820,7 +3820,7 @@ dependencies = [

 [[package]]
 name = "text-generation-router"
-version = "2.0.5-dev0"
+version = "2.1.1-dev0"
 dependencies = [
  "async-stream",
  "axum 0.7.5",
@@ -3855,8 +3855,6 @@ dependencies = [
  "tokio-stream",
  "tower-http",
  "tracing",
- "tracing-core",
- "tracing-log 0.2.0",
  "tracing-opentelemetry 0.21.0",
  "tracing-subscriber",
  "utoipa",

router/src/config.rs
@@ -162,6 +162,7 @@ pub enum Config {
     Baichuan,
     Paligemma(Paligemma),
     Gemma,
+    Gemma2,
     Cohere,
     Drbx,
     Falcon,

router/src/lib.rs
@@ -61,6 +61,9 @@ pub struct HubTokenizerConfig {
     pub bos_token: Option<String>,
     #[serde(deserialize_with = "token_serde::deserialize")]
     pub eos_token: Option<String>,
+    pub tokenizer_class: Option<String>,
+    pub add_bos_token: Option<bool>,
+    pub add_eos_token: Option<bool>,
 }

 impl HubTokenizerConfig {
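
To make the intent of the new fields concrete, here is a small illustration that is not part of the commit: a simplified stand-in struct (assuming serde and serde_json as dependencies, and plain-string bos/eos tokens rather than the object form handled by token_serde::deserialize) reading the kind of tokenizer_config.json keys these options come from. The JSON values shown are only an assumed Llama-style example.

use serde::Deserialize;

// Simplified stand-in for HubTokenizerConfig: only the fields touched by this
// commit, with bos/eos kept as plain strings.
#[derive(Debug, Default, Deserialize)]
struct TokenizerConfigSketch {
    bos_token: Option<String>,
    eos_token: Option<String>,
    tokenizer_class: Option<String>,
    add_bos_token: Option<bool>,
    add_eos_token: Option<bool>,
}

fn main() {
    // Assumed shape of a Llama-style tokenizer_config.json.
    let raw = r#"{
        "tokenizer_class": "LlamaTokenizerFast",
        "bos_token": "<s>",
        "eos_token": "</s>",
        "add_bos_token": true,
        "add_eos_token": false
    }"#;
    let cfg: TokenizerConfigSketch = serde_json::from_str(raw).unwrap();
    assert_eq!(cfg.tokenizer_class.as_deref(), Some("LlamaTokenizerFast"));
    assert_eq!(cfg.add_bos_token, Some(true));
    assert_eq!(cfg.add_eos_token, Some(false));
}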

router/src/main.rs
@@ -17,7 +17,7 @@ use text_generation_router::{
     server, HubModelInfo, HubPreprocessorConfig, HubProcessorConfig, HubTokenizerConfig,
 };
 use thiserror::Error;
-use tokenizers::Tokenizer;
+use tokenizers::{processors::template::TemplateProcessing, Tokenizer};
 use tower_http::cors::AllowOrigin;
 use tracing_subscriber::layer::SubscriberExt;
 use tracing_subscriber::util::SubscriberInitExt;
@@ -275,8 +275,6 @@ async fn main() -> Result<(), RouterError> {
             )
         }
     };
-    let tokenizer: Option<Tokenizer> =
-        tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok());
     let config: Option<Config> = config_filename.and_then(|filename| {
         std::fs::read_to_string(filename)
             .ok()
@@ -306,6 +304,37 @@ async fn main() -> Result<(), RouterError> {
         tracing::warn!("Could not find tokenizer config locally and no API specified");
         HubTokenizerConfig::default()
     });
+    let tokenizer: Option<Tokenizer> =
+        tokenizer_filename.and_then(|filename| {
+            let mut tokenizer = Tokenizer::from_file(filename).ok();
+            if let Some(tokenizer) = &mut tokenizer{
+                if let Some(class) = &tokenizer_config.tokenizer_class{
+                    if class == "LlamaTokenizer" || class == "LlamaTokenizerFast" {
+                        tracing::info!("Overriding LllamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205");
+                        let mut single = vec![];
+                        let mut special_tokens = vec![];
+                        if let Some(true) = &tokenizer_config.add_bos_token{
+                            if let Some(bos_token) = &tokenizer_config.bos_token{
+                                let bos_token_id = tokenizer.token_to_id(&bos_token).expect("Should have found the bos token id");
+                                special_tokens.push((bos_token.clone(), bos_token_id));
+                                single.push(bos_token.to_string());
+                            }
+                        }
+                        single.push("$0".to_string());
+                        if let Some(true) = &tokenizer_config.add_eos_token{
+                            if let Some(eos_token) = &tokenizer_config.eos_token{
+                                let eos_token_id = tokenizer.token_to_id(&eos_token).expect("Should have found the eos token id");
+                                special_tokens.push((eos_token.clone(), eos_token_id));
+                                single.push(eos_token.to_string());
+                            }
+                        }
+                        let post_processor = TemplateProcessing::builder().try_single(single).unwrap().special_tokens(special_tokens).build().unwrap();
+                        tokenizer.with_post_processor(post_processor);
+                    }}
+            }
+            tokenizer
+
+        });

     let preprocessor_config =
         preprocessor_config_filename.and_then(HubPreprocessorConfig::from_file);
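
For context, a minimal standalone sketch of what the override above changes, not taken from this commit: it assumes the same tokenizers crate version used here with its http feature enabled (for Tokenizer::from_pretrained) and uses an illustrative tokenizer repo id. With the "<s> $0" template installed, encode(..., true) now prepends the BOS id, matching the Python LlamaTokenizerFast behavior the router previously failed to reproduce.

use tokenizers::processors::template::TemplateProcessing;
use tokenizers::Tokenizer;

fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    // Illustrative repo id (assumption): any Llama-style tokenizer.json works here.
    let mut tokenizer = Tokenizer::from_pretrained("hf-internal-testing/llama-tokenizer", None)?;

    // Same template the router now builds from tokenizer_config.json:
    // "<s> $0", with the BOS token registered as a special token.
    let bos = "<s>".to_string();
    let bos_id = tokenizer
        .token_to_id(&bos)
        .expect("Should have found the bos token id");
    let post_processor = TemplateProcessing::builder()
        .try_single(vec![bos.clone(), "$0".to_string()])
        .unwrap()
        .special_tokens(vec![(bos, bos_id)])
        .build()
        .unwrap();
    tokenizer.with_post_processor(post_processor);

    // With the post-processor in place, special-token encoding starts with BOS.
    let encoding = tokenizer.encode("Hello world", true)?;
    assert_eq!(encoding.get_ids().first().copied(), Some(bos_id));
    Ok(())
}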