router: send the input as chunks to the backend

Before this change, the generation input was sent to the backend as a
single string, encoding images as Base64 and packing them in
Markdown-style links.

This change adds a new chunked input representation that separates text
chunks from image chunks. Image chunks contain binary data (for smaller
message sizes) and the image's MIME type.

The stringly-typed inputs are still sent to support backends that do not
support chunked inputs yet.
This commit is contained in:
Daniël de Kok 2024-06-03 07:27:22 +00:00 committed by Daniël de Kok
parent d1d724b027
commit df71aafdcc
12 changed files with 222 additions and 69 deletions

1
Cargo.lock generated
View File

@ -3554,6 +3554,7 @@ dependencies = [
name = "text-generation-client"
version = "2.0.5-dev0"
dependencies = [
"base64 0.22.1",
"futures",
"grpc-metadata",
"prost 0.12.6",

View File

@ -15,6 +15,7 @@ authors = ["Olivier Dehaene"]
homepage = "https://github.com/huggingface/text-generation-inference"
[workspace.dependencies]
base64 = "0.22.0"
tokenizers = { version = "0.19.1", features = ["http"] }
hf-hub = { version = "0.3.1", features = ["tokio"] }

View File

@ -1,7 +1,7 @@
use std::time::{Duration, Instant};
use text_generation_client::{
Batch, CachedBatch, ClientError, NextTokenChooserParameters, Request, ShardedClient,
StoppingCriteriaParameters,
Batch, CachedBatch, Chunk, ClientError, Input, NextTokenChooserParameters, Request,
ShardedClient, StoppingCriteriaParameters,
};
use tokenizers::{Tokenizer, TruncationDirection};
use tokio::sync::{broadcast, mpsc};
@ -142,6 +142,9 @@ async fn prefill(
.map(|id| Request {
id: id.into(),
prefill_logprobs: false,
input_chunks: Some(Input {
chunks: vec![Chunk::Text(sequence.clone()).into()],
}),
inputs: sequence.clone(),
truncate: sequence_length,
parameters: Some(parameters.clone()),

View File

@ -51,6 +51,27 @@ message ClearCacheRequest {
/// Empty response
message ClearCacheResponse {}
/// An image passed as part of the generation input: the binary image
/// data together with its MIME type.
message Image {
/// Binary image data.
bytes data = 1;
/// Image MIME type.
string mimetype = 2;
}
/// A single piece of the generation input: either plain text or an image.
message InputChunk {
oneof chunk {
/// Plain text data
string text = 1;
/// Image data
Image image = 2;
}
}
/// The generation input, represented as a sequence of chunks.
message Input {
/// Text and image chunks that make up the input.
repeated InputChunk chunks = 1;
}
enum GrammarType {
GRAMMAR_TYPE_NONE = 0;
GRAMMAR_TYPE_JSON = 1;
@ -95,7 +116,9 @@ message StoppingCriteriaParameters {
message Request {
/// Request ID
uint64 id = 1;
/// The generation context
/// The generation context as chunks
Input input_chunks = 8;
/// The generation context, stringified input_chunks
string inputs = 2;
/// Context truncation
uint32 truncate = 3;

View File

@ -49,7 +49,7 @@ futures-util = "0.3.30"
regex = "1.10.3"
once_cell = "1.19.0"
image = "0.25.1"
base64 = "0.22.0"
base64 = { workspace = true }
[build-dependencies]
vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }

View File

@ -6,6 +6,7 @@ authors.workspace = true
homepage.workspace = true
[dependencies]
base64 = { workspace = true }
futures = "^0.3"
grpc-metadata = { path = "../grpc-metadata" }
prost = "^0.12"

View File

@ -1,13 +1,17 @@
/// Single shard Client
use crate::pb::generate::v2::text_generation_service_client::TextGenerationServiceClient;
use crate::pb::generate::v2::*;
use crate::Result;
use crate::{Chunk, Result};
use base64::engine::general_purpose::STANDARD;
use base64::Engine;
use grpc_metadata::InjectTelemetryContext;
use std::cmp::min;
use std::time::Duration;
use tonic::transport::{Channel, Uri};
use tracing::instrument;
static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";
/// Text Generation Inference gRPC client
#[derive(Debug, Clone)]
pub struct Client {
@ -113,18 +117,39 @@ impl Client {
while n_tokens < max_prefill_tokens {
let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
let mut input_chunks = Vec::new();
input_chunks
.push(Chunk::Text("_test ".to_string().repeat(max_input_length as usize)).into());
if n_tokens == 0 {
input_chunks.push(
Chunk::Image(Image {
// Safe unwrap, because we control the data.
data: STANDARD.decode(WARMUP_IMAGE_BASE64).unwrap(),
mimetype: "image/jpeg;base64".to_string(),
})
.into(),
);
}
// Send stringly-typed inputs for compatibility for backends that haven't
// been updated to support chunks.
let mut inputs = String::new();
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
if n_tokens == 0 {
// 1 request is enough to test vision heads.
// Sending images on other queries messes up easily with truncation.
inputs.push_str("![]()");
inputs.push_str(&format!(
"![](data:image/jpeg;base64,{WARMUP_IMAGE_BASE64})",
));
}
requests.push(Request {
id: 0,
// We truncate the input on the server side to be sure that it has the correct size
input_chunks: Some(Input {
chunks: input_chunks,
}),
inputs,
// We truncate the input on the server side to be sure that it has the correct size
truncate,
// Set sampling parameters to also take these ops into account in the max memory
parameters: Some(NextTokenChooserParameters {

View File

@ -5,11 +5,14 @@ mod client;
mod pb;
mod sharded_client;
use base64::{engine::general_purpose::STANDARD, Engine};
pub use client::Client;
pub use pb::generate::v2::input_chunk::Chunk;
pub use pb::generate::v2::HealthResponse;
pub use pb::generate::v2::Image;
pub use pb::generate::v2::InfoResponse as ShardInfo;
pub use pb::generate::v2::{
Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, Input, InputChunk,
NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,
};
pub use sharded_client::ShardedClient;
@ -44,3 +47,33 @@ impl From<transport::Error> for ClientError {
}
pub type Result<T> = std::result::Result<T, ClientError>;
// Small convenience re-wrapping of `Chunk`.
impl From<Chunk> for InputChunk {
fn from(chunk: Chunk) -> Self {
InputChunk { chunk: Some(chunk) }
}
}
/// Convert input chunks to a stringly-typed input for backwards
/// compatibility with backends that haven't implemented chunked inputs.
pub trait ChunksToString {
/// Render the chunks as a single string.
fn chunks_to_string(&self) -> String;
}
impl ChunksToString for Vec<InputChunk> {
    /// Concatenate all chunks, in order, into a single string.
    ///
    /// Text chunks are appended verbatim. Image chunks are Base64-encoded and
    /// appended as a Markdown-style link of the form
    /// `![](data:<mimetype>;base64,<encoded data>)`.
    ///
    /// # Panics
    ///
    /// Panics if an `InputChunk` carries no chunk (`chunk` is `None`).
    fn chunks_to_string(&self) -> String {
        let mut rendered = String::new();
        for input_chunk in self {
            match input_chunk.chunk.as_ref() {
                Some(Chunk::Text(text)) => rendered.push_str(text),
                Some(Chunk::Image(Image { data, mimetype })) => {
                    let payload = STANDARD.encode(data);
                    rendered.push_str(&format!("![](data:{};base64,{})", mimetype, payload));
                }
                // Empty chunks are never constructed, so hitting this arm
                // means an invariant was broken upstream.
                None => unreachable!("Chunks should never be empty"),
            }
        }
        rendered
    }
}

View File

@ -4,9 +4,9 @@ use serde::{Deserialize, Serialize};
#[serde(tag = "model_type")]
#[serde(rename_all = "snake_case")]
pub struct LlavaNext {
text_config: TextConfig,
vision_config: VisionConfig,
image_grid_pinpoints: Vec<(usize, usize)>,
pub(crate) text_config: TextConfig,
pub(crate) vision_config: VisionConfig,
pub(crate) image_grid_pinpoints: Vec<(usize, usize)>,
}
fn get_anyres_image_grid_shape(
@ -119,13 +119,13 @@ impl Idefics2 {
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct PaliTextConfig {
num_image_tokens: usize,
pub(crate) num_image_tokens: usize,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct Paligemma {
text_config: PaliTextConfig,
pub(crate) text_config: PaliTextConfig,
}
impl Paligemma {
@ -175,8 +175,8 @@ pub struct TextConfig {}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct VisionConfig {
image_size: usize,
patch_size: usize,
pub(crate) image_size: usize,
pub(crate) patch_size: usize,
}
#[cfg(test)]

View File

@ -1,9 +1,9 @@
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use text_generation_client::GrammarType as ProtoGrammarType;
use text_generation_client::{
Batch, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters,
Batch, Input, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters,
};
use text_generation_client::{Chunk, GrammarType as ProtoGrammarType};
// Note: Request ids and batch ids cannot collide.
const LIVENESS_ID: u64 = u64::MAX;
@ -33,6 +33,9 @@ impl Health {
// Dummy batch of 1 token and 1 generated token
let liveness_request = Request {
id: LIVENESS_ID,
input_chunks: Some(Input {
chunks: vec![Chunk::Text("liveness".into()).into()],
}),
inputs: "liveness".to_string(),
truncate: 10,
prefill_logprobs: false,

View File

@ -4,6 +4,8 @@ use crate::validation::ValidGenerateRequest;
use nohash_hasher::{BuildNoHashHasher, IntMap};
use std::cmp::min;
use std::collections::VecDeque;
use text_generation_client::ChunksToString;
use text_generation_client::Input;
use text_generation_client::{Batch, Request};
use tokio::sync::{mpsc, oneshot};
use tokio::time::Instant;
@ -278,7 +280,10 @@ impl State {
batch_requests.push(Request {
id,
prefill_logprobs: entry.request.decoder_input_details,
inputs: entry.request.inputs.clone(),
input_chunks: Some(Input {
chunks: entry.request.inputs.clone(),
}),
inputs: entry.request.inputs.chunks_to_string(),
truncate: entry.request.truncate,
parameters: Some(entry.request.parameters.clone()),
stopping_parameters: Some(entry.request.stopping_parameters.clone()),
@ -366,7 +371,7 @@ mod tests {
let entry = Entry {
request: ValidGenerateRequest {
inputs: String::new(),
inputs: vec![],
input_length: 0,
truncate: 0,
decoder_input_details: false,

View File

@ -7,7 +7,8 @@ use rand::{thread_rng, Rng};
use serde_json::Value;
use std::io::Cursor;
use text_generation_client::{
GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters,
Chunk, GrammarType as ProtoGrammarType, Image, InputChunk, NextTokenChooserParameters,
StoppingCriteriaParameters,
};
use thiserror::Error;
use tokenizers::tokenizer::Tokenizer;
@ -89,7 +90,7 @@ impl Validation {
&self,
inputs: String,
truncate: Option<usize>,
) -> Result<Option<(tokenizers::Encoding, String)>, ValidationError> {
) -> Result<Option<(tokenizers::Encoding, Vec<InputChunk>)>, ValidationError> {
// If we have a fast tokenizer
if let Some(sender) = &self.sender {
// Create response channel
@ -115,7 +116,7 @@ impl Validation {
inputs: String,
truncate: Option<usize>,
max_new_tokens: Option<u32>,
) -> Result<(String, usize, u32), ValidationError> {
) -> Result<(Vec<InputChunk>, usize, u32), ValidationError> {
// If we have a fast tokenizer
if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
// Create response channel
@ -178,7 +179,11 @@ impl Validation {
// ));
}
Ok((inputs, input_length, max_new_tokens))
Ok((
vec![Chunk::Text(inputs).into()],
input_length,
max_new_tokens,
))
}
}
@ -465,7 +470,7 @@ fn format_to_mimetype(format: ImageFormat) -> String {
.to_string()
}
fn fetch_image(input: &str) -> Result<(String, usize, usize), ValidationError> {
fn fetch_image(input: &str) -> Result<(Vec<u8>, String, usize, usize), ValidationError> {
if input.starts_with("![](http://") || input.starts_with("![](https://") {
let url = &input["![](".len()..input.len() - 1];
let data = reqwest::blocking::get(url)?.bytes()?;
@ -476,9 +481,7 @@ fn fetch_image(input: &str) -> Result<(String, usize, usize), ValidationError> {
let height: usize = img.height().try_into()?;
let width: usize = img.width().try_into()?;
let mimetype = format_to_mimetype(format);
let encoded = STANDARD.encode(data);
let data_uri = format!("![](data:{mimetype};base64,{encoded})");
Ok((data_uri, height, width))
Ok((data.to_vec(), mimetype, height, width))
} else if input.starts_with("![](data:") {
// Remove ![](....)
let content = &input["![](data:".len()..input.len() - 1];
@ -495,9 +498,9 @@ fn fetch_image(input: &str) -> Result<(String, usize, usize), ValidationError> {
let data = STANDARD.decode(content["base64,".len()..].as_bytes())?;
let img = if let Some(format) = format_from_mimetype(mimetype) {
ImageReader::with_format(Cursor::new(data), format).decode()?
ImageReader::with_format(Cursor::new(&data), format).decode()?
} else {
ImageReader::new(Cursor::new(data))
ImageReader::new(Cursor::new(&data))
.with_guessed_format()
.map_err(|_io_error| ValidationError::InvalidImageContent(content.to_string()))?
.decode()?
@ -505,7 +508,7 @@ fn fetch_image(input: &str) -> Result<(String, usize, usize), ValidationError> {
let height: usize = img.height().try_into()?;
let width: usize = img.width().try_into()?;
Ok((input.to_string(), height, width))
Ok((data, mimetype.to_string(), height, width))
} else {
Err(ValidationError::InvalidImageContent(input.to_string()))
}
@ -513,113 +516,110 @@ fn fetch_image(input: &str) -> Result<(String, usize, usize), ValidationError> {
/// Get input length and optionally truncate it
fn prepare_input(
mut inputs: String,
inputs: String,
_truncate: Option<usize>,
tokenizer: &Tokenizer,
config: &Option<Config>,
) -> Result<(tokenizers::Encoding, String), ValidationError> {
) -> Result<(tokenizers::Encoding, Vec<InputChunk>), ValidationError> {
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
let tokenizer_query = match config {
let (tokenizer_query, input_chunks) = match config {
Some(Config::LlavaNext(config)) => {
let mut modified_inputs = String::with_capacity(inputs.len());
let mut input_chunks = Vec::new();
let mut tokenizer_query = String::with_capacity(inputs.len());
let mut start = 0;
for chunk in RE.find_iter(&inputs) {
let chunk_start = chunk.start();
let chunk_end = chunk.end();
if chunk_start != start {
modified_inputs.push_str(&inputs[start..chunk_start]);
input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
tokenizer_query.push_str(&inputs[start..chunk_start]);
}
let (image_uri, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
let slots = config.get_number_of_features(height, width);
input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
tokenizer_query.push_str(&"<image>".repeat(slots));
modified_inputs.push_str(&image_uri);
start = chunk_end;
}
if start != inputs.len() - 1 {
modified_inputs.push_str(&inputs[start..]);
if start != inputs.len() {
input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
tokenizer_query.push_str(&inputs[start..]);
}
inputs = modified_inputs;
tokenizer_query
(tokenizer_query, input_chunks)
}
Some(Config::Paligemma(config)) => {
let mut modified_inputs = String::with_capacity(inputs.len());
let mut input_chunks = Vec::new();
let mut tokenizer_query = String::with_capacity(inputs.len());
let mut start = 0;
for chunk in RE.find_iter(&inputs) {
let chunk_start = chunk.start();
let chunk_end = chunk.end();
if chunk_start != start {
modified_inputs.push_str(&inputs[start..chunk_start]);
input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
tokenizer_query.push_str(&inputs[start..chunk_start]);
}
let (image_uri, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
let slots = config.get_number_of_features(height, width);
input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
tokenizer_query.push_str(&"<image>".repeat(slots));
modified_inputs.push_str(&image_uri);
start = chunk_end;
}
if start != inputs.len() - 1 {
modified_inputs.push_str(&inputs[start..]);
if start != inputs.len() {
input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
tokenizer_query.push_str(&inputs[start..]);
}
inputs = modified_inputs;
tokenizer_query
(tokenizer_query, input_chunks)
}
Some(Config::Idefics2(config)) => {
let mut modified_inputs = String::with_capacity(inputs.len());
let mut input_chunks = Vec::new();
let mut tokenizer_query = String::with_capacity(inputs.len());
let mut start = 0;
for chunk in RE.find_iter(&inputs) {
let chunk_start = chunk.start();
let chunk_end = chunk.end();
if chunk_start != start {
modified_inputs.push_str(&inputs[start..chunk_start]);
input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
tokenizer_query.push_str(&inputs[start..chunk_start]);
}
let (image_uri, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
let slots = config.get_number_of_features(height, width);
tokenizer_query.push_str("<fake_token_around_image>");
tokenizer_query.push_str(&"<image>".repeat(slots));
tokenizer_query.push_str("<fake_token_around_image>");
modified_inputs.push_str(&image_uri);
input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
start = chunk_end;
}
if start != inputs.len() - 1 {
modified_inputs.push_str(&inputs[start..]);
if start != inputs.len() {
input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
tokenizer_query.push_str(&inputs[start..]);
}
inputs = modified_inputs;
tokenizer_query
(tokenizer_query, input_chunks)
}
Some(Config::Idefics) => {
let mut modified_inputs = String::with_capacity(inputs.len());
let mut input_chunks = Vec::new();
let mut tokenizer_query = String::with_capacity(inputs.len());
let mut start = 0;
for chunk in RE.find_iter(&inputs) {
let chunk_start = chunk.start();
let chunk_end = chunk.end();
if chunk_start != start {
modified_inputs.push_str(&inputs[start..chunk_start]);
input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
tokenizer_query.push_str(&inputs[start..chunk_start]);
}
let (image_uri, _height, _width) = fetch_image(&inputs[chunk_start..chunk_end])?;
let (data, mimetype, _height, _width) =
fetch_image(&inputs[chunk_start..chunk_end])?;
let slots = 1;
tokenizer_query.push_str(&"<image>".repeat(slots));
modified_inputs.push_str(&image_uri);
input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
start = chunk_end;
}
if start != inputs.len() - 1 {
modified_inputs.push_str(&inputs[start..]);
if start != inputs.len() {
input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
tokenizer_query.push_str(&inputs[start..]);
}
inputs = modified_inputs;
tokenizer_query
(tokenizer_query, input_chunks)
}
_ => inputs.clone(),
_ => (inputs.clone(), vec![Chunk::Text(inputs).into()]),
};
// Get the number of tokens in the input
@ -627,18 +627,18 @@ fn prepare_input(
.encode(tokenizer_query, true)
.map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
Ok((encoding, inputs))
Ok((encoding, input_chunks))
}
type TokenizerRequest = (
(String, Option<usize>),
oneshot::Sender<Result<(tokenizers::Encoding, String), ValidationError>>,
oneshot::Sender<Result<(tokenizers::Encoding, Vec<InputChunk>), ValidationError>>,
Span,
);
#[derive(Debug, Clone)]
pub(crate) struct ValidGenerateRequest {
pub inputs: String,
pub inputs: Vec<InputChunk>,
pub input_length: u32,
pub truncate: u32,
pub decoder_input_details: bool,
@ -714,6 +714,7 @@ pub enum ValidationError {
#[cfg(test)]
mod tests {
use super::*;
use crate::config::{PaliTextConfig, Paligemma};
use crate::default_parameters;
use crate::tests::get_tokenizer;
@ -964,4 +965,61 @@ mod tests {
assert_eq!(valid_request.top_n_tokens, 0);
}
static PIXEL_GIF: &str = "R0lGODdhAQABAIEAAP///wAAAAAAAAAAACwAAAAAAQABAAAIBAABBAQAOw==";
#[tokio::test]
async fn test_prepare_input_chunks() {
let pixel_data = STANDARD.decode(PIXEL_GIF).unwrap();
let tokenizer = Some(get_tokenizer().await);
let max_best_of = 2;
let max_stop_sequence = 3;
let max_top_n_tokens = 4;
let max_input_length = 5;
let max_total_tokens = 6;
let disable_grammar_support = true;
let workers = 1;
let config = Config::Paligemma(Paligemma {
text_config: PaliTextConfig {
num_image_tokens: 1,
},
});
let validation = Validation::new(
workers,
tokenizer,
Some(config),
max_best_of,
max_stop_sequence,
max_top_n_tokens,
max_input_length,
max_total_tokens,
disable_grammar_support,
);
let chunks = match validation
.tokenize(
format!("test![](data:image/gif;base64,{})", PIXEL_GIF),
None,
)
.await
{
Ok(Some((_encoding, chunks))) => chunks,
_ => panic!("Unexpected tokenization failure"),
};
assert!(
chunks
== vec![
Chunk::Text("test".to_string()).into(),
Chunk::Image(Image {
data: pixel_data.clone(),
mimetype: "image/gif".to_string()
})
.into()
],
"Failed to process images",
);
}
}