Merge branch 'main' into ci_amd2
This commit is contained in:
commit c73355b99c
@@ -152,8 +152,7 @@ jobs:
       group: ${{ github.workflow }}-${{ github.job }}-rocm-${{ github.head_ref || github.run_id }}
       cancel-in-progress: true
     runs-on: [amd-gpu-tgi, multi-gpu, mi250]
-    needs:
-      - build-and-push-image
+    needs: build-and-push-image
     steps:
       - name: Checkout repository
         uses: actions/checkout@v4
@@ -48,7 +48,7 @@ RUN wget -qO - https://repositories.intel.com/gpu/intel-graphics.key | gpg --dea
 RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \
 | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list

-RUN apt-get update && apt install -y intel-basekit xpu-smi
+RUN apt-get update && apt install -y intel-basekit xpu-smi cmake python3-dev ninja-build

 # Text Generation Inference base env
 ENV HUGGINGFACE_HUB_CACHE=/data \
@@ -57,8 +57,8 @@ ENV HUGGINGFACE_HUB_CACHE=/data \


 WORKDIR /usr/src
-RUN wget https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/intel_extension_for_pytorch-2.1.30a0-cp310-cp310-linux_x86_64.whl
-RUN pip install intel_extension_for_pytorch-2.1.30a0-cp310-cp310-linux_x86_64.whl
+RUN wget https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/torch-2.1.0.post1%2Bcxx11.abi-cp310-cp310-linux_x86_64.whl && pip install torch-2.1.0.post1+cxx11.abi-cp310-cp310-linux_x86_64.whl
+RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout -b group_rope origin/dev/gqa_rope

 # Install server
 COPY proto proto
@@ -76,6 +76,10 @@ ENV LIBRARY_PATH=/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/ccl/latest/l
 ENV LD_LIBRARY_PATH=/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib:/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/mkl/latest/lib:/opt/intel/oneapi/compiler/latest/opt/compiler/lib:/opt/intel/oneapi/compiler/latest/lib:/opt/intel/oneapi/lib:/opt/intel/oneapi/lib/intel64:
 ENV PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mpi/latest/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mkl/latest/bin/:/opt/intel/oneapi/compiler/latest/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
 ENV CCL_ZE_IPC_EXCHANGE=sockets
+ENV CMAKE_PREFIX_PATH=/opt/intel/oneapi/mkl/latest/lib/cmake:/opt/intel/oneapi/compiler/latest
+ENV CPATH=/opt/intel/oneapi/mpi/latest/include:/opt/intel/oneapi/ccl/latest/include:/opt/intel/oneapi/mkl/latest/include
+
+RUN pip uninstall -y intel-extension-for-pytorch && cd intel-extension-for-pytorch && git submodule update --init --recursive && USE_AOT_DEVLIST='pvc' BUILD_SEPARATE_OPS=OFF BUILD_WITH_CPU=OFF USE_XETLA=ON python setup.py install && rm -rf /usr/src/intel-extension-for-pytorch

 # Install benchmarker
 COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
@@ -29,6 +29,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="Test",
+        input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="Test")]),
         prefill_logprobs=True,
         truncate=100,
         parameters=default_pb_parameters,
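For clarity, a small usage sketch of what the updated fixture produces; it assumes the compiled generate_pb2 module from this repository is importable and it omits the sampling parameters. The chunked field carries the same text as the legacy inputs string, so either code path sees the same prompt.

from text_generation_server.pb import generate_pb2

# Build a request the way the fixture above does (parameters omitted here).
req = generate_pb2.Request(
    id=0,
    inputs="Test",
    input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="Test")]),
    prefill_logprobs=True,
    truncate=100,
)

# The single text chunk round-trips to the same string as the legacy field.
assert req.input_chunks.chunks[0].text == req.inputs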
@@ -25,6 +25,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="Test",
+        input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="Test")]),
         prefill_logprobs=True,
         truncate=100,
         parameters=default_pb_parameters,
@@ -15,6 +15,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="def",
+        input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="def")]),
         prefill_logprobs=True,
         truncate=100,
         parameters=default_pb_parameters,
@@ -32,6 +33,13 @@ def default_fim_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="<fim-prefix>def<fim-suffix>world<fim-middle>",
+        input_chunks=generate_pb2.Input(
+            chunks=[
+                generate_pb2.InputChunk(
+                    text="<fim-prefix>def<fim-suffix>world<fim-middle>"
+                )
+            ]
+        ),
         prefill_logprobs=True,
         truncate=100,
         parameters=default_pb_parameters,
@@ -28,6 +28,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="Test",
+        input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="Test")]),
         prefill_logprobs=True,
         truncate=100,
         parameters=default_pb_parameters,
@@ -7,6 +7,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenize
 from typing import Optional, Tuple, List, Type, Dict

 from text_generation_server.models import Model
+from text_generation_server.utils.chunks import concat_text_chunks
 from text_generation_server.utils.tokens import batch_top_tokens
 from text_generation_server.models.types import (
     Batch,
@@ -86,7 +87,8 @@ class CausalLMBatch(Batch):
         max_decode_tokens = 0
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
-            inputs.append(r.inputs)
+            inputs.append(concat_text_chunks(r.input_chunks.chunks))
+
             next_token_choosers.append(
                 NextTokenChooser.from_pb(r.parameters, device, tokenizer)
             )
@@ -11,9 +11,10 @@ from loguru import logger
 from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import PreTrainedTokenizerBase
-from typing import Optional, Tuple, List, Type, Dict
+from typing import Iterable, Optional, Tuple, List, Type, Dict

 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
+from text_generation_server.utils.chunks import concat_text_chunks
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.models import Model
 from text_generation_server.utils.tokens import batch_top_tokens
@@ -127,11 +128,13 @@ class FlashCausalLMBatch(Batch):
         )

     @classmethod
-    def batch_tokenized_inputs(cls, requests, tokenizer):
+    def batch_tokenized_inputs(
+        cls, requests: Iterable[generate_pb2.Request], tokenizer
+    ):
         batch_inputs = []
         max_truncation = 0
         for r in requests:
-            batch_inputs.append(r.inputs)
+            batch_inputs.append(concat_text_chunks(r.input_chunks.chunks))
             max_truncation = max(max_truncation, r.truncate)

         batch_tokenized_inputs = tokenizer(
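The reworked batch_tokenized_inputs above flattens each request's chunks back into plain strings before handing the batch to the tokenizer. A minimal sketch of that flow under stated assumptions: the compiled generate_pb2 module and the new concat_text_chunks helper are importable, any Hugging Face tokenizer will do (gpt2 here), and the tokenizer arguments are illustrative rather than the server's exact ones.

from transformers import AutoTokenizer

from text_generation_server.pb import generate_pb2
from text_generation_server.utils.chunks import concat_text_chunks

requests = [
    generate_pb2.Request(
        id=i,
        input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text=text)]),
        truncate=100,
    )
    for i, text in enumerate(["Test", "def"])
]

# Flatten chunked inputs to strings and track the largest truncation value,
# mirroring the loop in batch_tokenized_inputs.
batch_inputs = [concat_text_chunks(r.input_chunks.chunks) for r in requests]
max_truncation = max(r.truncate for r in requests)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
batch_tokenized_inputs = tokenizer(
    batch_inputs, truncation=True, max_length=max_truncation
)["input_ids"]
print(batch_tokenized_inputs)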
@@ -20,6 +20,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
+from text_generation_server.utils.chunks import concat_text_chunks

 # CREDIT: Papers with code => https://github.com/paperswithcode/galai/blob/main/galai/utils.py

@@ -91,7 +92,9 @@ class GalacticaCausalLMBatch(CausalLMBatch):
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
             # Add escape_custom_split_sequence to the CausalLMBatch logic
-            inputs.append(escape_custom_split_sequence(r.inputs))
+            inputs.append(
+                escape_custom_split_sequence(concat_text_chunks(r.input_chunks.chunks))
+            )
             next_token_choosers.append(
                 NextTokenChooser.from_pb(r.parameters, device, tokenizer)
             )
@@ -1,4 +1,5 @@
-import torch
+from io import BytesIO
+from PIL import Image
 import torch
 import time

@@ -21,11 +22,6 @@ from text_generation_server.models.types import (
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
-from text_generation_server.models.vlm_causal_lm import split
-
-import re
-
-IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)")


 tracer = trace.get_tracer(__name__)
@@ -109,7 +105,7 @@ class IdeficsCausalLMBatch(Batch):
         max_decode_tokens = 0
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
-            inputs.append(r.inputs)
+            inputs.append(r.input_chunks.chunks)
             next_token_choosers.append(
                 NextTokenChooser.from_pb(r.parameters, device, tokenizer)
             )
@@ -128,8 +124,15 @@
         for inp in inputs:
             # Each input is encoded into a list, where each element of this input list is either a string or a URL
             prompt = []
-            for chunk in split(inp):
-                prompt.append(chunk["content"])
+            for chunk in inp:
+                chunk_type = chunk.WhichOneof("chunk")
+                if chunk_type == "text":
+                    prompt.append(chunk.text)
+                elif chunk_type == "image":
+                    image = Image.open(BytesIO(chunk.image.data))
+                    prompt.append(image)
+                else:
+                    raise RuntimeError(f"Invalid chunk type {chunk_type}")
             prompts.append(prompt)

         # The processor replaces the call to tokenizer, and
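The new Idefics loop above no longer parses markdown-style image references; image chunks arrive as raw bytes and are decoded with Pillow. A self-contained sketch of just that decode step, with the PNG bytes generated locally rather than taken from a real chunk:

from io import BytesIO

from PIL import Image

# Stand-in for chunk.image.data: encode a tiny 4x4 image as PNG bytes.
buf = BytesIO()
Image.new("RGB", (4, 4), color=(255, 0, 0)).save(buf, format="PNG")
image_bytes = buf.getvalue()

# The decode used when building the prompt list.
image = Image.open(BytesIO(image_bytes))
assert image.size == (4, 4)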
@@ -27,6 +27,7 @@ from text_generation_server.models.types import (
     Generation,
     GeneratedText,
 )
+from text_generation_server.utils.chunks import concat_text_chunks
 from text_generation_server.utils.tokens import batch_top_tokens, Sampling
 from dataclasses import dataclass
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
@@ -139,7 +140,7 @@ class MambaBatch(Batch):
         max_decode_tokens = 0
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
-            inputs.append(r.inputs)
+            inputs.append(concat_text_chunks(r.input_chunks.chunks))
             next_token_choosers.append(
                 NextTokenChooser.from_pb(r.parameters, device, tokenizer)
             )
@@ -1,55 +1,48 @@
+from io import BytesIO
+from PIL import Image
 import torch
 import torch.distributed
 from opentelemetry import trace
-from typing import Optional, Tuple
+from typing import Iterable, Optional, Tuple
 from text_generation_server.models.vlm_causal_lm import (
     VlmCausalLM,
     VlmCausalLMBatch,
     image_text_replacement,
-    load_data_uri,
-    split,
 )
 from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
     PaliGemmaForConditionalGeneration,
 )
-from transformers import AutoProcessor, AutoConfig, AutoImageProcessor
+from transformers import AutoProcessor, AutoConfig
+
+from text_generation_server.pb.generate_pb2 import Request

 tracer = trace.get_tracer(__name__)


 class PaliGemmaBatch(VlmCausalLMBatch):
     @classmethod
-    def batch_tokenized_inputs(cls, requests, tokenizer, processor, config):
+    def batch_tokenized_inputs(
+        cls, requests: Iterable[Request], tokenizer, processor, config
+    ):
         batch_inputs = []
         image_inputs = []
         max_truncation = 0
         for r in requests:
-            chunks = split(r.inputs)
             full_text = ""
             image_id = 0
-            for chunk in chunks:
-                if chunk["type"] == "text":
-                    full_text += "<bos>" + chunk["content"] + "\n"
-                elif chunk["type"] == "image":
-                    image = chunk["content"]
-                    # Should never receive URLs anymore, processing should be done
-                    # On the rust layer.
-                    # This avoid making n queries per TP
-                    # if image.startswith("https://") or image.startswith("http://"):
-                    #     image = processor.image_processor.fetch_images(image)
-                    if image.startswith("data:"):
-                        image = load_data_uri(image)
-                    else:
-                        raise RuntimeError(
-                            "Cannot process input image not starting with data:"
-                        )
+            for chunk in r.input_chunks.chunks:
+                chunk_type = chunk.WhichOneof("chunk")
+                if chunk_type == "text":
+                    full_text += "<bos>" + chunk.text + "\n"
+                elif chunk_type == "image":
+                    image = Image.open(BytesIO(chunk.image.data))
                     # TODO do_convert_RGB should be on by default ?
                     image = image.convert("RGB")
                     image_input = processor.image_processor(image, return_tensors="pt")
                     full_text += image_text_replacement(image_input, config, image_id)
                     image_inputs.append(image_input)
                 else:
-                    raise RuntimeError(f"Invalid chunk type {chunk['type']}")
+                    raise RuntimeError(f"Invalid chunk type {chunk_type}")

             batch_inputs.append(full_text)
             max_truncation = max(max_truncation, r.truncate)
@@ -6,6 +6,7 @@ from opentelemetry import trace
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Dict

+from text_generation_server.utils.chunks import concat_text_chunks
 from text_generation_server.utils.tokens import batch_top_tokens
 from text_generation_server.models import Model
 from text_generation_server.models.types import (
@@ -93,7 +94,7 @@ class Seq2SeqLMBatch(Batch):
         padding_right_offset = 0
         max_decode_tokens = 0
         for i, r in enumerate(pb.requests):
-            inputs.append(r.inputs)
+            inputs.append(concat_text_chunks(r.input_chunks.chunks))
             requests_idx_mapping[r.id] = i
             decoder_input_lengths.append(1)
             next_token_choosers.append(
@@ -1,12 +1,9 @@
-import re
 import torch
 import math
 from PIL import Image
 from io import BytesIO
-import base64
-
 from opentelemetry import trace
-from typing import Optional, Tuple, List, Type, Dict
+from typing import Iterable, Optional, Tuple, List, Type, Dict

 from transformers import PreTrainedTokenizerBase
 from transformers.image_processing_utils import select_best_resolution
@@ -18,25 +15,6 @@ from text_generation_server.models.flash_mistral import (

 tracer = trace.get_tracer(__name__)

-IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)")
-
-
-def split(string) -> List[Dict[str, str]]:
-    parts = []
-    cursor = 0
-    for pattern in IMAGES.finditer(string):
-        start = pattern.start()
-        if start != cursor:
-            parts.append({"type": "text", "content": string[cursor:start]})
-
-        parts.append({"type": "image", "content": pattern.group(1)})
-        cursor = pattern.end()
-
-    if cursor != len(string):
-        parts.append({"type": "text", "content": string[cursor:]})
-
-    return parts
-

 def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
     """
@@ -129,13 +107,6 @@ def get_number_of_features(height: int, width: int, config) -> int:
     return unpadded_features + newline_features + base_features


-def load_data_uri(image_uri: str) -> Image.Image:
-    image_uri = image_uri.split(",")[-1]
-    content = base64.b64decode(image_uri)
-    image = Image.open(BytesIO(content))
-    return image
-
-
 class VlmCausalLMBatch(FlashCausalLMBatch):
     pixel_values: Optional[List[torch.Tensor]]
     pixel_attention_mask: Optional[List[torch.Tensor]]
@@ -159,35 +130,26 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         return batch

     @classmethod
-    def batch_tokenized_inputs(cls, requests, tokenizer, processor, config):
+    def batch_tokenized_inputs(
+        cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config
+    ):
         batch_inputs = []
         image_inputs = []
         max_truncation = 0
         for r in requests:
-            chunks = split(r.inputs)
             full_text = ""
             image_id = 0
-            for chunk in chunks:
-                if chunk["type"] == "text":
-                    full_text += chunk["content"]
-                elif chunk["type"] == "image":
-                    image = chunk["content"]
-                    # Should never receive URLs anymore, processing should be done
-                    # On the rust layer.
-                    # This avoid making n queries per TP
-                    # if image.startswith("https://") or image.startswith("http://"):
-                    #     image = processor.image_processor.fetch_images(image)
-                    if image.startswith("data:"):
-                        image = load_data_uri(image)
-                    else:
-                        raise RuntimeError(
-                            "Cannot process input image not starting with data:"
-                        )
+            for chunk in r.input_chunks.chunks:
+                chunk_type = chunk.WhichOneof("chunk")
+                if chunk_type == "text":
+                    full_text += chunk.text
+                elif chunk_type == "image":
+                    image = Image.open(BytesIO(chunk.image.data))
                     image_input = processor.image_processor(image, return_tensors="pt")
                     full_text += image_text_replacement(image_input, config, image_id)
                     image_inputs.append(image_input)
                 else:
-                    raise RuntimeError(f"Invalid chunk type {chunk['type']}")
+                    raise RuntimeError(f"Invalid chunk type {chunk_type}")

             batch_inputs.append(full_text)
             max_truncation = max(max_truncation, r.truncate)
@@ -0,0 +1,27 @@
+from typing import Iterable
+
+from loguru import logger
+
+from text_generation_server.pb import generate_pb2
+
+
+def concat_text_chunks(chunks: Iterable[generate_pb2.InputChunk]) -> str:
+    """
+    Concatenate text in text chunks. Non-text chunks are dropped.
+    """
+    text = None
+    for chunk in chunks:
+        chunk_type = chunk.WhichOneof("chunk")
+        if chunk_type == "text":
+            if text is None:
+                text = chunk.text
+            else:
+                raise NotImplementedError("Request contained more than one text chunk")
+        else:
+            # We cannot reject this, e.g. warmup sends an image chunk.
+            logger.debug(f"Encountered non-text chunk type {chunk_type}")
+
+    if text is None:
+        raise NotImplementedError("Request without a text chunk")
+
+    return text
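A short usage sketch of the new helper, assuming the compiled generate_pb2 module is importable: a single text chunk is returned unchanged, and an input with no text chunk is rejected as documented above.

from text_generation_server.pb import generate_pb2
from text_generation_server.utils.chunks import concat_text_chunks

chunks = generate_pb2.Input(
    chunks=[generate_pb2.InputChunk(text="Hello, world")]
).chunks
assert concat_text_chunks(chunks) == "Hello, world"

# No text chunk at all raises NotImplementedError("Request without a text chunk").
try:
    concat_text_chunks([])
except NotImplementedError as err:
    print(err)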