From e6c524c66b4b04b33203b98df115c8d2c3b5f611 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?=
Date: Mon, 2 Sep 2024 14:48:16 +0000
Subject: [PATCH] WIP

---
 Cargo.lock               | 18 ++++++++
 launcher/src/main.rs     | 15 ++-----
 router/Cargo.toml        |  1 +
 router/src/config.rs     |  1 +
 router/src/validation.rs | 97 ++++++++++++++++++++++++++++++++++++++--
 5 files changed, 116 insertions(+), 16 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 00c7f005..6d138032 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3884,6 +3884,12 @@ dependencies = [
  "unicode-segmentation",
 ]
 
+[[package]]
+name = "static_assertions"
+version = "1.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
+
 [[package]]
 name = "strsim"
 version = "0.11.1"
@@ -4175,6 +4181,7 @@ dependencies = [
  "tracing",
  "tracing-opentelemetry 0.21.0",
  "tracing-subscriber",
+ "twox-hash",
  "ureq",
  "utoipa",
  "utoipa-swagger-ui",
@@ -4776,6 +4783,17 @@ version = "0.2.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
 
+[[package]]
+name = "twox-hash"
+version = "1.6.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "97fee6b57c6a41524a810daee9286c02d7752c4253064d0b05472833a438f675"
+dependencies = [
+ "cfg-if",
+ "rand",
+ "static_assertions",
+]
+
 [[package]]
 name = "typenum"
 version = "1.17.0"
diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 8e5c9dcd..49fb3998 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -68,14 +68,9 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
     let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
     let mut attention: Option<String> = std::env::var("ATTENTION").ok();
     if let Some(config) = config {
-        if prefix_caching.is_none() {
-            if config.vision_config.is_some() {
-                tracing::info!("Disabling prefix caching because of VLM model");
-                prefix_caching = Some("0".to_string());
-            } else if config.is_encoder_decoder {
-                tracing::info!("Disabling prefix caching because of seq2seq model");
-                prefix_caching = Some("0".to_string());
-            }
+        if prefix_caching.is_none() && config.is_encoder_decoder {
+            tracing::info!("Disabling prefix caching because of seq2seq model");
+            prefix_caching = Some("0".to_string());
         }
         match config.head_dim {
             Some(h) if h == 64 || h == 128 || h == 256 => {
@@ -126,7 +121,6 @@ struct RawConfig {
     hidden_size: Option<usize>,
     num_attention_heads: Option<usize>,
     head_dim: Option<usize>,
-    vision_config: Option<VisionConfig>,
     is_encoder_decoder: Option<bool>,
 }
 
@@ -144,7 +138,6 @@ struct Config {
     quantize: Option<Quantization>,
     head_dim: Option<usize>,
     model_type: Option<String>,
-    vision_config: Option<VisionConfig>,
     is_encoder_decoder: bool,
 }
 
@@ -172,14 +165,12 @@ impl From<RawConfig> for Config {
             }
         });
         let model_type = other.model_type;
-        let vision_config = other.vision_config;
         let is_encoder_decoder = other.is_encoder_decoder.unwrap_or(false);
         Config {
             max_position_embeddings,
             quantize,
             head_dim,
             model_type,
-            vision_config,
             is_encoder_decoder,
         }
     }
diff --git a/router/Cargo.toml b/router/Cargo.toml
index 5c328e8a..fadb4ccc 100644
--- a/router/Cargo.toml
+++ b/router/Cargo.toml
@@ -61,6 +61,7 @@ uuid = { version = "1.9.1", default-features = false, features = [
 ] }
 csv = "1.3.0"
 ureq = "=2.9"
+twox-hash = "1.6.3"
 
 
 [build-dependencies]
diff --git a/router/src/config.rs b/router/src/config.rs
index 5d0be9c8..af148217 100644
--- a/router/src/config.rs
+++ b/router/src/config.rs
@@ -129,6 +129,7 @@ pub struct PaliTextConfig {
 #[derive(Clone, Debug, Serialize, Deserialize)]
 #[serde(rename_all = "snake_case")]
 pub struct Paligemma {
+    pub(crate) image_token_index: u32,
     pub(crate) text_config: PaliTextConfig,
 }
 
diff --git a/router/src/validation.rs b/router/src/validation.rs
index 92491d88..fd57c26a 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -6,17 +6,23 @@ use crate::{
 };
 use base64::{engine::general_purpose::STANDARD, Engine};
 use image::{ImageFormat, ImageReader};
+use itertools::Itertools;
 use jsonschema::{Draft, JSONSchema};
 use rand::{thread_rng, Rng};
 use serde_json::Value;
+use std::collections::HashMap;
+use std::hash::Hasher;
 use std::io::Cursor;
 use std::iter;
 use std::sync::Arc;
 use thiserror::Error;
 use tokenizers::tokenizer::Tokenizer;
+use tokenizers::Encoding;
 use tokio::sync::mpsc;
 use tokio::sync::oneshot;
 use tracing::{instrument, Span};
+use twox_hash::xxh3::HasherExt;
+use twox_hash::Xxh3Hash128;
 use {once_cell::sync::Lazy, regex::Regex};
 
 /// Validation
@@ -596,6 +602,45 @@ fn image_tokens(
     }
 }
 
+fn image_id(config: &Config) -> u32 {
+    use Config::*;
+    match config {
+        Paligemma(pali_gemma) => pali_gemma.image_token_index,
+        _ => unimplemented!("Images tokens are not supported for this model configuration"),
+    }
+}
+
+fn n_image_tokens(
+    config: &Config,
+    preprocessor_config: Option<&HubPreprocessorConfig>,
+    height: usize,
+    width: usize,
+) -> usize {
+    use Config::*;
+    use HubPreprocessorConfig::*;
+    match config {
+        Idefics => 1,
+        Idefics2(config) => {
+            let repeats = if matches!(
+                preprocessor_config,
+                Some(Idefics2Processor(Idefics2Preprocessor {
+                    do_image_splitting: true,
+                    ..
+                }))
+            ) {
+                5
+            } else {
+                1
+            };
+
+            config.get_number_of_features(height, width) * repeats
+        }
+        Paligemma(config) => config.get_number_of_features(height, width),
+        LlavaNext(config) => config.get_number_of_features(height, width),
+        _ => unimplemented!("Images tokens are not supported for this model configuration"),
+    }
+}
+
 fn image_tokens_fixup(config: &Config, text: String) -> String {
     match config {
         Config::Idefics2(_) => {
@@ -617,8 +662,10 @@ fn prepare_input(
 ) -> Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError> {
     use Config::*;
     static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
-    let (tokenizer_query, input_chunks) = match config {
+    let (tokenizer_query, input_chunks, image_token_id, image_hashes, image_lens) = match config {
         Some(config @ (Idefics | Idefics2(_) | Paligemma(_) | LlavaNext(_))) => {
+            let mut image_hashes = Vec::new();
+            let mut image_lens = Vec::new();
             let mut input_chunks = Vec::new();
             let mut tokenizer_query = String::with_capacity(inputs.len());
             let mut start = 0;
@@ -630,8 +677,15 @@ fn prepare_input(
                     tokenizer_query.push_str(&inputs[start..chunk_start]);
                 }
                 let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
-                input_chunks.push(Chunk::Image(Image { data, mimetype }));
+
                 tokenizer_query.push_str(&image_tokens(config, preprocessor_config, height, width));
+
+                let mut hasher = Xxh3Hash128::default();
+                hasher.write(&data);
+                image_hashes.push(hasher.finish_ext());
+                image_lens.push(n_image_tokens(config, preprocessor_config, height, width));
+
+                input_chunks.push(Chunk::Image(Image { data, mimetype }));
                 start = chunk_end;
             }
             if start != inputs.len() {
@@ -641,9 +695,15 @@ fn prepare_input(
 
             tokenizer_query = image_tokens_fixup(config, tokenizer_query);
 
-            (tokenizer_query, input_chunks)
+            (
+                tokenizer_query,
+                input_chunks,
+                image_id(&config),
+                image_hashes,
+                image_lens,
+            )
         }
-        _ => (inputs.clone(), vec![Chunk::Text(inputs)]),
+        _ => (inputs.clone(), vec![Chunk::Text(inputs)], 0, vec![], vec![]),
     };
 
     // Get the number of tokens in the input
@@ -651,6 +711,35 @@ fn prepare_input(
         .encode(tokenizer_query, add_special_tokens)
         .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
 
+    tracing::info!("encoding before hash: {:?}", encoding.get_ids());
+
+    // Replace image tokens by hashes. The first token of an image
+    // must be specific to the image for prefix caching.
+    let mut token_ids = encoding.get_ids().to_owned();
+    let mut iter = token_ids.iter_mut().filter(|id| **id == image_token_id);
+    for (image_hash, n_tokens) in image_hashes.iter().zip(image_lens.iter()) {
+        let image_token = iter.next().ok_or(ValidationError::Tokenizer(
+            "Image token not found".to_string(),
+        ))?;
+        *image_token = *image_hash as u32;
+        // Skip the remaining tokens of the current image.
+        iter = iter.dropping(n_tokens - 1);
+    }
+
+    let encoding = Encoding::new(
+        token_ids,
+        encoding.get_type_ids().to_owned(),
+        encoding.get_tokens().to_owned(),
+        encoding.get_word_ids().to_owned(),
+        encoding.get_offsets().to_owned(),
+        encoding.get_special_tokens_mask().to_owned(),
+        encoding.get_attention_mask().to_owned(),
+        encoding.get_overflowing().to_owned(),
+        HashMap::new(),
+    );
+
+    tracing::info!("encoding after hash: {:?}", encoding.get_ids());
+
     Ok((encoding, input_chunks))
 }
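
Reviewer note (not part of the patch): prefix caching keys cache blocks by token ids, so two prompts that differ only in their image bytes would otherwise produce token-identical prefixes and wrongly share cache entries. The patch therefore overwrites the first token of every image span with a 128-bit XXH3 hash of the raw image bytes, truncated to the u32 token-id space. The sketch below replays that rewrite on a plain Vec<u32> outside the router; the mark_images helper, the toy token ids, and the placeholder image token id 9 are invented for illustration, while the Xxh3Hash128/finish_ext/dropping calls are the same ones the patch adds.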
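// Standalone sketch (reviewer illustration, not part of the patch). It
// replays the token rewrite from prepare_input on a plain Vec<u32>,
// assuming the crates the patch adds: twox-hash 1.6.3 and itertools.
use std::hash::Hasher;

use itertools::Itertools;
use twox_hash::xxh3::HasherExt;
use twox_hash::Xxh3Hash128;

/// Overwrite the first token of each image span with a hash of the image
/// bytes, leaving the remaining tokens of the span untouched. Returns
/// None if fewer image tokens are found than expected.
fn mark_images(
    token_ids: &mut [u32],
    image_token_id: u32,
    images: &[(Vec<u8>, usize)], // (raw image bytes, span length in tokens)
) -> Option<()> {
    let mut iter = token_ids.iter_mut().filter(|id| **id == image_token_id);
    for (data, n_tokens) in images {
        // 128-bit XXH3 over the image bytes, truncated to the u32 id
        // space, exactly as in prepare_input above.
        let mut hasher = Xxh3Hash128::default();
        hasher.write(data);
        let hash = hasher.finish_ext();

        // The first token of the image becomes image-specific ...
        *iter.next()? = hash as u32;
        // ... and the rest of the span is skipped unchanged.
        iter = iter.dropping(n_tokens - 1);
    }
    Some(())
}

fn main() {
    // Hypothetical prompt: three text tokens, one image spanning four
    // tokens (placeholder image token id 9), then one more text token.
    let mut ids = vec![1, 2, 3, 9, 9, 9, 9, 4];
    mark_images(&mut ids, 9, &[(b"fake image bytes".to_vec(), 4)]).unwrap();
    println!("{:?}", ids); // [1, 2, 3, <hash as u32>, 9, 9, 9, 4]
}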
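Two things worth flagging in review, the first also visible in the sketch: truncating the 128-bit hash to u32 means a rewritten id can collide with a real vocabulary id or with another image's hash, and image_id() currently only handles Paligemma's image_token_index even though prepare_input takes this path for Idefics, Idefics2 and LlavaNext as well, so those models would hit the unimplemented!() panic. Both seem fine for a WIP but probably want addressing before merge.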