Pali gemma modeling (#1895)

This PR adds PaliGemma modeling code.

Blog post: https://huggingface.co/blog/paligemma
Transformers PR: https://github.com/huggingface/transformers/pull/30814

Install the latest changes and run with:

```bash
# get the weights
text-generation-server download-weights gv-hf/PaliGemma-base-224px-hf
# run TGI
text-generation-launcher --model-id gv-hf/PaliGemma-base-224px-hf
```

The launcher serves on port 3000 by default, which the client below assumes. A basic example sending various requests:

```python
from huggingface_hub import InferenceClient

client = InferenceClient("http://127.0.0.1:3000")

images = [
    "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png",
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png",
]
prompts = [
    "What animal is in this image?",
    "Name three colors in this image.",
    "What are 10 colors in this image?",
    "Where is the cow standing?",
    "answer en Where is the cow standing?",
    "Is there a bird in the image?",
    "Is there a cow in the image?",
    "Is there a rabbit in the image?",
    "how many birds are in the image?",
    "how many rabbits are in the image?",
]

for img in images:
    print(f"\nImage: {img.split('/')[-1]}")
    for prompt in prompts:
        # Markdown image syntax makes TGI fetch the image and treat it as an
        # image chunk preceding the text prompt.
        inputs = f"![]({img}){prompt}\n"
        generated_output = client.text_generation(
            inputs, max_new_tokens=30, do_sample=False, stream=False
        )
        print(f"{prompt}\n{generated_output}")
```

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>

The server-side integration, a `VlmCausalLM` subclass with a PaliGemma-specific batch class:

```python
from io import BytesIO
from PIL import Image
import torch
import torch.distributed

from opentelemetry import trace
from typing import Iterable, Optional, Tuple

from text_generation_server.models.vlm_causal_lm import (
    VlmCausalLM,
    VlmCausalLMBatch,
    image_text_replacement,
)
from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
    PaliGemmaForConditionalGeneration,
)
from transformers import AutoProcessor, AutoConfig

from text_generation_server.pb.generate_pb2 import Request

tracer = trace.get_tracer(__name__)


class PaliGemmaBatch(VlmCausalLMBatch):
    @classmethod
    def batch_tokenized_inputs(
        cls, requests: Iterable[Request], tokenizer, processor, config
    ):
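        """Tokenize a batch of multimodal requests.

        Text chunks are prefixed with "<bos>"; each image chunk is decoded,
        run through the image processor, and represented in the prompt text
        by the placeholder string from `image_text_replacement`.
        """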
        batch_inputs = []
        image_inputs = []
        max_truncation = 0
        for r in requests:
            full_text = ""
            image_id = 0
            for chunk in r.input_chunks.chunks:
                chunk_type = chunk.WhichOneof("chunk")
                if chunk_type == "text":
                    full_text += "<bos>" + chunk.text + "\n"
                elif chunk_type == "image":
                    image = Image.open(BytesIO(chunk.image.data))
                    # TODO: do_convert_RGB should be on by default?
                    image = image.convert("RGB")
                    image_input = processor.image_processor(image, return_tensors="pt")
                    full_text += image_text_replacement(image_input, config, image_id)
                    image_inputs.append(image_input)
                    image_id += 1
                else:
                    raise RuntimeError(f"Invalid chunk type {chunk_type}")

            batch_inputs.append(full_text)
            max_truncation = max(max_truncation, r.truncate)

        batch_tokenized_inputs = tokenizer(
            batch_inputs,
            truncation=True,
            max_length=max_truncation,
            add_special_tokens=False,
        )["input_ids"]
        if image_inputs:
            image_input = image_inputs[0]
            new_image_inputs = {
                "pixel_values": torch.cat(
                    [img["pixel_values"] for img in image_inputs], dim=0
                ),
            }
            if "pixel_attention_mask" in image_input:
                new_image_inputs["pixel_attention_mask"] = torch.cat(
                    [img["pixel_attention_mask"] for img in image_inputs], dim=0
                )
            if "image_sizes" in image_input:
                new_image_inputs["image_sizes"] = torch.cat(
                    [img["image_sizes"] for img in image_inputs], dim=0
                )
            image_inputs = new_image_inputs
        else:
            image_inputs = None
        return batch_tokenized_inputs, image_inputs


class PaliGemma(VlmCausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.processor = AutoProcessor.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )

        super().__init__(
            config_cls=AutoConfig,
            model_cls=PaliGemmaForConditionalGeneration,
            model_id=model_id,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    @property
    def batch_type(self):
        return PaliGemmaBatch

    def get_layer_config(self, model) -> Tuple[int, int, int]:
        return (
            len(model.text_model.model.layers),
            model.text_model.model.num_key_value_heads,
            model.text_model.model.head_size,
        )

    def max_past(self) -> Optional[int]:
        return getattr(self.model.text_model, "max_past", None)
```
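
The image batching in `batch_tokenized_inputs` is plain tensor concatenation. A minimal, self-contained sketch of that step with dummy tensors (the 224x224 shape is illustrative, echoing the 224px checkpoint name rather than anything read from the model config):

```python
import torch

# Two single-image inputs, shaped as an image processor would return them:
# one [1, channels, height, width] tensor per request.
image_inputs = [
    {"pixel_values": torch.randn(1, 3, 224, 224)},
    {"pixel_values": torch.randn(1, 3, 224, 224)},
]

# Concatenate along the batch dimension, as PaliGemmaBatch does.
batched = torch.cat([img["pixel_values"] for img in image_inputs], dim=0)
print(batched.shape)  # torch.Size([2, 3, 224, 224])
```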