feat: add deserialize_with that handles strings or objects with content (#1550)
This PR adds a simple custom `deserialize_with` function that parses either a plain string or an object with a `content` property. This should help support more tokenizer configuration files stored on the Hub.
This commit is contained in:
parent
0d794af6a5
commit
246ad39d04
|
@ -32,7 +32,9 @@ pub struct HubModelInfo {
|
|||
#[derive(Clone, Deserialize, Default)]
|
||||
pub struct HubTokenizerConfig {
|
||||
pub chat_template: Option<String>,
|
||||
#[serde(deserialize_with = "token_serde::deserialize")]
|
||||
pub bos_token: Option<String>,
|
||||
#[serde(deserialize_with = "token_serde::deserialize")]
|
||||
pub eos_token: Option<String>,
|
||||
}
|
||||
|
||||
|
@ -43,6 +45,34 @@ impl HubTokenizerConfig {
|
|||
}
|
||||
}
|
||||
|
||||
mod token_serde {
|
||||
use super::*;
|
||||
use serde::de;
|
||||
use serde::Deserializer;
|
||||
use serde_json::Value;
|
||||
|
||||
pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<String>, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let value = Value::deserialize(deserializer)?;
|
||||
|
||||
match value {
|
||||
Value::String(s) => Ok(Some(s)),
|
||||
Value::Object(map) => {
|
||||
if let Some(content) = map.get("content").and_then(|v| v.as_str()) {
|
||||
Ok(Some(content.to_string()))
|
||||
} else {
|
||||
Err(de::Error::custom(
|
||||
"content key not found in structured token",
|
||||
))
|
||||
}
|
||||
}
|
||||
_ => Err(de::Error::custom("invalid token format")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, ToSchema)]
|
||||
pub struct Info {
|
||||
/// Model info
|
||||
|
@ -638,6 +668,8 @@ pub(crate) struct ErrorResponse {
|
|||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
pub(crate) async fn get_tokenizer() -> Tokenizer {
|
||||
|
@ -646,4 +678,57 @@ mod tests {
|
|||
let filename = repo.get("tokenizer.json").unwrap();
|
||||
Tokenizer::from_file(filename).unwrap()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_hub_nested_tokens_tokenizer_config() {
|
||||
// this is a subset of the tokenizer.json file
|
||||
// in this case we expect the tokens to be encoded as simple strings
|
||||
let json_content = r#"{
|
||||
"chat_template": "test",
|
||||
"bos_token": "<|begin▁of▁sentence|>",
|
||||
"eos_token": "<|end▁of▁sentence|>"
|
||||
}"#;
|
||||
|
||||
let config: HubTokenizerConfig = serde_json::from_str(json_content).unwrap();
|
||||
|
||||
// check that we successfully parsed the tokens
|
||||
assert_eq!(config.chat_template, Some("test".to_string()));
|
||||
assert_eq!(
|
||||
config.bos_token,
|
||||
Some("<|begin▁of▁sentence|>".to_string())
|
||||
);
|
||||
assert_eq!(config.eos_token, Some("<|end▁of▁sentence|>".to_string()));
|
||||
|
||||
// in this case we expect the tokens to be encoded as structured tokens
|
||||
// we want the content of the structured token
|
||||
let json_content = r#"{
|
||||
"chat_template": "test",
|
||||
"bos_token": {
|
||||
"__type": "AddedToken",
|
||||
"content": "<|begin▁of▁sentence|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"eos_token": {
|
||||
"__type": "AddedToken",
|
||||
"content": "<|end▁of▁sentence|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
}
|
||||
}"#;
|
||||
|
||||
let config: HubTokenizerConfig = serde_json::from_str(json_content).unwrap();
|
||||
|
||||
// check that we successfully parsed the tokens
|
||||
assert_eq!(config.chat_template, Some("test".to_string()));
|
||||
assert_eq!(
|
||||
config.bos_token,
|
||||
Some("<|begin▁of▁sentence|>".to_string())
|
||||
);
|
||||
assert_eq!(config.eos_token, Some("<|end▁of▁sentence|>".to_string()));
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue