77 lines
2.8 KiB
Python
77 lines
2.8 KiB
Python
from io import BytesIO
|
|
from PIL import Image
|
|
import torch
|
|
import torch.distributed
|
|
from opentelemetry import trace
|
|
from typing import Iterable, Optional, Tuple
|
|
from text_generation_server.models.vlm_causal_lm import (
|
|
VlmCausalLM,
|
|
VlmCausalLMBatch,
|
|
image_text_replacement,
|
|
)
|
|
from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
|
|
PaliGemmaForConditionalGeneration,
|
|
)
|
|
from transformers import AutoProcessor, AutoConfig
|
|
|
|
from text_generation_server.pb.generate_pb2 import Request
|
|
|
|
# Module-level OpenTelemetry tracer named after this module. It is not
# referenced elsewhere in the visible source; presumably kept for parity with
# the other model modules — verify before removing.
tracer = trace.get_tracer(__name__)
|
|
|
|
|
|
class PaliGemmaBatch(VlmCausalLMBatch):
    """Batch type for PaliGemma requests.

    Only the prompt/image preprocessing step is overridden here; everything
    else is inherited unchanged from ``VlmCausalLMBatch``.
    """

    @staticmethod
    def _stack_paligemma_image_features(per_image):
        """Concatenate per-image processor outputs along dim 0.

        ``pixel_values`` is always concatenated (a missing key raises
        ``KeyError``); ``pixel_attention_mask`` and ``image_sizes`` are added
        only when the FIRST entry carries them. Returns ``None`` when
        ``per_image`` is empty.
        """
        if not per_image:
            return None

        first = per_image[0]
        stacked = {
            "pixel_values": torch.cat(
                [features["pixel_values"] for features in per_image], dim=0
            ),
        }
        # Optional keys are probed on the first entry only; the entries are
        # presumably homogeneous because they come from the same processor.
        for optional_key in ("pixel_attention_mask", "image_sizes"):
            if optional_key in first:
                stacked[optional_key] = torch.cat(
                    [features[optional_key] for features in per_image], dim=0
                )
        return stacked

    @classmethod
    def batch_tokenized_inputs(
        cls, requests: Iterable[Request], tokenizer, processor, config
    ):
        """Build token ids and stacked image tensors for a batch of requests.

        Each request's input chunks are flattened into one prompt string:
        a text chunk contributes ``"<bos>" + text + "\\n"``, and an image chunk
        contributes whatever ``image_text_replacement`` returns for it. Its
        pixels are decoded, converted to RGB and run through
        ``processor.image_processor``.

        All prompts are tokenized in one call without special tokens, with
        truncation to the largest ``truncate`` value among the requests.

        Args:
            requests: requests to batch; iterated exactly once.
            tokenizer: callable returning a mapping with an ``"input_ids"`` key.
            processor: object exposing ``image_processor``.
            config: model config forwarded to ``image_text_replacement``.

        Returns:
            ``(input_ids, image_inputs)``. ``image_inputs`` is a dict of tensors
            concatenated along dim 0 over every image in the batch, or ``None``
            when no request contained an image.

        Raises:
            RuntimeError: if a chunk is neither ``"text"`` nor ``"image"``.
        """
        prompts = []
        per_image = []
        truncate_to = 0

        for request in requests:
            pieces = []
            for chunk in request.input_chunks.chunks:
                kind = chunk.WhichOneof("chunk")
                if kind == "text":
                    pieces.append("<bos>" + chunk.text + "\n")
                    continue
                if kind != "image":
                    raise RuntimeError(f"Invalid chunk type {kind}")

                # TODO do_convert_RGB should be on by default ?
                rgb_image = Image.open(BytesIO(chunk.image.data)).convert("RGB")
                features = processor.image_processor(rgb_image, return_tensors="pt")
                # Each `features` is produced from a single image, so index 0
                # (the value the image_id argument always had) addresses it.
                pieces.append(
                    image_text_replacement(processor, features, config, 0)
                )
                per_image.append(features)

            prompts.append("".join(pieces))
            truncate_to = max(truncate_to, request.truncate)

        input_ids = tokenizer(
            prompts,
            truncation=True,
            max_length=truncate_to,
            add_special_tokens=False,
        )["input_ids"]
        return input_ids, cls._stack_paligemma_image_features(per_image)
|