Daniël de Kok 2024-09-02 14:48:16 +00:00
parent e4ab855480
commit e6c524c66b
5 changed files with 116 additions and 16 deletions

Cargo.lock (generated, 18 lines changed)

@@ -3884,6 +3884,12 @@ dependencies = [
"unicode-segmentation",
]
[[package]]
name = "static_assertions"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
[[package]]
name = "strsim"
version = "0.11.1"
@@ -4175,6 +4181,7 @@ dependencies = [
"tracing",
"tracing-opentelemetry 0.21.0",
"tracing-subscriber",
"twox-hash",
"ureq",
"utoipa",
"utoipa-swagger-ui",
@@ -4776,6 +4783,17 @@ version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
[[package]]
name = "twox-hash"
version = "1.6.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "97fee6b57c6a41524a810daee9286c02d7752c4253064d0b05472833a438f675"
dependencies = [
"cfg-if",
"rand",
"static_assertions",
]
[[package]]
name = "typenum"
version = "1.17.0"


@@ -68,14 +68,9 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
let mut attention: Option<String> = std::env::var("ATTENTION").ok();
if let Some(config) = config {
if prefix_caching.is_none() {
if config.vision_config.is_some() {
tracing::info!("Disabling prefix caching because of VLM model");
prefix_caching = Some("0".to_string());
} else if config.is_encoder_decoder {
tracing::info!("Disabling prefix caching because of seq2seq model");
prefix_caching = Some("0".to_string());
}
if prefix_caching.is_none() && config.is_encoder_decoder {
tracing::info!("Disabling prefix caching because of seq2seq model");
prefix_caching = Some("0".to_string());
}
match config.head_dim {
Some(h) if h == 64 || h == 128 || h == 256 => {
@@ -126,7 +121,6 @@ struct RawConfig {
hidden_size: Option<usize>,
num_attention_heads: Option<usize>,
head_dim: Option<usize>,
vision_config: Option<VisionConfig>,
is_encoder_decoder: Option<bool>,
}
@@ -144,7 +138,6 @@ struct Config {
quantize: Option<Quantization>,
head_dim: Option<usize>,
model_type: Option<String>,
vision_config: Option<VisionConfig>,
is_encoder_decoder: bool,
}
@@ -172,14 +165,12 @@ impl From<RawConfig> for Config {
}
});
let model_type = other.model_type;
let vision_config = other.vision_config;
let is_encoder_decoder = other.is_encoder_decoder.unwrap_or(false);
Config {
max_position_embeddings,
quantize,
head_dim,
model_type,
vision_config,
is_encoder_decoder,
}
}
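Note: with the vision_config check removed, the launcher now disables prefix caching by default only for encoder-decoder (seq2seq) models; vision-language models no longer force it off, and an explicit USE_PREFIX_CACHING value is always respected. A minimal sketch of the resulting rule, using a stripped-down stand-in for the launcher's Config and an assumed fallback of "1" (both simplifications, not the full launcher code):

// Stripped-down stand-in; only is_encoder_decoder matters for this rule.
struct Config {
    is_encoder_decoder: bool,
}

// An explicit USE_PREFIX_CACHING value wins; otherwise prefix caching is
// only disabled for seq2seq models. The "1" fallback is an assumption here.
fn resolve_prefix_caching(config: &Config) -> String {
    let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
    if prefix_caching.is_none() && config.is_encoder_decoder {
        // Prefix caching disabled because of seq2seq model.
        prefix_caching = Some("0".to_string());
    }
    prefix_caching.unwrap_or_else(|| "1".to_string())
}

fn main() {
    let seq2seq = Config { is_encoder_decoder: true };
    let decoder_only = Config { is_encoder_decoder: false };
    println!("seq2seq: {}", resolve_prefix_caching(&seq2seq));           // "0" unless overridden
    println!("decoder-only: {}", resolve_prefix_caching(&decoder_only)); // "1" by default
}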


@@ -61,6 +61,7 @@ uuid = { version = "1.9.1", default-features = false, features = [
] }
csv = "1.3.0"
ureq = "=2.9"
twox-hash = "1.6.3"
[build-dependencies]

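Note: twox-hash is pulled in for its 128-bit xxHash3 implementation, which the validation changes below use to fingerprint image bytes. A small sketch of the streaming API the router relies on (Xxh3Hash128 plus HasherExt::finish_ext, as in the diff); the image bytes are a placeholder:

use std::hash::Hasher;
use twox_hash::xxh3::HasherExt;
use twox_hash::Xxh3Hash128;

fn main() {
    // Placeholder for the decoded image bytes the router fetches.
    let image_bytes: &[u8] = b"\x89PNG\r\n\x1a\n...";

    // Stream the bytes into the hasher and read the full 128-bit digest.
    let mut hasher = Xxh3Hash128::default();
    hasher.write(image_bytes);
    let digest: u128 = hasher.finish_ext();

    // The validation code later narrows the digest to u32 to fit a token id slot.
    println!("digest = {:#x}, as u32 = {}", digest, digest as u32);
}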

@@ -129,6 +129,7 @@ pub struct PaliTextConfig {
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct Paligemma {
pub(crate) image_token_index: u32,
pub(crate) text_config: PaliTextConfig,
}
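Note: image_token_index lets the router read PaliGemma's image placeholder token id from the model's config.json rather than hard-coding it. A hedged sketch of the deserialization with a trimmed-down PaliTextConfig and illustrative values (the field set and the numbers are assumptions, not the real model config):

use serde::Deserialize;

// Trimmed-down stand-ins; the real structs carry more fields.
#[derive(Debug, Deserialize)]
struct PaliTextConfig {
    num_image_tokens: usize,
}

#[derive(Debug, Deserialize)]
#[serde(rename_all = "snake_case")]
struct Paligemma {
    image_token_index: u32,
    text_config: PaliTextConfig,
}

fn main() -> Result<(), serde_json::Error> {
    // Illustrative values only.
    let raw = r#"{ "image_token_index": 257152, "text_config": { "num_image_tokens": 256 } }"#;
    let config: Paligemma = serde_json::from_str(raw)?;
    println!("image token id: {}", config.image_token_index);
    println!("tokens per image: {}", config.text_config.num_image_tokens);
    Ok(())
}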


@@ -6,17 +6,23 @@ use crate::{
};
use base64::{engine::general_purpose::STANDARD, Engine};
use image::{ImageFormat, ImageReader};
use itertools::Itertools;
use jsonschema::{Draft, JSONSchema};
use rand::{thread_rng, Rng};
use serde_json::Value;
use std::collections::HashMap;
use std::hash::Hasher;
use std::io::Cursor;
use std::iter;
use std::sync::Arc;
use thiserror::Error;
use tokenizers::tokenizer::Tokenizer;
use tokenizers::Encoding;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tracing::{instrument, Span};
use twox_hash::xxh3::HasherExt;
use twox_hash::Xxh3Hash128;
use {once_cell::sync::Lazy, regex::Regex};
/// Validation
@@ -596,6 +602,45 @@ fn image_tokens(
}
}
fn image_id(config: &Config) -> u32 {
use Config::*;
match config {
Paligemma(pali_gemma) => pali_gemma.image_token_index,
_ => unimplemented!("Image tokens are not supported for this model configuration"),
}
}
fn n_image_tokens(
config: &Config,
preprocessor_config: Option<&HubPreprocessorConfig>,
height: usize,
width: usize,
) -> usize {
use Config::*;
use HubPreprocessorConfig::*;
match config {
Idefics => 1,
Idefics2(config) => {
let repeats = if matches!(
preprocessor_config,
Some(Idefics2Processor(Idefics2Preprocessor {
do_image_splitting: true,
..
}))
) {
5
} else {
1
};
config.get_number_of_features(height, width) * repeats
}
Paligemma(config) => config.get_number_of_features(height, width),
LlavaNext(config) => config.get_number_of_features(height, width),
_ => unimplemented!("Image tokens are not supported for this model configuration"),
}
}
fn image_tokens_fixup(config: &Config, text: String) -> String {
match config {
Config::Idefics2(_) => {
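Note: image_id and n_image_tokens give the validation layer, per model family, the image placeholder token id and the number of placeholders a single image expands to; for Idefics2 the per-image feature count is repeated 5 times when the preprocessor splits images. A standalone sketch of the same counting rule, with made-up feature counts since the real ones depend on image size and the vision config:

// Simplified stand-ins for the router's config enums; feature counts are made up.
enum Model {
    Idefics,
    Idefics2 { features: usize, do_image_splitting: bool },
    Paligemma { features: usize },
}

// Mirrors n_image_tokens: one placeholder for Idefics, features x 5 for
// Idefics2 with image splitting, the plain feature count otherwise.
fn n_image_tokens(model: &Model) -> usize {
    match model {
        Model::Idefics => 1,
        Model::Idefics2 { features, do_image_splitting } => {
            let repeats = if *do_image_splitting { 5 } else { 1 };
            features * repeats
        }
        Model::Paligemma { features } => *features,
    }
}

fn main() {
    let m = Model::Idefics2 { features: 64, do_image_splitting: true };
    assert_eq!(n_image_tokens(&m), 320); // 64 features x 5 repeats
    println!("Idefics2 placeholder tokens: {}", n_image_tokens(&m));
}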
@@ -617,8 +662,10 @@ fn prepare_input(
) -> Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError> {
use Config::*;
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
let (tokenizer_query, input_chunks) = match config {
let (tokenizer_query, input_chunks, image_token_id, image_hashes, image_lens) = match config {
Some(config @ (Idefics | Idefics2(_) | Paligemma(_) | LlavaNext(_))) => {
let mut image_hashes = Vec::new();
let mut image_lens = Vec::new();
let mut input_chunks = Vec::new();
let mut tokenizer_query = String::with_capacity(inputs.len());
let mut start = 0;
@@ -630,8 +677,15 @@
tokenizer_query.push_str(&inputs[start..chunk_start]);
}
let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
input_chunks.push(Chunk::Image(Image { data, mimetype }));
tokenizer_query.push_str(&image_tokens(config, preprocessor_config, height, width));
let mut hasher = Xxh3Hash128::default();
hasher.write(&data);
image_hashes.push(hasher.finish_ext());
image_lens.push(n_image_tokens(config, preprocessor_config, height, width));
input_chunks.push(Chunk::Image(Image { data, mimetype }));
start = chunk_end;
}
if start != inputs.len() {
@@ -641,9 +695,15 @@
tokenizer_query = image_tokens_fixup(config, tokenizer_query);
(tokenizer_query, input_chunks)
(
tokenizer_query,
input_chunks,
image_id(&config),
image_hashes,
image_lens,
)
}
_ => (inputs.clone(), vec![Chunk::Text(inputs)]),
_ => (inputs.clone(), vec![Chunk::Text(inputs)], 0, vec![], vec![]),
};
// Get the number of tokens in the input
@@ -651,6 +711,35 @@
.encode(tokenizer_query, add_special_tokens)
.map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
tracing::info!("encoding before hash: {:?}", encoding.get_ids());
// Replace image tokens by hashes. The first token of an image
// must be specific to the image for prefix caching.
let mut token_ids = encoding.get_ids().to_owned();
let mut iter = token_ids.iter_mut().filter(|id| **id == image_token_id);
for (image_hash, n_tokens) in image_hashes.iter().zip(image_lens.iter()) {
let image_token = iter.next().ok_or(ValidationError::Tokenizer(
"Image token not found".to_string(),
))?;
*image_token = *image_hash as u32;
// Skip the remaining tokens of the current image.
iter = iter.dropping(n_tokens - 1);
}
let encoding = Encoding::new(
token_ids,
encoding.get_type_ids().to_owned(),
encoding.get_tokens().to_owned(),
encoding.get_word_ids().to_owned(),
encoding.get_offsets().to_owned(),
encoding.get_special_tokens_mask().to_owned(),
encoding.get_attention_mask().to_owned(),
encoding.get_overflowing().to_owned(),
HashMap::new(),
);
tracing::info!("encoding after hash: {:?}", encoding.get_ids());
Ok((encoding, input_chunks))
}
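Note: the net effect of the validation changes is that, after tokenization, the first placeholder token of each image is overwritten with a u32 truncation of that image's xxh3 hash, so prompts that differ only in image content no longer share a prefix-cache entry, while the remaining placeholders stay untouched. A standalone sketch of that substitution over a plain token-id vector (the ids and hash value are illustrative; itertools is assumed available, as in the diff):

use itertools::Itertools;

// Overwrite the first placeholder token of every image with its hash (narrowed
// to u32) and skip that image's remaining placeholders, as prepare_input does.
fn mark_images(
    token_ids: &mut [u32],
    image_token_id: u32,
    image_hashes: &[u128],
    image_lens: &[usize],
) -> Result<(), &'static str> {
    let mut iter = token_ids.iter_mut().filter(|id| **id == image_token_id);
    for (hash, n_tokens) in image_hashes.iter().zip(image_lens.iter()) {
        let first = iter.next().ok_or("Image token not found")?;
        *first = *hash as u32;
        // Leave the remaining placeholders of this image unchanged.
        iter = iter.dropping(n_tokens - 1);
    }
    Ok(())
}

fn main() {
    // Illustrative ids: 77 is the image placeholder; one image spans 3 placeholders.
    let mut ids = vec![1, 77, 77, 77, 42];
    mark_images(&mut ids, 77, &[0xdead_beef_u128], &[3]).unwrap();
    assert_eq!(ids, vec![1, 0xdead_beef, 77, 77, 42]);
    println!("{:?}", ids);
}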