2024-02-21 06:15:22 -07:00
|
|
|
import torch
|
|
|
|
import torch.distributed
|
|
|
|
|
|
|
|
from opentelemetry import trace
|
|
|
|
from typing import Optional
|
Pali gemma modeling (#1895)
This PR adds PaliGemma modeling code.
Blog post: https://huggingface.co/blog/paligemma
Transformers PR: https://github.com/huggingface/transformers/pull/30814
Install the latest changes and run with:
```bash
# get the weights
# text-generation-server download-weights gv-hf/PaliGemma-base-224px-hf
# run TGI
text-generation-launcher --model-id gv-hf/PaliGemma-base-224px-hf
```
Basic example sending various requests:
```python
from huggingface_hub import InferenceClient

client = InferenceClient("http://127.0.0.1:3000")
images = [
    "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png",
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png",
]
prompts = [
    "What animal is in this image?",
    "Name three colors in this image.",
    "What are 10 colors in this image?",
    "Where is the cow standing?",
    "answer en Where is the cow standing?",
    "Is there a bird in the image?",
    "Is there a cow in the image?",
    "Is there a rabbit in the image?",
    "how many birds are in the image?",
    "how many rabbits are in the image?",
]
for img in images:
    print(f"\nImage: {img.split('/')[-1]}")
    for prompt in prompts:
        # Prepend the image (as markdown) to the prompt.
        inputs = f"![]({img}){prompt}\n"
        # Equivalent raw request body, if calling the HTTP API directly.
        json_data = {
            "inputs": inputs,
            "parameters": {
                "max_new_tokens": 30,
                "do_sample": False,
            },
        }
        # Send `inputs` (image + prompt), not the bare prompt.
        generated_output = client.text_generation(inputs, max_new_tokens=30, stream=False)
        print([f"{prompt}\n{generated_output}"])
```
---------
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
2024-05-15 22:58:47 -06:00
|
|
|
from transformers import AutoConfig, AutoTokenizer
|
2024-02-21 06:15:22 -07:00
|
|
|
|
|
|
|
from text_generation_server.models import FlashCausalLM
|
|
|
|
from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
|
|
|
|
FlashGemmaForCausalLM,
|
|
|
|
)
|
|
|
|
from text_generation_server.utils import (
|
|
|
|
initialize_torch_distributed,
|
|
|
|
weight_files,
|
|
|
|
Weights,
|
|
|
|
)
|
|
|
|
|
|
|
|
# Module-level OpenTelemetry tracer, named after this module.
tracer = trace.get_tracer(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
class FlashGemma(FlashCausalLM):
    """Gemma causal language model served through ``FlashCausalLM``.

    The constructor loads the tokenizer, config and ``.safetensors`` weights
    for ``model_id`` and wraps them in a ``FlashGemmaForCausalLM`` placed on
    this rank's CUDA device.
    """

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
        *,
        prefix: str = "language_model",
    ):
        """Build the Gemma model and initialise the ``FlashCausalLM`` base.

        Args:
            model_id: Model identifier passed to the tokenizer/config loaders
                and to ``weight_files``.
            revision: Optional revision used for tokenizer, config and weights.
            quantize: Stored on the config as ``config.quantize``; ``"gptq"``
                and ``"awq"`` additionally load GPTQ parameters into the
                weights.
            speculator: Stored on the config as ``config.speculator``.
            dtype: Weight dtype; defaults to ``torch.bfloat16`` when ``None``.
            trust_remote_code: Forwarded to the ``transformers`` loaders.
            prefix: Weight-name prefix handed to ``FlashGemmaForCausalLM``.
                Defaults to ``"language_model"``, the value that used to be
                hardcoded here.

        Raises:
            NotImplementedError: If CUDA is not available.
        """
        self.process_group, rank, world_size = initialize_torch_distributed()

        # Only a CUDA implementation exists; fail fast on other devices.
        if not torch.cuda.is_available():
            raise NotImplementedError("FlashGemma is only available on GPU")
        device = torch.device(f"cuda:{rank}")
        dtype = torch.bfloat16 if dtype is None else dtype

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        config = AutoConfig.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        # The modeling code reads these settings from the config object.
        config.quantize = quantize
        config.speculator = speculator

        # Make every rank reach this point before loading weights.
        torch.distributed.barrier(group=self.process_group)

        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(filenames, device, dtype, process_group=self.process_group)
        if config.quantize in ("gptq", "awq"):
            weights._set_gptq_params(model_id, revision)

        # NOTE(review): plain Gemma checkpoints may not store their weights
        # under a "language_model" prefix (the default appears to target the
        # PaliGemma layout) -- confirm against FlashGemmaForCausalLM's weight
        # names and pass `prefix=` accordingly.
        model = FlashGemmaForCausalLM(prefix, config, weights, causal=True)

        # Make every rank finish building the model before continuing.
        torch.distributed.barrier(group=self.process_group)
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            num_layers=len(model.model.layers),
            num_kv_heads=model.model.num_key_value_heads,
            head_size=model.model.head_size,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )
|