feat: add deserialize_with that handles strings or objects with content (#1550)
This PR adds a simple custom `deserialize_with` function that parses a string or an object with a content property. This should help support more token configuration files stored on the hub
This commit is contained in:
parent
0d794af6a5
commit
246ad39d04
|
@ -32,7 +32,9 @@ pub struct HubModelInfo {
|
||||||
#[derive(Clone, Deserialize, Default)]
|
#[derive(Clone, Deserialize, Default)]
|
||||||
pub struct HubTokenizerConfig {
|
pub struct HubTokenizerConfig {
|
||||||
pub chat_template: Option<String>,
|
pub chat_template: Option<String>,
|
||||||
|
#[serde(deserialize_with = "token_serde::deserialize")]
|
||||||
pub bos_token: Option<String>,
|
pub bos_token: Option<String>,
|
||||||
|
#[serde(deserialize_with = "token_serde::deserialize")]
|
||||||
pub eos_token: Option<String>,
|
pub eos_token: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -43,6 +45,34 @@ impl HubTokenizerConfig {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mod token_serde {
|
||||||
|
use super::*;
|
||||||
|
use serde::de;
|
||||||
|
use serde::Deserializer;
|
||||||
|
use serde_json::Value;
|
||||||
|
|
||||||
|
pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<String>, D::Error>
|
||||||
|
where
|
||||||
|
D: Deserializer<'de>,
|
||||||
|
{
|
||||||
|
let value = Value::deserialize(deserializer)?;
|
||||||
|
|
||||||
|
match value {
|
||||||
|
Value::String(s) => Ok(Some(s)),
|
||||||
|
Value::Object(map) => {
|
||||||
|
if let Some(content) = map.get("content").and_then(|v| v.as_str()) {
|
||||||
|
Ok(Some(content.to_string()))
|
||||||
|
} else {
|
||||||
|
Err(de::Error::custom(
|
||||||
|
"content key not found in structured token",
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => Err(de::Error::custom("invalid token format")),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Serialize, ToSchema)]
|
#[derive(Clone, Debug, Serialize, ToSchema)]
|
||||||
pub struct Info {
|
pub struct Info {
|
||||||
/// Model info
|
/// Model info
|
||||||
|
@ -638,6 +668,8 @@ pub(crate) struct ErrorResponse {
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
pub(crate) async fn get_tokenizer() -> Tokenizer {
|
pub(crate) async fn get_tokenizer() -> Tokenizer {
|
||||||
|
@ -646,4 +678,57 @@ mod tests {
|
||||||
let filename = repo.get("tokenizer.json").unwrap();
|
let filename = repo.get("tokenizer.json").unwrap();
|
||||||
Tokenizer::from_file(filename).unwrap()
|
Tokenizer::from_file(filename).unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_hub_nested_tokens_tokenizer_config() {
|
||||||
|
// this is a subset of the tokenizer.json file
|
||||||
|
// in this case we expect the tokens to be encoded as simple strings
|
||||||
|
let json_content = r#"{
|
||||||
|
"chat_template": "test",
|
||||||
|
"bos_token": "<|begin▁of▁sentence|>",
|
||||||
|
"eos_token": "<|end▁of▁sentence|>"
|
||||||
|
}"#;
|
||||||
|
|
||||||
|
let config: HubTokenizerConfig = serde_json::from_str(json_content).unwrap();
|
||||||
|
|
||||||
|
// check that we successfully parsed the tokens
|
||||||
|
assert_eq!(config.chat_template, Some("test".to_string()));
|
||||||
|
assert_eq!(
|
||||||
|
config.bos_token,
|
||||||
|
Some("<|begin▁of▁sentence|>".to_string())
|
||||||
|
);
|
||||||
|
assert_eq!(config.eos_token, Some("<|end▁of▁sentence|>".to_string()));
|
||||||
|
|
||||||
|
// in this case we expect the tokens to be encoded as structured tokens
|
||||||
|
// we want the content of the structured token
|
||||||
|
let json_content = r#"{
|
||||||
|
"chat_template": "test",
|
||||||
|
"bos_token": {
|
||||||
|
"__type": "AddedToken",
|
||||||
|
"content": "<|begin▁of▁sentence|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": true,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false
|
||||||
|
},
|
||||||
|
"eos_token": {
|
||||||
|
"__type": "AddedToken",
|
||||||
|
"content": "<|end▁of▁sentence|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": true,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false
|
||||||
|
}
|
||||||
|
}"#;
|
||||||
|
|
||||||
|
let config: HubTokenizerConfig = serde_json::from_str(json_content).unwrap();
|
||||||
|
|
||||||
|
// check that we successfully parsed the tokens
|
||||||
|
assert_eq!(config.chat_template, Some("test".to_string()));
|
||||||
|
assert_eq!(
|
||||||
|
config.bos_token,
|
||||||
|
Some("<|begin▁of▁sentence|>".to_string())
|
||||||
|
);
|
||||||
|
assert_eq!(config.eos_token, Some("<|end▁of▁sentence|>".to_string()));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue