From f171bdc82313e58bf82a05fee1d5c01d2f3f2be0 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Fri, 22 Mar 2024 17:14:54 +0100
Subject: [PATCH] Inline images for multimodal models. (#1666)

---
 Cargo.lock                               |  2 ++
 integration-tests/models/test_idefics.py |  8 +++++---
 router/Cargo.toml                        |  2 ++
 router/src/validation.rs                 | 18 +++++++++++++++---
 4 files changed, 24 insertions(+), 6 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 3659a048..4db5b24f 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3045,9 +3045,11 @@ dependencies = [
  "minijinja",
  "ngrok",
  "nohash-hasher",
+ "once_cell",
  "opentelemetry",
  "opentelemetry-otlp",
  "rand",
+ "regex",
  "reqwest",
  "serde",
  "serde_json",
diff --git a/integration-tests/models/test_idefics.py b/integration-tests/models/test_idefics.py
index f38b9cde..882971f2 100644
--- a/integration-tests/models/test_idefics.py
+++ b/integration-tests/models/test_idefics.py
@@ -20,13 +20,14 @@ async def idefics(idefics_handle):
 def get_chicken():
     with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
         encoded_string = base64.b64encode(image_file.read())
-    return f"data:image/png;base64,{encoded_string}"
+    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
 
 
 @pytest.mark.asyncio
 async def test_idefics(idefics, response_snapshot):
+    chicken = get_chicken()
     response = await idefics.generate(
-        "User:![](https://huggingface.co/spaces/HuggingFaceM4/idefics_playground/resolve/main/example_images/chicken_on_money.png?download=true)Can you tell me a very short story based on the image?",
+        f"User:![]({chicken})Can you tell me a very short story based on the image?",
         max_new_tokens=10,
         decoder_input_details=True,
     )
@@ -37,9 +38,10 @@ async def test_idefics(idefics, response_snapshot):
 
 @pytest.mark.asyncio
 async def test_idefics_load(idefics, generate_load, response_snapshot):
+    chicken = get_chicken()
     responses = await generate_load(
         idefics,
-        "User:![](https://huggingface.co/spaces/HuggingFaceM4/idefics_playground/resolve/main/example_images/chicken_on_money.png?download=true)Can you tell me a very short story based on the image?",
+        f"User:![]({chicken})Can you tell me a very short story based on the image?",
         max_new_tokens=10,
         n=4,
     )
diff --git a/router/Cargo.toml b/router/Cargo.toml
index 71f9f89c..0c57a886 100644
--- a/router/Cargo.toml
+++ b/router/Cargo.toml
@@ -46,6 +46,8 @@ ngrok = { version = "0.13.1", features = ["axum"], optional = true }
 init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
 minijinja = { git = "https://github.com/mitsuhiko/minijinja.git", branch = "main", commit = "5cd4efb" }
 futures-util = "0.3.30"
+regex = "1.10.3"
+once_cell = "1.19.0"
 
 [build-dependencies]
 vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
diff --git a/router/src/validation.rs b/router/src/validation.rs
index 204dbf92..2aa9775a 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -13,6 +13,7 @@ use tokenizers::TruncationDirection;
 use tokio::sync::mpsc;
 use tokio::sync::oneshot;
 use tracing::{instrument, Span};
+use {once_cell::sync::Lazy, regex::Regex};
 
 /// Validation
 #[derive(Debug, Clone)]
@@ -409,10 +410,14 @@ async fn round_robin_task(
 /// Start tokenization workers
 fn tokenizer_worker(tokenizer: Tokenizer, mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>) {
     // Loop over requests
+    let is_multimodal = {
+        let vocab = tokenizer.get_vocab(true);
+        vocab.contains_key("<image>")
+    };
     while let Some(((inputs, truncate), response_tx, parent_span)) = receiver.blocking_recv() {
         parent_span.in_scope(|| {
             response_tx
-                .send(prepare_input(inputs, truncate, &tokenizer))
+                .send(prepare_input(inputs, truncate, &tokenizer, is_multimodal))
                 .unwrap_or(())
         })
     }
@@ -423,15 +428,22 @@ fn prepare_input(
     mut inputs: String,
     truncate: Option<usize>,
     tokenizer: &Tokenizer,
+    is_multimodal: bool,
 ) -> Result<(tokenizers::Encoding, String), ValidationError> {
+    let simplified_query = if is_multimodal {
+        static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
+        RE.replace_all(&inputs, "").into()
+    } else {
+        inputs.clone()
+    };
     // Get the number of tokens in the input
     let mut encoding = tokenizer
-        .encode(inputs.clone(), true)
+        .encode(simplified_query, true)
         .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
 
     // Optionally truncate
     if let Some(truncate) = truncate {
-        if truncate < encoding.len() {
+        if truncate < encoding.len() && !is_multimodal {
             encoding.truncate(truncate, 0, TruncationDirection::Left);
             inputs = tokenizer
                 .decode(encoding.get_ids(), false)
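
Note (not part of the patch above): a minimal sketch of what the regex introduced in prepare_input does, stripping inline `![](...)` markdown images (including base64 data URIs) from the prompt before the tokenizer length/truncation check. The helper name strip_inline_images and the example prompt are made up for illustration; it only assumes the regex and once_cell crates added in router/Cargo.toml.

    use once_cell::sync::Lazy;
    use regex::Regex;

    // Same pattern as in prepare_input: matches a markdown image tag `![](...)`,
    // whatever URL or data URI it contains.
    static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());

    // Hypothetical helper: returns the prompt with inline images removed,
    // i.e. the "simplified query" that gets tokenized to measure input length.
    fn strip_inline_images(inputs: &str) -> String {
        RE.replace_all(inputs, "").into_owned()
    }

    fn main() {
        let prompt = "User:![](data:image/png;base64,iVBORw0KGgo=)Can you tell me a very short story based on the image?";
        // The huge base64 payload is dropped, so it no longer inflates the token count.
        assert_eq!(
            strip_inline_images(prompt),
            "User:Can you tell me a very short story based on the image?"
        );
    }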