From e903770897ae80f9b9ea02ba02eac4c680fd6202 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?=
Date: Mon, 17 Jun 2024 10:49:41 +0200
Subject: [PATCH] Support different image sizes in prefill in VLMs (#2065)

When a batch contained images of different sizes during prefill, the
server would fail (see e.g. #2056). Images were processed separately and
then concatenated. However, this can fail for images with different
sizes.

Fix this by preprocessing all images in the batch together, so that the
image processor can ensure that all image tensors have compatible sizes.
---
 .../test_flash_pali_gemma_two_images.json     |  61 ++++++++
 .../test_idefics/test_idefics_two_images.json |  85 +++++++++++
 .../test_flash_idefics2_two_images.json       | 133 ++++++++++++++++++
 .../models/test_flash_pali_gemma.py           |  23 +++
 integration-tests/models/test_idefics.py      |  21 +++
 integration-tests/models/test_idefics2.py     |  23 +++
 .../models/vlm_causal_lm.py                   |  57 ++++----
 7 files changed, 376 insertions(+), 27 deletions(-)
 create mode 100644 integration-tests/models/__snapshots__/test_flash_pali_gemma/test_flash_pali_gemma_two_images.json
 create mode 100644 integration-tests/models/__snapshots__/test_idefics/test_idefics_two_images.json
 create mode 100644 integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_two_images.json

diff --git a/integration-tests/models/__snapshots__/test_flash_pali_gemma/test_flash_pali_gemma_two_images.json b/integration-tests/models/__snapshots__/test_flash_pali_gemma/test_flash_pali_gemma_two_images.json
new file mode 100644
index 00000000..ab4f3015
--- /dev/null
+++ b/integration-tests/models/__snapshots__/test_flash_pali_gemma/test_flash_pali_gemma_two_images.json
@@ -0,0 +1,61 @@
+{
+  "details": {
+    "best_of_sequences": null,
+    "finish_reason": "eos_token",
+    "generated_tokens": 8,
+    "prefill": [],
+    "seed": null,
+    "tokens": [
+      {
+        "id": 2502,
+        "logprob": -1.734375,
+        "special": false,
+        "text": "image"
+      },
+      {
+        "id": 2196,
+        "logprob": -0.5756836,
+        "special": false,
+        "text": " result"
+      },
+      {
+        "id": 604,
+        "logprob": -0.007843018,
+        "special": false,
+        "text": " for"
+      },
+      {
+        "id": 12254,
+        "logprob": -1.7167969,
+        "special": false,
+        "text": " chicken"
+      },
+      {
+        "id": 611,
+        "logprob": -0.17053223,
+        "special": false,
+        "text": " on"
+      },
+      {
+        "id": 573,
+        "logprob": -0.7626953,
+        "special": false,
+        "text": " the"
+      },
+      {
+        "id": 8318,
+        "logprob": -0.02709961,
+        "special": false,
+        "text": " beach"
+      },
+      {
+        "id": 1,
+        "logprob": -0.20739746,
+        "special": true,
+        "text": "<eos>"
+      }
+    ],
+    "top_tokens": null
+  },
+  "generated_text": "image result for chicken on the beach"
+}
diff --git a/integration-tests/models/__snapshots__/test_idefics/test_idefics_two_images.json b/integration-tests/models/__snapshots__/test_idefics/test_idefics_two_images.json
new file mode 100644
index 00000000..a4727707
--- /dev/null
+++ b/integration-tests/models/__snapshots__/test_idefics/test_idefics_two_images.json
@@ -0,0 +1,85 @@
+{
+  "details": {
+    "best_of_sequences": null,
+    "finish_reason": "eos_token",
+    "generated_tokens": 12,
+    "prefill": [],
+    "seed": null,
+    "tokens": [
+      {
+        "id": 450,
+        "logprob": -0.26342773,
+        "special": false,
+        "text": " The"
+      },
+      {
+        "id": 21282,
+        "logprob": -0.01838684,
+        "special": false,
+        "text": " cow"
+      },
+      {
+        "id": 322,
+        "logprob": -0.18041992,
+        "special": false,
+        "text": " and"
+      },
+      {
+        "id": 521,
+        "logprob": -0.62841797,
+        "special": false,
+        "text": " ch"
+      },
+      {
+        "id": 21475,
+        "logprob": -0.0037956238,
"special": false, + "text": "icken" + }, + { + "id": 526, + "logprob": -0.018737793, + "special": false, + "text": " are" + }, + { + "id": 373, + "logprob": -1.0820312, + "special": false, + "text": " on" + }, + { + "id": 263, + "logprob": -0.5083008, + "special": false, + "text": " a" + }, + { + "id": 25695, + "logprob": -0.07128906, + "special": false, + "text": " beach" + }, + { + "id": 29889, + "logprob": -0.12573242, + "special": false, + "text": "." + }, + { + "id": 32002, + "logprob": -0.0029792786, + "special": true, + "text": "" + }, + { + "id": 2, + "logprob": -0.00024962425, + "special": true, + "text": "" + } + ], + "top_tokens": null + }, + "generated_text": " The cow and chicken are on a beach." +} diff --git a/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_two_images.json b/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_two_images.json new file mode 100644 index 00000000..86c95b29 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_two_images.json @@ -0,0 +1,133 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 20, + "prefill": [], + "seed": null, + "tokens": [ + { + "id": 415, + "logprob": -0.04421997, + "special": false, + "text": " The" + }, + { + "id": 12072, + "logprob": -0.13500977, + "special": false, + "text": " cow" + }, + { + "id": 349, + "logprob": -0.06750488, + "special": false, + "text": " is" + }, + { + "id": 6328, + "logprob": -0.6352539, + "special": false, + "text": " standing" + }, + { + "id": 356, + "logprob": -0.16186523, + "special": false, + "text": " on" + }, + { + "id": 272, + "logprob": -0.5078125, + "special": false, + "text": " the" + }, + { + "id": 10305, + "logprob": -0.017913818, + "special": false, + "text": " beach" + }, + { + "id": 304, + "logprob": -1.5205078, + "special": false, + "text": " and" + }, + { + "id": 272, + "logprob": -0.029174805, + "special": false, + "text": " the" + }, + { + "id": 13088, + "logprob": -0.003479004, + "special": false, + "text": " chicken" + }, + { + "id": 349, + "logprob": -0.0035095215, + "special": false, + "text": " is" + }, + { + "id": 6398, + "logprob": -0.3088379, + "special": false, + "text": " sitting" + }, + { + "id": 356, + "logprob": -0.027755737, + "special": false, + "text": " on" + }, + { + "id": 264, + "logprob": -0.31884766, + "special": false, + "text": " a" + }, + { + "id": 17972, + "logprob": -0.047943115, + "special": false, + "text": " pile" + }, + { + "id": 302, + "logprob": -0.0002925396, + "special": false, + "text": " of" + }, + { + "id": 2445, + "logprob": -0.02935791, + "special": false, + "text": " money" + }, + { + "id": 28723, + "logprob": -0.031219482, + "special": false, + "text": "." + }, + { + "id": 32002, + "logprob": -0.00034475327, + "special": true, + "text": "" + }, + { + "id": 2, + "logprob": -1.1920929e-07, + "special": true, + "text": "" + } + ], + "top_tokens": null + }, + "generated_text": " The cow is standing on the beach and the chicken is sitting on a pile of money." 
+}
diff --git a/integration-tests/models/test_flash_pali_gemma.py b/integration-tests/models/test_flash_pali_gemma.py
index d4e83c9f..6be1750c 100644
--- a/integration-tests/models/test_flash_pali_gemma.py
+++ b/integration-tests/models/test_flash_pali_gemma.py
@@ -22,6 +22,12 @@ async def flash_pali_gemma(flash_pali_gemma_handle):
     return flash_pali_gemma_handle.client
 
 
+def get_chicken():
+    with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
+        encoded_string = base64.b64encode(image_file.read())
+    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
+
+
 def get_cow_beach():
     with open("integration-tests/images/cow_beach.png", "rb") as image_file:
         encoded_string = base64.b64encode(image_file.read())
@@ -37,3 +43,20 @@ async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot):
 
     assert response.generated_text == "beach"
     assert response == response_snapshot
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_pali_gemma_two_images(flash_pali_gemma, response_snapshot):
+    chicken = get_chicken()
+    cow_beach = get_cow_beach()
+    response = await flash_pali_gemma.generate(
+        f"caption![]({chicken})![]({cow_beach})\n",
+        max_new_tokens=20,
+    )
+    # Is PaliGemma not able to handle two separate images? At least we
+    # get output showing that both images are used.
+    assert (
+        response.generated_text == "image result for chicken on the beach"
+    ), f"{repr(response.generated_text)}"
+    assert response == response_snapshot
diff --git a/integration-tests/models/test_idefics.py b/integration-tests/models/test_idefics.py
index aeeaffa1..ac807b76 100644
--- a/integration-tests/models/test_idefics.py
+++ b/integration-tests/models/test_idefics.py
@@ -23,6 +23,12 @@ def get_chicken():
     return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
 
 
+def get_cow_beach():
+    with open("integration-tests/images/cow_beach.png", "rb") as image_file:
+        encoded_string = base64.b64encode(image_file.read())
+    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
+
+
 @pytest.mark.asyncio
 async def test_idefics(idefics, response_snapshot):
     chicken = get_chicken()
@@ -39,6 +45,21 @@ async def test_idefics(idefics, response_snapshot):
     assert response == response_snapshot
 
 
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_idefics_two_images(idefics, response_snapshot):
+    chicken = get_chicken()
+    cow_beach = get_cow_beach()
+    response = await idefics.generate(
+        f"User:![]({chicken})![]({cow_beach})Where are the cow and chicken? \nAssistant:",
+        max_new_tokens=20,
+    )
+    assert (
+        response.generated_text == " The cow and chicken are on a beach."
+ ), f"{repr(response.generated_text)}" + assert response == response_snapshot + + @pytest.mark.asyncio async def test_idefics_load(idefics, generate_load, response_snapshot): chicken = get_chicken() diff --git a/integration-tests/models/test_idefics2.py b/integration-tests/models/test_idefics2.py index d34cce34..9aaf6d8a 100644 --- a/integration-tests/models/test_idefics2.py +++ b/integration-tests/models/test_idefics2.py @@ -9,6 +9,12 @@ def get_chicken(): return f"data:image/png;base64,{encoded_string.decode('utf-8')}" +def get_cow_beach(): + with open("integration-tests/images/cow_beach.png", "rb") as image_file: + encoded_string = base64.b64encode(image_file.read()) + return f"data:image/png;base64,{encoded_string.decode('utf-8')}" + + @pytest.fixture(scope="module") def flash_idefics2_next_handle(launcher): with launcher( @@ -38,6 +44,23 @@ async def test_flash_idefics2_next_simple(flash_idefics2_next, response_snapshot assert response == response_snapshot +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_idefics2_two_images(flash_idefics2_next, response_snapshot): + chicken = get_chicken() + cow_beach = get_cow_beach() + response = await flash_idefics2_next.generate( + f"User:![]({chicken})![]({cow_beach})Where are the cow and chicken? \nAssistant:", + max_new_tokens=20, + ) + assert ( + response.generated_text + == " The cow is standing on the beach and the chicken is sitting on a pile of money." + ), f"{repr(response.generated_text)}" + assert response.details.generated_tokens == 20 + assert response == response_snapshot + + @pytest.mark.asyncio @pytest.mark.private async def test_flash_idefics2_next_all_params(flash_idefics2_next, response_snapshot): diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 59a6fab1..8b5819d1 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -53,7 +53,9 @@ def image_text_replacement(image_input, config, image_id) -> str: num_features = get_number_of_features(height, width, config) from loguru import logger - logger.info(f"Found {num_features} in image of resolution {height}x{width}") + logger.info( + f"Found {num_features} features in image of resolution {height}x{width}" + ) return "" * num_features elif config.model_type == "paligemma": @@ -133,23 +135,41 @@ class VlmCausalLMBatch(FlashCausalLMBatch): def batch_tokenized_inputs( cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config ): + # Process images first. We need all of them so that the processor + # can make the image splits the same size. And we need the final + # sizes to insert correct number of image tokens. 
+        images = []
+        for r in requests:
+            for chunk in r.input_chunks.chunks:
+                chunk_type = chunk.WhichOneof("chunk")
+                if chunk_type == "text":
+                    pass
+                elif chunk_type == "image":
+                    image = Image.open(BytesIO(chunk.image.data))
+                    if config.model_type == "llava_next":
+                        images.append(image)
+                    else:
+                        images.append([image])
+                else:
+                    raise RuntimeError(f"Invalid chunk type {chunk_type}")
+
+        if images:
+            image_inputs = processor.image_processor(images, return_tensors="pt")
+        else:
+            image_inputs = None
+
         batch_inputs = []
-        image_inputs = []
         max_truncation = 0
+        image_id = 0
         for r in requests:
             full_text = ""
-            image_id = 0
             for chunk in r.input_chunks.chunks:
                 chunk_type = chunk.WhichOneof("chunk")
                 if chunk_type == "text":
                     full_text += chunk.text
                 elif chunk_type == "image":
-                    image = Image.open(BytesIO(chunk.image.data))
-                    image_input = processor.image_processor(image, return_tensors="pt")
-                    full_text += image_text_replacement(image_input, config, image_id)
-                    image_inputs.append(image_input)
-                else:
-                    raise RuntimeError(f"Invalid chunk type {chunk_type}")
+                    full_text += image_text_replacement(image_inputs, config, image_id)
+                    image_id += 1
 
             batch_inputs.append(full_text)
             max_truncation = max(max_truncation, r.truncate)
@@ -160,24 +180,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
             max_length=max_truncation,
             add_special_tokens=not config.model_type == "paligemma",
         )["input_ids"]
-        if image_inputs:
-            image_input = image_inputs[0]
-            new_image_inputs = {
-                "pixel_values": torch.cat(
-                    [img["pixel_values"] for img in image_inputs], dim=0
-                ),
-            }
-            if "pixel_attention_mask" in image_input:
-                new_image_inputs["pixel_attention_mask"] = torch.cat(
-                    [img["pixel_attention_mask"] for img in image_inputs], dim=0
-                )
-            if "image_sizes" in image_input:
-                new_image_inputs["image_sizes"] = torch.cat(
-                    [img["image_sizes"] for img in image_inputs], dim=0
-                )
-            image_inputs = new_image_inputs
-        else:
-            image_inputs = None
+
         return batch_tokenized_inputs, image_inputs
 
     @classmethod
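-- 
Why batching the preprocessing matters: the minimal standalone sketch below
reproduces the failure mode in isolation (placed after the signature marker
so it does not affect applying the patch). It assumes the Idefics2 processor
from transformers; the model id, image sizes, and variable names are
illustrative only, not taken from the server code above.

import torch
from PIL import Image
from transformers import AutoProcessor

# Illustrative checkpoint; any model with an Idefics2-style image processor works.
processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2-8b")

# Two images with different resolutions, as in a mixed prefill batch.
images = [Image.new("RGB", (640, 480)), Image.new("RGB", (1024, 768))]

# Old behaviour: preprocess each image separately, then concatenate. The
# per-call tensors may come out with different spatial dimensions, so the
# concatenation can fail with a size-mismatch RuntimeError.
try:
    separate = [
        processor.image_processor([image], return_tensors="pt") for image in images
    ]
    pixel_values = torch.cat([x["pixel_values"] for x in separate], dim=0)
except RuntimeError as e:
    print(f"separate preprocessing failed to concatenate: {e}")

# New behaviour: hand the whole batch to the image processor in one call,
# so it can resize/pad all images to compatible sizes before stacking.
batched = processor.image_processor([[image] for image in images], return_tensors="pt")
print(batched["pixel_values"].shape)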