Merge branch 'main' into ci_amd2
commit c73355b99c
@@ -152,8 +152,7 @@ jobs:
       group: ${{ github.workflow }}-${{ github.job }}-rocm-${{ github.head_ref || github.run_id }}
       cancel-in-progress: true
     runs-on: [amd-gpu-tgi, multi-gpu, mi250]
-    needs:
-      - build-and-push-image
+    needs: build-and-push-image
     steps:
       - name: Checkout repository
         uses: actions/checkout@v4
@@ -48,7 +48,7 @@ RUN wget -qO - https://repositories.intel.com/gpu/intel-graphics.key | gpg --dea
 RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \
 | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list
 
-RUN apt-get update && apt install -y intel-basekit xpu-smi
+RUN apt-get update && apt install -y intel-basekit xpu-smi cmake python3-dev ninja-build
 
 # Text Generation Inference base env
 ENV HUGGINGFACE_HUB_CACHE=/data \
@@ -57,8 +57,8 @@ ENV HUGGINGFACE_HUB_CACHE=/data \
 
 
 WORKDIR /usr/src
-RUN wget https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/intel_extension_for_pytorch-2.1.30a0-cp310-cp310-linux_x86_64.whl
-RUN pip install intel_extension_for_pytorch-2.1.30a0-cp310-cp310-linux_x86_64.whl
+RUN wget https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/torch-2.1.0.post1%2Bcxx11.abi-cp310-cp310-linux_x86_64.whl && pip install torch-2.1.0.post1+cxx11.abi-cp310-cp310-linux_x86_64.whl
+RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout -b group_rope origin/dev/gqa_rope
 
 # Install server
 COPY proto proto
@@ -76,6 +76,10 @@ ENV LIBRARY_PATH=/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/ccl/latest/l
 ENV LD_LIBRARY_PATH=/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib:/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/mkl/latest/lib:/opt/intel/oneapi/compiler/latest/opt/compiler/lib:/opt/intel/oneapi/compiler/latest/lib:/opt/intel/oneapi/lib:/opt/intel/oneapi/lib/intel64:
 ENV PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mpi/latest/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mkl/latest/bin/:/opt/intel/oneapi/compiler/latest/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
 ENV CCL_ZE_IPC_EXCHANGE=sockets
+ENV CMAKE_PREFIX_PATH=/opt/intel/oneapi/mkl/latest/lib/cmake:/opt/intel/oneapi/compiler/latest
+ENV CPATH=/opt/intel/oneapi/mpi/latest/include:/opt/intel/oneapi/ccl/latest/include:/opt/intel/oneapi/mkl/latest/include
 
+RUN pip uninstall -y intel-extension-for-pytorch && cd intel-extension-for-pytorch && git submodule update --init --recursive && USE_AOT_DEVLIST='pvc' BUILD_SEPARATE_OPS=OFF BUILD_WITH_CPU=OFF USE_XETLA=ON python setup.py install && rm -rf /usr/src/intel-extension-for-pytorch
+
 # Install benchmarker
 COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
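
Note: the Dockerfile hunks above replace the prebuilt intel-extension-for-pytorch wheel with a source build of the dev/gqa_rope branch against the oneAPI toolchain. The snippet below is not part of the commit; it is a minimal smoke test one might run inside the resulting image, assuming the XPU build succeeded and a device is visible to the container.

# Hypothetical sanity check for the source-built IPEX XPU backend (not in the diff).
import torch
import intel_extension_for_pytorch as ipex  # importing IPEX registers the torch.xpu device

print("torch:", torch.__version__, "ipex:", ipex.__version__)
print("XPU available:", torch.xpu.is_available())
if torch.xpu.is_available():
    x = torch.randn(2, 2, device="xpu")
    print((x @ x).cpu())  # round-trip a small matmul through the XPU
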
@@ -29,6 +29,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="Test",
+        input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="Test")]),
         prefill_logprobs=True,
         truncate=100,
         parameters=default_pb_parameters,
@@ -25,6 +25,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="Test",
+        input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="Test")]),
         prefill_logprobs=True,
         truncate=100,
         parameters=default_pb_parameters,
@@ -15,6 +15,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="def",
+        input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="def")]),
         prefill_logprobs=True,
         truncate=100,
         parameters=default_pb_parameters,
@@ -32,6 +33,13 @@ def default_fim_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="<fim-prefix>def<fim-suffix>world<fim-middle>",
+        input_chunks=generate_pb2.Input(
+            chunks=[
+                generate_pb2.InputChunk(
+                    text="<fim-prefix>def<fim-suffix>world<fim-middle>"
+                )
+            ]
+        ),
         prefill_logprobs=True,
         truncate=100,
         parameters=default_pb_parameters,
@@ -28,6 +28,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="Test",
+        input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="Test")]),
         prefill_logprobs=True,
         truncate=100,
         parameters=default_pb_parameters,
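
The test fixtures above now populate the new input_chunks field alongside the legacy inputs string. For reference, a hedged sketch of a mixed text-and-image request built from the fields these changes touch (InputChunk.text and InputChunk.image.data); the Image message name is an assumption, since the proto itself is not part of this diff.

# Illustrative only: message names beyond InputChunk.text and InputChunk.image.data
# are assumptions about the proto, which this diff does not show.
from text_generation_server.pb import generate_pb2

with open("cat.png", "rb") as f:
    png_bytes = f.read()

request = generate_pb2.Request(
    id=0,
    inputs="What is in this picture?",
    input_chunks=generate_pb2.Input(
        chunks=[
            generate_pb2.InputChunk(image=generate_pb2.Image(data=png_bytes)),
            generate_pb2.InputChunk(text="What is in this picture?"),
        ]
    ),
    prefill_logprobs=True,
    truncate=100,
)
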
@@ -7,6 +7,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenize
 from typing import Optional, Tuple, List, Type, Dict
 
 from text_generation_server.models import Model
+from text_generation_server.utils.chunks import concat_text_chunks
 from text_generation_server.utils.tokens import batch_top_tokens
 from text_generation_server.models.types import (
     Batch,
@@ -86,7 +87,8 @@ class CausalLMBatch(Batch):
         max_decode_tokens = 0
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
-            inputs.append(r.inputs)
+            inputs.append(concat_text_chunks(r.input_chunks.chunks))
+
             next_token_choosers.append(
                 NextTokenChooser.from_pb(r.parameters, device, tokenizer)
             )
@@ -11,9 +11,10 @@ from loguru import logger
 from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import PreTrainedTokenizerBase
-from typing import Optional, Tuple, List, Type, Dict
+from typing import Iterable, Optional, Tuple, List, Type, Dict
 
 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
+from text_generation_server.utils.chunks import concat_text_chunks
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.models import Model
 from text_generation_server.utils.tokens import batch_top_tokens
@@ -127,11 +128,13 @@ class FlashCausalLMBatch(Batch):
         )
 
     @classmethod
-    def batch_tokenized_inputs(cls, requests, tokenizer):
+    def batch_tokenized_inputs(
+        cls, requests: Iterable[generate_pb2.Request], tokenizer
+    ):
         batch_inputs = []
         max_truncation = 0
         for r in requests:
-            batch_inputs.append(r.inputs)
+            batch_inputs.append(concat_text_chunks(r.input_chunks.chunks))
             max_truncation = max(max_truncation, r.truncate)
 
         batch_tokenized_inputs = tokenizer(
@@ -20,6 +20,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
+from text_generation_server.utils.chunks import concat_text_chunks
 
 # CREDIT: Papers with code => https://github.com/paperswithcode/galai/blob/main/galai/utils.py
 
@@ -91,7 +92,9 @@ class GalacticaCausalLMBatch(CausalLMBatch):
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
             # Add escape_custom_split_sequence to the CausalLMBatch logic
-            inputs.append(escape_custom_split_sequence(r.inputs))
+            inputs.append(
+                escape_custom_split_sequence(concat_text_chunks(r.input_chunks.chunks))
+            )
             next_token_choosers.append(
                 NextTokenChooser.from_pb(r.parameters, device, tokenizer)
             )
@@ -1,4 +1,5 @@
-import torch
+from io import BytesIO
+from PIL import Image
 import torch
 import time
 
@@ -21,11 +22,6 @@ from text_generation_server.models.types import (
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
-from text_generation_server.models.vlm_causal_lm import split
-
-import re
-
-IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)")
 
 
 tracer = trace.get_tracer(__name__)
@@ -109,7 +105,7 @@ class IdeficsCausalLMBatch(Batch):
         max_decode_tokens = 0
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
-            inputs.append(r.inputs)
+            inputs.append(r.input_chunks.chunks)
             next_token_choosers.append(
                 NextTokenChooser.from_pb(r.parameters, device, tokenizer)
             )
@@ -128,8 +124,15 @@ class IdeficsCausalLMBatch(Batch):
         for inp in inputs:
             # Each input is encoded into a list, where each element of this input list is either a string or a URL
             prompt = []
-            for chunk in split(inp):
-                prompt.append(chunk["content"])
+            for chunk in inp:
+                chunk_type = chunk.WhichOneof("chunk")
+                if chunk_type == "text":
+                    prompt.append(chunk.text)
+                elif chunk_type == "image":
+                    image = Image.open(BytesIO(chunk.image.data))
+                    prompt.append(image)
+                else:
+                    raise RuntimeError(f"Invalid chunk type {chunk_type}")
             prompts.append(prompt)
 
             # The processor replaces the call to tokenizer, and
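
The loop above dispatches on protobuf's WhichOneof, which returns the name of whichever member of the chunk oneof is set (or None when none is). A small illustration, not part of the commit:

# Not part of the diff: shows the oneof accessor the chunk dispatch relies on.
from text_generation_server.pb import generate_pb2

text_chunk = generate_pb2.InputChunk(text="hello")
assert text_chunk.WhichOneof("chunk") == "text"

empty_chunk = generate_pb2.InputChunk()
assert empty_chunk.WhichOneof("chunk") is None
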
@@ -27,6 +27,7 @@ from text_generation_server.models.types import (
     Generation,
     GeneratedText,
 )
+from text_generation_server.utils.chunks import concat_text_chunks
 from text_generation_server.utils.tokens import batch_top_tokens, Sampling
 from dataclasses import dataclass
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
@@ -139,7 +140,7 @@ class MambaBatch(Batch):
         max_decode_tokens = 0
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
-            inputs.append(r.inputs)
+            inputs.append(concat_text_chunks(r.input_chunks.chunks))
             next_token_choosers.append(
                 NextTokenChooser.from_pb(r.parameters, device, tokenizer)
             )
@@ -1,55 +1,48 @@
+from io import BytesIO
+from PIL import Image
 import torch
 import torch.distributed
 from opentelemetry import trace
-from typing import Optional, Tuple
+from typing import Iterable, Optional, Tuple
 from text_generation_server.models.vlm_causal_lm import (
     VlmCausalLM,
     VlmCausalLMBatch,
     image_text_replacement,
-    load_data_uri,
-    split,
 )
 from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
     PaliGemmaForConditionalGeneration,
 )
-from transformers import AutoProcessor, AutoConfig, AutoImageProcessor
+from transformers import AutoProcessor, AutoConfig
 
+from text_generation_server.pb.generate_pb2 import Request
+
 tracer = trace.get_tracer(__name__)
 
 
 class PaliGemmaBatch(VlmCausalLMBatch):
     @classmethod
-    def batch_tokenized_inputs(cls, requests, tokenizer, processor, config):
+    def batch_tokenized_inputs(
+        cls, requests: Iterable[Request], tokenizer, processor, config
+    ):
         batch_inputs = []
         image_inputs = []
         max_truncation = 0
         for r in requests:
-            chunks = split(r.inputs)
             full_text = ""
             image_id = 0
-            for chunk in chunks:
-                if chunk["type"] == "text":
-                    full_text += "<bos>" + chunk["content"] + "\n"
-                elif chunk["type"] == "image":
-                    image = chunk["content"]
-                    # Should never receive URLs anymore, processing should be done
-                    # On the rust layer.
-                    # This avoid making n queries per TP
-                    # if image.startswith("https://") or image.startswith("http://"):
-                    # image = processor.image_processor.fetch_images(image)
-                    if image.startswith("data:"):
-                        image = load_data_uri(image)
-                    else:
-                        raise RuntimeError(
-                            "Cannot process input image not starting with data:"
-                        )
+            for chunk in r.input_chunks.chunks:
+                chunk_type = chunk.WhichOneof("chunk")
+                if chunk_type == "text":
+                    full_text += "<bos>" + chunk.text + "\n"
+                elif chunk_type == "image":
+                    image = Image.open(BytesIO(chunk.image.data))
                     # TODO do_convert_RGB should be on by default ?
                     image = image.convert("RGB")
                     image_input = processor.image_processor(image, return_tensors="pt")
                     full_text += image_text_replacement(image_input, config, image_id)
                     image_inputs.append(image_input)
                 else:
-                    raise RuntimeError(f"Invalid chunk type {chunk['type']}")
+                    raise RuntimeError(f"Invalid chunk type {chunk_type}")
 
             batch_inputs.append(full_text)
             max_truncation = max(max_truncation, r.truncate)
@@ -6,6 +6,7 @@ from opentelemetry import trace
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Dict
 
+from text_generation_server.utils.chunks import concat_text_chunks
 from text_generation_server.utils.tokens import batch_top_tokens
 from text_generation_server.models import Model
 from text_generation_server.models.types import (
@@ -93,7 +94,7 @@ class Seq2SeqLMBatch(Batch):
         padding_right_offset = 0
         max_decode_tokens = 0
         for i, r in enumerate(pb.requests):
-            inputs.append(r.inputs)
+            inputs.append(concat_text_chunks(r.input_chunks.chunks))
             requests_idx_mapping[r.id] = i
             decoder_input_lengths.append(1)
             next_token_choosers.append(
@@ -1,12 +1,9 @@
-import re
 import torch
-import math
 from PIL import Image
 from io import BytesIO
-import base64
 
 from opentelemetry import trace
-from typing import Optional, Tuple, List, Type, Dict
+from typing import Iterable, Optional, Tuple, List, Type, Dict
 
 from transformers import PreTrainedTokenizerBase
 from transformers.image_processing_utils import select_best_resolution
@@ -18,25 +15,6 @@ from text_generation_server.models.flash_mistral import (
 
 tracer = trace.get_tracer(__name__)
 
-IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)")
-
-
-def split(string) -> List[Dict[str, str]]:
-    parts = []
-    cursor = 0
-    for pattern in IMAGES.finditer(string):
-        start = pattern.start()
-        if start != cursor:
-            parts.append({"type": "text", "content": string[cursor:start]})
-
-        parts.append({"type": "image", "content": pattern.group(1)})
-        cursor = pattern.end()
-
-    if cursor != len(string):
-        parts.append({"type": "text", "content": string[cursor:]})
-
-    return parts
-
 
 def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
     """
@@ -129,13 +107,6 @@ def get_number_of_features(height: int, width: int, config) -> int:
     return unpadded_features + newline_features + base_features
 
 
-def load_data_uri(image_uri: str) -> Image.Image:
-    image_uri = image_uri.split(",")[-1]
-    content = base64.b64decode(image_uri)
-    image = Image.open(BytesIO(content))
-    return image
-
-
 class VlmCausalLMBatch(FlashCausalLMBatch):
     pixel_values: Optional[List[torch.Tensor]]
     pixel_attention_mask: Optional[List[torch.Tensor]]
@@ -159,35 +130,26 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         return batch
 
     @classmethod
-    def batch_tokenized_inputs(cls, requests, tokenizer, processor, config):
+    def batch_tokenized_inputs(
+        cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config
+    ):
         batch_inputs = []
         image_inputs = []
         max_truncation = 0
         for r in requests:
-            chunks = split(r.inputs)
             full_text = ""
             image_id = 0
-            for chunk in chunks:
-                if chunk["type"] == "text":
-                    full_text += chunk["content"]
-                elif chunk["type"] == "image":
-                    image = chunk["content"]
-                    # Should never receive URLs anymore, processing should be done
-                    # On the rust layer.
-                    # This avoid making n queries per TP
-                    # if image.startswith("https://") or image.startswith("http://"):
-                    # image = processor.image_processor.fetch_images(image)
-                    if image.startswith("data:"):
-                        image = load_data_uri(image)
-                    else:
-                        raise RuntimeError(
-                            "Cannot process input image not starting with data:"
-                        )
+            for chunk in r.input_chunks.chunks:
+                chunk_type = chunk.WhichOneof("chunk")
+                if chunk_type == "text":
+                    full_text += chunk.text
+                elif chunk_type == "image":
+                    image = Image.open(BytesIO(chunk.image.data))
                     image_input = processor.image_processor(image, return_tensors="pt")
                     full_text += image_text_replacement(image_input, config, image_id)
                     image_inputs.append(image_input)
                 else:
-                    raise RuntimeError(f"Invalid chunk type {chunk['type']}")
+                    raise RuntimeError(f"Invalid chunk type {chunk_type}")
 
             batch_inputs.append(full_text)
             max_truncation = max(max_truncation, r.truncate)
@@ -0,0 +1,27 @@
+from typing import Iterable
+
+from loguru import logger
+
+from text_generation_server.pb import generate_pb2
+
+
+def concat_text_chunks(chunks: Iterable[generate_pb2.InputChunk]) -> str:
+    """
+    Concatenate text in text chunks. Non-text chunks are dropped.
+    """
+    text = None
+    for chunk in chunks:
+        chunk_type = chunk.WhichOneof("chunk")
+        if chunk_type == "text":
+            if text is None:
+                text = chunk.text
+            else:
+                raise NotImplementedError("Request contained more than one text chunk")
+        else:
+            # We cannot reject this, e.g. warmup sends an image chunk.
+            logger.debug(f"Encountered non-text chunk type {chunk_type}")
+
+    if text is None:
+        raise NotImplementedError("Request without a text chunk")
+
+    return text
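
For reference, a usage sketch of the helper added above (not part of the commit), mirroring the test fixtures earlier in the diff. Despite the docstring, the current implementation accepts at most one text chunk and only logs non-text chunks; a request with more than one text chunk, or with none, raises NotImplementedError.

# Usage sketch of concat_text_chunks on a single-text-chunk input.
from text_generation_server.pb import generate_pb2
from text_generation_server.utils.chunks import concat_text_chunks

pb_input = generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="Test")])
assert concat_text_chunks(pb_input.chunks) == "Test"
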