Processor config chat template (#1954)
This PR loads the `processor_config` in the same way as the `tokenizer_config` and falls back to the `processor_config`'s `chat_template` when the `tokenizer_config` does not include one. These changes enable chat with idefics2.
This commit is contained in:
parent
a401c83c35
commit
0732b9d2f0
|
@ -2,7 +2,8 @@
|
||||||
use crate::validation::{Validation, ValidationError};
|
use crate::validation::{Validation, ValidationError};
|
||||||
use crate::{
|
use crate::{
|
||||||
ChatTemplateInputs, ChatTemplateVersions, Entry, GenerateRequest, GenerateStreamResponse,
|
ChatTemplateInputs, ChatTemplateVersions, Entry, GenerateRequest, GenerateStreamResponse,
|
||||||
HubTokenizerConfig, Message, MessageChunk, PrefillToken, Queue, Text, TextMessage, Token,
|
HubProcessorConfig, HubTokenizerConfig, Message, MessageChunk, PrefillToken, Queue, Text,
|
||||||
|
TextMessage, Token,
|
||||||
};
|
};
|
||||||
use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools};
|
use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools};
|
||||||
use futures::future::try_join_all;
|
use futures::future::try_join_all;
|
||||||
|
@ -67,6 +68,7 @@ impl Infer {
|
||||||
speculate: u32,
|
speculate: u32,
|
||||||
generation_health: Arc<AtomicBool>,
|
generation_health: Arc<AtomicBool>,
|
||||||
tokenizer_config: HubTokenizerConfig,
|
tokenizer_config: HubTokenizerConfig,
|
||||||
|
processor_config: HubProcessorConfig,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
// Infer shared state
|
// Infer shared state
|
||||||
let queue = Queue::new(requires_padding, 16, window_size, speculate);
|
let queue = Queue::new(requires_padding, 16, window_size, speculate);
|
||||||
|
@ -89,6 +91,7 @@ impl Infer {
|
||||||
|
|
||||||
let chat_template = tokenizer_config
|
let chat_template = tokenizer_config
|
||||||
.chat_template
|
.chat_template
|
||||||
|
.or(processor_config.chat_template)
|
||||||
.and_then(|t| match t {
|
.and_then(|t| match t {
|
||||||
ChatTemplateVersions::Single(template) => Some(template),
|
ChatTemplateVersions::Single(template) => Some(template),
|
||||||
ChatTemplateVersions::Multiple(templates) => templates
|
ChatTemplateVersions::Multiple(templates) => templates
|
||||||
|
@ -98,7 +101,10 @@ impl Infer {
|
||||||
})
|
})
|
||||||
.map(|t| {
|
.map(|t| {
|
||||||
// .strip() is not supported in minijinja
|
// .strip() is not supported in minijinja
|
||||||
let t = t.replace(".strip()", " | trim");
|
// .capitalize() is not supported in minijinja but we can use | capitalize
|
||||||
|
let t = t
|
||||||
|
.replace(".strip()", " | trim")
|
||||||
|
.replace(".capitalize()", " | capitalize");
|
||||||
ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token)
|
ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token)
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
|
@ -80,6 +80,20 @@ impl HubTokenizerConfig {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Fields deserialized from a model's `processor_config.json`.
///
/// The `chat_template` here is used as a fallback when the tokenizer config
/// does not provide one. `Default` is derived so callers can fall back to an
/// empty config when the file is missing or cannot be parsed.
#[derive(Debug, Clone, Deserialize, Default)]
|
||||||
|
pub struct HubProcessorConfig {
|
||||||
|
/// Chat template (single or multiple-variant form), if the file defines one.
pub chat_template: Option<ChatTemplateVersions>,
|
||||||
|
// NOTE(review): this field is required (not an `Option`, no `#[serde(default)]`),
// so a file without this key presumably fails to deserialize and `from_file`
// returns `None`, discarding `chat_template` as well — confirm this is intended.
pub image_seq_len: usize,
|
||||||
|
/// Value of the `processor_class` key, if present in the file.
pub processor_class: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl HubProcessorConfig {
|
||||||
|
/// Reads `filename` and parses its contents as JSON into a `HubProcessorConfig`.
///
/// Returns `None` if the file cannot be read or if the JSON does not
/// deserialize into this struct; the underlying error is discarded in both cases.
pub fn from_file<P: AsRef<std::path::Path>>(filename: P) -> Option<Self> {
|
||||||
|
// A read failure (e.g. missing file) short-circuits to `None`.
let content = std::fs::read_to_string(filename).ok()?;
|
||||||
|
// A parse failure is likewise mapped to `None` rather than reported.
serde_json::from_str(&content).ok()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, ToSchema, Serialize)]
|
#[derive(Clone, Debug, Deserialize, ToSchema, Serialize)]
|
||||||
#[serde(tag = "type", content = "value")]
|
#[serde(tag = "type", content = "value")]
|
||||||
pub(crate) enum GrammarType {
|
pub(crate) enum GrammarType {
|
||||||
|
|
|
@ -14,7 +14,7 @@ use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||||
use std::path::{Path, PathBuf};
|
use std::path::{Path, PathBuf};
|
||||||
use text_generation_client::{ClientError, ShardedClient};
|
use text_generation_client::{ClientError, ShardedClient};
|
||||||
use text_generation_router::config::Config;
|
use text_generation_router::config::Config;
|
||||||
use text_generation_router::{server, HubModelInfo, HubTokenizerConfig};
|
use text_generation_router::{server, HubModelInfo, HubProcessorConfig, HubTokenizerConfig};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
use tower_http::cors::AllowOrigin;
|
use tower_http::cors::AllowOrigin;
|
||||||
|
@ -206,11 +206,18 @@ async fn main() -> Result<(), RouterError> {
|
||||||
};
|
};
|
||||||
|
|
||||||
// Load tokenizer and model info
|
// Load tokenizer and model info
|
||||||
let (tokenizer_filename, config_filename, tokenizer_config_filename, model_info) = match api {
|
let (
|
||||||
|
tokenizer_filename,
|
||||||
|
config_filename,
|
||||||
|
tokenizer_config_filename,
|
||||||
|
processor_config_filename,
|
||||||
|
model_info,
|
||||||
|
) = match api {
|
||||||
Type::None => (
|
Type::None => (
|
||||||
Some(local_path.join("tokenizer.json")),
|
Some(local_path.join("tokenizer.json")),
|
||||||
Some(local_path.join("config.json")),
|
Some(local_path.join("config.json")),
|
||||||
Some(local_path.join("tokenizer_config.json")),
|
Some(local_path.join("tokenizer_config.json")),
|
||||||
|
Some(local_path.join("processor_config.json")),
|
||||||
None,
|
None,
|
||||||
),
|
),
|
||||||
Type::Api(api) => {
|
Type::Api(api) => {
|
||||||
|
@ -226,6 +233,7 @@ async fn main() -> Result<(), RouterError> {
|
||||||
};
|
};
|
||||||
let config_filename = api_repo.get("config.json").await.ok();
|
let config_filename = api_repo.get("config.json").await.ok();
|
||||||
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
|
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
|
||||||
|
let processor_config_filename = api_repo.get("processor_config.json").await.ok();
|
||||||
|
|
||||||
let model_info = if let Some(model_info) = get_model_info(&api_repo).await {
|
let model_info = if let Some(model_info) = get_model_info(&api_repo).await {
|
||||||
Some(model_info)
|
Some(model_info)
|
||||||
|
@ -237,6 +245,7 @@ async fn main() -> Result<(), RouterError> {
|
||||||
tokenizer_filename,
|
tokenizer_filename,
|
||||||
config_filename,
|
config_filename,
|
||||||
tokenizer_config_filename,
|
tokenizer_config_filename,
|
||||||
|
processor_config_filename,
|
||||||
model_info,
|
model_info,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -250,6 +259,7 @@ async fn main() -> Result<(), RouterError> {
|
||||||
repo.get("tokenizer.json"),
|
repo.get("tokenizer.json"),
|
||||||
repo.get("config.json"),
|
repo.get("config.json"),
|
||||||
repo.get("tokenizer_config.json"),
|
repo.get("tokenizer_config.json"),
|
||||||
|
repo.get("processor_config.json"),
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -286,6 +296,10 @@ async fn main() -> Result<(), RouterError> {
|
||||||
HubTokenizerConfig::default()
|
HubTokenizerConfig::default()
|
||||||
});
|
});
|
||||||
|
|
||||||
|
let processor_config = processor_config_filename
|
||||||
|
.and_then(HubProcessorConfig::from_file)
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
tracing::info!("Using config {config:?}");
|
tracing::info!("Using config {config:?}");
|
||||||
if tokenizer.is_none() {
|
if tokenizer.is_none() {
|
||||||
tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}");
|
tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}");
|
||||||
|
@ -397,6 +411,7 @@ async fn main() -> Result<(), RouterError> {
|
||||||
ngrok_authtoken,
|
ngrok_authtoken,
|
||||||
ngrok_edge,
|
ngrok_edge,
|
||||||
tokenizer_config,
|
tokenizer_config,
|
||||||
|
processor_config,
|
||||||
messages_api_enabled,
|
messages_api_enabled,
|
||||||
disable_grammar_support,
|
disable_grammar_support,
|
||||||
max_client_batch_size,
|
max_client_batch_size,
|
||||||
|
|
|
@ -5,9 +5,9 @@ use crate::infer::{InferError, InferResponse, InferStreamResponse, ToolGrammar};
|
||||||
use crate::validation::ValidationError;
|
use crate::validation::ValidationError;
|
||||||
use crate::{
|
use crate::{
|
||||||
BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
|
BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
|
||||||
GenerateResponse, GrammarType, HubModelInfo, HubTokenizerConfig, Infer, Info, Message,
|
GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig, HubTokenizerConfig, Infer,
|
||||||
PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, Usage,
|
Info, Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token,
|
||||||
Validation,
|
TokenizeResponse, Usage, Validation,
|
||||||
};
|
};
|
||||||
use crate::{
|
use crate::{
|
||||||
ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
|
ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
|
||||||
|
@ -1395,6 +1395,7 @@ pub async fn run(
|
||||||
ngrok_authtoken: Option<String>,
|
ngrok_authtoken: Option<String>,
|
||||||
ngrok_edge: Option<String>,
|
ngrok_edge: Option<String>,
|
||||||
tokenizer_config: HubTokenizerConfig,
|
tokenizer_config: HubTokenizerConfig,
|
||||||
|
processor_config: HubProcessorConfig,
|
||||||
messages_api_enabled: bool,
|
messages_api_enabled: bool,
|
||||||
grammar_support: bool,
|
grammar_support: bool,
|
||||||
max_client_batch_size: usize,
|
max_client_batch_size: usize,
|
||||||
|
@ -1495,6 +1496,7 @@ pub async fn run(
|
||||||
shard_info.speculate,
|
shard_info.speculate,
|
||||||
generation_health,
|
generation_health,
|
||||||
tokenizer_config,
|
tokenizer_config,
|
||||||
|
processor_config,
|
||||||
);
|
);
|
||||||
|
|
||||||
// Duration buckets
|
// Duration buckets
|
||||||
|
|
Loading…
Reference in New Issue