Inline images for multimodal models. (#1666)
This commit is contained in:
parent
66914f7b19
commit
f171bdc823
|
@ -3045,9 +3045,11 @@ dependencies = [
|
||||||
"minijinja",
|
"minijinja",
|
||||||
"ngrok",
|
"ngrok",
|
||||||
"nohash-hasher",
|
"nohash-hasher",
|
||||||
|
"once_cell",
|
||||||
"opentelemetry",
|
"opentelemetry",
|
||||||
"opentelemetry-otlp",
|
"opentelemetry-otlp",
|
||||||
"rand",
|
"rand",
|
||||||
|
"regex",
|
||||||
"reqwest",
|
"reqwest",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
|
|
|
@ -20,13 +20,14 @@ async def idefics(idefics_handle):
|
||||||
def get_chicken():
|
def get_chicken():
|
||||||
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
|
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
|
||||||
encoded_string = base64.b64encode(image_file.read())
|
encoded_string = base64.b64encode(image_file.read())
|
||||||
return f"data:image/png;base64,{encoded_string}"
|
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_idefics(idefics, response_snapshot):
|
async def test_idefics(idefics, response_snapshot):
|
||||||
|
chicken = get_chicken()
|
||||||
response = await idefics.generate(
|
response = await idefics.generate(
|
||||||
"User:![](https://huggingface.co/spaces/HuggingFaceM4/idefics_playground/resolve/main/example_images/chicken_on_money.png?download=true)Can you tell me a very short story based on the image?",
|
f"User:![]({chicken})Can you tell me a very short story based on the image?",
|
||||||
max_new_tokens=10,
|
max_new_tokens=10,
|
||||||
decoder_input_details=True,
|
decoder_input_details=True,
|
||||||
)
|
)
|
||||||
|
@ -37,9 +38,10 @@ async def test_idefics(idefics, response_snapshot):
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_idefics_load(idefics, generate_load, response_snapshot):
|
async def test_idefics_load(idefics, generate_load, response_snapshot):
|
||||||
|
chicken = get_chicken()
|
||||||
responses = await generate_load(
|
responses = await generate_load(
|
||||||
idefics,
|
idefics,
|
||||||
"User:![](https://huggingface.co/spaces/HuggingFaceM4/idefics_playground/resolve/main/example_images/chicken_on_money.png?download=true)Can you tell me a very short story based on the image?",
|
f"User:![]({chicken})Can you tell me a very short story based on the image?",
|
||||||
max_new_tokens=10,
|
max_new_tokens=10,
|
||||||
n=4,
|
n=4,
|
||||||
)
|
)
|
||||||
|
|
|
@ -46,6 +46,8 @@ ngrok = { version = "0.13.1", features = ["axum"], optional = true }
|
||||||
init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
|
init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
|
||||||
minijinja = { git = "https://github.com/mitsuhiko/minijinja.git", branch = "main", commit = "5cd4efb" }
|
minijinja = { git = "https://github.com/mitsuhiko/minijinja.git", branch = "main", commit = "5cd4efb" }
|
||||||
futures-util = "0.3.30"
|
futures-util = "0.3.30"
|
||||||
|
regex = "1.10.3"
|
||||||
|
once_cell = "1.19.0"
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
|
vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
|
||||||
|
|
|
@ -13,6 +13,7 @@ use tokenizers::TruncationDirection;
|
||||||
use tokio::sync::mpsc;
|
use tokio::sync::mpsc;
|
||||||
use tokio::sync::oneshot;
|
use tokio::sync::oneshot;
|
||||||
use tracing::{instrument, Span};
|
use tracing::{instrument, Span};
|
||||||
|
use {once_cell::sync::Lazy, regex::Regex};
|
||||||
|
|
||||||
/// Validation
|
/// Validation
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
|
@ -409,10 +410,14 @@ async fn round_robin_task(
|
||||||
/// Start tokenization workers
|
/// Start tokenization workers
|
||||||
fn tokenizer_worker(tokenizer: Tokenizer, mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>) {
|
fn tokenizer_worker(tokenizer: Tokenizer, mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>) {
|
||||||
// Loop over requests
|
// Loop over requests
|
||||||
|
let is_multimodal = {
|
||||||
|
let vocab = tokenizer.get_vocab(true);
|
||||||
|
vocab.contains_key("<image>")
|
||||||
|
};
|
||||||
while let Some(((inputs, truncate), response_tx, parent_span)) = receiver.blocking_recv() {
|
while let Some(((inputs, truncate), response_tx, parent_span)) = receiver.blocking_recv() {
|
||||||
parent_span.in_scope(|| {
|
parent_span.in_scope(|| {
|
||||||
response_tx
|
response_tx
|
||||||
.send(prepare_input(inputs, truncate, &tokenizer))
|
.send(prepare_input(inputs, truncate, &tokenizer, is_multimodal))
|
||||||
.unwrap_or(())
|
.unwrap_or(())
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -423,15 +428,22 @@ fn prepare_input(
|
||||||
mut inputs: String,
|
mut inputs: String,
|
||||||
truncate: Option<usize>,
|
truncate: Option<usize>,
|
||||||
tokenizer: &Tokenizer,
|
tokenizer: &Tokenizer,
|
||||||
|
is_multimodal: bool,
|
||||||
) -> Result<(tokenizers::Encoding, String), ValidationError> {
|
) -> Result<(tokenizers::Encoding, String), ValidationError> {
|
||||||
|
let simplified_query = if is_multimodal {
|
||||||
|
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
|
||||||
|
RE.replace_all(&inputs, "<image>").into()
|
||||||
|
} else {
|
||||||
|
inputs.clone()
|
||||||
|
};
|
||||||
// Get the number of tokens in the input
|
// Get the number of tokens in the input
|
||||||
let mut encoding = tokenizer
|
let mut encoding = tokenizer
|
||||||
.encode(inputs.clone(), true)
|
.encode(simplified_query, true)
|
||||||
.map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
|
.map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
|
||||||
|
|
||||||
// Optionally truncate
|
// Optionally truncate
|
||||||
if let Some(truncate) = truncate {
|
if let Some(truncate) = truncate {
|
||||||
if truncate < encoding.len() {
|
if truncate < encoding.len() && !is_multimodal {
|
||||||
encoding.truncate(truncate, 0, TruncationDirection::Left);
|
encoding.truncate(truncate, 0, TruncationDirection::Left);
|
||||||
inputs = tokenizer
|
inputs = tokenizer
|
||||||
.decode(encoding.get_ids(), false)
|
.decode(encoding.get_ids(), false)
|
||||||
|
|
Loading…
Reference in New Issue