diff --git a/router/src/lib.rs b/router/src/lib.rs
index 3ce9eca8..a9d783bb 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -32,7 +32,9 @@ pub struct HubModelInfo {
 #[derive(Clone, Deserialize, Default)]
 pub struct HubTokenizerConfig {
     pub chat_template: Option<String>,
+    #[serde(deserialize_with = "token_serde::deserialize")]
     pub bos_token: Option<String>,
+    #[serde(deserialize_with = "token_serde::deserialize")]
     pub eos_token: Option<String>,
 }
 
@@ -43,6 +45,34 @@ impl HubTokenizerConfig {
     }
 }
 
+mod token_serde {
+    use super::*;
+    use serde::de;
+    use serde::Deserializer;
+    use serde_json::Value;
+
+    pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<String>, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        let value = Value::deserialize(deserializer)?;
+
+        match value {
+            Value::String(s) => Ok(Some(s)),
+            Value::Object(map) => {
+                if let Some(content) = map.get("content").and_then(|v| v.as_str()) {
+                    Ok(Some(content.to_string()))
+                } else {
+                    Err(de::Error::custom(
+                        "content key not found in structured token",
+                    ))
+                }
+            }
+            _ => Err(de::Error::custom("invalid token format")),
+        }
+    }
+}
+
 #[derive(Clone, Debug, Serialize, ToSchema)]
 pub struct Info {
     /// Model info
@@ -638,6 +668,8 @@ pub(crate) struct ErrorResponse {
 
 #[cfg(test)]
 mod tests {
+    use super::*;
+
     use tokenizers::Tokenizer;
 
     pub(crate) async fn get_tokenizer() -> Tokenizer {
@@ -646,4 +678,57 @@ mod tests {
         let filename = repo.get("tokenizer.json").unwrap();
         Tokenizer::from_file(filename).unwrap()
     }
+
+    #[test]
+    fn test_hub_nested_tokens_tokenizer_config() {
+        // this is a subset of the tokenizer.json file
+        // in this case we expect the tokens to be encoded as simple strings
+        let json_content = r#"{
+            "chat_template": "test",
+            "bos_token": "<|begin▁of▁sentence|>",
+            "eos_token": "<|end▁of▁sentence|>"
+        }"#;
+
+        let config: HubTokenizerConfig = serde_json::from_str(json_content).unwrap();
+
+        // check that we successfully parsed the tokens
+        assert_eq!(config.chat_template, Some("test".to_string()));
+        assert_eq!(
+            config.bos_token,
+            Some("<|begin▁of▁sentence|>".to_string())
+        );
+        assert_eq!(config.eos_token, Some("<|end▁of▁sentence|>".to_string()));
+
+        // in this case we expect the tokens to be encoded as structured tokens
+        // we want the content of the structured token
+        let json_content = r#"{
+            "chat_template": "test",
+            "bos_token": {
+                "__type": "AddedToken",
+                "content": "<|begin▁of▁sentence|>",
+                "lstrip": false,
+                "normalized": true,
+                "rstrip": false,
+                "single_word": false
+            },
+            "eos_token": {
+                "__type": "AddedToken",
+                "content": "<|end▁of▁sentence|>",
+                "lstrip": false,
+                "normalized": true,
+                "rstrip": false,
+                "single_word": false
+            }
+        }"#;
+
+        let config: HubTokenizerConfig = serde_json::from_str(json_content).unwrap();
+
+        // check that we successfully parsed the tokens
+        assert_eq!(config.chat_template, Some("test".to_string()));
+        assert_eq!(
+            config.bos_token,
+            Some("<|begin▁of▁sentence|>".to_string())
+        );
+        assert_eq!(config.eos_token, Some("<|end▁of▁sentence|>".to_string()));
+    }
 }