Support different image sizes in prefill in VLMs (#2065)
When a batch contained images of different sizes during prefill, the server would fail (see e.g. #2056). Images were processed separately and the resulting tensors were then concatenated; however, this concatenation can fail for images with different sizes. Fix this by preprocessing all images in the batch together, so that the image processor can ensure that all image tensors have compatible sizes.
This commit is contained in:
parent
445f313504
commit
e903770897
|
@ -0,0 +1,61 @@
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "eos_token",
|
||||||
|
"generated_tokens": 8,
|
||||||
|
"prefill": [],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 2502,
|
||||||
|
"logprob": -1.734375,
|
||||||
|
"special": false,
|
||||||
|
"text": "image"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2196,
|
||||||
|
"logprob": -0.5756836,
|
||||||
|
"special": false,
|
||||||
|
"text": " result"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 604,
|
||||||
|
"logprob": -0.007843018,
|
||||||
|
"special": false,
|
||||||
|
"text": " for"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 12254,
|
||||||
|
"logprob": -1.7167969,
|
||||||
|
"special": false,
|
||||||
|
"text": " chicken"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 611,
|
||||||
|
"logprob": -0.17053223,
|
||||||
|
"special": false,
|
||||||
|
"text": " on"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 573,
|
||||||
|
"logprob": -0.7626953,
|
||||||
|
"special": false,
|
||||||
|
"text": " the"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 8318,
|
||||||
|
"logprob": -0.02709961,
|
||||||
|
"special": false,
|
||||||
|
"text": " beach"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"logprob": -0.20739746,
|
||||||
|
"special": true,
|
||||||
|
"text": "<eos>"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "image result for chicken on the beach"
|
||||||
|
}
|
|
@ -0,0 +1,85 @@
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "eos_token",
|
||||||
|
"generated_tokens": 12,
|
||||||
|
"prefill": [],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 450,
|
||||||
|
"logprob": -0.26342773,
|
||||||
|
"special": false,
|
||||||
|
"text": " The"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 21282,
|
||||||
|
"logprob": -0.01838684,
|
||||||
|
"special": false,
|
||||||
|
"text": " cow"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 322,
|
||||||
|
"logprob": -0.18041992,
|
||||||
|
"special": false,
|
||||||
|
"text": " and"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 521,
|
||||||
|
"logprob": -0.62841797,
|
||||||
|
"special": false,
|
||||||
|
"text": " ch"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 21475,
|
||||||
|
"logprob": -0.0037956238,
|
||||||
|
"special": false,
|
||||||
|
"text": "icken"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 526,
|
||||||
|
"logprob": -0.018737793,
|
||||||
|
"special": false,
|
||||||
|
"text": " are"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 373,
|
||||||
|
"logprob": -1.0820312,
|
||||||
|
"special": false,
|
||||||
|
"text": " on"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 263,
|
||||||
|
"logprob": -0.5083008,
|
||||||
|
"special": false,
|
||||||
|
"text": " a"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 25695,
|
||||||
|
"logprob": -0.07128906,
|
||||||
|
"special": false,
|
||||||
|
"text": " beach"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29889,
|
||||||
|
"logprob": -0.12573242,
|
||||||
|
"special": false,
|
||||||
|
"text": "."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 32002,
|
||||||
|
"logprob": -0.0029792786,
|
||||||
|
"special": true,
|
||||||
|
"text": "<end_of_utterance>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2,
|
||||||
|
"logprob": -0.00024962425,
|
||||||
|
"special": true,
|
||||||
|
"text": "</s>"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": " The cow and chicken are on a beach."
|
||||||
|
}
|
|
@ -0,0 +1,133 @@
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 20,
|
||||||
|
"prefill": [],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 415,
|
||||||
|
"logprob": -0.04421997,
|
||||||
|
"special": false,
|
||||||
|
"text": " The"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 12072,
|
||||||
|
"logprob": -0.13500977,
|
||||||
|
"special": false,
|
||||||
|
"text": " cow"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 349,
|
||||||
|
"logprob": -0.06750488,
|
||||||
|
"special": false,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 6328,
|
||||||
|
"logprob": -0.6352539,
|
||||||
|
"special": false,
|
||||||
|
"text": " standing"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 356,
|
||||||
|
"logprob": -0.16186523,
|
||||||
|
"special": false,
|
||||||
|
"text": " on"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 272,
|
||||||
|
"logprob": -0.5078125,
|
||||||
|
"special": false,
|
||||||
|
"text": " the"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 10305,
|
||||||
|
"logprob": -0.017913818,
|
||||||
|
"special": false,
|
||||||
|
"text": " beach"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 304,
|
||||||
|
"logprob": -1.5205078,
|
||||||
|
"special": false,
|
||||||
|
"text": " and"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 272,
|
||||||
|
"logprob": -0.029174805,
|
||||||
|
"special": false,
|
||||||
|
"text": " the"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13088,
|
||||||
|
"logprob": -0.003479004,
|
||||||
|
"special": false,
|
||||||
|
"text": " chicken"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 349,
|
||||||
|
"logprob": -0.0035095215,
|
||||||
|
"special": false,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 6398,
|
||||||
|
"logprob": -0.3088379,
|
||||||
|
"special": false,
|
||||||
|
"text": " sitting"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 356,
|
||||||
|
"logprob": -0.027755737,
|
||||||
|
"special": false,
|
||||||
|
"text": " on"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 264,
|
||||||
|
"logprob": -0.31884766,
|
||||||
|
"special": false,
|
||||||
|
"text": " a"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 17972,
|
||||||
|
"logprob": -0.047943115,
|
||||||
|
"special": false,
|
||||||
|
"text": " pile"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 302,
|
||||||
|
"logprob": -0.0002925396,
|
||||||
|
"special": false,
|
||||||
|
"text": " of"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2445,
|
||||||
|
"logprob": -0.02935791,
|
||||||
|
"special": false,
|
||||||
|
"text": " money"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28723,
|
||||||
|
"logprob": -0.031219482,
|
||||||
|
"special": false,
|
||||||
|
"text": "."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 32002,
|
||||||
|
"logprob": -0.00034475327,
|
||||||
|
"special": true,
|
||||||
|
"text": "<end_of_utterance>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2,
|
||||||
|
"logprob": -1.1920929e-07,
|
||||||
|
"special": true,
|
||||||
|
"text": "</s>"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": " The cow is standing on the beach and the chicken is sitting on a pile of money."
|
||||||
|
}
|
|
@ -22,6 +22,12 @@ async def flash_pali_gemma(flash_pali_gemma_handle):
|
||||||
return flash_pali_gemma_handle.client
|
return flash_pali_gemma_handle.client
|
||||||
|
|
||||||
|
|
||||||
|
def get_chicken():
|
||||||
|
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
|
||||||
|
encoded_string = base64.b64encode(image_file.read())
|
||||||
|
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
|
||||||
|
|
||||||
|
|
||||||
def get_cow_beach():
|
def get_cow_beach():
|
||||||
with open("integration-tests/images/cow_beach.png", "rb") as image_file:
|
with open("integration-tests/images/cow_beach.png", "rb") as image_file:
|
||||||
encoded_string = base64.b64encode(image_file.read())
|
encoded_string = base64.b64encode(image_file.read())
|
||||||
|
@ -37,3 +43,20 @@ async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot):
|
||||||
|
|
||||||
assert response.generated_text == "beach"
|
assert response.generated_text == "beach"
|
||||||
assert response == response_snapshot
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.private
|
||||||
|
async def test_flash_pali_gemma_two_images(flash_pali_gemma, response_snapshot):
|
||||||
|
chicken = get_chicken()
|
||||||
|
cow_beach = get_cow_beach()
|
||||||
|
response = await flash_pali_gemma.generate(
|
||||||
|
f"caption![]({chicken})![]({cow_beach})\n",
|
||||||
|
max_new_tokens=20,
|
||||||
|
)
|
||||||
|
# Is PaliGemma not able to handle two separate images? At least we
|
||||||
|
# get output showing that both images are used.
|
||||||
|
assert (
|
||||||
|
response.generated_text == "image result for chicken on the beach"
|
||||||
|
), f"{repr(response.generated_text)}"
|
||||||
|
assert response == response_snapshot
|
||||||
|
|
|
@ -23,6 +23,12 @@ def get_chicken():
|
||||||
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
|
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
|
||||||
|
|
||||||
|
|
||||||
|
def get_cow_beach():
|
||||||
|
with open("integration-tests/images/cow_beach.png", "rb") as image_file:
|
||||||
|
encoded_string = base64.b64encode(image_file.read())
|
||||||
|
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_idefics(idefics, response_snapshot):
|
async def test_idefics(idefics, response_snapshot):
|
||||||
chicken = get_chicken()
|
chicken = get_chicken()
|
||||||
|
@ -39,6 +45,21 @@ async def test_idefics(idefics, response_snapshot):
|
||||||
assert response == response_snapshot
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.private
|
||||||
|
async def test_idefics_two_images(idefics, response_snapshot):
|
||||||
|
chicken = get_chicken()
|
||||||
|
cow_beach = get_cow_beach()
|
||||||
|
response = await idefics.generate(
|
||||||
|
f"User:![]({chicken})![]({cow_beach})Where are the cow and chicken?<end_of_utterance> \nAssistant:",
|
||||||
|
max_new_tokens=20,
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
response.generated_text == " The cow and chicken are on a beach."
|
||||||
|
), f"{repr(response.generated_text)}"
|
||||||
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_idefics_load(idefics, generate_load, response_snapshot):
|
async def test_idefics_load(idefics, generate_load, response_snapshot):
|
||||||
chicken = get_chicken()
|
chicken = get_chicken()
|
||||||
|
|
|
@ -9,6 +9,12 @@ def get_chicken():
|
||||||
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
|
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
|
||||||
|
|
||||||
|
|
||||||
|
def get_cow_beach():
|
||||||
|
with open("integration-tests/images/cow_beach.png", "rb") as image_file:
|
||||||
|
encoded_string = base64.b64encode(image_file.read())
|
||||||
|
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def flash_idefics2_next_handle(launcher):
|
def flash_idefics2_next_handle(launcher):
|
||||||
with launcher(
|
with launcher(
|
||||||
|
@ -38,6 +44,23 @@ async def test_flash_idefics2_next_simple(flash_idefics2_next, response_snapshot
|
||||||
assert response == response_snapshot
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.private
|
||||||
|
async def test_flash_idefics2_two_images(flash_idefics2_next, response_snapshot):
|
||||||
|
chicken = get_chicken()
|
||||||
|
cow_beach = get_cow_beach()
|
||||||
|
response = await flash_idefics2_next.generate(
|
||||||
|
f"User:![]({chicken})![]({cow_beach})Where are the cow and chicken?<end_of_utterance> \nAssistant:",
|
||||||
|
max_new_tokens=20,
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
response.generated_text
|
||||||
|
== " The cow is standing on the beach and the chicken is sitting on a pile of money."
|
||||||
|
), f"{repr(response.generated_text)}"
|
||||||
|
assert response.details.generated_tokens == 20
|
||||||
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
@pytest.mark.private
|
||||||
async def test_flash_idefics2_next_all_params(flash_idefics2_next, response_snapshot):
|
async def test_flash_idefics2_next_all_params(flash_idefics2_next, response_snapshot):
|
||||||
|
|
|
@ -53,7 +53,9 @@ def image_text_replacement(image_input, config, image_id) -> str:
|
||||||
num_features = get_number_of_features(height, width, config)
|
num_features = get_number_of_features(height, width, config)
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
logger.info(f"Found {num_features} in image of resolution {height}x{width}")
|
logger.info(
|
||||||
|
f"Found {num_features} features in image of resolution {height}x{width}"
|
||||||
|
)
|
||||||
return "<image>" * num_features
|
return "<image>" * num_features
|
||||||
|
|
||||||
elif config.model_type == "paligemma":
|
elif config.model_type == "paligemma":
|
||||||
|
@ -133,23 +135,41 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
|
||||||
def batch_tokenized_inputs(
|
def batch_tokenized_inputs(
|
||||||
cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config
|
cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config
|
||||||
):
|
):
|
||||||
|
# Process images first. We need all of them so that the processor
|
||||||
|
# can make the image splits the same size. And we need the final
|
||||||
|
# sizes to insert correct number of image tokens.
|
||||||
|
images = []
|
||||||
|
for r in requests:
|
||||||
|
for chunk in r.input_chunks.chunks:
|
||||||
|
chunk_type = chunk.WhichOneof("chunk")
|
||||||
|
if chunk_type == "text":
|
||||||
|
pass
|
||||||
|
elif chunk_type == "image":
|
||||||
|
image = Image.open(BytesIO(chunk.image.data))
|
||||||
|
if config.model_type == "llava_next":
|
||||||
|
images.append(image)
|
||||||
|
else:
|
||||||
|
images.append([image])
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Invalid chunk type {chunk_type}")
|
||||||
|
|
||||||
|
if images:
|
||||||
|
image_inputs = processor.image_processor(images, return_tensors="pt")
|
||||||
|
else:
|
||||||
|
image_inputs = None
|
||||||
|
|
||||||
batch_inputs = []
|
batch_inputs = []
|
||||||
image_inputs = []
|
|
||||||
max_truncation = 0
|
max_truncation = 0
|
||||||
|
image_id = 0
|
||||||
for r in requests:
|
for r in requests:
|
||||||
full_text = ""
|
full_text = ""
|
||||||
image_id = 0
|
|
||||||
for chunk in r.input_chunks.chunks:
|
for chunk in r.input_chunks.chunks:
|
||||||
chunk_type = chunk.WhichOneof("chunk")
|
chunk_type = chunk.WhichOneof("chunk")
|
||||||
if chunk_type == "text":
|
if chunk_type == "text":
|
||||||
full_text += chunk.text
|
full_text += chunk.text
|
||||||
elif chunk_type == "image":
|
elif chunk_type == "image":
|
||||||
image = Image.open(BytesIO(chunk.image.data))
|
full_text += image_text_replacement(image_inputs, config, image_id)
|
||||||
image_input = processor.image_processor(image, return_tensors="pt")
|
image_id += 1
|
||||||
full_text += image_text_replacement(image_input, config, image_id)
|
|
||||||
image_inputs.append(image_input)
|
|
||||||
else:
|
|
||||||
raise RuntimeError(f"Invalid chunk type {chunk_type}")
|
|
||||||
|
|
||||||
batch_inputs.append(full_text)
|
batch_inputs.append(full_text)
|
||||||
max_truncation = max(max_truncation, r.truncate)
|
max_truncation = max(max_truncation, r.truncate)
|
||||||
|
@ -160,24 +180,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
|
||||||
max_length=max_truncation,
|
max_length=max_truncation,
|
||||||
add_special_tokens=not config.model_type == "paligemma",
|
add_special_tokens=not config.model_type == "paligemma",
|
||||||
)["input_ids"]
|
)["input_ids"]
|
||||||
if image_inputs:
|
|
||||||
image_input = image_inputs[0]
|
|
||||||
new_image_inputs = {
|
|
||||||
"pixel_values": torch.cat(
|
|
||||||
[img["pixel_values"] for img in image_inputs], dim=0
|
|
||||||
),
|
|
||||||
}
|
|
||||||
if "pixel_attention_mask" in image_input:
|
|
||||||
new_image_inputs["pixel_attention_mask"] = torch.cat(
|
|
||||||
[img["pixel_attention_mask"] for img in image_inputs], dim=0
|
|
||||||
)
|
|
||||||
if "image_sizes" in image_input:
|
|
||||||
new_image_inputs["image_sizes"] = torch.cat(
|
|
||||||
[img["image_sizes"] for img in image_inputs], dim=0
|
|
||||||
)
|
|
||||||
image_inputs = new_image_inputs
|
|
||||||
else:
|
|
||||||
image_inputs = None
|
|
||||||
return batch_tokenized_inputs, image_inputs
|
return batch_tokenized_inputs, image_inputs
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
Loading…
Reference in New Issue