feat: add deserialize_with that handles strings or objects with content (#1550)

This PR adds a simple custom `deserialize_with` function that parses either a
plain string or an object with a `content` property. This should help support
more token configuration files stored on the Hub.
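
As an illustration (not part of the diff), here is a minimal sketch of the two
shapes the new deserializer accepts, assuming the `HubTokenizerConfig` struct
from the change below; the `<s>`/`</s>` token strings are placeholders:

    // Hypothetical usage sketch: both payloads should yield the same token.
    // Note: because the fields use `deserialize_with` without `default`,
    // `bos_token` and `eos_token` must both be present in the JSON.
    let plain: HubTokenizerConfig =
        serde_json::from_str(r#"{ "bos_token": "<s>", "eos_token": "</s>" }"#).unwrap();
    let structured: HubTokenizerConfig = serde_json::from_str(
        r#"{ "bos_token": { "content": "<s>" }, "eos_token": { "content": "</s>" } }"#,
    )
    .unwrap();
    assert_eq!(plain.bos_token, structured.bos_token); // both Some("<s>".to_string())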
drbh authored 2024-02-13 10:01:02 -05:00, committed by GitHub
parent 0d794af6a5
commit 246ad39d04
1 changed file with 85 additions and 0 deletions


@@ -32,7 +32,9 @@ pub struct HubModelInfo {
 #[derive(Clone, Deserialize, Default)]
 pub struct HubTokenizerConfig {
     pub chat_template: Option<String>,
+    #[serde(deserialize_with = "token_serde::deserialize")]
     pub bos_token: Option<String>,
+    #[serde(deserialize_with = "token_serde::deserialize")]
     pub eos_token: Option<String>,
 }
@@ -43,6 +45,34 @@ impl HubTokenizerConfig {
     }
 }
 
+mod token_serde {
+    use super::*;
+    use serde::de;
+    use serde::Deserializer;
+    use serde_json::Value;
+
+    pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<String>, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        let value = Value::deserialize(deserializer)?;
+
+        match value {
+            Value::String(s) => Ok(Some(s)),
+            Value::Object(map) => {
+                if let Some(content) = map.get("content").and_then(|v| v.as_str()) {
+                    Ok(Some(content.to_string()))
+                } else {
+                    Err(de::Error::custom(
+                        "content key not found in structured token",
+                    ))
+                }
+            }
+            _ => Err(de::Error::custom("invalid token format")),
+        }
+    }
+}
+
 #[derive(Clone, Debug, Serialize, ToSchema)]
 pub struct Info {
     /// Model info
@@ -638,6 +668,8 @@ pub(crate) struct ErrorResponse {
 #[cfg(test)]
 mod tests {
+    use super::*;
+
     use tokenizers::Tokenizer;
 
     pub(crate) async fn get_tokenizer() -> Tokenizer {
@@ -646,4 +678,57 @@ mod tests {
         let filename = repo.get("tokenizer.json").unwrap();
         Tokenizer::from_file(filename).unwrap()
     }
+
+    #[test]
+    fn test_hub_nested_tokens_tokenizer_config() {
+        // this is a subset of the tokenizer_config.json file
+        // in this case we expect the tokens to be encoded as simple strings
+        let json_content = r#"{
+            "chat_template": "test",
+            "bos_token": "<begin▁of▁sentence>",
+            "eos_token": "<end▁of▁sentence>"
+        }"#;
+        let config: HubTokenizerConfig = serde_json::from_str(json_content).unwrap();
+
+        // check that we successfully parsed the tokens
+        assert_eq!(config.chat_template, Some("test".to_string()));
+        assert_eq!(
+            config.bos_token,
+            Some("<begin▁of▁sentence>".to_string())
+        );
+        assert_eq!(config.eos_token, Some("<end▁of▁sentence>".to_string()));
+
+        // in this case we expect the tokens to be encoded as structured tokens
+        // we want the content of the structured token
+        let json_content = r#"{
+            "chat_template": "test",
+            "bos_token": {
+                "__type": "AddedToken",
+                "content": "<begin▁of▁sentence>",
+                "lstrip": false,
+                "normalized": true,
+                "rstrip": false,
+                "single_word": false
+            },
+            "eos_token": {
+                "__type": "AddedToken",
+                "content": "<end▁of▁sentence>",
+                "lstrip": false,
+                "normalized": true,
+                "rstrip": false,
+                "single_word": false
+            }
+        }"#;
+        let config: HubTokenizerConfig = serde_json::from_str(json_content).unwrap();
+
+        // check that we successfully parsed the tokens
+        assert_eq!(config.chat_template, Some("test".to_string()));
+        assert_eq!(
+            config.bos_token,
+            Some("<begin▁of▁sentence>".to_string())
+        );
+        assert_eq!(config.eos_token, Some("<end▁of▁sentence>".to_string()));
+    }
 }