router: send the input as chunks to the backend
Before this change, the generation input was sent to the backend as a single string, encoding images as Base64 and packing them in Markdown-style links. This change adds a new chunked input representation that separates text chunks from images chunks. Image chunks contain binary data (for smaller message sizes) and the image's MIME type. The stringly-typed inputs are still sent to support backends that do not support chunked inputs yet.
This commit is contained in:
parent
d1d724b027
commit
df71aafdcc
|
@ -3554,6 +3554,7 @@ dependencies = [
|
|||
name = "text-generation-client"
|
||||
version = "2.0.5-dev0"
|
||||
dependencies = [
|
||||
"base64 0.22.1",
|
||||
"futures",
|
||||
"grpc-metadata",
|
||||
"prost 0.12.6",
|
||||
|
|
|
@ -15,6 +15,7 @@ authors = ["Olivier Dehaene"]
|
|||
homepage = "https://github.com/huggingface/text-generation-inference"
|
||||
|
||||
[workspace.dependencies]
|
||||
base64 = "0.22.0"
|
||||
tokenizers = { version = "0.19.1", features = ["http"] }
|
||||
hf-hub = { version = "0.3.1", features = ["tokio"] }
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use std::time::{Duration, Instant};
|
||||
use text_generation_client::{
|
||||
Batch, CachedBatch, ClientError, NextTokenChooserParameters, Request, ShardedClient,
|
||||
StoppingCriteriaParameters,
|
||||
Batch, CachedBatch, Chunk, ClientError, Input, NextTokenChooserParameters, Request,
|
||||
ShardedClient, StoppingCriteriaParameters,
|
||||
};
|
||||
use tokenizers::{Tokenizer, TruncationDirection};
|
||||
use tokio::sync::{broadcast, mpsc};
|
||||
|
@ -142,6 +142,9 @@ async fn prefill(
|
|||
.map(|id| Request {
|
||||
id: id.into(),
|
||||
prefill_logprobs: false,
|
||||
input_chunks: Some(Input {
|
||||
chunks: vec![Chunk::Text(sequence.clone()).into()],
|
||||
}),
|
||||
inputs: sequence.clone(),
|
||||
truncate: sequence_length,
|
||||
parameters: Some(parameters.clone()),
|
||||
|
|
|
@ -51,6 +51,27 @@ message ClearCacheRequest {
|
|||
/// Empty response
|
||||
message ClearCacheResponse {}
|
||||
|
||||
message Image {
|
||||
/// Binary image data.
|
||||
bytes data = 1;
|
||||
|
||||
/// Image MIME type.
|
||||
string mimetype = 2;
|
||||
}
|
||||
|
||||
message InputChunk {
|
||||
oneof chunk {
|
||||
/// Plain text data
|
||||
string text = 1;
|
||||
/// Image data
|
||||
Image image = 2;
|
||||
}
|
||||
}
|
||||
|
||||
message Input {
|
||||
repeated InputChunk chunks = 1;
|
||||
}
|
||||
|
||||
enum GrammarType {
|
||||
GRAMMAR_TYPE_NONE = 0;
|
||||
GRAMMAR_TYPE_JSON = 1;
|
||||
|
@ -95,7 +116,9 @@ message StoppingCriteriaParameters {
|
|||
message Request {
|
||||
/// Request ID
|
||||
uint64 id = 1;
|
||||
/// The generation context
|
||||
/// The generation context as chunks
|
||||
Input input_chunks = 8;
|
||||
/// The generation context, stringified input_chunks
|
||||
string inputs = 2;
|
||||
/// Context truncation
|
||||
uint32 truncate = 3;
|
||||
|
|
|
@ -49,7 +49,7 @@ futures-util = "0.3.30"
|
|||
regex = "1.10.3"
|
||||
once_cell = "1.19.0"
|
||||
image = "0.25.1"
|
||||
base64 = "0.22.0"
|
||||
base64 = { workspace = true }
|
||||
|
||||
[build-dependencies]
|
||||
vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
|
||||
|
|
|
@ -6,6 +6,7 @@ authors.workspace = true
|
|||
homepage.workspace = true
|
||||
|
||||
[dependencies]
|
||||
base64 = { workspace = true }
|
||||
futures = "^0.3"
|
||||
grpc-metadata = { path = "../grpc-metadata" }
|
||||
prost = "^0.12"
|
||||
|
|
|
@ -1,13 +1,17 @@
|
|||
/// Single shard Client
|
||||
use crate::pb::generate::v2::text_generation_service_client::TextGenerationServiceClient;
|
||||
use crate::pb::generate::v2::*;
|
||||
use crate::Result;
|
||||
use crate::{Chunk, Result};
|
||||
use base64::engine::general_purpose::STANDARD;
|
||||
use base64::Engine;
|
||||
use grpc_metadata::InjectTelemetryContext;
|
||||
use std::cmp::min;
|
||||
use std::time::Duration;
|
||||
use tonic::transport::{Channel, Uri};
|
||||
use tracing::instrument;
|
||||
|
||||
static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";
|
||||
|
||||
/// Text Generation Inference gRPC client
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Client {
|
||||
|
@ -113,18 +117,39 @@ impl Client {
|
|||
while n_tokens < max_prefill_tokens {
|
||||
let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
|
||||
|
||||
let mut input_chunks = Vec::new();
|
||||
input_chunks
|
||||
.push(Chunk::Text("_test ".to_string().repeat(max_input_length as usize)).into());
|
||||
if n_tokens == 0 {
|
||||
input_chunks.push(
|
||||
Chunk::Image(Image {
|
||||
// Safe unwrap, because we control the data.
|
||||
data: STANDARD.decode(WARMUP_IMAGE_BASE64).unwrap(),
|
||||
mimetype: "image/jpeg;base64".to_string(),
|
||||
})
|
||||
.into(),
|
||||
);
|
||||
}
|
||||
|
||||
// Send stringly-typed inputs for compatibility for backends that haven't
|
||||
// been updated to support chunks.
|
||||
let mut inputs = String::new();
|
||||
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
|
||||
if n_tokens == 0 {
|
||||
// 1 request is enough to test vision heads.
|
||||
// Sending images on other queries messes up easily with truncation.
|
||||
inputs.push_str("![](data:image/jpeg;base64,iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=)");
|
||||
inputs.push_str(&format!(
|
||||
"![](data:image/jpeg;base64,{WARMUP_IMAGE_BASE64})",
|
||||
));
|
||||
}
|
||||
|
||||
requests.push(Request {
|
||||
id: 0,
|
||||
// We truncate the input on the server side to be sure that it has the correct size
|
||||
input_chunks: Some(Input {
|
||||
chunks: input_chunks,
|
||||
}),
|
||||
inputs,
|
||||
// We truncate the input on the server side to be sure that it has the correct size
|
||||
truncate,
|
||||
// Set sampling parameters to also take these ops into account in the max memory
|
||||
parameters: Some(NextTokenChooserParameters {
|
||||
|
|
|
@ -5,11 +5,14 @@ mod client;
|
|||
mod pb;
|
||||
mod sharded_client;
|
||||
|
||||
use base64::{engine::general_purpose::STANDARD, Engine};
|
||||
pub use client::Client;
|
||||
pub use pb::generate::v2::input_chunk::Chunk;
|
||||
pub use pb::generate::v2::HealthResponse;
|
||||
pub use pb::generate::v2::Image;
|
||||
pub use pb::generate::v2::InfoResponse as ShardInfo;
|
||||
pub use pb::generate::v2::{
|
||||
Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
|
||||
Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, Input, InputChunk,
|
||||
NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,
|
||||
};
|
||||
pub use sharded_client::ShardedClient;
|
||||
|
@ -44,3 +47,33 @@ impl From<transport::Error> for ClientError {
|
|||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, ClientError>;
|
||||
|
||||
// Small convenience re-wrapping of `Chunk`.
|
||||
impl From<Chunk> for InputChunk {
|
||||
fn from(chunk: Chunk) -> Self {
|
||||
InputChunk { chunk: Some(chunk) }
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert input chunks to a stringly-typed input for backwards
|
||||
/// compat for backends that haven't implemented chunked inputs.
|
||||
pub trait ChunksToString {
|
||||
/// Convert chunks to string.
|
||||
fn chunks_to_string(&self) -> String;
|
||||
}
|
||||
|
||||
impl ChunksToString for Vec<InputChunk> {
|
||||
fn chunks_to_string(&self) -> String {
|
||||
let mut output = String::new();
|
||||
self.iter().for_each(|c| match &c.chunk {
|
||||
Some(Chunk::Text(text)) => output.push_str(text),
|
||||
Some(Chunk::Image(Image { data, mimetype })) => {
|
||||
let encoded = STANDARD.encode(data);
|
||||
output.push_str(&format!("![](data:{};base64,{})", mimetype, encoded))
|
||||
}
|
||||
// We don't create empty chunks, so this should be unreachable.
|
||||
None => unreachable!("Chunks should never be empty"),
|
||||
});
|
||||
output
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,9 +4,9 @@ use serde::{Deserialize, Serialize};
|
|||
#[serde(tag = "model_type")]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct LlavaNext {
|
||||
text_config: TextConfig,
|
||||
vision_config: VisionConfig,
|
||||
image_grid_pinpoints: Vec<(usize, usize)>,
|
||||
pub(crate) text_config: TextConfig,
|
||||
pub(crate) vision_config: VisionConfig,
|
||||
pub(crate) image_grid_pinpoints: Vec<(usize, usize)>,
|
||||
}
|
||||
|
||||
fn get_anyres_image_grid_shape(
|
||||
|
@ -119,13 +119,13 @@ impl Idefics2 {
|
|||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct PaliTextConfig {
|
||||
num_image_tokens: usize,
|
||||
pub(crate) num_image_tokens: usize,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct Paligemma {
|
||||
text_config: PaliTextConfig,
|
||||
pub(crate) text_config: PaliTextConfig,
|
||||
}
|
||||
|
||||
impl Paligemma {
|
||||
|
@ -175,8 +175,8 @@ pub struct TextConfig {}
|
|||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct VisionConfig {
|
||||
image_size: usize,
|
||||
patch_size: usize,
|
||||
pub(crate) image_size: usize,
|
||||
pub(crate) patch_size: usize,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
use text_generation_client::GrammarType as ProtoGrammarType;
|
||||
use text_generation_client::{
|
||||
Batch, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters,
|
||||
Batch, Input, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters,
|
||||
};
|
||||
use text_generation_client::{Chunk, GrammarType as ProtoGrammarType};
|
||||
|
||||
// Note: Request ids and batch ids cannot collide.
|
||||
const LIVENESS_ID: u64 = u64::MAX;
|
||||
|
@ -33,6 +33,9 @@ impl Health {
|
|||
// Dummy batch of 1 token and 1 generated token
|
||||
let liveness_request = Request {
|
||||
id: LIVENESS_ID,
|
||||
input_chunks: Some(Input {
|
||||
chunks: vec![Chunk::Text("liveness".into()).into()],
|
||||
}),
|
||||
inputs: "liveness".to_string(),
|
||||
truncate: 10,
|
||||
prefill_logprobs: false,
|
||||
|
|
|
@ -4,6 +4,8 @@ use crate::validation::ValidGenerateRequest;
|
|||
use nohash_hasher::{BuildNoHashHasher, IntMap};
|
||||
use std::cmp::min;
|
||||
use std::collections::VecDeque;
|
||||
use text_generation_client::ChunksToString;
|
||||
use text_generation_client::Input;
|
||||
use text_generation_client::{Batch, Request};
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use tokio::time::Instant;
|
||||
|
@ -278,7 +280,10 @@ impl State {
|
|||
batch_requests.push(Request {
|
||||
id,
|
||||
prefill_logprobs: entry.request.decoder_input_details,
|
||||
inputs: entry.request.inputs.clone(),
|
||||
input_chunks: Some(Input {
|
||||
chunks: entry.request.inputs.clone(),
|
||||
}),
|
||||
inputs: entry.request.inputs.chunks_to_string(),
|
||||
truncate: entry.request.truncate,
|
||||
parameters: Some(entry.request.parameters.clone()),
|
||||
stopping_parameters: Some(entry.request.stopping_parameters.clone()),
|
||||
|
@ -366,7 +371,7 @@ mod tests {
|
|||
|
||||
let entry = Entry {
|
||||
request: ValidGenerateRequest {
|
||||
inputs: String::new(),
|
||||
inputs: vec![],
|
||||
input_length: 0,
|
||||
truncate: 0,
|
||||
decoder_input_details: false,
|
||||
|
|
|
@ -7,7 +7,8 @@ use rand::{thread_rng, Rng};
|
|||
use serde_json::Value;
|
||||
use std::io::Cursor;
|
||||
use text_generation_client::{
|
||||
GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters,
|
||||
Chunk, GrammarType as ProtoGrammarType, Image, InputChunk, NextTokenChooserParameters,
|
||||
StoppingCriteriaParameters,
|
||||
};
|
||||
use thiserror::Error;
|
||||
use tokenizers::tokenizer::Tokenizer;
|
||||
|
@ -89,7 +90,7 @@ impl Validation {
|
|||
&self,
|
||||
inputs: String,
|
||||
truncate: Option<usize>,
|
||||
) -> Result<Option<(tokenizers::Encoding, String)>, ValidationError> {
|
||||
) -> Result<Option<(tokenizers::Encoding, Vec<InputChunk>)>, ValidationError> {
|
||||
// If we have a fast tokenizer
|
||||
if let Some(sender) = &self.sender {
|
||||
// Create response channel
|
||||
|
@ -115,7 +116,7 @@ impl Validation {
|
|||
inputs: String,
|
||||
truncate: Option<usize>,
|
||||
max_new_tokens: Option<u32>,
|
||||
) -> Result<(String, usize, u32), ValidationError> {
|
||||
) -> Result<(Vec<InputChunk>, usize, u32), ValidationError> {
|
||||
// If we have a fast tokenizer
|
||||
if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
|
||||
// Create response channel
|
||||
|
@ -178,7 +179,11 @@ impl Validation {
|
|||
// ));
|
||||
}
|
||||
|
||||
Ok((inputs, input_length, max_new_tokens))
|
||||
Ok((
|
||||
vec![Chunk::Text(inputs).into()],
|
||||
input_length,
|
||||
max_new_tokens,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -465,7 +470,7 @@ fn format_to_mimetype(format: ImageFormat) -> String {
|
|||
.to_string()
|
||||
}
|
||||
|
||||
fn fetch_image(input: &str) -> Result<(String, usize, usize), ValidationError> {
|
||||
fn fetch_image(input: &str) -> Result<(Vec<u8>, String, usize, usize), ValidationError> {
|
||||
if input.starts_with("![](http://") || input.starts_with("![](https://") {
|
||||
let url = &input["![](".len()..input.len() - 1];
|
||||
let data = reqwest::blocking::get(url)?.bytes()?;
|
||||
|
@ -476,9 +481,7 @@ fn fetch_image(input: &str) -> Result<(String, usize, usize), ValidationError> {
|
|||
let height: usize = img.height().try_into()?;
|
||||
let width: usize = img.width().try_into()?;
|
||||
let mimetype = format_to_mimetype(format);
|
||||
let encoded = STANDARD.encode(data);
|
||||
let data_uri = format!("![](data:{mimetype};base64,{encoded})");
|
||||
Ok((data_uri, height, width))
|
||||
Ok((data.to_vec(), mimetype, height, width))
|
||||
} else if input.starts_with("![](data:") {
|
||||
// Remove ![](....)
|
||||
let content = &input["![](data:".len()..input.len() - 1];
|
||||
|
@ -495,9 +498,9 @@ fn fetch_image(input: &str) -> Result<(String, usize, usize), ValidationError> {
|
|||
|
||||
let data = STANDARD.decode(content["base64,".len()..].as_bytes())?;
|
||||
let img = if let Some(format) = format_from_mimetype(mimetype) {
|
||||
ImageReader::with_format(Cursor::new(data), format).decode()?
|
||||
ImageReader::with_format(Cursor::new(&data), format).decode()?
|
||||
} else {
|
||||
ImageReader::new(Cursor::new(data))
|
||||
ImageReader::new(Cursor::new(&data))
|
||||
.with_guessed_format()
|
||||
.map_err(|_io_error| ValidationError::InvalidImageContent(content.to_string()))?
|
||||
.decode()?
|
||||
|
@ -505,7 +508,7 @@ fn fetch_image(input: &str) -> Result<(String, usize, usize), ValidationError> {
|
|||
|
||||
let height: usize = img.height().try_into()?;
|
||||
let width: usize = img.width().try_into()?;
|
||||
Ok((input.to_string(), height, width))
|
||||
Ok((data, mimetype.to_string(), height, width))
|
||||
} else {
|
||||
Err(ValidationError::InvalidImageContent(input.to_string()))
|
||||
}
|
||||
|
@ -513,113 +516,110 @@ fn fetch_image(input: &str) -> Result<(String, usize, usize), ValidationError> {
|
|||
|
||||
/// Get input length and optionally truncate it
|
||||
fn prepare_input(
|
||||
mut inputs: String,
|
||||
inputs: String,
|
||||
_truncate: Option<usize>,
|
||||
tokenizer: &Tokenizer,
|
||||
config: &Option<Config>,
|
||||
) -> Result<(tokenizers::Encoding, String), ValidationError> {
|
||||
) -> Result<(tokenizers::Encoding, Vec<InputChunk>), ValidationError> {
|
||||
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
|
||||
let tokenizer_query = match config {
|
||||
let (tokenizer_query, input_chunks) = match config {
|
||||
Some(Config::LlavaNext(config)) => {
|
||||
let mut modified_inputs = String::with_capacity(inputs.len());
|
||||
let mut input_chunks = Vec::new();
|
||||
let mut tokenizer_query = String::with_capacity(inputs.len());
|
||||
let mut start = 0;
|
||||
for chunk in RE.find_iter(&inputs) {
|
||||
let chunk_start = chunk.start();
|
||||
let chunk_end = chunk.end();
|
||||
if chunk_start != start {
|
||||
modified_inputs.push_str(&inputs[start..chunk_start]);
|
||||
input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
|
||||
tokenizer_query.push_str(&inputs[start..chunk_start]);
|
||||
}
|
||||
let (image_uri, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
|
||||
let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
|
||||
let slots = config.get_number_of_features(height, width);
|
||||
input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
|
||||
tokenizer_query.push_str(&"<image>".repeat(slots));
|
||||
modified_inputs.push_str(&image_uri);
|
||||
start = chunk_end;
|
||||
}
|
||||
if start != inputs.len() - 1 {
|
||||
modified_inputs.push_str(&inputs[start..]);
|
||||
if start != inputs.len() {
|
||||
input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
|
||||
tokenizer_query.push_str(&inputs[start..]);
|
||||
}
|
||||
inputs = modified_inputs;
|
||||
tokenizer_query
|
||||
(tokenizer_query, input_chunks)
|
||||
}
|
||||
Some(Config::Paligemma(config)) => {
|
||||
let mut modified_inputs = String::with_capacity(inputs.len());
|
||||
let mut input_chunks = Vec::new();
|
||||
let mut tokenizer_query = String::with_capacity(inputs.len());
|
||||
let mut start = 0;
|
||||
for chunk in RE.find_iter(&inputs) {
|
||||
let chunk_start = chunk.start();
|
||||
let chunk_end = chunk.end();
|
||||
if chunk_start != start {
|
||||
modified_inputs.push_str(&inputs[start..chunk_start]);
|
||||
input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
|
||||
tokenizer_query.push_str(&inputs[start..chunk_start]);
|
||||
}
|
||||
let (image_uri, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
|
||||
let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
|
||||
let slots = config.get_number_of_features(height, width);
|
||||
input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
|
||||
tokenizer_query.push_str(&"<image>".repeat(slots));
|
||||
modified_inputs.push_str(&image_uri);
|
||||
start = chunk_end;
|
||||
}
|
||||
if start != inputs.len() - 1 {
|
||||
modified_inputs.push_str(&inputs[start..]);
|
||||
if start != inputs.len() {
|
||||
input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
|
||||
tokenizer_query.push_str(&inputs[start..]);
|
||||
}
|
||||
inputs = modified_inputs;
|
||||
tokenizer_query
|
||||
(tokenizer_query, input_chunks)
|
||||
}
|
||||
Some(Config::Idefics2(config)) => {
|
||||
let mut modified_inputs = String::with_capacity(inputs.len());
|
||||
let mut input_chunks = Vec::new();
|
||||
let mut tokenizer_query = String::with_capacity(inputs.len());
|
||||
let mut start = 0;
|
||||
for chunk in RE.find_iter(&inputs) {
|
||||
let chunk_start = chunk.start();
|
||||
let chunk_end = chunk.end();
|
||||
if chunk_start != start {
|
||||
modified_inputs.push_str(&inputs[start..chunk_start]);
|
||||
input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
|
||||
tokenizer_query.push_str(&inputs[start..chunk_start]);
|
||||
}
|
||||
let (image_uri, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
|
||||
let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
|
||||
let slots = config.get_number_of_features(height, width);
|
||||
tokenizer_query.push_str("<fake_token_around_image>");
|
||||
tokenizer_query.push_str(&"<image>".repeat(slots));
|
||||
tokenizer_query.push_str("<fake_token_around_image>");
|
||||
|
||||
modified_inputs.push_str(&image_uri);
|
||||
input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
|
||||
start = chunk_end;
|
||||
}
|
||||
if start != inputs.len() - 1 {
|
||||
modified_inputs.push_str(&inputs[start..]);
|
||||
if start != inputs.len() {
|
||||
input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
|
||||
tokenizer_query.push_str(&inputs[start..]);
|
||||
}
|
||||
inputs = modified_inputs;
|
||||
tokenizer_query
|
||||
(tokenizer_query, input_chunks)
|
||||
}
|
||||
Some(Config::Idefics) => {
|
||||
let mut modified_inputs = String::with_capacity(inputs.len());
|
||||
let mut input_chunks = Vec::new();
|
||||
let mut tokenizer_query = String::with_capacity(inputs.len());
|
||||
let mut start = 0;
|
||||
for chunk in RE.find_iter(&inputs) {
|
||||
let chunk_start = chunk.start();
|
||||
let chunk_end = chunk.end();
|
||||
if chunk_start != start {
|
||||
modified_inputs.push_str(&inputs[start..chunk_start]);
|
||||
input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
|
||||
tokenizer_query.push_str(&inputs[start..chunk_start]);
|
||||
}
|
||||
let (image_uri, _height, _width) = fetch_image(&inputs[chunk_start..chunk_end])?;
|
||||
let (data, mimetype, _height, _width) =
|
||||
fetch_image(&inputs[chunk_start..chunk_end])?;
|
||||
let slots = 1;
|
||||
tokenizer_query.push_str(&"<image>".repeat(slots));
|
||||
modified_inputs.push_str(&image_uri);
|
||||
input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
|
||||
start = chunk_end;
|
||||
}
|
||||
if start != inputs.len() - 1 {
|
||||
modified_inputs.push_str(&inputs[start..]);
|
||||
if start != inputs.len() {
|
||||
input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
|
||||
tokenizer_query.push_str(&inputs[start..]);
|
||||
}
|
||||
inputs = modified_inputs;
|
||||
tokenizer_query
|
||||
(tokenizer_query, input_chunks)
|
||||
}
|
||||
_ => inputs.clone(),
|
||||
_ => (inputs.clone(), vec![Chunk::Text(inputs).into()]),
|
||||
};
|
||||
|
||||
// Get the number of tokens in the input
|
||||
|
@ -627,18 +627,18 @@ fn prepare_input(
|
|||
.encode(tokenizer_query, true)
|
||||
.map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
|
||||
|
||||
Ok((encoding, inputs))
|
||||
Ok((encoding, input_chunks))
|
||||
}
|
||||
|
||||
type TokenizerRequest = (
|
||||
(String, Option<usize>),
|
||||
oneshot::Sender<Result<(tokenizers::Encoding, String), ValidationError>>,
|
||||
oneshot::Sender<Result<(tokenizers::Encoding, Vec<InputChunk>), ValidationError>>,
|
||||
Span,
|
||||
);
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct ValidGenerateRequest {
|
||||
pub inputs: String,
|
||||
pub inputs: Vec<InputChunk>,
|
||||
pub input_length: u32,
|
||||
pub truncate: u32,
|
||||
pub decoder_input_details: bool,
|
||||
|
@ -714,6 +714,7 @@ pub enum ValidationError {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::config::{PaliTextConfig, Paligemma};
|
||||
use crate::default_parameters;
|
||||
use crate::tests::get_tokenizer;
|
||||
|
||||
|
@ -964,4 +965,61 @@ mod tests {
|
|||
|
||||
assert_eq!(valid_request.top_n_tokens, 0);
|
||||
}
|
||||
|
||||
static PIXEL_GIF: &str = "R0lGODdhAQABAIEAAP///wAAAAAAAAAAACwAAAAAAQABAAAIBAABBAQAOw==";
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_prepare_input_chunks() {
|
||||
let pixel_data = STANDARD.decode(PIXEL_GIF).unwrap();
|
||||
|
||||
let tokenizer = Some(get_tokenizer().await);
|
||||
|
||||
let max_best_of = 2;
|
||||
let max_stop_sequence = 3;
|
||||
let max_top_n_tokens = 4;
|
||||
let max_input_length = 5;
|
||||
let max_total_tokens = 6;
|
||||
let disable_grammar_support = true;
|
||||
let workers = 1;
|
||||
let config = Config::Paligemma(Paligemma {
|
||||
text_config: PaliTextConfig {
|
||||
num_image_tokens: 1,
|
||||
},
|
||||
});
|
||||
let validation = Validation::new(
|
||||
workers,
|
||||
tokenizer,
|
||||
Some(config),
|
||||
max_best_of,
|
||||
max_stop_sequence,
|
||||
max_top_n_tokens,
|
||||
max_input_length,
|
||||
max_total_tokens,
|
||||
disable_grammar_support,
|
||||
);
|
||||
|
||||
let chunks = match validation
|
||||
.tokenize(
|
||||
format!("test![](data:image/gif;base64,{})", PIXEL_GIF),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(Some((_encoding, chunks))) => chunks,
|
||||
_ => panic!("Unexpected tokenization failure"),
|
||||
};
|
||||
|
||||
assert!(
|
||||
chunks
|
||||
== vec![
|
||||
Chunk::Text("test".to_string()).into(),
|
||||
Chunk::Image(Image {
|
||||
data: pixel_data.clone(),
|
||||
mimetype: "image/gif".to_string()
|
||||
})
|
||||
.into()
|
||||
],
|
||||
"Failed to process images",
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue