Processor config chat template (#1954)

This PR loads the `processor_config` in the same way as the `tokenizer_config`,
and uses the processor_config's `chat_template` when the tokenizer_config
does not include one. These changes enable chat with idefics2.
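
In outline, the router now resolves the chat template with a simple precedence: use the template from `tokenizer_config.json` if present, otherwise fall back to the one in `processor_config.json`. A minimal sketch of that ordering, with the config types simplified from the router's `HubTokenizerConfig`/`HubProcessorConfig` (which wrap templates in a `ChatTemplateVersions` enum rather than a plain `String`):

```rust
// Simplified sketch of the fallback this PR introduces.
#[derive(Default)]
struct TokenizerConfig {
    chat_template: Option<String>,
}

#[derive(Default)]
struct ProcessorConfig {
    chat_template: Option<String>,
}

fn resolve_chat_template(
    tokenizer_config: TokenizerConfig,
    processor_config: ProcessorConfig,
) -> Option<String> {
    // tokenizer_config.json wins; processor_config.json is only a fallback.
    tokenizer_config
        .chat_template
        .or(processor_config.chat_template)
}

fn main() {
    let tokenizer_config = TokenizerConfig::default(); // no template here
    let processor_config = ProcessorConfig {
        chat_template: Some("{% for m in messages %}...{% endfor %}".into()),
    };
    // With no template in tokenizer_config.json, the processor's is used.
    assert!(resolve_chat_template(tokenizer_config, processor_config).is_some());
}
```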
drbh 2024-05-27 10:03:16 -04:00 committed by GitHub
parent a401c83c35
commit 0732b9d2f0
4 changed files with 44 additions and 7 deletions

View File

@@ -2,7 +2,8 @@
use crate::validation::{Validation, ValidationError};
use crate::{
ChatTemplateInputs, ChatTemplateVersions, Entry, GenerateRequest, GenerateStreamResponse,
HubTokenizerConfig, Message, MessageChunk, PrefillToken, Queue, Text, TextMessage, Token,
HubProcessorConfig, HubTokenizerConfig, Message, MessageChunk, PrefillToken, Queue, Text,
TextMessage, Token,
};
use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools};
use futures::future::try_join_all;
@@ -67,6 +68,7 @@ impl Infer {
speculate: u32,
generation_health: Arc<AtomicBool>,
tokenizer_config: HubTokenizerConfig,
processor_config: HubProcessorConfig,
) -> Self {
// Infer shared state
let queue = Queue::new(requires_padding, 16, window_size, speculate);
@@ -89,6 +91,7 @@

let chat_template = tokenizer_config
.chat_template
.or(processor_config.chat_template)
.and_then(|t| match t {
ChatTemplateVersions::Single(template) => Some(template),
ChatTemplateVersions::Multiple(templates) => templates
@@ -98,7 +101,10 @@
})
.map(|t| {
// .strip() is not supported in minijinja
let t = t.replace(".strip()", " | trim");
// .capitalize() is not supported in minijinja but we can use | capitalize
let t = t
.replace(".strip()", " | trim")
.replace(".capitalize()", " | capitalize");
ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token)
});
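
The `.replace` calls above exist because chat templates on the Hub are written for Python's Jinja2, while the router renders them with minijinja, which does not implement the `.strip()` and `.capitalize()` string methods; both are rewritten into the equivalent filters before the template is compiled. A standalone sketch of that rewrite (plain text substitution, matching the diff):

```rust
// Rewrites Python string-method calls that minijinja does not support
// into the equivalent minijinja filters, mirroring the diff above.
fn fixup_template(template: &str) -> String {
    template
        .replace(".strip()", " | trim")
        .replace(".capitalize()", " | capitalize")
}

fn main() {
    let raw = "{{ message['content'].strip() }} {{ role.capitalize() }}";
    assert_eq!(
        fixup_template(raw),
        "{{ message['content'] | trim }} {{ role | capitalize }}"
    );
}
```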

View File

@@ -80,6 +80,20 @@ impl HubTokenizerConfig {
}
}

#[derive(Debug, Clone, Deserialize, Default)]
pub struct HubProcessorConfig {
pub chat_template: Option<ChatTemplateVersions>,
pub image_seq_len: usize,
pub processor_class: Option<String>,
}

impl HubProcessorConfig {
pub fn from_file<P: AsRef<std::path::Path>>(filename: P) -> Option<Self> {
let content = std::fs::read_to_string(filename).ok()?;
serde_json::from_str(&content).ok()
}
}

#[derive(Clone, Debug, Deserialize, ToSchema, Serialize)]
#[serde(tag = "type", content = "value")]
pub(crate) enum GrammarType {
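
For context, the `from_file` helper above parses a `processor_config.json` like the one idefics2 ships. A self-contained sketch of that round trip; here `chat_template` is simplified to a plain `String` (the router uses `ChatTemplateVersions`) and the JSON values are illustrative, not copied from the model repo:

```rust
use serde::Deserialize;

// Same shape as HubProcessorConfig above, with chat_template simplified.
#[derive(Debug, Clone, Deserialize, Default)]
struct ProcessorConfig {
    chat_template: Option<String>,
    image_seq_len: usize,
    processor_class: Option<String>,
}

fn main() {
    // Illustrative processor_config.json; the values are made up.
    let content = r#"{
        "chat_template": "{% for m in messages %}...{% endfor %}",
        "image_seq_len": 64,
        "processor_class": "Idefics2Processor"
    }"#;
    let config: ProcessorConfig = serde_json::from_str(content).unwrap();
    assert_eq!(config.image_seq_len, 64);
    assert_eq!(config.processor_class.as_deref(), Some("Idefics2Processor"));
    assert!(config.chat_template.is_some());
}
```

Note that `image_seq_len` is a bare `usize` with no `#[serde(default)]`, so a file that omits it fails to deserialize and `from_file` returns `None`; as the `main.rs` diff below shows, the caller then falls back to `HubProcessorConfig::default()`.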

View File

@@ -14,7 +14,7 @@ use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::{Path, PathBuf};
use text_generation_client::{ClientError, ShardedClient};
use text_generation_router::config::Config;
use text_generation_router::{server, HubModelInfo, HubTokenizerConfig};
use text_generation_router::{server, HubModelInfo, HubProcessorConfig, HubTokenizerConfig};
use thiserror::Error;
use tokenizers::Tokenizer;
use tower_http::cors::AllowOrigin;
@@ -206,11 +206,18 @@ async fn main() -> Result<(), RouterError> {
};

// Load tokenizer and model info
let (tokenizer_filename, config_filename, tokenizer_config_filename, model_info) = match api {
let (
tokenizer_filename,
config_filename,
tokenizer_config_filename,
processor_config_filename,
model_info,
) = match api {
Type::None => (
Some(local_path.join("tokenizer.json")),
Some(local_path.join("config.json")),
Some(local_path.join("tokenizer_config.json")),
Some(local_path.join("processor_config.json")),
None,
),
Type::Api(api) => {
@@ -226,6 +233,7 @@
};
let config_filename = api_repo.get("config.json").await.ok();
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
let processor_config_filename = api_repo.get("processor_config.json").await.ok();

let model_info = if let Some(model_info) = get_model_info(&api_repo).await {
Some(model_info)
@@ -237,6 +245,7 @@
tokenizer_filename,
config_filename,
tokenizer_config_filename,
processor_config_filename,
model_info,
)
}
@@ -250,6 +259,7 @@
repo.get("tokenizer.json"),
repo.get("config.json"),
repo.get("tokenizer_config.json"),
repo.get("processor_config.json"),
None,
)
}
@@ -286,6 +296,10 @@
HubTokenizerConfig::default()
});

let processor_config = processor_config_filename
.and_then(HubProcessorConfig::from_file)
.unwrap_or_default();

tracing::info!("Using config {config:?}");
if tokenizer.is_none() {
tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}");
@@ -397,6 +411,7 @@
ngrok_authtoken,
ngrok_edge,
tokenizer_config,
processor_config,
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
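
The wiring in `main` above is deliberately forgiving: a missing or unparsable `processor_config.json` degrades to `HubProcessorConfig::default()` rather than failing startup, which matters because most text-only models do not ship a processor config at all. A minimal, self-contained sketch of that pattern (simplified config type, hypothetical file name):

```rust
use serde::Deserialize;
use std::path::PathBuf;

#[derive(Debug, Deserialize, Default)]
struct ProcessorConfig {
    chat_template: Option<String>,
}

impl ProcessorConfig {
    // Mirrors HubProcessorConfig::from_file: any I/O or parse error -> None.
    fn from_file<P: AsRef<std::path::Path>>(filename: P) -> Option<Self> {
        let content = std::fs::read_to_string(filename).ok()?;
        serde_json::from_str(&content).ok()
    }
}

fn main() {
    // As in main.rs: Option<path> -> config, defaulting when anything fails.
    let processor_config_filename: Option<PathBuf> =
        Some(PathBuf::from("does_not_exist.json")); // hypothetical path
    let processor_config = processor_config_filename
        .and_then(ProcessorConfig::from_file)
        .unwrap_or_default();
    assert!(processor_config.chat_template.is_none());
}
```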

View File

@@ -5,9 +5,9 @@ use crate::infer::{InferError, InferResponse, InferStreamResponse, ToolGrammar};
use crate::validation::ValidationError;
use crate::{
BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
GenerateResponse, GrammarType, HubModelInfo, HubTokenizerConfig, Infer, Info, Message,
PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, Usage,
Validation,
GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig, HubTokenizerConfig, Infer,
Info, Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token,
TokenizeResponse, Usage, Validation,
};
use crate::{
ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
@@ -1395,6 +1395,7 @@ pub async fn run(
ngrok_authtoken: Option<String>,
ngrok_edge: Option<String>,
tokenizer_config: HubTokenizerConfig,
processor_config: HubProcessorConfig,
messages_api_enabled: bool,
grammar_support: bool,
max_client_batch_size: usize,
@@ -1495,6 +1496,7 @@ pub async fn run(
shard_info.speculate,
generation_health,
tokenizer_config,
processor_config,
);

// Duration buckets