Processor config chat template (#1954)
This PR loads the `processor_config` in the same way as the `tokenizer_config`, and uses the `processor_config`'s `chat_template` when the `tokenizer_config` does not include one. These changes enable chat with idefics2.
This commit is contained in:
parent
a401c83c35
commit
0732b9d2f0
|
@ -2,7 +2,8 @@
|
|||
use crate::validation::{Validation, ValidationError};
|
||||
use crate::{
|
||||
ChatTemplateInputs, ChatTemplateVersions, Entry, GenerateRequest, GenerateStreamResponse,
|
||||
HubTokenizerConfig, Message, MessageChunk, PrefillToken, Queue, Text, TextMessage, Token,
|
||||
HubProcessorConfig, HubTokenizerConfig, Message, MessageChunk, PrefillToken, Queue, Text,
|
||||
TextMessage, Token,
|
||||
};
|
||||
use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools};
|
||||
use futures::future::try_join_all;
|
||||
|
@ -67,6 +68,7 @@ impl Infer {
|
|||
speculate: u32,
|
||||
generation_health: Arc<AtomicBool>,
|
||||
tokenizer_config: HubTokenizerConfig,
|
||||
processor_config: HubProcessorConfig,
|
||||
) -> Self {
|
||||
// Infer shared state
|
||||
let queue = Queue::new(requires_padding, 16, window_size, speculate);
|
||||
|
@ -89,6 +91,7 @@ impl Infer {
|
|||
|
||||
let chat_template = tokenizer_config
|
||||
.chat_template
|
||||
.or(processor_config.chat_template)
|
||||
.and_then(|t| match t {
|
||||
ChatTemplateVersions::Single(template) => Some(template),
|
||||
ChatTemplateVersions::Multiple(templates) => templates
|
||||
|
@ -98,7 +101,10 @@ impl Infer {
|
|||
})
|
||||
.map(|t| {
|
||||
// .strip() is not supported in minijinja
|
||||
let t = t.replace(".strip()", " | trim");
|
||||
// .capitalize() is not supported in minijinja but we can use | capitalize
|
||||
let t = t
|
||||
.replace(".strip()", " | trim")
|
||||
.replace(".capitalize()", " | capitalize");
|
||||
ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token)
|
||||
});
|
||||
|
||||
|
|
|
@ -80,6 +80,20 @@ impl HubTokenizerConfig {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Default)]
|
||||
pub struct HubProcessorConfig {
|
||||
pub chat_template: Option<ChatTemplateVersions>,
|
||||
pub image_seq_len: usize,
|
||||
pub processor_class: Option<String>,
|
||||
}
|
||||
|
||||
impl HubProcessorConfig {
|
||||
pub fn from_file<P: AsRef<std::path::Path>>(filename: P) -> Option<Self> {
|
||||
let content = std::fs::read_to_string(filename).ok()?;
|
||||
serde_json::from_str(&content).ok()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, ToSchema, Serialize)]
|
||||
#[serde(tag = "type", content = "value")]
|
||||
pub(crate) enum GrammarType {
|
||||
|
|
|
@ -14,7 +14,7 @@ use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
|||
use std::path::{Path, PathBuf};
|
||||
use text_generation_client::{ClientError, ShardedClient};
|
||||
use text_generation_router::config::Config;
|
||||
use text_generation_router::{server, HubModelInfo, HubTokenizerConfig};
|
||||
use text_generation_router::{server, HubModelInfo, HubProcessorConfig, HubTokenizerConfig};
|
||||
use thiserror::Error;
|
||||
use tokenizers::Tokenizer;
|
||||
use tower_http::cors::AllowOrigin;
|
||||
|
@ -206,11 +206,18 @@ async fn main() -> Result<(), RouterError> {
|
|||
};
|
||||
|
||||
// Load tokenizer and model info
|
||||
let (tokenizer_filename, config_filename, tokenizer_config_filename, model_info) = match api {
|
||||
let (
|
||||
tokenizer_filename,
|
||||
config_filename,
|
||||
tokenizer_config_filename,
|
||||
processor_config_filename,
|
||||
model_info,
|
||||
) = match api {
|
||||
Type::None => (
|
||||
Some(local_path.join("tokenizer.json")),
|
||||
Some(local_path.join("config.json")),
|
||||
Some(local_path.join("tokenizer_config.json")),
|
||||
Some(local_path.join("processor_config.json")),
|
||||
None,
|
||||
),
|
||||
Type::Api(api) => {
|
||||
|
@ -226,6 +233,7 @@ async fn main() -> Result<(), RouterError> {
|
|||
};
|
||||
let config_filename = api_repo.get("config.json").await.ok();
|
||||
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
|
||||
let processor_config_filename = api_repo.get("processor_config.json").await.ok();
|
||||
|
||||
let model_info = if let Some(model_info) = get_model_info(&api_repo).await {
|
||||
Some(model_info)
|
||||
|
@ -237,6 +245,7 @@ async fn main() -> Result<(), RouterError> {
|
|||
tokenizer_filename,
|
||||
config_filename,
|
||||
tokenizer_config_filename,
|
||||
processor_config_filename,
|
||||
model_info,
|
||||
)
|
||||
}
|
||||
|
@ -250,6 +259,7 @@ async fn main() -> Result<(), RouterError> {
|
|||
repo.get("tokenizer.json"),
|
||||
repo.get("config.json"),
|
||||
repo.get("tokenizer_config.json"),
|
||||
repo.get("processor_config.json"),
|
||||
None,
|
||||
)
|
||||
}
|
||||
|
@ -286,6 +296,10 @@ async fn main() -> Result<(), RouterError> {
|
|||
HubTokenizerConfig::default()
|
||||
});
|
||||
|
||||
let processor_config = processor_config_filename
|
||||
.and_then(HubProcessorConfig::from_file)
|
||||
.unwrap_or_default();
|
||||
|
||||
tracing::info!("Using config {config:?}");
|
||||
if tokenizer.is_none() {
|
||||
tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}");
|
||||
|
@ -397,6 +411,7 @@ async fn main() -> Result<(), RouterError> {
|
|||
ngrok_authtoken,
|
||||
ngrok_edge,
|
||||
tokenizer_config,
|
||||
processor_config,
|
||||
messages_api_enabled,
|
||||
disable_grammar_support,
|
||||
max_client_batch_size,
|
||||
|
|
|
@ -5,9 +5,9 @@ use crate::infer::{InferError, InferResponse, InferStreamResponse, ToolGrammar};
|
|||
use crate::validation::ValidationError;
|
||||
use crate::{
|
||||
BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
|
||||
GenerateResponse, GrammarType, HubModelInfo, HubTokenizerConfig, Infer, Info, Message,
|
||||
PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, Usage,
|
||||
Validation,
|
||||
GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig, HubTokenizerConfig, Infer,
|
||||
Info, Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token,
|
||||
TokenizeResponse, Usage, Validation,
|
||||
};
|
||||
use crate::{
|
||||
ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
|
||||
|
@ -1395,6 +1395,7 @@ pub async fn run(
|
|||
ngrok_authtoken: Option<String>,
|
||||
ngrok_edge: Option<String>,
|
||||
tokenizer_config: HubTokenizerConfig,
|
||||
processor_config: HubProcessorConfig,
|
||||
messages_api_enabled: bool,
|
||||
grammar_support: bool,
|
||||
max_client_batch_size: usize,
|
||||
|
@ -1495,6 +1496,7 @@ pub async fn run(
|
|||
shard_info.speculate,
|
||||
generation_health,
|
||||
tokenizer_config,
|
||||
processor_config,
|
||||
);
|
||||
|
||||
// Duration buckets
|
||||
|
|
Loading…
Reference in New Issue