Pali gemma modeling (#1895)
This PR adds paligemma modeling code
Blog post: https://huggingface.co/blog/paligemma
Transformers PR: https://github.com/huggingface/transformers/pull/30814
install the latest changes and run with
```bash
# get the weights
# text-generation-server download-weights gv-hf/PaliGemma-base-224px-hf
# run TGI
text-generation-launcher --model-id gv-hf/PaliGemma-base-224px-hf
```
basic example sending various requests
```python
from huggingface_hub import InferenceClient
client = InferenceClient("http://127.0.0.1:3000")
images = [
"https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png",
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png",
]
prompts = [
"What animal is in this image?",
"Name three colors in this image.",
"What are 10 colors in this image?",
"Where is the cow standing?",
"answer en Where is the cow standing?",
"Is there a bird in the image?",
"Is there a cow in the image?",
"Is there a rabbit in the image?",
"how many birds are in the image?",
"how many rabbits are in the image?",
]
for img in images:
print(f"\nImage: {img.split('/')[-1]}")
for prompt in prompts:
inputs = f"![]({img}){prompt}\n"
json_data = {
"inputs": inputs,
"parameters": {
"max_new_tokens": 30,
"do_sample": False,
},
}
generated_output = client.text_generation(inputs, max_new_tokens=30, stream=False)
print([f"{prompt}\n{generated_output}"])
```
---------
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
2024-05-15 22:58:47 -06:00
|
|
|
import pytest
|
|
|
|
import base64
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture(scope="module")
def flash_pali_gemma_handle(launcher):
    """Launch ``google/paligemma-3b-pt-224`` via ``launcher`` and yield the handle.

    The fixture is module-scoped, so the launcher context is entered once
    and shared by every test in this module; the context exits after the
    ``yield`` resumes.
    """
    launch_options = {
        "num_shard": 1,
        "revision": "float16",
        "max_input_length": 4000,
        "max_total_tokens": 4096,
    }
    with launcher("google/paligemma-3b-pt-224", **launch_options) as server:
        yield server
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture(scope="module")
async def flash_pali_gemma(flash_pali_gemma_handle):
    """Await the handle's health check, then return its client.

    NOTE(review): ``health`` is called with 300 -- presumably a timeout in
    seconds; confirm against the launcher helper.
    """
    server = flash_pali_gemma_handle
    await server.health(300)
    client = server.client
    return client
|
|
|
|
|
|
|
|
|
2024-06-17 02:49:41 -06:00
|
|
|
def get_chicken():
    """Return ``chicken_on_money.png`` as a base64 ``data:image/png`` URI.

    The image path is relative, so it is resolved against the current
    working directory.
    """
    image_path = "integration-tests/images/chicken_on_money.png"
    with open(image_path, "rb") as png:
        raw_bytes = png.read()
    payload = base64.b64encode(raw_bytes).decode("utf-8")
    return f"data:image/png;base64,{payload}"
|
|
|
|
|
|
|
|
|
Pali gemma modeling (#1895)
This PR adds paligemma modeling code
Blog post: https://huggingface.co/blog/paligemma
Transformers PR: https://github.com/huggingface/transformers/pull/30814
install the latest changes and run with
```bash
# get the weights
# text-generation-server download-weights gv-hf/PaliGemma-base-224px-hf
# run TGI
text-generation-launcher --model-id gv-hf/PaliGemma-base-224px-hf
```
basic example sending various requests
```python
from huggingface_hub import InferenceClient
client = InferenceClient("http://127.0.0.1:3000")
images = [
"https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png",
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png",
]
prompts = [
"What animal is in this image?",
"Name three colors in this image.",
"What are 10 colors in this image?",
"Where is the cow standing?",
"answer en Where is the cow standing?",
"Is there a bird in the image?",
"Is there a cow in the image?",
"Is there a rabbit in the image?",
"how many birds are in the image?",
"how many rabbits are in the image?",
]
for img in images:
print(f"\nImage: {img.split('/')[-1]}")
for prompt in prompts:
inputs = f"![]({img}){prompt}\n"
json_data = {
"inputs": inputs,
"parameters": {
"max_new_tokens": 30,
"do_sample": False,
},
}
generated_output = client.text_generation(inputs, max_new_tokens=30, stream=False)
print([f"{prompt}\n{generated_output}"])
```
---------
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
2024-05-15 22:58:47 -06:00
|
|
|
def get_cow_beach():
    """Return ``cow_beach.png`` as a base64 ``data:image/png`` URI.

    The image path is relative, so it is resolved against the current
    working directory.
    """
    image_path = "integration-tests/images/cow_beach.png"
    with open(image_path, "rb") as png:
        raw_bytes = png.read()
    payload = base64.b64encode(raw_bytes).decode("utf-8")
    return f"data:image/png;base64,{payload}"
|
|
|
|
|
|
|
|
|
2024-06-25 08:53:20 -06:00
|
|
|
@pytest.mark.release
|
Pali gemma modeling (#1895)
This PR adds paligemma modeling code
Blog post: https://huggingface.co/blog/paligemma
Transformers PR: https://github.com/huggingface/transformers/pull/30814
install the latest changes and run with
```bash
# get the weights
# text-generation-server download-weights gv-hf/PaliGemma-base-224px-hf
# run TGI
text-generation-launcher --model-id gv-hf/PaliGemma-base-224px-hf
```
basic example sending various requests
```python
from huggingface_hub import InferenceClient
client = InferenceClient("http://127.0.0.1:3000")
images = [
"https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png",
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png",
]
prompts = [
"What animal is in this image?",
"Name three colors in this image.",
"What are 10 colors in this image?",
"Where is the cow standing?",
"answer en Where is the cow standing?",
"Is there a bird in the image?",
"Is there a cow in the image?",
"Is there a rabbit in the image?",
"how many birds are in the image?",
"how many rabbits are in the image?",
]
for img in images:
print(f"\nImage: {img.split('/')[-1]}")
for prompt in prompts:
inputs = f"![]({img}){prompt}\n"
json_data = {
"inputs": inputs,
"parameters": {
"max_new_tokens": 30,
"do_sample": False,
},
}
generated_output = client.text_generation(inputs, max_new_tokens=30, stream=False)
print([f"{prompt}\n{generated_output}"])
```
---------
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
2024-05-15 22:58:47 -06:00
|
|
|
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot):
    """Ask where the cow is standing and check the answer and the snapshot."""
    # The image is embedded inline as a markdown image carrying a data URI.
    result = await flash_pali_gemma.generate(
        f"![]({get_cow_beach()})Where is the cow standing?\n",
        max_new_tokens=20,
    )

    answer = result.generated_text
    assert answer == "beach"
    assert result == response_snapshot
|
2024-06-17 02:49:41 -06:00
|
|
|
|
|
|
|
|
2024-06-25 08:53:20 -06:00
|
|
|
@pytest.mark.release
|
2024-06-17 02:49:41 -06:00
|
|
|
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_pali_gemma_two_images(flash_pali_gemma, response_snapshot):
    """Caption a prompt holding two inline images and check text and snapshot."""
    first_image = get_chicken()
    second_image = get_cow_beach()
    prompt = f"caption![]({first_image})![]({second_image})\n"
    result = await flash_pali_gemma.generate(prompt, max_new_tokens=20)

    # Is PaliGemma not able to handle two separate images? At least we
    # get output showing that both images are used.
    caption = result.generated_text
    assert caption == "image result for chicken on the beach", f"{caption!r}"
    assert result == response_snapshot
|