Idefics2: sync added image tokens with transformers (#2080)
Before this change, the number of reserved image tokens was not the same as the number of images. Fixes #2029. While at it, also remove all the image token handling duplication in `prepare_input`.
parent b53b21c63a
commit dd2d91b043
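Note on the token arithmetic (reviewer summary, not part of the diff): with `do_image_splitting` enabled, the transformers Idefics2 processor renders each image as five sub-images (four crops plus the original), each encoded as 64 `<image>` slots wrapped in `<fake_token_around_image>`, and adjacent fake tokens collapse into one. The router previously reserved a flat 320 `<image>` slots inside a single pair of fake tokens, which matches neither case: without splitting the processor emits 64 slots, and with splitting it emits five 64-slot groups separated by fake tokens. A minimal sketch of the new accounting, under those assumptions:

    // Sketch only; mirrors the arithmetic of this PR, not code from it.
    fn idefics2_reserved_tokens(do_image_splitting: bool) -> usize {
        let per_group = 64 + 2; // 64 <image> slots bracketed by 2 fake tokens
        let groups = if do_image_splitting { 5 } else { 1 }; // 4 crops + the original
        // Collapsing the doubled fake token at each of the `groups - 1` seams.
        per_group * groups - (groups - 1)
    }

    fn main() {
        assert_eq!(idefics2_reserved_tokens(false), 66);
        assert_eq!(idefics2_reserved_tokens(true), 326);
    }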
@@ -3832,6 +3832,7 @@ dependencies = [
 "hf-hub",
 "image",
 "init-tracing-opentelemetry",
 "itertools 0.10.5",
 "jsonschema",
 "metrics 0.21.1",
 "metrics-exporter-prometheus",

(File diff suppressed because it is too large.)
@@ -8,61 +8,61 @@
     "tokens": [
       {
         "id": 330,
-        "logprob": -0.13000488,
+        "logprob": -0.08660889,
         "special": false,
         "text": " A"
       },
       {
         "id": 13088,
-        "logprob": -0.6713867,
+        "logprob": -0.7089844,
         "special": false,
         "text": " chicken"
       },
       {
         "id": 349,
-        "logprob": -0.2980957,
+        "logprob": -0.32885742,
         "special": false,
         "text": " is"
       },
       {
         "id": 6398,
-        "logprob": -0.060638428,
+        "logprob": -0.05126953,
         "special": false,
         "text": " sitting"
       },
       {
         "id": 356,
-        "logprob": -0.27319336,
+        "logprob": -0.35229492,
         "special": false,
         "text": " on"
       },
       {
         "id": 264,
-        "logprob": -0.140625,
+        "logprob": -0.12561035,
         "special": false,
         "text": " a"
       },
       {
         "id": 17972,
-        "logprob": -0.040405273,
+        "logprob": -0.038085938,
         "special": false,
         "text": " pile"
       },
       {
         "id": 302,
-        "logprob": -0.0002708435,
+        "logprob": -0.00018656254,
         "special": false,
         "text": " of"
       },
       {
         "id": 2445,
-        "logprob": -0.095336914,
+        "logprob": -0.07293701,
         "special": false,
         "text": " money"
       },
       {
         "id": 28723,
-        "logprob": -0.0068359375,
+        "logprob": -0.004852295,
         "special": false,
         "text": "."
       }
@@ -8,115 +8,115 @@
     "tokens": [
       {
         "id": 415,
-        "logprob": -0.04421997,
+        "logprob": -0.039886475,
         "special": false,
         "text": " The"
       },
       {
         "id": 12072,
-        "logprob": -0.13500977,
+        "logprob": -0.1430664,
         "special": false,
         "text": " cow"
       },
       {
         "id": 349,
-        "logprob": -0.06750488,
+        "logprob": -0.056488037,
         "special": false,
         "text": " is"
       },
       {
         "id": 6328,
-        "logprob": -0.6352539,
+        "logprob": -0.6855469,
         "special": false,
         "text": " standing"
       },
       {
         "id": 356,
-        "logprob": -0.16186523,
+        "logprob": -0.1685791,
         "special": false,
         "text": " on"
       },
       {
         "id": 272,
-        "logprob": -0.5078125,
+        "logprob": -0.50097656,
         "special": false,
         "text": " the"
       },
       {
         "id": 10305,
-        "logprob": -0.017913818,
+        "logprob": -0.017303467,
         "special": false,
         "text": " beach"
       },
       {
         "id": 304,
-        "logprob": -1.5205078,
+        "logprob": -1.3564453,
         "special": false,
         "text": " and"
       },
       {
         "id": 272,
-        "logprob": -0.029174805,
+        "logprob": -0.017868042,
         "special": false,
         "text": " the"
       },
       {
         "id": 13088,
-        "logprob": -0.003479004,
+        "logprob": -0.0027103424,
         "special": false,
         "text": " chicken"
       },
       {
         "id": 349,
-        "logprob": -0.0035095215,
+        "logprob": -0.003156662,
         "special": false,
         "text": " is"
       },
       {
         "id": 6398,
-        "logprob": -0.3088379,
+        "logprob": -0.37304688,
         "special": false,
         "text": " sitting"
       },
       {
         "id": 356,
-        "logprob": -0.027755737,
+        "logprob": -0.034576416,
         "special": false,
         "text": " on"
       },
       {
         "id": 264,
-        "logprob": -0.31884766,
+        "logprob": -0.29418945,
         "special": false,
         "text": " a"
       },
       {
         "id": 17972,
-        "logprob": -0.047943115,
+        "logprob": -0.042877197,
         "special": false,
         "text": " pile"
       },
       {
         "id": 302,
-        "logprob": -0.0002925396,
+        "logprob": -0.00028443336,
         "special": false,
         "text": " of"
       },
       {
         "id": 2445,
-        "logprob": -0.02935791,
+        "logprob": -0.023223877,
         "special": false,
         "text": " money"
       },
       {
         "id": 28723,
-        "logprob": -0.031219482,
+        "logprob": -0.018157959,
         "special": false,
         "text": "."
       },
       {
         "id": 32002,
-        "logprob": -0.00034475327,
+        "logprob": -0.00018393993,
         "special": true,
         "text": "<end_of_utterance>"
       },
@@ -22,6 +22,7 @@ text-generation-client = { path = "client" }
 clap = { version = "4.4.5", features = ["derive", "env"] }
 futures = "0.3.28"
 hf-hub = { workspace = true }
 itertools = "0.10"
 jsonschema = { version = "0.17.1", features = ["draft202012"] }
 metrics = "0.21.1"
 metrics-exporter-prometheus = { version = "0.15.1", features = [] }
@@ -71,10 +71,12 @@ fn get_unpadded_features(
     let current_aspect_ratio: f64 = current_width as f64 / current_height as f64;
     let (current_height, current_width) = if aspect_ratio > current_aspect_ratio {
         let new_height = (height * current_width) / width;
-        (new_height, current_width)
+        let padding = (current_height - new_height) / 2;
+        (current_height - (2 * padding), current_width)
     } else {
         let new_width = (width * current_height) / height;
-        (current_height, new_width)
+        let padding = (current_width - new_width) / 2;
+        (current_height, current_width - (2 * padding))
     };

     let unpadded_features = current_height * current_width;
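Why the padding form matters (worked example; the values are made up): transformers computes a symmetric padding with floor division and strips `2 * padding`, which differs from keeping the resized dimension directly whenever the gap is odd:

    fn main() {
        let (current_height, new_height) = (24u32, 15u32);
        // Old behavior: keep the resized height as-is.
        assert_eq!(new_height, 15);
        // New behavior, matching the Python reference: strip symmetric padding.
        let padding = (current_height - new_height) / 2; // floor(9 / 2) = 4
        assert_eq!(current_height - 2 * padding, 16); // one row more when the gap is odd
    }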
@@ -88,7 +90,9 @@ impl LlavaNext {
         let patch_size = self.vision_config.patch_size;
         assert!(image_size % patch_size == 0);
         let npatches = image_size / patch_size;
-        let (num_patch_height, num_patch_width) =
+        // Dimensions are intentionally swapped to be bug-compatible with
+        // upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
+        let (num_patch_width, num_patch_height) =
             get_anyres_image_grid_shape(height, width, &self.image_grid_pinpoints, image_size);

         let (unpadded_features, newline_features) =

@@ -112,7 +116,7 @@ pub struct Idefics2 {}

 impl Idefics2 {
     pub fn get_number_of_features(&self, _height: usize, _width: usize) -> usize {
-        320
+        64
     }
 }
@@ -70,6 +70,25 @@ impl HubTokenizerConfig {
     }
 }

+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(tag = "processor_class")]
+pub enum HubPreprocessorConfig {
+    Idefics2Processor(Idefics2Preprocessor),
+}
+
+impl HubPreprocessorConfig {
+    pub fn from_file<P: AsRef<std::path::Path>>(filename: P) -> Option<Self> {
+        let content = std::fs::read_to_string(filename).ok()?;
+        serde_json::from_str(&content).ok()
+    }
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct Idefics2Preprocessor {
+    #[serde(default)]
+    do_image_splitting: bool,
+}
+
 #[derive(Debug, Clone, Deserialize, Default)]
 pub struct HubProcessorConfig {
     pub chat_template: Option<ChatTemplateVersions>,
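For reviewers unfamiliar with internally tagged enums: serde picks the `HubPreprocessorConfig` variant from the `processor_class` field of `preprocessor_config.json`. A self-contained sketch of the mechanism (needs the `serde` and `serde_json` crates; the JSON value is a made-up example):

    use serde::Deserialize;

    #[derive(Debug, Deserialize)]
    #[serde(tag = "processor_class")]
    enum HubPreprocessorConfig {
        Idefics2Processor(Idefics2Preprocessor),
    }

    #[derive(Debug, Deserialize)]
    struct Idefics2Preprocessor {
        #[serde(default)] // absent field -> false
        do_image_splitting: bool,
    }

    fn main() {
        let json = r#"{ "processor_class": "Idefics2Processor", "do_image_splitting": true }"#;
        let config: HubPreprocessorConfig = serde_json::from_str(json).unwrap();
        println!("{config:?}");
    }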
@@ -13,7 +13,9 @@ use std::io::BufReader;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
 use std::path::{Path, PathBuf};
 use text_generation_router::config::Config;
-use text_generation_router::{server, HubModelInfo, HubProcessorConfig, HubTokenizerConfig};
+use text_generation_router::{
+    server, HubModelInfo, HubPreprocessorConfig, HubProcessorConfig, HubTokenizerConfig,
+};
 use thiserror::Error;
 use tokenizers::Tokenizer;
 use tower_http::cors::AllowOrigin;

@@ -214,6 +216,7 @@ async fn main() -> Result<(), RouterError> {
         tokenizer_filename,
         config_filename,
         tokenizer_config_filename,
+        preprocessor_config_filename,
         processor_config_filename,
         model_info,
     ) = match api {

@@ -221,6 +224,7 @@ async fn main() -> Result<(), RouterError> {
             Some(local_path.join("tokenizer.json")),
             Some(local_path.join("config.json")),
             Some(local_path.join("tokenizer_config.json")),
+            Some(local_path.join("preprocessor_config.json")),
             Some(local_path.join("processor_config.json")),
             None,
         ),

@@ -237,6 +241,7 @@ async fn main() -> Result<(), RouterError> {
             };
             let config_filename = api_repo.get("config.json").await.ok();
             let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
+            let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();
             let processor_config_filename = api_repo.get("processor_config.json").await.ok();

             let model_info = if let Some(model_info) = get_model_info(&api_repo).await {

@@ -249,6 +254,7 @@ async fn main() -> Result<(), RouterError> {
                 tokenizer_filename,
                 config_filename,
                 tokenizer_config_filename,
+                preprocessor_config_filename,
                 processor_config_filename,
                 model_info,
             )

@@ -263,6 +269,7 @@ async fn main() -> Result<(), RouterError> {
                 repo.get("tokenizer.json"),
                 repo.get("config.json"),
                 repo.get("tokenizer_config.json"),
+                repo.get("preprocessor_config.json"),
                 repo.get("processor_config.json"),
                 None,
             )

@@ -300,6 +307,8 @@ async fn main() -> Result<(), RouterError> {
         HubTokenizerConfig::default()
     });

+    let preprocessor_config =
+        preprocessor_config_filename.and_then(HubPreprocessorConfig::from_file);
     let processor_config = processor_config_filename
         .and_then(HubProcessorConfig::from_file)
         .unwrap_or_default();

@@ -361,6 +370,7 @@ async fn main() -> Result<(), RouterError> {
         ngrok_authtoken,
         ngrok_edge,
         tokenizer_config,
+        preprocessor_config,
         processor_config,
         messages_api_enabled,
         disable_grammar_support,
@@ -12,9 +12,9 @@ use crate::kserve::{
 use crate::validation::ValidationError;
 use crate::{
     BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
-    GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig, HubTokenizerConfig, Info,
-    Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse,
-    Usage, Validation,
+    GenerateResponse, GrammarType, HubModelInfo, HubPreprocessorConfig, HubProcessorConfig,
+    HubTokenizerConfig, Info, Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse,
+    Token, TokenizeResponse, Usage, Validation,
 };
 use crate::{
     ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,

@@ -1423,6 +1423,7 @@ pub async fn run(
     _ngrok_authtoken: Option<String>,
     _ngrok_edge: Option<String>,
     tokenizer_config: HubTokenizerConfig,
+    preprocessor_config: Option<HubPreprocessorConfig>,
     processor_config: HubProcessorConfig,
     messages_api_enabled: bool,
     grammar_support: bool,

@@ -1636,6 +1637,7 @@ pub async fn run(
         validation_workers,
         tokenizer,
         config,
+        preprocessor_config,
         max_best_of,
         max_stop_sequences,
         max_top_n_tokens,
@@ -1,13 +1,16 @@
 /// Payload validation logic
 use crate::config::Config;
 use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
-use crate::{GenerateParameters, GenerateRequest, GrammarType};
+use crate::{
+    GenerateParameters, GenerateRequest, GrammarType, HubPreprocessorConfig, Idefics2Preprocessor,
+};
 use base64::{engine::general_purpose::STANDARD, Engine};
 use image::{io::Reader as ImageReader, ImageFormat};
 use jsonschema::{Draft, JSONSchema};
 use rand::{thread_rng, Rng};
 use serde_json::Value;
 use std::io::Cursor;
+use std::iter;
 use text_generation_client::{Chunk, Image, InputChunk};
 use thiserror::Error;
 use tokenizers::tokenizer::Tokenizer;

@@ -36,6 +39,7 @@ impl Validation {
         workers: usize,
         tokenizer: Option<Tokenizer>,
         config: Option<Config>,
+        preprocessor_config: Option<HubPreprocessorConfig>,
         max_best_of: usize,
         max_stop_sequences: usize,
         max_top_n_tokens: u32,

@@ -53,12 +57,18 @@ impl Validation {
             for _ in 0..workers {
                 let tokenizer_clone = tokenizer.clone();
                 let config_clone = config.clone();
+                let preprocessor_config_clone = preprocessor_config.clone();
                 let (tokenizer_sender, tokenizer_receiver) = mpsc::unbounded_channel();
                 senders.push(tokenizer_sender);

                 // Spawn worker
                 tokio::task::spawn_blocking(move || {
-                    tokenizer_worker(tokenizer_clone, config_clone, tokenizer_receiver)
+                    tokenizer_worker(
+                        tokenizer_clone,
+                        config_clone,
+                        preprocessor_config_clone,
+                        tokenizer_receiver,
+                    )
                 });
             }

@@ -422,13 +432,20 @@ async fn round_robin_task(
 fn tokenizer_worker(
     tokenizer: Tokenizer,
     config: Option<Config>,
+    preprocessor_config: Option<HubPreprocessorConfig>,
     mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
 ) {
     // Loop over requests
     while let Some(((inputs, truncate), response_tx, parent_span)) = receiver.blocking_recv() {
         parent_span.in_scope(|| {
             response_tx
-                .send(prepare_input(inputs, truncate, &tokenizer, &config))
+                .send(prepare_input(
+                    inputs,
+                    truncate,
+                    &tokenizer,
+                    config.as_ref(),
+                    preprocessor_config.as_ref(),
+                ))
                 .unwrap_or(())
         })
     }

@@ -508,16 +525,67 @@ fn fetch_image(input: &str) -> Result<(Vec<u8>, String, usize, usize), ValidationError> {
     }
 }

+fn image_tokens(
+    config: &Config,
+    preprocessor_config: Option<&HubPreprocessorConfig>,
+    height: usize,
+    width: usize,
+) -> String {
+    use Config::*;
+    use HubPreprocessorConfig::*;
+    match config {
+        Idefics => "<image>".to_string(),
+        Idefics2(config) => {
+            const FAKE: &str = "<fake_token_around_image>";
+            const IMAGE: &str = "<image>";
+
+            let slots = config.get_number_of_features(height, width);
+
+            let mut image_string = String::with_capacity(2 * FAKE.len() + slots * IMAGE.len());
+            image_string.push_str(FAKE);
+            image_string.extend(iter::repeat(IMAGE).take(slots));
+            image_string.push_str(FAKE);
+
+            if matches!(
+                preprocessor_config,
+                Some(Idefics2Processor(Idefics2Preprocessor {
+                    do_image_splitting: true,
+                    ..
+                }))
+            ) {
+                image_string = image_string.repeat(5);
+            };
+
+            image_string
+        }
+        Paligemma(config) => "<image>".repeat(config.get_number_of_features(height, width)),
+        LlavaNext(config) => "<image>".repeat(config.get_number_of_features(height, width)),
+        _ => unimplemented!("Images tokens are not supported for this model configuration"),
+    }
+}
+
+fn image_tokens_fixup(config: &Config, text: String) -> String {
+    match config {
+        Config::Idefics2(_) => {
+            const FAKE: &str = "<fake_token_around_image>";
+            text.replace(&format!("{FAKE}{FAKE}"), FAKE)
+        }
+        _ => text,
+    }
+}
+
 /// Get input length and optionally truncate it
 fn prepare_input(
     inputs: String,
     _truncate: Option<usize>,
     tokenizer: &Tokenizer,
-    config: &Option<Config>,
+    config: Option<&Config>,
+    preprocessor_config: Option<&HubPreprocessorConfig>,
 ) -> Result<(tokenizers::Encoding, Vec<InputChunk>), ValidationError> {
+    use Config::*;
     static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
     let (tokenizer_query, input_chunks) = match config {
-        Some(Config::LlavaNext(config)) => {
+        Some(config @ (Idefics | Idefics2(_) | Paligemma(_) | LlavaNext(_))) => {
            let mut input_chunks = Vec::new();
            let mut tokenizer_query = String::with_capacity(inputs.len());
            let mut start = 0;
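End-to-end effect of `image_tokens` plus `image_tokens_fixup` for one image with splitting enabled (standalone sketch; the constants are copied from the diff above):

    fn main() {
        const FAKE: &str = "<fake_token_around_image>";
        const IMAGE: &str = "<image>";
        let one = format!("{FAKE}{}{FAKE}", IMAGE.repeat(64));
        let five = one.repeat(5); // 4 crops + the original image
        // The fixup collapses the doubled fake token at each of the 4 seams.
        let fixed = five.replace(&format!("{FAKE}{FAKE}"), FAKE);
        assert_eq!(fixed.matches(IMAGE).count(), 320);
        assert_eq!(fixed.matches(FAKE).count(), 6); // 10 before the fixup
    }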
@@ -529,88 +597,17 @@ fn prepare_input(
                     tokenizer_query.push_str(&inputs[start..chunk_start]);
                 }
                 let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
-                let slots = config.get_number_of_features(height, width);
                 input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
-                tokenizer_query.push_str(&"<image>".repeat(slots));
+                tokenizer_query.push_str(&image_tokens(config, preprocessor_config, height, width));
                 start = chunk_end;
             }
             if start != inputs.len() {
                 input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
                 tokenizer_query.push_str(&inputs[start..]);
             }
-            (tokenizer_query, input_chunks)
-        }
-        Some(Config::Paligemma(config)) => {
-            let mut input_chunks = Vec::new();
-            let mut tokenizer_query = String::with_capacity(inputs.len());
-            let mut start = 0;
-            for chunk in RE.find_iter(&inputs) {
-                let chunk_start = chunk.start();
-                let chunk_end = chunk.end();
-                if chunk_start != start {
-                    input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
-                    tokenizer_query.push_str(&inputs[start..chunk_start]);
-                }
-                let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
-                let slots = config.get_number_of_features(height, width);
-                input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
-                tokenizer_query.push_str(&"<image>".repeat(slots));
-                start = chunk_end;
-            }
-            if start != inputs.len() {
-                input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
-                tokenizer_query.push_str(&inputs[start..]);
-            }
-            (tokenizer_query, input_chunks)
-        }
-        Some(Config::Idefics2(config)) => {
-            let mut input_chunks = Vec::new();
-            let mut tokenizer_query = String::with_capacity(inputs.len());
-            let mut start = 0;
-            for chunk in RE.find_iter(&inputs) {
-                let chunk_start = chunk.start();
-                let chunk_end = chunk.end();
-                if chunk_start != start {
-                    input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
-                    tokenizer_query.push_str(&inputs[start..chunk_start]);
-                }
-                let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
-                let slots = config.get_number_of_features(height, width);
-                tokenizer_query.push_str("<fake_token_around_image>");
-                tokenizer_query.push_str(&"<image>".repeat(slots));
-                tokenizer_query.push_str("<fake_token_around_image>");
-
-                input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
-                start = chunk_end;
-            }
-            if start != inputs.len() {
-                input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
-                tokenizer_query.push_str(&inputs[start..]);
-            }
-            (tokenizer_query, input_chunks)
-        }
-        Some(Config::Idefics) => {
-            let mut input_chunks = Vec::new();
-            let mut tokenizer_query = String::with_capacity(inputs.len());
-            let mut start = 0;
-            for chunk in RE.find_iter(&inputs) {
-                let chunk_start = chunk.start();
-                let chunk_end = chunk.end();
-                if chunk_start != start {
-                    input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
-                    tokenizer_query.push_str(&inputs[start..chunk_start]);
-                }
-                let (data, mimetype, _height, _width) =
-                    fetch_image(&inputs[chunk_start..chunk_end])?;
-                let slots = 1;
-                tokenizer_query.push_str(&"<image>".repeat(slots));
-                input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
-                start = chunk_end;
-            }
-            if start != inputs.len() {
-                input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
-                tokenizer_query.push_str(&inputs[start..]);
-            }
+
+            tokenizer_query = image_tokens_fixup(config, tokenizer_query);
+
             (tokenizer_query, input_chunks)
         }
         _ => (inputs.clone(), vec![Chunk::Text(inputs).into()]),
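The deduplication above leans on an `ident @ (pat | pat)` binding (stable since Rust 1.65): one arm matches every image-capable config and rebinds it for dispatch inside `image_tokens`. A minimal standalone sketch of the pattern (the enum and strings here are placeholders):

    enum Config { Idefics, Idefics2, Paligemma, LlavaNext, Gpt2 }

    fn image_tokens(config: &Config) -> String {
        use Config::*;
        match config {
            Idefics => "<image>".to_string(),
            Idefics2 => "<fake_token_around_image><image><fake_token_around_image>".to_string(),
            Paligemma | LlavaNext => "<image>".to_string(),
            Gpt2 => unreachable!("text-only models never reach this branch"),
        }
    }

    fn main() {
        use Config::*;
        let config = Some(Idefics2);
        let query = match &config {
            // One arm for every image-capable variant; per-model details stay in image_tokens.
            Some(config @ (Idefics | Idefics2 | Paligemma | LlavaNext)) => image_tokens(config),
            _ => String::new(),
        };
        println!("{query}");
    }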
@@ -750,7 +747,7 @@ pub enum ValidationError {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::config::{PaliTextConfig, Paligemma};
+    use crate::config::{Idefics2, PaliTextConfig, Paligemma};
     use crate::default_parameters;
     use crate::tests::get_tokenizer;

@@ -769,6 +766,7 @@ mod tests {
             workers,
             tokenizer,
             config,
+            None,
             max_best_of,
             max_stop_sequence,
             max_top_n_tokens,

@@ -803,6 +801,7 @@ mod tests {
             workers,
             tokenizer,
             config,
+            None,
             max_best_of,
             max_stop_sequence,
             max_top_n_tokens,

@@ -836,6 +835,7 @@ mod tests {
             workers,
             tokenizer,
             config,
+            None,
             max_best_of,
             max_stop_sequence,
             max_top_n_tokens,

@@ -874,6 +874,7 @@ mod tests {
             workers,
             tokenizer,
             config,
+            None,
             max_best_of,
             max_stop_sequence,
             max_top_n_tokens,

@@ -941,6 +942,7 @@ mod tests {
             workers,
             tokenizer,
             config,
+            None,
             max_best_of,
             max_stop_sequences,
             max_top_n_tokens,

@@ -1026,6 +1028,7 @@ mod tests {
             workers,
             tokenizer,
             Some(config),
+            None,
             max_best_of,
             max_stop_sequence,
             max_top_n_tokens,

@@ -1058,4 +1061,83 @@ mod tests {
             "Failed to process images",
         );
     }
+
+    #[tokio::test]
+    async fn test_idefics2_correct_n_fake_tokens() {
+        let pixel_data = STANDARD.decode(PIXEL_GIF).unwrap();
+
+        let tokenizer = Some(get_tokenizer().await);
+
+        let max_best_of = 2;
+        let max_stop_sequence = 3;
+        let max_top_n_tokens = 4;
+        let max_input_length = 5;
+        let max_total_tokens = 6;
+        let disable_grammar_support = true;
+        let workers = 1;
+        let config = Config::Idefics2(Idefics2 {});
+        let validation = Validation::new(
+            workers,
+            tokenizer,
+            Some(config),
+            Some(HubPreprocessorConfig::Idefics2Processor(
+                Idefics2Preprocessor {
+                    do_image_splitting: true,
+                },
+            )),
+            max_best_of,
+            max_stop_sequence,
+            max_top_n_tokens,
+            max_input_length,
+            max_total_tokens,
+            disable_grammar_support,
+        );
+
+        let (encoding, chunks) = match validation
+            .tokenize(
+                format!(
+                    "test![](data:image/gif;base64,{})![](data:image/gif;base64,{})",
+                    PIXEL_GIF, PIXEL_GIF
+                ),
+                None,
+            )
+            .await
+        {
+            Ok(Some((encoding, chunks))) => (encoding, chunks),
+            _ => panic!("Unexpected tokenization failure"),
+        };
+
+        assert!(
+            chunks
+                == vec![
+                    Chunk::Text("test".to_string()).into(),
+                    Chunk::Image(Image {
+                        data: pixel_data.clone(),
+                        mimetype: "image/gif".to_string()
+                    })
+                    .into(),
+                    Chunk::Image(Image {
+                        data: pixel_data.clone(),
+                        mimetype: "image/gif".to_string()
+                    })
+                    .into()
+                ],
+            "Failed to process images",
+        );
+
+        // Verify the number of fake tokens:
+        //
+        // - Two images surrounded/separated by a fake token = 3.
+        // - Both are split in 5 subimages, separated by a fake token: 2 * 4.
+        //
+        // Fake tokens get split up by the testing tokenizer, but we don't care.
+        assert_eq!(
+            encoding
+                .get_tokens()
+                .iter()
+                .filter(|t| *t == "fake")
+                .count(),
+            11
+        );
+    }
 }
@@ -39,7 +39,7 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):

     Args:
         image_size (`tuple`):
-            The size of the input image in the format (width, height).
+            The size of the input image in the format (height, width).
         grid_pinpoints (`List`):
             A list containing possible resolutions. Each item in the list should be a tuple or list
             of the form `(height, width)`.

@@ -47,7 +47,7 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
             The size of each image patch.

     Returns:
-        tuple: The shape of the image patch grid in the format (width, height).
+        tuple: The shape of the image patch grid in the format (height, width).
     """
     if not isinstance(grid_pinpoints, list):
         raise ValueError("grid_pinpoints should be a list of tuples or lists")

@@ -230,7 +230,10 @@ class LlavaNextForConditionalGeneration(nn.Module):
                     raise ValueError(
                         "The number of patches is not consistent with the image size."
                     )
-                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
+
+                # Dimensions are intentionally swapped to be bug-compatible with
+                # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
+                num_patch_width, num_patch_height = get_anyres_image_grid_shape(
                     image_sizes[image_idx],
                     self.config.image_grid_pinpoints,
                     self.config.vision_config.image_size,
@@ -39,7 +39,9 @@ class PaliGemmaBatch(VlmCausalLMBatch):
                     # TODO do_convert_RGB should be on by default ?
                     image = image.convert("RGB")
                     image_input = processor.image_processor(image, return_tensors="pt")
-                    full_text += image_text_replacement(image_input, config, image_id)
+                    full_text += image_text_replacement(
+                        processor, image_input, config, image_id
+                    )
                     image_inputs.append(image_input)
                 else:
                     raise RuntimeError(f"Invalid chunk type {chunk_type}")
@@ -1,3 +1,4 @@
+from itertools import repeat
 import torch
 from PIL import Image
 from io import BytesIO

@@ -15,6 +16,9 @@ from text_generation_server.models.flash_mistral import (

 tracer = trace.get_tracer(__name__)

+IDEFICS2_FAKE_TOKEN = "<fake_token_around_image>"
+IDEFICS2_IMAGE_TOKEN = "<image>"
+

 def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
     """

@@ -22,7 +26,7 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):

     Args:
         image_size (`tuple`):
-            The size of the input image in the format (width, height).
+            The size of the input image in the format (height, width).
         grid_pinpoints (`List`):
             A list containing possible resolutions. Each item in the list should be a tuple or list
             of the form `(height, width)`.

@@ -39,15 +43,13 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
     return height // patch_size, width // patch_size


-def image_text_replacement(image_input, config, image_id) -> str:
+def image_text_replacement(processor, image_input, config, image_id: int) -> str:
     if config.model_type == "idefics2":
-        # TODO technically depends on image splitting which is not implemented.
-        num_features = 320
-        return (
-            "<fake_token_around_image>"
-            + "<image>" * num_features
-            + "<fake_token_around_image>"
-        )
+        image_seq_len = 64
+        image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN * image_seq_len}{IDEFICS2_FAKE_TOKEN}"
+        if processor.image_processor.do_image_splitting:
+            image_str *= 5
+        return image_str
     elif config.model_type == "llava_next":
         height, width = image_input["image_sizes"][image_id]
         num_features = get_number_of_features(height, width, config)

@@ -64,20 +66,35 @@
     raise RuntimeError(f"Unknown config {config.model_type} for multimodal")


+def image_text_replacement_fixup(config, text: str) -> str:
+    if config.model_type == "idefics2":
+        return text.replace(
+            f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}", IDEFICS2_FAKE_TOKEN
+        )
+    return text
+
+
 def get_unpadded_features(
-    height: int, width: int, npatches: int, num_patch_height: int, num_patch_width: int
+    original_height: int,
+    original_width: int,
+    npatches: int,
+    num_patch_height: int,
+    num_patch_width: int,
 ) -> Tuple[int, int]:
     current_height = npatches * num_patch_height
     current_width = npatches * num_patch_width

-    aspect_ratio: float = width / height
+    aspect_ratio: float = original_width / original_height
     current_aspect_ratio: float = current_width / current_height
+
     if aspect_ratio > current_aspect_ratio:
-        new_height = (height * current_width) // width
-        current_height = new_height
+        new_height = (original_height * current_width) // original_width
+        padding = (current_height - new_height) // 2
+        current_height = current_height - (2 * padding)
     else:
-        new_width = (width * current_height) // height
-        current_width = new_width
+        new_width = (original_width * current_height) // original_height
+        padding = (current_width - new_width) // 2
+        current_width = current_width - (2 * padding)

     unpadded_features = current_height * current_width
     newline_features = current_height

@@ -96,7 +113,9 @@ def get_number_of_features(height: int, width: int, config) -> int:

     npatches = image_size // patch_size

-    num_patch_height, num_patch_width = get_anyres_image_grid_shape(
+    # Dimensions are intentionally swapped to be bug-compatible with
+    # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
+    num_patch_width, num_patch_height = get_anyres_image_grid_shape(
         [height, width],
         image_grid_pinpoints,
         image_size,

@@ -168,9 +187,13 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                 if chunk_type == "text":
                     full_text += chunk.text
                 elif chunk_type == "image":
-                    full_text += image_text_replacement(image_inputs, config, image_id)
+                    full_text += image_text_replacement(
+                        processor, image_inputs, config, image_id
+                    )
                     image_id += 1

+            full_text = image_text_replacement_fixup(config, full_text)
+
             batch_inputs.append(full_text)
             max_truncation = max(max_truncation, r.truncate)