From df71aafdccac59ba354d3f59f6498ce4c5075a76 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Mon, 3 Jun 2024 07:27:22 +0000 Subject: [PATCH] router: send the input as chunks to the backend Before this change, the generation input was sent to the backend as a single string, encoding images as Base64 and packing them in Markdown-style links. This change adds a new chunked input representation that separates text chunks from images chunks. Image chunks contain binary data (for smaller message sizes) and the image's MIME type. The stringly-typed inputs are still sent to support backends that do not support chunked inputs yet. --- Cargo.lock | 1 + Cargo.toml | 1 + benchmark/src/generation.rs | 7 +- proto/generate.proto | 25 +++++- router/Cargo.toml | 2 +- router/client/Cargo.toml | 1 + router/client/src/client.rs | 31 ++++++- router/client/src/lib.rs | 35 +++++++- router/src/config.rs | 14 ++-- router/src/health.rs | 7 +- router/src/queue.rs | 9 +- router/src/validation.rs | 158 ++++++++++++++++++++++++------------ 12 files changed, 222 insertions(+), 69 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d58f4cb1..413ff8ab 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3554,6 +3554,7 @@ dependencies = [ name = "text-generation-client" version = "2.0.5-dev0" dependencies = [ + "base64 0.22.1", "futures", "grpc-metadata", "prost 0.12.6", diff --git a/Cargo.toml b/Cargo.toml index c5c6ca6e..16dd9423 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,7 @@ authors = ["Olivier Dehaene"] homepage = "https://github.com/huggingface/text-generation-inference" [workspace.dependencies] +base64 = "0.22.0" tokenizers = { version = "0.19.1", features = ["http"] } hf-hub = { version = "0.3.1", features = ["tokio"] } diff --git a/benchmark/src/generation.rs b/benchmark/src/generation.rs index ea7c9778..8c07e62b 100644 --- a/benchmark/src/generation.rs +++ b/benchmark/src/generation.rs @@ -1,7 +1,7 @@ use std::time::{Duration, Instant}; use text_generation_client::{ - Batch, CachedBatch, ClientError, NextTokenChooserParameters, Request, ShardedClient, - StoppingCriteriaParameters, + Batch, CachedBatch, Chunk, ClientError, Input, NextTokenChooserParameters, Request, + ShardedClient, StoppingCriteriaParameters, }; use tokenizers::{Tokenizer, TruncationDirection}; use tokio::sync::{broadcast, mpsc}; @@ -142,6 +142,9 @@ async fn prefill( .map(|id| Request { id: id.into(), prefill_logprobs: false, + input_chunks: Some(Input { + chunks: vec![Chunk::Text(sequence.clone()).into()], + }), inputs: sequence.clone(), truncate: sequence_length, parameters: Some(parameters.clone()), diff --git a/proto/generate.proto b/proto/generate.proto index 6351e37f..f568d01c 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -51,6 +51,27 @@ message ClearCacheRequest { /// Empty response message ClearCacheResponse {} +message Image { + /// Binary image data. + bytes data = 1; + + /// Image MIME type. + string mimetype = 2; +} + +message InputChunk { + oneof chunk { + /// Plain text data + string text = 1; + /// Image data + Image image = 2; + } +} + +message Input { + repeated InputChunk chunks = 1; + } + enum GrammarType { GRAMMAR_TYPE_NONE = 0; GRAMMAR_TYPE_JSON = 1; @@ -95,7 +116,9 @@ message StoppingCriteriaParameters { message Request { /// Request ID uint64 id = 1; - /// The generation context + /// The generation context as chunks + Input input_chunks = 8; + /// The generation context, stringified input_chunks string inputs = 2; /// Context truncation uint32 truncate = 3; diff --git a/router/Cargo.toml b/router/Cargo.toml index fdfe1a5b..2e6264be 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -49,7 +49,7 @@ futures-util = "0.3.30" regex = "1.10.3" once_cell = "1.19.0" image = "0.25.1" -base64 = "0.22.0" +base64 = { workspace = true } [build-dependencies] vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] } diff --git a/router/client/Cargo.toml b/router/client/Cargo.toml index d0131784..abbde82d 100644 --- a/router/client/Cargo.toml +++ b/router/client/Cargo.toml @@ -6,6 +6,7 @@ authors.workspace = true homepage.workspace = true [dependencies] +base64 = { workspace = true } futures = "^0.3" grpc-metadata = { path = "../grpc-metadata" } prost = "^0.12" diff --git a/router/client/src/client.rs b/router/client/src/client.rs index e8035106..8b509d6b 100644 --- a/router/client/src/client.rs +++ b/router/client/src/client.rs @@ -1,13 +1,17 @@ /// Single shard Client use crate::pb::generate::v2::text_generation_service_client::TextGenerationServiceClient; use crate::pb::generate::v2::*; -use crate::Result; +use crate::{Chunk, Result}; +use base64::engine::general_purpose::STANDARD; +use base64::Engine; use grpc_metadata::InjectTelemetryContext; use std::cmp::min; use std::time::Duration; use tonic::transport::{Channel, Uri}; use tracing::instrument; +static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII="; + /// Text Generation Inference gRPC client #[derive(Debug, Clone)] pub struct Client { @@ -113,18 +117,39 @@ impl Client { while n_tokens < max_prefill_tokens { let truncate = min(max_input_length, max_prefill_tokens - n_tokens); + let mut input_chunks = Vec::new(); + input_chunks + .push(Chunk::Text("_test ".to_string().repeat(max_input_length as usize)).into()); + if n_tokens == 0 { + input_chunks.push( + Chunk::Image(Image { + // Safe unwrap, because we control the data. + data: STANDARD.decode(WARMUP_IMAGE_BASE64).unwrap(), + mimetype: "image/jpeg;base64".to_string(), + }) + .into(), + ); + } + + // Send stringly-typed inputs for compatibility for backends that haven't + // been updated to support chunks. let mut inputs = String::new(); inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize)); if n_tokens == 0 { // 1 request is enough to test vision heads. // Sending images on other queries messes up easily with truncation. - inputs.push_str("![](data:image/jpeg;base64,iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=)"); + inputs.push_str(&format!( + "![](data:image/jpeg;base64,{WARMUP_IMAGE_BASE64})", + )); } requests.push(Request { id: 0, - // We truncate the input on the server side to be sure that it has the correct size + input_chunks: Some(Input { + chunks: input_chunks, + }), inputs, + // We truncate the input on the server side to be sure that it has the correct size truncate, // Set sampling parameters to also take these ops into account in the max memory parameters: Some(NextTokenChooserParameters { diff --git a/router/client/src/lib.rs b/router/client/src/lib.rs index 6782d9ff..9e9ef13b 100644 --- a/router/client/src/lib.rs +++ b/router/client/src/lib.rs @@ -5,11 +5,14 @@ mod client; mod pb; mod sharded_client; +use base64::{engine::general_purpose::STANDARD, Engine}; pub use client::Client; +pub use pb::generate::v2::input_chunk::Chunk; pub use pb::generate::v2::HealthResponse; +pub use pb::generate::v2::Image; pub use pb::generate::v2::InfoResponse as ShardInfo; pub use pb::generate::v2::{ - Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, + Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, Input, InputChunk, NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens, }; pub use sharded_client::ShardedClient; @@ -44,3 +47,33 @@ impl From for ClientError { } pub type Result = std::result::Result; + +// Small convenience re-wrapping of `Chunk`. +impl From for InputChunk { + fn from(chunk: Chunk) -> Self { + InputChunk { chunk: Some(chunk) } + } +} + +/// Convert input chunks to a stringly-typed input for backwards +/// compat for backends that haven't implemented chunked inputs. +pub trait ChunksToString { + /// Convert chunks to string. + fn chunks_to_string(&self) -> String; +} + +impl ChunksToString for Vec { + fn chunks_to_string(&self) -> String { + let mut output = String::new(); + self.iter().for_each(|c| match &c.chunk { + Some(Chunk::Text(text)) => output.push_str(text), + Some(Chunk::Image(Image { data, mimetype })) => { + let encoded = STANDARD.encode(data); + output.push_str(&format!("![](data:{};base64,{})", mimetype, encoded)) + } + // We don't create empty chunks, so this should be unreachable. + None => unreachable!("Chunks should never be empty"), + }); + output + } +} diff --git a/router/src/config.rs b/router/src/config.rs index d27b1136..29fefd5b 100644 --- a/router/src/config.rs +++ b/router/src/config.rs @@ -4,9 +4,9 @@ use serde::{Deserialize, Serialize}; #[serde(tag = "model_type")] #[serde(rename_all = "snake_case")] pub struct LlavaNext { - text_config: TextConfig, - vision_config: VisionConfig, - image_grid_pinpoints: Vec<(usize, usize)>, + pub(crate) text_config: TextConfig, + pub(crate) vision_config: VisionConfig, + pub(crate) image_grid_pinpoints: Vec<(usize, usize)>, } fn get_anyres_image_grid_shape( @@ -119,13 +119,13 @@ impl Idefics2 { #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub struct PaliTextConfig { - num_image_tokens: usize, + pub(crate) num_image_tokens: usize, } #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub struct Paligemma { - text_config: PaliTextConfig, + pub(crate) text_config: PaliTextConfig, } impl Paligemma { @@ -175,8 +175,8 @@ pub struct TextConfig {} #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub struct VisionConfig { - image_size: usize, - patch_size: usize, + pub(crate) image_size: usize, + pub(crate) patch_size: usize, } #[cfg(test)] diff --git a/router/src/health.rs b/router/src/health.rs index b05b3094..121255b9 100644 --- a/router/src/health.rs +++ b/router/src/health.rs @@ -1,9 +1,9 @@ use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; -use text_generation_client::GrammarType as ProtoGrammarType; use text_generation_client::{ - Batch, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters, + Batch, Input, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters, }; +use text_generation_client::{Chunk, GrammarType as ProtoGrammarType}; // Note: Request ids and batch ids cannot collide. const LIVENESS_ID: u64 = u64::MAX; @@ -33,6 +33,9 @@ impl Health { // Dummy batch of 1 token and 1 generated token let liveness_request = Request { id: LIVENESS_ID, + input_chunks: Some(Input { + chunks: vec![Chunk::Text("liveness".into()).into()], + }), inputs: "liveness".to_string(), truncate: 10, prefill_logprobs: false, diff --git a/router/src/queue.rs b/router/src/queue.rs index a32673dd..40692ffc 100644 --- a/router/src/queue.rs +++ b/router/src/queue.rs @@ -4,6 +4,8 @@ use crate::validation::ValidGenerateRequest; use nohash_hasher::{BuildNoHashHasher, IntMap}; use std::cmp::min; use std::collections::VecDeque; +use text_generation_client::ChunksToString; +use text_generation_client::Input; use text_generation_client::{Batch, Request}; use tokio::sync::{mpsc, oneshot}; use tokio::time::Instant; @@ -278,7 +280,10 @@ impl State { batch_requests.push(Request { id, prefill_logprobs: entry.request.decoder_input_details, - inputs: entry.request.inputs.clone(), + input_chunks: Some(Input { + chunks: entry.request.inputs.clone(), + }), + inputs: entry.request.inputs.chunks_to_string(), truncate: entry.request.truncate, parameters: Some(entry.request.parameters.clone()), stopping_parameters: Some(entry.request.stopping_parameters.clone()), @@ -366,7 +371,7 @@ mod tests { let entry = Entry { request: ValidGenerateRequest { - inputs: String::new(), + inputs: vec![], input_length: 0, truncate: 0, decoder_input_details: false, diff --git a/router/src/validation.rs b/router/src/validation.rs index 96b6cb27..863bb99b 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -7,7 +7,8 @@ use rand::{thread_rng, Rng}; use serde_json::Value; use std::io::Cursor; use text_generation_client::{ - GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters, + Chunk, GrammarType as ProtoGrammarType, Image, InputChunk, NextTokenChooserParameters, + StoppingCriteriaParameters, }; use thiserror::Error; use tokenizers::tokenizer::Tokenizer; @@ -89,7 +90,7 @@ impl Validation { &self, inputs: String, truncate: Option, - ) -> Result, ValidationError> { + ) -> Result)>, ValidationError> { // If we have a fast tokenizer if let Some(sender) = &self.sender { // Create response channel @@ -115,7 +116,7 @@ impl Validation { inputs: String, truncate: Option, max_new_tokens: Option, - ) -> Result<(String, usize, u32), ValidationError> { + ) -> Result<(Vec, usize, u32), ValidationError> { // If we have a fast tokenizer if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? { // Create response channel @@ -178,7 +179,11 @@ impl Validation { // )); } - Ok((inputs, input_length, max_new_tokens)) + Ok(( + vec![Chunk::Text(inputs).into()], + input_length, + max_new_tokens, + )) } } @@ -465,7 +470,7 @@ fn format_to_mimetype(format: ImageFormat) -> String { .to_string() } -fn fetch_image(input: &str) -> Result<(String, usize, usize), ValidationError> { +fn fetch_image(input: &str) -> Result<(Vec, String, usize, usize), ValidationError> { if input.starts_with("![](http://") || input.starts_with("![](https://") { let url = &input["![](".len()..input.len() - 1]; let data = reqwest::blocking::get(url)?.bytes()?; @@ -476,9 +481,7 @@ fn fetch_image(input: &str) -> Result<(String, usize, usize), ValidationError> { let height: usize = img.height().try_into()?; let width: usize = img.width().try_into()?; let mimetype = format_to_mimetype(format); - let encoded = STANDARD.encode(data); - let data_uri = format!("![](data:{mimetype};base64,{encoded})"); - Ok((data_uri, height, width)) + Ok((data.to_vec(), mimetype, height, width)) } else if input.starts_with("![](data:") { // Remove ![](....) let content = &input["![](data:".len()..input.len() - 1]; @@ -495,9 +498,9 @@ fn fetch_image(input: &str) -> Result<(String, usize, usize), ValidationError> { let data = STANDARD.decode(content["base64,".len()..].as_bytes())?; let img = if let Some(format) = format_from_mimetype(mimetype) { - ImageReader::with_format(Cursor::new(data), format).decode()? + ImageReader::with_format(Cursor::new(&data), format).decode()? } else { - ImageReader::new(Cursor::new(data)) + ImageReader::new(Cursor::new(&data)) .with_guessed_format() .map_err(|_io_error| ValidationError::InvalidImageContent(content.to_string()))? .decode()? @@ -505,7 +508,7 @@ fn fetch_image(input: &str) -> Result<(String, usize, usize), ValidationError> { let height: usize = img.height().try_into()?; let width: usize = img.width().try_into()?; - Ok((input.to_string(), height, width)) + Ok((data, mimetype.to_string(), height, width)) } else { Err(ValidationError::InvalidImageContent(input.to_string())) } @@ -513,113 +516,110 @@ fn fetch_image(input: &str) -> Result<(String, usize, usize), ValidationError> { /// Get input length and optionally truncate it fn prepare_input( - mut inputs: String, + inputs: String, _truncate: Option, tokenizer: &Tokenizer, config: &Option, -) -> Result<(tokenizers::Encoding, String), ValidationError> { +) -> Result<(tokenizers::Encoding, Vec), ValidationError> { static RE: Lazy = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap()); - let tokenizer_query = match config { + let (tokenizer_query, input_chunks) = match config { Some(Config::LlavaNext(config)) => { - let mut modified_inputs = String::with_capacity(inputs.len()); + let mut input_chunks = Vec::new(); let mut tokenizer_query = String::with_capacity(inputs.len()); let mut start = 0; for chunk in RE.find_iter(&inputs) { let chunk_start = chunk.start(); let chunk_end = chunk.end(); if chunk_start != start { - modified_inputs.push_str(&inputs[start..chunk_start]); + input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into()); tokenizer_query.push_str(&inputs[start..chunk_start]); } - let (image_uri, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?; + let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?; let slots = config.get_number_of_features(height, width); + input_chunks.push(Chunk::Image(Image { data, mimetype }).into()); tokenizer_query.push_str(&"".repeat(slots)); - modified_inputs.push_str(&image_uri); start = chunk_end; } - if start != inputs.len() - 1 { - modified_inputs.push_str(&inputs[start..]); + if start != inputs.len() { + input_chunks.push(Chunk::Text(inputs[start..].to_string()).into()); tokenizer_query.push_str(&inputs[start..]); } - inputs = modified_inputs; - tokenizer_query + (tokenizer_query, input_chunks) } Some(Config::Paligemma(config)) => { - let mut modified_inputs = String::with_capacity(inputs.len()); + let mut input_chunks = Vec::new(); let mut tokenizer_query = String::with_capacity(inputs.len()); let mut start = 0; for chunk in RE.find_iter(&inputs) { let chunk_start = chunk.start(); let chunk_end = chunk.end(); if chunk_start != start { - modified_inputs.push_str(&inputs[start..chunk_start]); + input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into()); tokenizer_query.push_str(&inputs[start..chunk_start]); } - let (image_uri, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?; + let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?; let slots = config.get_number_of_features(height, width); + input_chunks.push(Chunk::Image(Image { data, mimetype }).into()); tokenizer_query.push_str(&"".repeat(slots)); - modified_inputs.push_str(&image_uri); start = chunk_end; } - if start != inputs.len() - 1 { - modified_inputs.push_str(&inputs[start..]); + if start != inputs.len() { + input_chunks.push(Chunk::Text(inputs[start..].to_string()).into()); tokenizer_query.push_str(&inputs[start..]); } - inputs = modified_inputs; - tokenizer_query + (tokenizer_query, input_chunks) } Some(Config::Idefics2(config)) => { - let mut modified_inputs = String::with_capacity(inputs.len()); + let mut input_chunks = Vec::new(); let mut tokenizer_query = String::with_capacity(inputs.len()); let mut start = 0; for chunk in RE.find_iter(&inputs) { let chunk_start = chunk.start(); let chunk_end = chunk.end(); if chunk_start != start { - modified_inputs.push_str(&inputs[start..chunk_start]); + input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into()); tokenizer_query.push_str(&inputs[start..chunk_start]); } - let (image_uri, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?; + let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?; let slots = config.get_number_of_features(height, width); tokenizer_query.push_str(""); tokenizer_query.push_str(&"".repeat(slots)); tokenizer_query.push_str(""); - modified_inputs.push_str(&image_uri); + input_chunks.push(Chunk::Image(Image { data, mimetype }).into()); start = chunk_end; } - if start != inputs.len() - 1 { - modified_inputs.push_str(&inputs[start..]); + if start != inputs.len() { + input_chunks.push(Chunk::Text(inputs[start..].to_string()).into()); tokenizer_query.push_str(&inputs[start..]); } - inputs = modified_inputs; - tokenizer_query + (tokenizer_query, input_chunks) } Some(Config::Idefics) => { - let mut modified_inputs = String::with_capacity(inputs.len()); + let mut input_chunks = Vec::new(); let mut tokenizer_query = String::with_capacity(inputs.len()); let mut start = 0; for chunk in RE.find_iter(&inputs) { let chunk_start = chunk.start(); let chunk_end = chunk.end(); if chunk_start != start { - modified_inputs.push_str(&inputs[start..chunk_start]); + input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into()); tokenizer_query.push_str(&inputs[start..chunk_start]); } - let (image_uri, _height, _width) = fetch_image(&inputs[chunk_start..chunk_end])?; + let (data, mimetype, _height, _width) = + fetch_image(&inputs[chunk_start..chunk_end])?; let slots = 1; tokenizer_query.push_str(&"".repeat(slots)); - modified_inputs.push_str(&image_uri); + input_chunks.push(Chunk::Image(Image { data, mimetype }).into()); start = chunk_end; } - if start != inputs.len() - 1 { - modified_inputs.push_str(&inputs[start..]); + if start != inputs.len() { + input_chunks.push(Chunk::Text(inputs[start..].to_string()).into()); tokenizer_query.push_str(&inputs[start..]); } - inputs = modified_inputs; - tokenizer_query + (tokenizer_query, input_chunks) } - _ => inputs.clone(), + _ => (inputs.clone(), vec![Chunk::Text(inputs).into()]), }; // Get the number of tokens in the input @@ -627,18 +627,18 @@ fn prepare_input( .encode(tokenizer_query, true) .map_err(|err| ValidationError::Tokenizer(err.to_string()))?; - Ok((encoding, inputs)) + Ok((encoding, input_chunks)) } type TokenizerRequest = ( (String, Option), - oneshot::Sender>, + oneshot::Sender), ValidationError>>, Span, ); #[derive(Debug, Clone)] pub(crate) struct ValidGenerateRequest { - pub inputs: String, + pub inputs: Vec, pub input_length: u32, pub truncate: u32, pub decoder_input_details: bool, @@ -714,6 +714,7 @@ pub enum ValidationError { #[cfg(test)] mod tests { use super::*; + use crate::config::{PaliTextConfig, Paligemma}; use crate::default_parameters; use crate::tests::get_tokenizer; @@ -964,4 +965,61 @@ mod tests { assert_eq!(valid_request.top_n_tokens, 0); } + + static PIXEL_GIF: &str = "R0lGODdhAQABAIEAAP///wAAAAAAAAAAACwAAAAAAQABAAAIBAABBAQAOw=="; + + #[tokio::test] + async fn test_prepare_input_chunks() { + let pixel_data = STANDARD.decode(PIXEL_GIF).unwrap(); + + let tokenizer = Some(get_tokenizer().await); + + let max_best_of = 2; + let max_stop_sequence = 3; + let max_top_n_tokens = 4; + let max_input_length = 5; + let max_total_tokens = 6; + let disable_grammar_support = true; + let workers = 1; + let config = Config::Paligemma(Paligemma { + text_config: PaliTextConfig { + num_image_tokens: 1, + }, + }); + let validation = Validation::new( + workers, + tokenizer, + Some(config), + max_best_of, + max_stop_sequence, + max_top_n_tokens, + max_input_length, + max_total_tokens, + disable_grammar_support, + ); + + let chunks = match validation + .tokenize( + format!("test![](data:image/gif;base64,{})", PIXEL_GIF), + None, + ) + .await + { + Ok(Some((_encoding, chunks))) => chunks, + _ => panic!("Unexpected tokenization failure"), + }; + + assert!( + chunks + == vec![ + Chunk::Text("test".to_string()).into(), + Chunk::Image(Image { + data: pixel_data.clone(), + mimetype: "image/gif".to_string() + }) + .into() + ], + "Failed to process images", + ); + } }