router: send the input as chunks to the backend
Before this change, the generation input was sent to the backend as a single string, encoding images as Base64 and packing them in Markdown-style links. This change adds a new chunked input representation that separates text chunks from image chunks. Image chunks contain binary data (for smaller message sizes) and the image's MIME type. The stringly-typed inputs are still sent to support backends that do not support chunked inputs yet.
This commit is contained in:
parent
d1d724b027
commit
df71aafdcc
|
@ -3554,6 +3554,7 @@ dependencies = [
|
||||||
name = "text-generation-client"
|
name = "text-generation-client"
|
||||||
version = "2.0.5-dev0"
|
version = "2.0.5-dev0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"base64 0.22.1",
|
||||||
"futures",
|
"futures",
|
||||||
"grpc-metadata",
|
"grpc-metadata",
|
||||||
"prost 0.12.6",
|
"prost 0.12.6",
|
||||||
|
|
|
@ -15,6 +15,7 @@ authors = ["Olivier Dehaene"]
|
||||||
homepage = "https://github.com/huggingface/text-generation-inference"
|
homepage = "https://github.com/huggingface/text-generation-inference"
|
||||||
|
|
||||||
[workspace.dependencies]
|
[workspace.dependencies]
|
||||||
|
base64 = "0.22.0"
|
||||||
tokenizers = { version = "0.19.1", features = ["http"] }
|
tokenizers = { version = "0.19.1", features = ["http"] }
|
||||||
hf-hub = { version = "0.3.1", features = ["tokio"] }
|
hf-hub = { version = "0.3.1", features = ["tokio"] }
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
use text_generation_client::{
|
use text_generation_client::{
|
||||||
Batch, CachedBatch, ClientError, NextTokenChooserParameters, Request, ShardedClient,
|
Batch, CachedBatch, Chunk, ClientError, Input, NextTokenChooserParameters, Request,
|
||||||
StoppingCriteriaParameters,
|
ShardedClient, StoppingCriteriaParameters,
|
||||||
};
|
};
|
||||||
use tokenizers::{Tokenizer, TruncationDirection};
|
use tokenizers::{Tokenizer, TruncationDirection};
|
||||||
use tokio::sync::{broadcast, mpsc};
|
use tokio::sync::{broadcast, mpsc};
|
||||||
|
@ -142,6 +142,9 @@ async fn prefill(
|
||||||
.map(|id| Request {
|
.map(|id| Request {
|
||||||
id: id.into(),
|
id: id.into(),
|
||||||
prefill_logprobs: false,
|
prefill_logprobs: false,
|
||||||
|
input_chunks: Some(Input {
|
||||||
|
chunks: vec![Chunk::Text(sequence.clone()).into()],
|
||||||
|
}),
|
||||||
inputs: sequence.clone(),
|
inputs: sequence.clone(),
|
||||||
truncate: sequence_length,
|
truncate: sequence_length,
|
||||||
parameters: Some(parameters.clone()),
|
parameters: Some(parameters.clone()),
|
||||||
|
|
|
@ -51,6 +51,27 @@ message ClearCacheRequest {
|
||||||
/// Empty response
|
/// Empty response
|
||||||
message ClearCacheResponse {}
|
message ClearCacheResponse {}
|
||||||
|
|
||||||
|
message Image {
|
||||||
|
/// Binary image data.
|
||||||
|
bytes data = 1;
|
||||||
|
|
||||||
|
/// Image MIME type.
|
||||||
|
string mimetype = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
message InputChunk {
|
||||||
|
oneof chunk {
|
||||||
|
/// Plain text data
|
||||||
|
string text = 1;
|
||||||
|
/// Image data
|
||||||
|
Image image = 2;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
message Input {
|
||||||
|
repeated InputChunk chunks = 1;
|
||||||
|
}
|
||||||
|
|
||||||
enum GrammarType {
|
enum GrammarType {
|
||||||
GRAMMAR_TYPE_NONE = 0;
|
GRAMMAR_TYPE_NONE = 0;
|
||||||
GRAMMAR_TYPE_JSON = 1;
|
GRAMMAR_TYPE_JSON = 1;
|
||||||
|
@ -95,7 +116,9 @@ message StoppingCriteriaParameters {
|
||||||
message Request {
|
message Request {
|
||||||
/// Request ID
|
/// Request ID
|
||||||
uint64 id = 1;
|
uint64 id = 1;
|
||||||
/// The generation context
|
/// The generation context as chunks
|
||||||
|
Input input_chunks = 8;
|
||||||
|
/// The generation context, stringified input_chunks
|
||||||
string inputs = 2;
|
string inputs = 2;
|
||||||
/// Context truncation
|
/// Context truncation
|
||||||
uint32 truncate = 3;
|
uint32 truncate = 3;
|
||||||
|
|
|
@ -49,7 +49,7 @@ futures-util = "0.3.30"
|
||||||
regex = "1.10.3"
|
regex = "1.10.3"
|
||||||
once_cell = "1.19.0"
|
once_cell = "1.19.0"
|
||||||
image = "0.25.1"
|
image = "0.25.1"
|
||||||
base64 = "0.22.0"
|
base64 = { workspace = true }
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
|
vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
|
||||||
|
|
|
@ -6,6 +6,7 @@ authors.workspace = true
|
||||||
homepage.workspace = true
|
homepage.workspace = true
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
|
base64 = { workspace = true }
|
||||||
futures = "^0.3"
|
futures = "^0.3"
|
||||||
grpc-metadata = { path = "../grpc-metadata" }
|
grpc-metadata = { path = "../grpc-metadata" }
|
||||||
prost = "^0.12"
|
prost = "^0.12"
|
||||||
|
|
|
@ -1,13 +1,17 @@
|
||||||
/// Single shard Client
|
/// Single shard Client
|
||||||
use crate::pb::generate::v2::text_generation_service_client::TextGenerationServiceClient;
|
use crate::pb::generate::v2::text_generation_service_client::TextGenerationServiceClient;
|
||||||
use crate::pb::generate::v2::*;
|
use crate::pb::generate::v2::*;
|
||||||
use crate::Result;
|
use crate::{Chunk, Result};
|
||||||
|
use base64::engine::general_purpose::STANDARD;
|
||||||
|
use base64::Engine;
|
||||||
use grpc_metadata::InjectTelemetryContext;
|
use grpc_metadata::InjectTelemetryContext;
|
||||||
use std::cmp::min;
|
use std::cmp::min;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use tonic::transport::{Channel, Uri};
|
use tonic::transport::{Channel, Uri};
|
||||||
use tracing::instrument;
|
use tracing::instrument;
|
||||||
|
|
||||||
|
static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";
|
||||||
|
|
||||||
/// Text Generation Inference gRPC client
|
/// Text Generation Inference gRPC client
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct Client {
|
pub struct Client {
|
||||||
|
@ -113,18 +117,39 @@ impl Client {
|
||||||
while n_tokens < max_prefill_tokens {
|
while n_tokens < max_prefill_tokens {
|
||||||
let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
|
let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
|
||||||
|
|
||||||
|
let mut input_chunks = Vec::new();
|
||||||
|
input_chunks
|
||||||
|
.push(Chunk::Text("_test ".to_string().repeat(max_input_length as usize)).into());
|
||||||
|
if n_tokens == 0 {
|
||||||
|
input_chunks.push(
|
||||||
|
Chunk::Image(Image {
|
||||||
|
// Safe unwrap, because we control the data.
|
||||||
|
data: STANDARD.decode(WARMUP_IMAGE_BASE64).unwrap(),
|
||||||
|
mimetype: "image/jpeg;base64".to_string(),
|
||||||
|
})
|
||||||
|
.into(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send stringly-typed inputs for compatibility for backends that haven't
|
||||||
|
// been updated to support chunks.
|
||||||
let mut inputs = String::new();
|
let mut inputs = String::new();
|
||||||
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
|
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
|
||||||
if n_tokens == 0 {
|
if n_tokens == 0 {
|
||||||
// 1 request is enough to test vision heads.
|
// 1 request is enough to test vision heads.
|
||||||
// Sending images on other queries messes up easily with truncation.
|
// Sending images on other queries messes up easily with truncation.
|
||||||
inputs.push_str("![]()");
|
inputs.push_str(&format!(
|
||||||
|
"![](data:image/jpeg;base64,{WARMUP_IMAGE_BASE64})",
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
requests.push(Request {
|
requests.push(Request {
|
||||||
id: 0,
|
id: 0,
|
||||||
// We truncate the input on the server side to be sure that it has the correct size
|
input_chunks: Some(Input {
|
||||||
|
chunks: input_chunks,
|
||||||
|
}),
|
||||||
inputs,
|
inputs,
|
||||||
|
// We truncate the input on the server side to be sure that it has the correct size
|
||||||
truncate,
|
truncate,
|
||||||
// Set sampling parameters to also take these ops into account in the max memory
|
// Set sampling parameters to also take these ops into account in the max memory
|
||||||
parameters: Some(NextTokenChooserParameters {
|
parameters: Some(NextTokenChooserParameters {
|
||||||
|
|
|
@ -5,11 +5,14 @@ mod client;
|
||||||
mod pb;
|
mod pb;
|
||||||
mod sharded_client;
|
mod sharded_client;
|
||||||
|
|
||||||
|
use base64::{engine::general_purpose::STANDARD, Engine};
|
||||||
pub use client::Client;
|
pub use client::Client;
|
||||||
|
pub use pb::generate::v2::input_chunk::Chunk;
|
||||||
pub use pb::generate::v2::HealthResponse;
|
pub use pb::generate::v2::HealthResponse;
|
||||||
|
pub use pb::generate::v2::Image;
|
||||||
pub use pb::generate::v2::InfoResponse as ShardInfo;
|
pub use pb::generate::v2::InfoResponse as ShardInfo;
|
||||||
pub use pb::generate::v2::{
|
pub use pb::generate::v2::{
|
||||||
Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
|
Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, Input, InputChunk,
|
||||||
NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,
|
NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,
|
||||||
};
|
};
|
||||||
pub use sharded_client::ShardedClient;
|
pub use sharded_client::ShardedClient;
|
||||||
|
@ -44,3 +47,33 @@ impl From<transport::Error> for ClientError {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub type Result<T> = std::result::Result<T, ClientError>;
|
pub type Result<T> = std::result::Result<T, ClientError>;
|
||||||
|
|
||||||
|
// Small convenience re-wrapping of `Chunk`.
|
||||||
|
impl From<Chunk> for InputChunk {
|
||||||
|
fn from(chunk: Chunk) -> Self {
|
||||||
|
InputChunk { chunk: Some(chunk) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convert input chunks to a stringly-typed input for backwards
|
||||||
|
/// compat for backends that haven't implemented chunked inputs.
|
||||||
|
pub trait ChunksToString {
|
||||||
|
/// Convert chunks to string.
|
||||||
|
fn chunks_to_string(&self) -> String;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ChunksToString for Vec<InputChunk> {
|
||||||
|
fn chunks_to_string(&self) -> String {
|
||||||
|
let mut output = String::new();
|
||||||
|
self.iter().for_each(|c| match &c.chunk {
|
||||||
|
Some(Chunk::Text(text)) => output.push_str(text),
|
||||||
|
Some(Chunk::Image(Image { data, mimetype })) => {
|
||||||
|
let encoded = STANDARD.encode(data);
|
||||||
|
output.push_str(&format!("![](data:{};base64,{})", mimetype, encoded))
|
||||||
|
}
|
||||||
|
// We don't create empty chunks, so this should be unreachable.
|
||||||
|
None => unreachable!("Chunks should never be empty"),
|
||||||
|
});
|
||||||
|
output
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -4,9 +4,9 @@ use serde::{Deserialize, Serialize};
|
||||||
#[serde(tag = "model_type")]
|
#[serde(tag = "model_type")]
|
||||||
#[serde(rename_all = "snake_case")]
|
#[serde(rename_all = "snake_case")]
|
||||||
pub struct LlavaNext {
|
pub struct LlavaNext {
|
||||||
text_config: TextConfig,
|
pub(crate) text_config: TextConfig,
|
||||||
vision_config: VisionConfig,
|
pub(crate) vision_config: VisionConfig,
|
||||||
image_grid_pinpoints: Vec<(usize, usize)>,
|
pub(crate) image_grid_pinpoints: Vec<(usize, usize)>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_anyres_image_grid_shape(
|
fn get_anyres_image_grid_shape(
|
||||||
|
@ -119,13 +119,13 @@ impl Idefics2 {
|
||||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "snake_case")]
|
#[serde(rename_all = "snake_case")]
|
||||||
pub struct PaliTextConfig {
|
pub struct PaliTextConfig {
|
||||||
num_image_tokens: usize,
|
pub(crate) num_image_tokens: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "snake_case")]
|
#[serde(rename_all = "snake_case")]
|
||||||
pub struct Paligemma {
|
pub struct Paligemma {
|
||||||
text_config: PaliTextConfig,
|
pub(crate) text_config: PaliTextConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Paligemma {
|
impl Paligemma {
|
||||||
|
@ -175,8 +175,8 @@ pub struct TextConfig {}
|
||||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "snake_case")]
|
#[serde(rename_all = "snake_case")]
|
||||||
pub struct VisionConfig {
|
pub struct VisionConfig {
|
||||||
image_size: usize,
|
pub(crate) image_size: usize,
|
||||||
patch_size: usize,
|
pub(crate) patch_size: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|
|
@ -1,9 +1,9 @@
|
||||||
use std::sync::atomic::{AtomicBool, Ordering};
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use text_generation_client::GrammarType as ProtoGrammarType;
|
|
||||||
use text_generation_client::{
|
use text_generation_client::{
|
||||||
Batch, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters,
|
Batch, Input, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters,
|
||||||
};
|
};
|
||||||
|
use text_generation_client::{Chunk, GrammarType as ProtoGrammarType};
|
||||||
|
|
||||||
// Note: Request ids and batch ids cannot collide.
|
// Note: Request ids and batch ids cannot collide.
|
||||||
const LIVENESS_ID: u64 = u64::MAX;
|
const LIVENESS_ID: u64 = u64::MAX;
|
||||||
|
@ -33,6 +33,9 @@ impl Health {
|
||||||
// Dummy batch of 1 token and 1 generated token
|
// Dummy batch of 1 token and 1 generated token
|
||||||
let liveness_request = Request {
|
let liveness_request = Request {
|
||||||
id: LIVENESS_ID,
|
id: LIVENESS_ID,
|
||||||
|
input_chunks: Some(Input {
|
||||||
|
chunks: vec![Chunk::Text("liveness".into()).into()],
|
||||||
|
}),
|
||||||
inputs: "liveness".to_string(),
|
inputs: "liveness".to_string(),
|
||||||
truncate: 10,
|
truncate: 10,
|
||||||
prefill_logprobs: false,
|
prefill_logprobs: false,
|
||||||
|
|
|
@ -4,6 +4,8 @@ use crate::validation::ValidGenerateRequest;
|
||||||
use nohash_hasher::{BuildNoHashHasher, IntMap};
|
use nohash_hasher::{BuildNoHashHasher, IntMap};
|
||||||
use std::cmp::min;
|
use std::cmp::min;
|
||||||
use std::collections::VecDeque;
|
use std::collections::VecDeque;
|
||||||
|
use text_generation_client::ChunksToString;
|
||||||
|
use text_generation_client::Input;
|
||||||
use text_generation_client::{Batch, Request};
|
use text_generation_client::{Batch, Request};
|
||||||
use tokio::sync::{mpsc, oneshot};
|
use tokio::sync::{mpsc, oneshot};
|
||||||
use tokio::time::Instant;
|
use tokio::time::Instant;
|
||||||
|
@ -278,7 +280,10 @@ impl State {
|
||||||
batch_requests.push(Request {
|
batch_requests.push(Request {
|
||||||
id,
|
id,
|
||||||
prefill_logprobs: entry.request.decoder_input_details,
|
prefill_logprobs: entry.request.decoder_input_details,
|
||||||
inputs: entry.request.inputs.clone(),
|
input_chunks: Some(Input {
|
||||||
|
chunks: entry.request.inputs.clone(),
|
||||||
|
}),
|
||||||
|
inputs: entry.request.inputs.chunks_to_string(),
|
||||||
truncate: entry.request.truncate,
|
truncate: entry.request.truncate,
|
||||||
parameters: Some(entry.request.parameters.clone()),
|
parameters: Some(entry.request.parameters.clone()),
|
||||||
stopping_parameters: Some(entry.request.stopping_parameters.clone()),
|
stopping_parameters: Some(entry.request.stopping_parameters.clone()),
|
||||||
|
@ -366,7 +371,7 @@ mod tests {
|
||||||
|
|
||||||
let entry = Entry {
|
let entry = Entry {
|
||||||
request: ValidGenerateRequest {
|
request: ValidGenerateRequest {
|
||||||
inputs: String::new(),
|
inputs: vec![],
|
||||||
input_length: 0,
|
input_length: 0,
|
||||||
truncate: 0,
|
truncate: 0,
|
||||||
decoder_input_details: false,
|
decoder_input_details: false,
|
||||||
|
|
|
@ -7,7 +7,8 @@ use rand::{thread_rng, Rng};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use std::io::Cursor;
|
use std::io::Cursor;
|
||||||
use text_generation_client::{
|
use text_generation_client::{
|
||||||
GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters,
|
Chunk, GrammarType as ProtoGrammarType, Image, InputChunk, NextTokenChooserParameters,
|
||||||
|
StoppingCriteriaParameters,
|
||||||
};
|
};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tokenizers::tokenizer::Tokenizer;
|
use tokenizers::tokenizer::Tokenizer;
|
||||||
|
@ -89,7 +90,7 @@ impl Validation {
|
||||||
&self,
|
&self,
|
||||||
inputs: String,
|
inputs: String,
|
||||||
truncate: Option<usize>,
|
truncate: Option<usize>,
|
||||||
) -> Result<Option<(tokenizers::Encoding, String)>, ValidationError> {
|
) -> Result<Option<(tokenizers::Encoding, Vec<InputChunk>)>, ValidationError> {
|
||||||
// If we have a fast tokenizer
|
// If we have a fast tokenizer
|
||||||
if let Some(sender) = &self.sender {
|
if let Some(sender) = &self.sender {
|
||||||
// Create response channel
|
// Create response channel
|
||||||
|
@ -115,7 +116,7 @@ impl Validation {
|
||||||
inputs: String,
|
inputs: String,
|
||||||
truncate: Option<usize>,
|
truncate: Option<usize>,
|
||||||
max_new_tokens: Option<u32>,
|
max_new_tokens: Option<u32>,
|
||||||
) -> Result<(String, usize, u32), ValidationError> {
|
) -> Result<(Vec<InputChunk>, usize, u32), ValidationError> {
|
||||||
// If we have a fast tokenizer
|
// If we have a fast tokenizer
|
||||||
if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
|
if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
|
||||||
// Create response channel
|
// Create response channel
|
||||||
|
@ -178,7 +179,11 @@ impl Validation {
|
||||||
// ));
|
// ));
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok((inputs, input_length, max_new_tokens))
|
Ok((
|
||||||
|
vec![Chunk::Text(inputs).into()],
|
||||||
|
input_length,
|
||||||
|
max_new_tokens,
|
||||||
|
))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -465,7 +470,7 @@ fn format_to_mimetype(format: ImageFormat) -> String {
|
||||||
.to_string()
|
.to_string()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn fetch_image(input: &str) -> Result<(String, usize, usize), ValidationError> {
|
fn fetch_image(input: &str) -> Result<(Vec<u8>, String, usize, usize), ValidationError> {
|
||||||
if input.starts_with("![](http://") || input.starts_with("![](https://") {
|
if input.starts_with("![](http://") || input.starts_with("![](https://") {
|
||||||
let url = &input["![](".len()..input.len() - 1];
|
let url = &input["![](".len()..input.len() - 1];
|
||||||
let data = reqwest::blocking::get(url)?.bytes()?;
|
let data = reqwest::blocking::get(url)?.bytes()?;
|
||||||
|
@ -476,9 +481,7 @@ fn fetch_image(input: &str) -> Result<(String, usize, usize), ValidationError> {
|
||||||
let height: usize = img.height().try_into()?;
|
let height: usize = img.height().try_into()?;
|
||||||
let width: usize = img.width().try_into()?;
|
let width: usize = img.width().try_into()?;
|
||||||
let mimetype = format_to_mimetype(format);
|
let mimetype = format_to_mimetype(format);
|
||||||
let encoded = STANDARD.encode(data);
|
Ok((data.to_vec(), mimetype, height, width))
|
||||||
let data_uri = format!("![](data:{mimetype};base64,{encoded})");
|
|
||||||
Ok((data_uri, height, width))
|
|
||||||
} else if input.starts_with("![](data:") {
|
} else if input.starts_with("![](data:") {
|
||||||
// Remove ![](....)
|
// Remove ![](....)
|
||||||
let content = &input["![](data:".len()..input.len() - 1];
|
let content = &input["![](data:".len()..input.len() - 1];
|
||||||
|
@ -495,9 +498,9 @@ fn fetch_image(input: &str) -> Result<(String, usize, usize), ValidationError> {
|
||||||
|
|
||||||
let data = STANDARD.decode(content["base64,".len()..].as_bytes())?;
|
let data = STANDARD.decode(content["base64,".len()..].as_bytes())?;
|
||||||
let img = if let Some(format) = format_from_mimetype(mimetype) {
|
let img = if let Some(format) = format_from_mimetype(mimetype) {
|
||||||
ImageReader::with_format(Cursor::new(data), format).decode()?
|
ImageReader::with_format(Cursor::new(&data), format).decode()?
|
||||||
} else {
|
} else {
|
||||||
ImageReader::new(Cursor::new(data))
|
ImageReader::new(Cursor::new(&data))
|
||||||
.with_guessed_format()
|
.with_guessed_format()
|
||||||
.map_err(|_io_error| ValidationError::InvalidImageContent(content.to_string()))?
|
.map_err(|_io_error| ValidationError::InvalidImageContent(content.to_string()))?
|
||||||
.decode()?
|
.decode()?
|
||||||
|
@ -505,7 +508,7 @@ fn fetch_image(input: &str) -> Result<(String, usize, usize), ValidationError> {
|
||||||
|
|
||||||
let height: usize = img.height().try_into()?;
|
let height: usize = img.height().try_into()?;
|
||||||
let width: usize = img.width().try_into()?;
|
let width: usize = img.width().try_into()?;
|
||||||
Ok((input.to_string(), height, width))
|
Ok((data, mimetype.to_string(), height, width))
|
||||||
} else {
|
} else {
|
||||||
Err(ValidationError::InvalidImageContent(input.to_string()))
|
Err(ValidationError::InvalidImageContent(input.to_string()))
|
||||||
}
|
}
|
||||||
|
@ -513,113 +516,110 @@ fn fetch_image(input: &str) -> Result<(String, usize, usize), ValidationError> {
|
||||||
|
|
||||||
/// Get input length and optionally truncate it
|
/// Get input length and optionally truncate it
|
||||||
fn prepare_input(
|
fn prepare_input(
|
||||||
mut inputs: String,
|
inputs: String,
|
||||||
_truncate: Option<usize>,
|
_truncate: Option<usize>,
|
||||||
tokenizer: &Tokenizer,
|
tokenizer: &Tokenizer,
|
||||||
config: &Option<Config>,
|
config: &Option<Config>,
|
||||||
) -> Result<(tokenizers::Encoding, String), ValidationError> {
|
) -> Result<(tokenizers::Encoding, Vec<InputChunk>), ValidationError> {
|
||||||
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
|
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
|
||||||
let tokenizer_query = match config {
|
let (tokenizer_query, input_chunks) = match config {
|
||||||
Some(Config::LlavaNext(config)) => {
|
Some(Config::LlavaNext(config)) => {
|
||||||
let mut modified_inputs = String::with_capacity(inputs.len());
|
let mut input_chunks = Vec::new();
|
||||||
let mut tokenizer_query = String::with_capacity(inputs.len());
|
let mut tokenizer_query = String::with_capacity(inputs.len());
|
||||||
let mut start = 0;
|
let mut start = 0;
|
||||||
for chunk in RE.find_iter(&inputs) {
|
for chunk in RE.find_iter(&inputs) {
|
||||||
let chunk_start = chunk.start();
|
let chunk_start = chunk.start();
|
||||||
let chunk_end = chunk.end();
|
let chunk_end = chunk.end();
|
||||||
if chunk_start != start {
|
if chunk_start != start {
|
||||||
modified_inputs.push_str(&inputs[start..chunk_start]);
|
input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
|
||||||
tokenizer_query.push_str(&inputs[start..chunk_start]);
|
tokenizer_query.push_str(&inputs[start..chunk_start]);
|
||||||
}
|
}
|
||||||
let (image_uri, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
|
let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
|
||||||
let slots = config.get_number_of_features(height, width);
|
let slots = config.get_number_of_features(height, width);
|
||||||
|
input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
|
||||||
tokenizer_query.push_str(&"<image>".repeat(slots));
|
tokenizer_query.push_str(&"<image>".repeat(slots));
|
||||||
modified_inputs.push_str(&image_uri);
|
|
||||||
start = chunk_end;
|
start = chunk_end;
|
||||||
}
|
}
|
||||||
if start != inputs.len() - 1 {
|
if start != inputs.len() {
|
||||||
modified_inputs.push_str(&inputs[start..]);
|
input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
|
||||||
tokenizer_query.push_str(&inputs[start..]);
|
tokenizer_query.push_str(&inputs[start..]);
|
||||||
}
|
}
|
||||||
inputs = modified_inputs;
|
(tokenizer_query, input_chunks)
|
||||||
tokenizer_query
|
|
||||||
}
|
}
|
||||||
Some(Config::Paligemma(config)) => {
|
Some(Config::Paligemma(config)) => {
|
||||||
let mut modified_inputs = String::with_capacity(inputs.len());
|
let mut input_chunks = Vec::new();
|
||||||
let mut tokenizer_query = String::with_capacity(inputs.len());
|
let mut tokenizer_query = String::with_capacity(inputs.len());
|
||||||
let mut start = 0;
|
let mut start = 0;
|
||||||
for chunk in RE.find_iter(&inputs) {
|
for chunk in RE.find_iter(&inputs) {
|
||||||
let chunk_start = chunk.start();
|
let chunk_start = chunk.start();
|
||||||
let chunk_end = chunk.end();
|
let chunk_end = chunk.end();
|
||||||
if chunk_start != start {
|
if chunk_start != start {
|
||||||
modified_inputs.push_str(&inputs[start..chunk_start]);
|
input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
|
||||||
tokenizer_query.push_str(&inputs[start..chunk_start]);
|
tokenizer_query.push_str(&inputs[start..chunk_start]);
|
||||||
}
|
}
|
||||||
let (image_uri, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
|
let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
|
||||||
let slots = config.get_number_of_features(height, width);
|
let slots = config.get_number_of_features(height, width);
|
||||||
|
input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
|
||||||
tokenizer_query.push_str(&"<image>".repeat(slots));
|
tokenizer_query.push_str(&"<image>".repeat(slots));
|
||||||
modified_inputs.push_str(&image_uri);
|
|
||||||
start = chunk_end;
|
start = chunk_end;
|
||||||
}
|
}
|
||||||
if start != inputs.len() - 1 {
|
if start != inputs.len() {
|
||||||
modified_inputs.push_str(&inputs[start..]);
|
input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
|
||||||
tokenizer_query.push_str(&inputs[start..]);
|
tokenizer_query.push_str(&inputs[start..]);
|
||||||
}
|
}
|
||||||
inputs = modified_inputs;
|
(tokenizer_query, input_chunks)
|
||||||
tokenizer_query
|
|
||||||
}
|
}
|
||||||
Some(Config::Idefics2(config)) => {
|
Some(Config::Idefics2(config)) => {
|
||||||
let mut modified_inputs = String::with_capacity(inputs.len());
|
let mut input_chunks = Vec::new();
|
||||||
let mut tokenizer_query = String::with_capacity(inputs.len());
|
let mut tokenizer_query = String::with_capacity(inputs.len());
|
||||||
let mut start = 0;
|
let mut start = 0;
|
||||||
for chunk in RE.find_iter(&inputs) {
|
for chunk in RE.find_iter(&inputs) {
|
||||||
let chunk_start = chunk.start();
|
let chunk_start = chunk.start();
|
||||||
let chunk_end = chunk.end();
|
let chunk_end = chunk.end();
|
||||||
if chunk_start != start {
|
if chunk_start != start {
|
||||||
modified_inputs.push_str(&inputs[start..chunk_start]);
|
input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
|
||||||
tokenizer_query.push_str(&inputs[start..chunk_start]);
|
tokenizer_query.push_str(&inputs[start..chunk_start]);
|
||||||
}
|
}
|
||||||
let (image_uri, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
|
let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
|
||||||
let slots = config.get_number_of_features(height, width);
|
let slots = config.get_number_of_features(height, width);
|
||||||
tokenizer_query.push_str("<fake_token_around_image>");
|
tokenizer_query.push_str("<fake_token_around_image>");
|
||||||
tokenizer_query.push_str(&"<image>".repeat(slots));
|
tokenizer_query.push_str(&"<image>".repeat(slots));
|
||||||
tokenizer_query.push_str("<fake_token_around_image>");
|
tokenizer_query.push_str("<fake_token_around_image>");
|
||||||
|
|
||||||
modified_inputs.push_str(&image_uri);
|
input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
|
||||||
start = chunk_end;
|
start = chunk_end;
|
||||||
}
|
}
|
||||||
if start != inputs.len() - 1 {
|
if start != inputs.len() {
|
||||||
modified_inputs.push_str(&inputs[start..]);
|
input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
|
||||||
tokenizer_query.push_str(&inputs[start..]);
|
tokenizer_query.push_str(&inputs[start..]);
|
||||||
}
|
}
|
||||||
inputs = modified_inputs;
|
(tokenizer_query, input_chunks)
|
||||||
tokenizer_query
|
|
||||||
}
|
}
|
||||||
Some(Config::Idefics) => {
|
Some(Config::Idefics) => {
|
||||||
let mut modified_inputs = String::with_capacity(inputs.len());
|
let mut input_chunks = Vec::new();
|
||||||
let mut tokenizer_query = String::with_capacity(inputs.len());
|
let mut tokenizer_query = String::with_capacity(inputs.len());
|
||||||
let mut start = 0;
|
let mut start = 0;
|
||||||
for chunk in RE.find_iter(&inputs) {
|
for chunk in RE.find_iter(&inputs) {
|
||||||
let chunk_start = chunk.start();
|
let chunk_start = chunk.start();
|
||||||
let chunk_end = chunk.end();
|
let chunk_end = chunk.end();
|
||||||
if chunk_start != start {
|
if chunk_start != start {
|
||||||
modified_inputs.push_str(&inputs[start..chunk_start]);
|
input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
|
||||||
tokenizer_query.push_str(&inputs[start..chunk_start]);
|
tokenizer_query.push_str(&inputs[start..chunk_start]);
|
||||||
}
|
}
|
||||||
let (image_uri, _height, _width) = fetch_image(&inputs[chunk_start..chunk_end])?;
|
let (data, mimetype, _height, _width) =
|
||||||
|
fetch_image(&inputs[chunk_start..chunk_end])?;
|
||||||
let slots = 1;
|
let slots = 1;
|
||||||
tokenizer_query.push_str(&"<image>".repeat(slots));
|
tokenizer_query.push_str(&"<image>".repeat(slots));
|
||||||
modified_inputs.push_str(&image_uri);
|
input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
|
||||||
start = chunk_end;
|
start = chunk_end;
|
||||||
}
|
}
|
||||||
if start != inputs.len() - 1 {
|
if start != inputs.len() {
|
||||||
modified_inputs.push_str(&inputs[start..]);
|
input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
|
||||||
tokenizer_query.push_str(&inputs[start..]);
|
tokenizer_query.push_str(&inputs[start..]);
|
||||||
}
|
}
|
||||||
inputs = modified_inputs;
|
(tokenizer_query, input_chunks)
|
||||||
tokenizer_query
|
|
||||||
}
|
}
|
||||||
_ => inputs.clone(),
|
_ => (inputs.clone(), vec![Chunk::Text(inputs).into()]),
|
||||||
};
|
};
|
||||||
|
|
||||||
// Get the number of tokens in the input
|
// Get the number of tokens in the input
|
||||||
|
@ -627,18 +627,18 @@ fn prepare_input(
|
||||||
.encode(tokenizer_query, true)
|
.encode(tokenizer_query, true)
|
||||||
.map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
|
.map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
|
||||||
|
|
||||||
Ok((encoding, inputs))
|
Ok((encoding, input_chunks))
|
||||||
}
|
}
|
||||||
|
|
||||||
type TokenizerRequest = (
|
type TokenizerRequest = (
|
||||||
(String, Option<usize>),
|
(String, Option<usize>),
|
||||||
oneshot::Sender<Result<(tokenizers::Encoding, String), ValidationError>>,
|
oneshot::Sender<Result<(tokenizers::Encoding, Vec<InputChunk>), ValidationError>>,
|
||||||
Span,
|
Span,
|
||||||
);
|
);
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub(crate) struct ValidGenerateRequest {
|
pub(crate) struct ValidGenerateRequest {
|
||||||
pub inputs: String,
|
pub inputs: Vec<InputChunk>,
|
||||||
pub input_length: u32,
|
pub input_length: u32,
|
||||||
pub truncate: u32,
|
pub truncate: u32,
|
||||||
pub decoder_input_details: bool,
|
pub decoder_input_details: bool,
|
||||||
|
@ -714,6 +714,7 @@ pub enum ValidationError {
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use crate::config::{PaliTextConfig, Paligemma};
|
||||||
use crate::default_parameters;
|
use crate::default_parameters;
|
||||||
use crate::tests::get_tokenizer;
|
use crate::tests::get_tokenizer;
|
||||||
|
|
||||||
|
@ -964,4 +965,61 @@ mod tests {
|
||||||
|
|
||||||
assert_eq!(valid_request.top_n_tokens, 0);
|
assert_eq!(valid_request.top_n_tokens, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static PIXEL_GIF: &str = "R0lGODdhAQABAIEAAP///wAAAAAAAAAAACwAAAAAAQABAAAIBAABBAQAOw==";
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_prepare_input_chunks() {
|
||||||
|
let pixel_data = STANDARD.decode(PIXEL_GIF).unwrap();
|
||||||
|
|
||||||
|
let tokenizer = Some(get_tokenizer().await);
|
||||||
|
|
||||||
|
let max_best_of = 2;
|
||||||
|
let max_stop_sequence = 3;
|
||||||
|
let max_top_n_tokens = 4;
|
||||||
|
let max_input_length = 5;
|
||||||
|
let max_total_tokens = 6;
|
||||||
|
let disable_grammar_support = true;
|
||||||
|
let workers = 1;
|
||||||
|
let config = Config::Paligemma(Paligemma {
|
||||||
|
text_config: PaliTextConfig {
|
||||||
|
num_image_tokens: 1,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
let validation = Validation::new(
|
||||||
|
workers,
|
||||||
|
tokenizer,
|
||||||
|
Some(config),
|
||||||
|
max_best_of,
|
||||||
|
max_stop_sequence,
|
||||||
|
max_top_n_tokens,
|
||||||
|
max_input_length,
|
||||||
|
max_total_tokens,
|
||||||
|
disable_grammar_support,
|
||||||
|
);
|
||||||
|
|
||||||
|
let chunks = match validation
|
||||||
|
.tokenize(
|
||||||
|
format!("test![](data:image/gif;base64,{})", PIXEL_GIF),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(Some((_encoding, chunks))) => chunks,
|
||||||
|
_ => panic!("Unexpected tokenization failure"),
|
||||||
|
};
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
chunks
|
||||||
|
== vec![
|
||||||
|
Chunk::Text("test".to_string()).into(),
|
||||||
|
Chunk::Image(Image {
|
||||||
|
data: pixel_data.clone(),
|
||||||
|
mimetype: "image/gif".to_string()
|
||||||
|
})
|
||||||
|
.into()
|
||||||
|
],
|
||||||
|
"Failed to process images",
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue