WIP
This commit is contained in:
parent
e4ab855480
commit
e6c524c66b
|
@ -3884,6 +3884,12 @@ dependencies = [
|
|||
"unicode-segmentation",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "static_assertions"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
|
||||
|
||||
[[package]]
|
||||
name = "strsim"
|
||||
version = "0.11.1"
|
||||
|
@ -4175,6 +4181,7 @@ dependencies = [
|
|||
"tracing",
|
||||
"tracing-opentelemetry 0.21.0",
|
||||
"tracing-subscriber",
|
||||
"twox-hash",
|
||||
"ureq",
|
||||
"utoipa",
|
||||
"utoipa-swagger-ui",
|
||||
|
@ -4776,6 +4783,17 @@ version = "0.2.5"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
|
||||
|
||||
[[package]]
|
||||
name = "twox-hash"
|
||||
version = "1.6.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "97fee6b57c6a41524a810daee9286c02d7752c4253064d0b05472833a438f675"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"rand",
|
||||
"static_assertions",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "typenum"
|
||||
version = "1.17.0"
|
||||
|
|
|
@ -68,14 +68,9 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
|
|||
let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
|
||||
let mut attention: Option<String> = std::env::var("ATTENTION").ok();
|
||||
if let Some(config) = config {
|
||||
if prefix_caching.is_none() {
|
||||
if config.vision_config.is_some() {
|
||||
tracing::info!("Disabling prefix caching because of VLM model");
|
||||
prefix_caching = Some("0".to_string());
|
||||
} else if config.is_encoder_decoder {
|
||||
tracing::info!("Disabling prefix caching because of seq2seq model");
|
||||
prefix_caching = Some("0".to_string());
|
||||
}
|
||||
if prefix_caching.is_none() && config.is_encoder_decoder {
|
||||
tracing::info!("Disabling prefix caching because of seq2seq model");
|
||||
prefix_caching = Some("0".to_string());
|
||||
}
|
||||
match config.head_dim {
|
||||
Some(h) if h == 64 || h == 128 || h == 256 => {
|
||||
|
@ -126,7 +121,6 @@ struct RawConfig {
|
|||
hidden_size: Option<usize>,
|
||||
num_attention_heads: Option<usize>,
|
||||
head_dim: Option<usize>,
|
||||
vision_config: Option<VisionConfig>,
|
||||
is_encoder_decoder: Option<bool>,
|
||||
}
|
||||
|
||||
|
@ -144,7 +138,6 @@ struct Config {
|
|||
quantize: Option<Quantization>,
|
||||
head_dim: Option<usize>,
|
||||
model_type: Option<String>,
|
||||
vision_config: Option<VisionConfig>,
|
||||
is_encoder_decoder: bool,
|
||||
}
|
||||
|
||||
|
@ -172,14 +165,12 @@ impl From<RawConfig> for Config {
|
|||
}
|
||||
});
|
||||
let model_type = other.model_type;
|
||||
let vision_config = other.vision_config;
|
||||
let is_encoder_decoder = other.is_encoder_decoder.unwrap_or(false);
|
||||
Config {
|
||||
max_position_embeddings,
|
||||
quantize,
|
||||
head_dim,
|
||||
model_type,
|
||||
vision_config,
|
||||
is_encoder_decoder,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -61,6 +61,7 @@ uuid = { version = "1.9.1", default-features = false, features = [
|
|||
] }
|
||||
csv = "1.3.0"
|
||||
ureq = "=2.9"
|
||||
twox-hash = "1.6.3"
|
||||
|
||||
|
||||
[build-dependencies]
|
||||
|
|
|
@ -129,6 +129,7 @@ pub struct PaliTextConfig {
|
|||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct Paligemma {
|
||||
pub(crate) image_token_index: u32,
|
||||
pub(crate) text_config: PaliTextConfig,
|
||||
}
|
||||
|
||||
|
|
|
@ -6,17 +6,23 @@ use crate::{
|
|||
};
|
||||
use base64::{engine::general_purpose::STANDARD, Engine};
|
||||
use image::{ImageFormat, ImageReader};
|
||||
use itertools::Itertools;
|
||||
use jsonschema::{Draft, JSONSchema};
|
||||
use rand::{thread_rng, Rng};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::hash::Hasher;
|
||||
use std::io::Cursor;
|
||||
use std::iter;
|
||||
use std::sync::Arc;
|
||||
use thiserror::Error;
|
||||
use tokenizers::tokenizer::Tokenizer;
|
||||
use tokenizers::Encoding;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::oneshot;
|
||||
use tracing::{instrument, Span};
|
||||
use twox_hash::xxh3::HasherExt;
|
||||
use twox_hash::Xxh3Hash128;
|
||||
use {once_cell::sync::Lazy, regex::Regex};
|
||||
|
||||
/// Validation
|
||||
|
@ -596,6 +602,45 @@ fn image_tokens(
|
|||
}
|
||||
}
|
||||
|
||||
fn image_id(config: &Config) -> u32 {
|
||||
use Config::*;
|
||||
match config {
|
||||
Paligemma(pali_gemma) => pali_gemma.image_token_index,
|
||||
_ => unimplemented!("Images tokens are not supported for this model configuration"),
|
||||
}
|
||||
}
|
||||
|
||||
fn n_image_tokens(
|
||||
config: &Config,
|
||||
preprocessor_config: Option<&HubPreprocessorConfig>,
|
||||
height: usize,
|
||||
width: usize,
|
||||
) -> usize {
|
||||
use Config::*;
|
||||
use HubPreprocessorConfig::*;
|
||||
match config {
|
||||
Idefics => 1,
|
||||
Idefics2(config) => {
|
||||
let repeats = if matches!(
|
||||
preprocessor_config,
|
||||
Some(Idefics2Processor(Idefics2Preprocessor {
|
||||
do_image_splitting: true,
|
||||
..
|
||||
}))
|
||||
) {
|
||||
5
|
||||
} else {
|
||||
1
|
||||
};
|
||||
|
||||
config.get_number_of_features(height, width) * repeats
|
||||
}
|
||||
Paligemma(config) => config.get_number_of_features(height, width),
|
||||
LlavaNext(config) => config.get_number_of_features(height, width),
|
||||
_ => unimplemented!("Images tokens are not supported for this model configuration"),
|
||||
}
|
||||
}
|
||||
|
||||
fn image_tokens_fixup(config: &Config, text: String) -> String {
|
||||
match config {
|
||||
Config::Idefics2(_) => {
|
||||
|
@ -617,8 +662,10 @@ fn prepare_input(
|
|||
) -> Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError> {
|
||||
use Config::*;
|
||||
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
|
||||
let (tokenizer_query, input_chunks) = match config {
|
||||
let (tokenizer_query, input_chunks, image_token_id, image_hashes, image_lens) = match config {
|
||||
Some(config @ (Idefics | Idefics2(_) | Paligemma(_) | LlavaNext(_))) => {
|
||||
let mut image_hashes = Vec::new();
|
||||
let mut image_lens = Vec::new();
|
||||
let mut input_chunks = Vec::new();
|
||||
let mut tokenizer_query = String::with_capacity(inputs.len());
|
||||
let mut start = 0;
|
||||
|
@ -630,8 +677,15 @@ fn prepare_input(
|
|||
tokenizer_query.push_str(&inputs[start..chunk_start]);
|
||||
}
|
||||
let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
|
||||
input_chunks.push(Chunk::Image(Image { data, mimetype }));
|
||||
|
||||
tokenizer_query.push_str(&image_tokens(config, preprocessor_config, height, width));
|
||||
|
||||
let mut hasher = Xxh3Hash128::default();
|
||||
hasher.write(&data);
|
||||
image_hashes.push(hasher.finish_ext());
|
||||
image_lens.push(n_image_tokens(config, preprocessor_config, height, width));
|
||||
|
||||
input_chunks.push(Chunk::Image(Image { data, mimetype }));
|
||||
start = chunk_end;
|
||||
}
|
||||
if start != inputs.len() {
|
||||
|
@ -641,9 +695,15 @@ fn prepare_input(
|
|||
|
||||
tokenizer_query = image_tokens_fixup(config, tokenizer_query);
|
||||
|
||||
(tokenizer_query, input_chunks)
|
||||
(
|
||||
tokenizer_query,
|
||||
input_chunks,
|
||||
image_id(&config),
|
||||
image_hashes,
|
||||
image_lens,
|
||||
)
|
||||
}
|
||||
_ => (inputs.clone(), vec![Chunk::Text(inputs)]),
|
||||
_ => (inputs.clone(), vec![Chunk::Text(inputs)], 0, vec![], vec![]),
|
||||
};
|
||||
|
||||
// Get the number of tokens in the input
|
||||
|
@ -651,6 +711,35 @@ fn prepare_input(
|
|||
.encode(tokenizer_query, add_special_tokens)
|
||||
.map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
|
||||
|
||||
tracing::info!("encoding before hash: {:?}", encoding.get_ids());
|
||||
|
||||
// Replace image tokens by hashes. The first token of an image
|
||||
// must be specific to the image for prefix caching.
|
||||
let mut token_ids = encoding.get_ids().to_owned();
|
||||
let mut iter = token_ids.iter_mut().filter(|id| **id == image_token_id);
|
||||
for (image_hash, n_tokens) in image_hashes.iter().zip(image_lens.iter()) {
|
||||
let image_token = iter.next().ok_or(ValidationError::Tokenizer(
|
||||
"Image token not found".to_string(),
|
||||
))?;
|
||||
*image_token = *image_hash as u32;
|
||||
// Skip the remaining tokens of the current image.
|
||||
iter = iter.dropping(n_tokens - 1);
|
||||
}
|
||||
|
||||
let encoding = Encoding::new(
|
||||
token_ids,
|
||||
encoding.get_type_ids().to_owned(),
|
||||
encoding.get_tokens().to_owned(),
|
||||
encoding.get_word_ids().to_owned(),
|
||||
encoding.get_offsets().to_owned(),
|
||||
encoding.get_special_tokens_mask().to_owned(),
|
||||
encoding.get_attention_mask().to_owned(),
|
||||
encoding.get_overflowing().to_owned(),
|
||||
HashMap::new(),
|
||||
);
|
||||
|
||||
tracing::info!("encoding after hash: {:?}", encoding.get_ids());
|
||||
|
||||
Ok((encoding, input_chunks))
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue