Inline images for multimodal models. (#1666)

This commit is contained in:
Nicolas Patry 2024-03-22 17:14:54 +01:00 committed by GitHub
parent 66914f7b19
commit f171bdc823
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 24 additions and 6 deletions

2
Cargo.lock generated
View File

@ -3045,9 +3045,11 @@ dependencies = [
"minijinja", "minijinja",
"ngrok", "ngrok",
"nohash-hasher", "nohash-hasher",
"once_cell",
"opentelemetry", "opentelemetry",
"opentelemetry-otlp", "opentelemetry-otlp",
"rand", "rand",
"regex",
"reqwest", "reqwest",
"serde", "serde",
"serde_json", "serde_json",

View File

@ -20,13 +20,14 @@ async def idefics(idefics_handle):
def get_chicken(): def get_chicken():
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file: with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read()) encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string}" return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_idefics(idefics, response_snapshot): async def test_idefics(idefics, response_snapshot):
chicken = get_chicken()
response = await idefics.generate( response = await idefics.generate(
"User:![](https://huggingface.co/spaces/HuggingFaceM4/idefics_playground/resolve/main/example_images/chicken_on_money.png?download=true)Can you tell me a very short story based on the image?", f"User:![]({chicken})Can you tell me a very short story based on the image?",
max_new_tokens=10, max_new_tokens=10,
decoder_input_details=True, decoder_input_details=True,
) )
@ -37,9 +38,10 @@ async def test_idefics(idefics, response_snapshot):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_idefics_load(idefics, generate_load, response_snapshot): async def test_idefics_load(idefics, generate_load, response_snapshot):
chicken = get_chicken()
responses = await generate_load( responses = await generate_load(
idefics, idefics,
"User:![](https://huggingface.co/spaces/HuggingFaceM4/idefics_playground/resolve/main/example_images/chicken_on_money.png?download=true)Can you tell me a very short story based on the image?", f"User:![]({chicken})Can you tell me a very short story based on the image?",
max_new_tokens=10, max_new_tokens=10,
n=4, n=4,
) )

View File

@ -46,6 +46,8 @@ ngrok = { version = "0.13.1", features = ["axum"], optional = true }
init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] } init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
minijinja = { git = "https://github.com/mitsuhiko/minijinja.git", branch = "main", commit = "5cd4efb" } minijinja = { git = "https://github.com/mitsuhiko/minijinja.git", branch = "main", commit = "5cd4efb" }
futures-util = "0.3.30" futures-util = "0.3.30"
regex = "1.10.3"
once_cell = "1.19.0"
[build-dependencies] [build-dependencies]
vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] } vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }

View File

@ -13,6 +13,7 @@ use tokenizers::TruncationDirection;
use tokio::sync::mpsc; use tokio::sync::mpsc;
use tokio::sync::oneshot; use tokio::sync::oneshot;
use tracing::{instrument, Span}; use tracing::{instrument, Span};
use {once_cell::sync::Lazy, regex::Regex};
/// Validation /// Validation
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -409,10 +410,14 @@ async fn round_robin_task(
/// Start tokenization workers /// Start tokenization workers
fn tokenizer_worker(tokenizer: Tokenizer, mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>) { fn tokenizer_worker(tokenizer: Tokenizer, mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>) {
// Loop over requests // Loop over requests
let is_multimodal = {
let vocab = tokenizer.get_vocab(true);
vocab.contains_key("<image>")
};
while let Some(((inputs, truncate), response_tx, parent_span)) = receiver.blocking_recv() { while let Some(((inputs, truncate), response_tx, parent_span)) = receiver.blocking_recv() {
parent_span.in_scope(|| { parent_span.in_scope(|| {
response_tx response_tx
.send(prepare_input(inputs, truncate, &tokenizer)) .send(prepare_input(inputs, truncate, &tokenizer, is_multimodal))
.unwrap_or(()) .unwrap_or(())
}) })
} }
@ -423,15 +428,22 @@ fn prepare_input(
mut inputs: String, mut inputs: String,
truncate: Option<usize>, truncate: Option<usize>,
tokenizer: &Tokenizer, tokenizer: &Tokenizer,
is_multimodal: bool,
) -> Result<(tokenizers::Encoding, String), ValidationError> { ) -> Result<(tokenizers::Encoding, String), ValidationError> {
let simplified_query = if is_multimodal {
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
RE.replace_all(&inputs, "<image>").into()
} else {
inputs.clone()
};
// Get the number of tokens in the input // Get the number of tokens in the input
let mut encoding = tokenizer let mut encoding = tokenizer
.encode(inputs.clone(), true) .encode(simplified_query, true)
.map_err(|err| ValidationError::Tokenizer(err.to_string()))?; .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
// Optionally truncate // Optionally truncate
if let Some(truncate) = truncate { if let Some(truncate) = truncate {
if truncate < encoding.len() { if truncate < encoding.len() && !is_multimodal {
encoding.truncate(truncate, 0, TruncationDirection::Left); encoding.truncate(truncate, 0, TruncationDirection::Left);
inputs = tokenizer inputs = tokenizer
.decode(encoding.get_ids(), false) .decode(encoding.get_ids(), false)