fix: refactor post_processor logic and add test (#2137)
* fix: refactor post_processor logic and add test * fix: remove dev comment * fix: adjust when post_processor is overridden and improve create_post_processor
This commit is contained in:
parent
3ea8259af1
commit
74b0231b19
|
@ -304,36 +304,20 @@ async fn main() -> Result<(), RouterError> {
|
|||
tracing::warn!("Could not find tokenizer config locally and no API specified");
|
||||
HubTokenizerConfig::default()
|
||||
});
|
||||
let tokenizer: Option<Tokenizer> =
|
||||
tokenizer_filename.and_then(|filename| {
|
||||
|
||||
let tokenizer: Option<Tokenizer> = tokenizer_filename.and_then(|filename| {
|
||||
let mut tokenizer = Tokenizer::from_file(filename).ok();
|
||||
if let Some(tokenizer) = &mut tokenizer {
|
||||
if let Some(class) = &tokenizer_config.tokenizer_class {
|
||||
if class == "LlamaTokenizer" || class == "LlamaTokenizerFast" {
|
||||
tracing::info!("Overriding LllamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205");
|
||||
let mut single = vec![];
|
||||
let mut special_tokens = vec![];
|
||||
if let Some(true) = &tokenizer_config.add_bos_token{
|
||||
if let Some(bos_token) = &tokenizer_config.bos_token{
|
||||
let bos_token_id = tokenizer.token_to_id(&bos_token).expect("Should have found the bos token id");
|
||||
special_tokens.push((bos_token.clone(), bos_token_id));
|
||||
single.push(bos_token.to_string());
|
||||
}
|
||||
}
|
||||
single.push("$0".to_string());
|
||||
if let Some(true) = &tokenizer_config.add_eos_token{
|
||||
if let Some(eos_token) = &tokenizer_config.eos_token{
|
||||
let eos_token_id = tokenizer.token_to_id(&eos_token).expect("Should have found the eos token id");
|
||||
special_tokens.push((eos_token.clone(), eos_token_id));
|
||||
single.push(eos_token.to_string());
|
||||
}
|
||||
}
|
||||
let post_processor = TemplateProcessing::builder().try_single(single).unwrap().special_tokens(special_tokens).build().unwrap();
|
||||
if (class == "LlamaTokenizer" || class == "LlamaTokenizerFast") && tokenizer.get_post_processor().is_none() {
|
||||
if let Ok(post_processor) = create_post_processor(tokenizer, &tokenizer_config) {
|
||||
tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205");
|
||||
tokenizer.with_post_processor(post_processor);
|
||||
}}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
tokenizer
|
||||
|
||||
});
|
||||
|
||||
let preprocessor_config =
|
||||
|
@ -543,6 +527,77 @@ pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option<HubTokenizerConf
|
|||
Some(tokenizer_config)
|
||||
}
|
||||
|
||||
/// Create a post_processor for the LlamaTokenizer
|
||||
pub fn create_post_processor(
|
||||
tokenizer: &Tokenizer,
|
||||
tokenizer_config: &HubTokenizerConfig,
|
||||
) -> Result<TemplateProcessing, tokenizers::processors::template::TemplateProcessingBuilderError> {
|
||||
let add_bos_token = tokenizer_config.add_bos_token.unwrap_or(true);
|
||||
let add_eos_token = tokenizer_config.add_eos_token.unwrap_or(false);
|
||||
|
||||
let bos_token = tokenizer_config.bos_token.as_ref();
|
||||
let eos_token = tokenizer_config.eos_token.as_ref();
|
||||
|
||||
if add_bos_token && bos_token.is_none() {
|
||||
panic!("add_bos_token = true but bos_token is None");
|
||||
}
|
||||
|
||||
if add_eos_token && eos_token.is_none() {
|
||||
panic!("add_eos_token = true but eos_token is None");
|
||||
}
|
||||
|
||||
let mut single = Vec::new();
|
||||
let mut pair = Vec::new();
|
||||
let mut special_tokens = Vec::new();
|
||||
|
||||
if add_bos_token {
|
||||
if let Some(bos) = bos_token {
|
||||
let bos_token_id = tokenizer
|
||||
.token_to_id(bos)
|
||||
.expect("Should have found the bos token id");
|
||||
special_tokens.push((bos.clone(), bos_token_id));
|
||||
single.push(format!("{}:0", bos));
|
||||
pair.push(format!("{}:0", bos));
|
||||
}
|
||||
}
|
||||
|
||||
single.push("$A:0".to_string());
|
||||
pair.push("$A:0".to_string());
|
||||
|
||||
if add_eos_token {
|
||||
if let Some(eos) = eos_token {
|
||||
let eos_token_id = tokenizer
|
||||
.token_to_id(eos)
|
||||
.expect("Should have found the eos token id");
|
||||
special_tokens.push((eos.clone(), eos_token_id));
|
||||
single.push(format!("{}:0", eos));
|
||||
pair.push(format!("{}:0", eos));
|
||||
}
|
||||
}
|
||||
|
||||
if add_bos_token {
|
||||
if let Some(bos) = bos_token {
|
||||
single.push(format!("{}:1", bos));
|
||||
}
|
||||
}
|
||||
|
||||
pair.push("$B:1".to_string());
|
||||
|
||||
if add_eos_token {
|
||||
if let Some(eos) = eos_token {
|
||||
pair.push(format!("{}:1", eos));
|
||||
}
|
||||
}
|
||||
|
||||
let post_processor = TemplateProcessing::builder()
|
||||
.try_single(single)?
|
||||
.try_pair(pair)?
|
||||
.special_tokens(special_tokens)
|
||||
.build()?;
|
||||
|
||||
Ok(post_processor)
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
enum RouterError {
|
||||
#[error("Argument validation error: {0}")]
|
||||
|
@ -552,3 +607,36 @@ enum RouterError {
|
|||
#[error("Tokio runtime failed to start: {0}")]
|
||||
Tokio(#[from] std::io::Error),
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_create_post_processor() {
|
||||
let tokenizer_config = HubTokenizerConfig {
|
||||
add_bos_token: None,
|
||||
add_eos_token: None,
|
||||
bos_token: Some("<s>".to_string()),
|
||||
eos_token: Some("</s>".to_string()),
|
||||
chat_template: None,
|
||||
tokenizer_class: None,
|
||||
completion_template: None,
|
||||
};
|
||||
|
||||
let tokenizer =
|
||||
Tokenizer::from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", None).unwrap();
|
||||
let post_processor = create_post_processor(&tokenizer, &tokenizer_config).unwrap();
|
||||
|
||||
let expected = TemplateProcessing::builder()
|
||||
.try_single("<s>:0 $A:0 <s>:1")
|
||||
.unwrap()
|
||||
.try_pair("<s>:0 $A:0 $B:1")
|
||||
.unwrap()
|
||||
.special_tokens(vec![("<s>".to_string(), 1)])
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(post_processor, expected);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue