router changes

Miquel Farre, 2024-11-13 12:42:11 +00:00, committed by drbh
parent de6c68443e
commit fc5b0ac1fd
2 changed files with 40 additions and 2 deletions

@@ -1134,6 +1134,7 @@ pub struct Url {
 pub enum MessageChunk {
     Text { text: String },
     ImageUrl { image_url: Url },
+    Video { video_url: Url },
 }

 #[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
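
For context, a minimal round-trip sketch of the new chat-message shape. The serde attributes on the real MessageChunk (and the exact field layout of Url) are not visible in this diff, so the tagged representation below (type: "video") is an assumption, not the router's confirmed wire format:

    use serde::Deserialize;

    // Hypothetical stand-ins; the router's actual serde attributes may differ.
    #[derive(Deserialize, Debug, PartialEq)]
    struct Url {
        url: String,
    }

    #[derive(Deserialize, Debug, PartialEq)]
    #[serde(tag = "type", rename_all = "snake_case")]
    enum MessageChunk {
        Text { text: String },
        ImageUrl { image_url: Url },
        Video { video_url: Url },
    }

    fn main() -> Result<(), serde_json::Error> {
        let json = r#"{"type": "video", "video_url": {"url": "https://example.com/clip.mp4"}}"#;
        let chunk: MessageChunk = serde_json::from_str(json)?;
        assert_eq!(
            chunk,
            MessageChunk::Video {
                video_url: Url { url: "https://example.com/clip.mp4".to_string() }
            }
        );
        Ok(())
    }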

@@ -560,6 +560,14 @@ fn fetch_image(input: &str) -> Result<(Vec<u8>, String, usize, usize), ValidationError> {
     }
 }

+fn fetch_video(input: &str) -> Result<String, ValidationError> {
+    if input.starts_with("http://") || input.starts_with("https://") {
+        Ok(input.to_string())
+    } else {
+        Err(ValidationError::InvalidVideoContent(input.to_string()))
+    }
+}
+
 fn image_tokens(
     config: &Config,
     preprocessor_config: Option<&HubPreprocessorConfig>,
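
A quick sanity check of the accept/reject rule above; this standalone sketch mirrors fetch_video's scheme test rather than calling the router itself:

    // Mirrors the validation above: only http(s) URLs are accepted;
    // data: URIs and local paths are rejected for video.
    fn is_valid_video_input(input: &str) -> bool {
        input.starts_with("http://") || input.starts_with("https://")
    }

    fn main() {
        assert!(is_valid_video_input("https://example.com/clip.mp4"));
        assert!(is_valid_video_input("http://example.com/clip.mp4"));
        assert!(!is_valid_video_input("file:///tmp/clip.mp4"));
        assert!(!is_valid_video_input("data:video/mp4;base64,AAAA"));
    }

Note that unlike fetch_image, which downloads the bytes and extracts the dimensions, fetch_video only validates the scheme and passes the URL through unchanged, so the actual fetching presumably happens downstream in the backend.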
@@ -625,6 +633,9 @@ fn prepare_input<T: TokenizerTrait>(
 ) -> Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError> {
     use Config::*;
     static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
+    // Add video regex
+    static VIDEO_RE: Lazy<Regex> =
+        Lazy::new(|| Regex::new(r"<video>\((https?://[^\)]+)\)").unwrap());
     let (tokenizer_query, input_chunks) = match config {
         Some(
             config @ (Idefics | Mllama | Idefics2(_) | Paligemma(_) | LlavaNext(_) | Qwen2Vl(_)),
@@ -632,25 +643,45 @@ fn prepare_input<T: TokenizerTrait>(
             let mut input_chunks = Vec::new();
             let mut tokenizer_query = String::with_capacity(inputs.len());
             let mut start = 0;
-            for chunk in RE.find_iter(&inputs) {
+            // Process videos first
+            for chunk in VIDEO_RE.find_iter(&inputs) {
                 let chunk_start = chunk.start();
                 let chunk_end = chunk.end();
                 if chunk_start != start {
                     input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()));
                     tokenizer_query.push_str(&inputs[start..chunk_start]);
                 }
+                let video_url = &inputs[chunk_start + 8..chunk_end - 1]; // Remove <video>( and )
+                input_chunks.push(Chunk::Video(video_url.to_string()));
+                // For videos, we use the default size as height/width don't matter for the initial processing
+                tokenizer_query.push_str(&image_tokens(config, preprocessor_config, 1, 1));
+                start = chunk_end;
+            }
+            // Process remaining content for images
+            let remaining_input = &inputs[start..];
+            for chunk in RE.find_iter(remaining_input) {
+                let chunk_start = chunk.start() + start;
+                let chunk_end = chunk.end() + start;
+                if chunk_start != start {
+                    input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()));
+                    tokenizer_query.push_str(&inputs[start..chunk_start]);
+                }
                 let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
                 input_chunks.push(Chunk::Image(Image { data, mimetype }));
                 tokenizer_query.push_str(&image_tokens(config, preprocessor_config, height, width));
                 start = chunk_end;
             }
+            // Add any remaining text
             if start != inputs.len() {
                 input_chunks.push(Chunk::Text(inputs[start..].to_string()));
                 tokenizer_query.push_str(&inputs[start..]);
             }
+            // Apply any necessary token fixups
             tokenizer_query = image_tokens_fixup(config, tokenizer_query);
             (tokenizer_query, input_chunks)
         }
         _ => (inputs.clone(), vec![Chunk::Text(inputs)]),
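
To see the two-pass parse in action, here is a self-contained sketch of the chunking logic, with simplified stand-ins (no tokenizer and no image fetching). One deviation is labeled below: the committed loop adds the mutated start to match offsets that are relative to the fixed remaining_input base, whereas this sketch pins the base so that a second image after a video still maps to correct absolute indices:

    use once_cell::sync::Lazy;
    use regex::Regex;

    // Simplified stand-in for the router's Chunk; images are kept as raw
    // markdown here instead of fetched bytes.
    #[derive(Debug)]
    enum Chunk {
        Text(String),
        Image(String),
        Video(String),
    }

    static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
    static VIDEO_RE: Lazy<Regex> =
        Lazy::new(|| Regex::new(r"<video>\((https?://[^\)]+)\)").unwrap());

    fn split_chunks(inputs: &str) -> Vec<Chunk> {
        let mut chunks = Vec::new();
        let mut start = 0;
        // Pass 1: video markers over the whole input.
        for m in VIDEO_RE.find_iter(inputs) {
            if m.start() != start {
                chunks.push(Chunk::Text(inputs[start..m.start()].to_string()));
            }
            // Strip the leading "<video>(" (8 bytes) and the trailing ")".
            chunks.push(Chunk::Video(inputs[m.start() + 8..m.end() - 1].to_string()));
            start = m.end();
        }
        // Pass 2: markdown images in the text after the last video. Offsets
        // are taken against a fixed base (this is where the sketch deviates
        // from the committed code, which adds the advancing `start`).
        let base = start;
        for m in RE.find_iter(&inputs[base..]) {
            let (s, e) = (m.start() + base, m.end() + base);
            if s != start {
                chunks.push(Chunk::Text(inputs[start..s].to_string()));
            }
            chunks.push(Chunk::Image(inputs[s..e].to_string()));
            start = e;
        }
        if start != inputs.len() {
            chunks.push(Chunk::Text(inputs[start..].to_string()));
        }
        chunks
    }

    fn main() {
        let prompt = "Look at <video>(https://example.com/a.mp4) then ![](https://example.com/b.png) and answer.";
        for chunk in split_chunks(prompt) {
            println!("{:?}", chunk);
        }
        // Text("Look at "), Video("https://example.com/a.mp4"),
        // Text(" then "), Image("![](https://example.com/b.png)"), Text(" and answer.")
    }

Also worth noting: because the image pass only scans the text after the last video marker, any ![](...) markup that precedes a video ends up inside a plain Text chunk under this scheme.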
@@ -680,6 +711,7 @@ pub struct Image {
 pub enum Chunk {
     Text(String),
     Image(Image),
+    Video(String),
 }

 /// Convert input chunks to a stringly-typed input for backwards
@@ -698,6 +730,9 @@ impl ChunksToString for Vec<Chunk> {
                 let encoded = STANDARD.encode(data);
                 output.push_str(&format!("![](data:{};base64,{})", mimetype, encoded))
             }
+            Chunk::Video(url) => {
+                output.push_str(&format!("<video>({})", url))
+            }
         });
         output
     }
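
For intuition, a minimal sketch (assuming a pared-down Chunk without image data) showing that the string rendering emits exactly the <video>(url) marker that prepare_input's VIDEO_RE recognizes, so chunked and stringly inputs round-trip:

    #[derive(Debug)]
    enum Chunk {
        Text(String),
        Video(String),
    }

    fn chunks_to_string(chunks: &[Chunk]) -> String {
        let mut output = String::new();
        chunks.iter().for_each(|c| match c {
            Chunk::Text(t) => output.push_str(t),
            // Same marker that the video regex parses on the way in.
            Chunk::Video(url) => output.push_str(&format!("<video>({})", url)),
        });
        output
    }

    fn main() {
        let s = chunks_to_string(&[
            Chunk::Text("Describe ".to_string()),
            Chunk::Video("https://example.com/a.mp4".to_string()),
        ]);
        assert_eq!(s, "Describe <video>(https://example.com/a.mp4)");
    }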
@@ -822,6 +857,8 @@ pub enum ValidationError {
     FailedFetchImage(#[from] reqwest::Error),
     #[error("{0} modality is not supported")]
     UnsupportedModality(&'static str),
+    #[error("invalid video content: {0}")]
+    InvalidVideoContent(String),
 }

 #[cfg(test)]
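
And a quick check of the new error's Display output, assuming nothing beyond the thiserror derive visible in the diff:

    use thiserror::Error;

    #[derive(Debug, Error)]
    enum ValidationError {
        #[error("invalid video content: {0}")]
        InvalidVideoContent(String),
    }

    fn main() {
        let err = ValidationError::InvalidVideoContent("file:///tmp/a.mp4".to_string());
        // thiserror derives Display from the #[error(...)] attribute.
        assert_eq!(err.to_string(), "invalid video content: file:///tmp/a.mp4");
    }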