From 4dabddb7ea1fbbd4aee47dfd68e462e73f8e6b87 Mon Sep 17 00:00:00 2001 From: "Wang, Yi" Date: Fri, 7 Jun 2024 01:12:57 +0800 Subject: [PATCH 1/2] Xpu gqa (#2013) # What does this PR do? Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. Signed-off-by: Wang, Yi A --- Dockerfile_intel | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/Dockerfile_intel b/Dockerfile_intel index ee963928..cb0e84bb 100644 --- a/Dockerfile_intel +++ b/Dockerfile_intel @@ -48,7 +48,7 @@ RUN wget -qO - https://repositories.intel.com/gpu/intel-graphics.key | gpg --dea RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \ | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list -RUN apt-get update && apt install -y intel-basekit xpu-smi +RUN apt-get update && apt install -y intel-basekit xpu-smi cmake python3-dev ninja-build # Text Generation Inference base env ENV HUGGINGFACE_HUB_CACHE=/data \ @@ -57,8 +57,8 @@ ENV HUGGINGFACE_HUB_CACHE=/data \ WORKDIR /usr/src -RUN wget https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/intel_extension_for_pytorch-2.1.30a0-cp310-cp310-linux_x86_64.whl -RUN pip install intel_extension_for_pytorch-2.1.30a0-cp310-cp310-linux_x86_64.whl +RUN wget https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/torch-2.1.0.post1%2Bcxx11.abi-cp310-cp310-linux_x86_64.whl && pip install torch-2.1.0.post1+cxx11.abi-cp310-cp310-linux_x86_64.whl +RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout -b group_rope origin/dev/gqa_rope # Install server COPY proto proto @@ -76,6 +76,10 @@ ENV LIBRARY_PATH=/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/ccl/latest/l ENV LD_LIBRARY_PATH=/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib:/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/mkl/latest/lib:/opt/intel/oneapi/compiler/latest/opt/compiler/lib:/opt/intel/oneapi/compiler/latest/lib:/opt/intel/oneapi/lib:/opt/intel/oneapi/lib/intel64: ENV PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mpi/latest/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mkl/latest/bin/:/opt/intel/oneapi/compiler/latest/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin ENV CCL_ZE_IPC_EXCHANGE=sockets +ENV CMAKE_PREFIX_PATH=/opt/intel/oneapi/mkl/latest/lib/cmake:/opt/intel/oneapi/compiler/latest +ENV CPATH=/opt/intel/oneapi/mpi/latest/include:/opt/intel/oneapi/ccl/latest/include:/opt/intel/oneapi/mkl/latest/include + +RUN pip uninstall -y intel-extension-for-pytorch && cd intel-extension-for-pytorch && git submodule update --init --recursive && USE_AOT_DEVLIST='pvc' BUILD_SEPARATE_OPS=OFF BUILD_WITH_CPU=OFF USE_XETLA=ON python setup.py install && rm -rf /usr/src/intel-extension-for-pytorch # Install benchmarker COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark From bf3c81378223fa2ee4212050c9886338feb19371 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Fri, 31 May 2024 11:51:42 +0000 Subject: [PATCH 2/2] server: use chunked inputs The router will now send the input as chunks besides as a single string. This change modifies the server to process chunked input rather than strings. This also allows us to remove the image extraction code from the server. --- .github/workflows/build.yaml | 2 +- server/tests/models/test_bloom.py | 1 + server/tests/models/test_causal_lm.py | 1 + server/tests/models/test_santacoder.py | 8 +++ server/tests/models/test_seq2seq_lm.py | 1 + .../models/causal_lm.py | 4 +- .../models/flash_causal_lm.py | 9 ++- .../models/galactica.py | 5 +- .../models/idefics_causal_lm.py | 21 ++++--- server/text_generation_server/models/mamba.py | 3 +- .../models/pali_gemma.py | 39 +++++------- .../models/seq2seq_lm.py | 3 +- .../models/vlm_causal_lm.py | 60 ++++--------------- server/text_generation_server/utils/chunks.py | 27 +++++++++ 14 files changed, 95 insertions(+), 89 deletions(-) create mode 100644 server/text_generation_server/utils/chunks.py diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 84266ce5..e80037b1 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -30,7 +30,7 @@ jobs: cancel-in-progress: true runs-on: [self-hosted, nvidia-gpu , multi-gpu, 4-a10, ci] strategy: - matrix: + matrix: include: - name: "cuda" label: "" diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py index 66df708a..32ee6686 100644 --- a/server/tests/models/test_bloom.py +++ b/server/tests/models/test_bloom.py @@ -29,6 +29,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters): return generate_pb2.Request( id=0, inputs="Test", + input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="Test")]), prefill_logprobs=True, truncate=100, parameters=default_pb_parameters, diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py index 250fa354..6e6463bc 100644 --- a/server/tests/models/test_causal_lm.py +++ b/server/tests/models/test_causal_lm.py @@ -25,6 +25,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters): return generate_pb2.Request( id=0, inputs="Test", + input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="Test")]), prefill_logprobs=True, truncate=100, parameters=default_pb_parameters, diff --git a/server/tests/models/test_santacoder.py b/server/tests/models/test_santacoder.py index 1e40e766..cb2622d9 100644 --- a/server/tests/models/test_santacoder.py +++ b/server/tests/models/test_santacoder.py @@ -15,6 +15,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters): return generate_pb2.Request( id=0, inputs="def", + input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="def")]), prefill_logprobs=True, truncate=100, parameters=default_pb_parameters, @@ -32,6 +33,13 @@ def default_fim_pb_request(default_pb_parameters, default_pb_stop_parameters): return generate_pb2.Request( id=0, inputs="defworld", + input_chunks=generate_pb2.Input( + chunks=[ + generate_pb2.InputChunk( + text="defworld" + ) + ] + ), prefill_logprobs=True, truncate=100, parameters=default_pb_parameters, diff --git a/server/tests/models/test_seq2seq_lm.py b/server/tests/models/test_seq2seq_lm.py index 735ab5eb..943c3b08 100644 --- a/server/tests/models/test_seq2seq_lm.py +++ b/server/tests/models/test_seq2seq_lm.py @@ -28,6 +28,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters): return generate_pb2.Request( id=0, inputs="Test", + input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="Test")]), prefill_logprobs=True, truncate=100, parameters=default_pb_parameters, diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 81a02163..e896c831 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -7,6 +7,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenize from typing import Optional, Tuple, List, Type, Dict from text_generation_server.models import Model +from text_generation_server.utils.chunks import concat_text_chunks from text_generation_server.utils.tokens import batch_top_tokens from text_generation_server.models.types import ( Batch, @@ -86,7 +87,8 @@ class CausalLMBatch(Batch): max_decode_tokens = 0 for i, r in enumerate(pb.requests): requests_idx_mapping[r.id] = i - inputs.append(r.inputs) + inputs.append(concat_text_chunks(r.input_chunks.chunks)) + next_token_choosers.append( NextTokenChooser.from_pb(r.parameters, device, tokenizer) ) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index d8c8838c..acf77b09 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -11,9 +11,10 @@ from loguru import logger from dataclasses import dataclass from opentelemetry import trace from transformers import PreTrainedTokenizerBase -from typing import Optional, Tuple, List, Type, Dict +from typing import Iterable, Optional, Tuple, List, Type, Dict from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE +from text_generation_server.utils.chunks import concat_text_chunks from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.models import Model from text_generation_server.utils.tokens import batch_top_tokens @@ -127,11 +128,13 @@ class FlashCausalLMBatch(Batch): ) @classmethod - def batch_tokenized_inputs(cls, requests, tokenizer): + def batch_tokenized_inputs( + cls, requests: Iterable[generate_pb2.Request], tokenizer + ): batch_inputs = [] max_truncation = 0 for r in requests: - batch_inputs.append(r.inputs) + batch_inputs.append(concat_text_chunks(r.input_chunks.chunks)) max_truncation = max(max_truncation, r.truncate) batch_tokenized_inputs = tokenizer( diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py index 4656fd45..d0f2b915 100644 --- a/server/text_generation_server/models/galactica.py +++ b/server/text_generation_server/models/galactica.py @@ -20,6 +20,7 @@ from text_generation_server.utils import ( weight_files, Weights, ) +from text_generation_server.utils.chunks import concat_text_chunks # CREDIT: Papers with code => https://github.com/paperswithcode/galai/blob/main/galai/utils.py @@ -91,7 +92,9 @@ class GalacticaCausalLMBatch(CausalLMBatch): for i, r in enumerate(pb.requests): requests_idx_mapping[r.id] = i # Add escape_custom_split_sequence to the CausalLMBatch logic - inputs.append(escape_custom_split_sequence(r.inputs)) + inputs.append( + escape_custom_split_sequence(concat_text_chunks(r.input_chunks.chunks)) + ) next_token_choosers.append( NextTokenChooser.from_pb(r.parameters, device, tokenizer) ) diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py index e78a9655..f507d669 100644 --- a/server/text_generation_server/models/idefics_causal_lm.py +++ b/server/text_generation_server/models/idefics_causal_lm.py @@ -1,4 +1,5 @@ -import torch +from io import BytesIO +from PIL import Image import torch import time @@ -21,11 +22,6 @@ from text_generation_server.models.types import ( ) from text_generation_server.pb import generate_pb2 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling -from text_generation_server.models.vlm_causal_lm import split - -import re - -IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)") tracer = trace.get_tracer(__name__) @@ -109,7 +105,7 @@ class IdeficsCausalLMBatch(Batch): max_decode_tokens = 0 for i, r in enumerate(pb.requests): requests_idx_mapping[r.id] = i - inputs.append(r.inputs) + inputs.append(r.input_chunks.chunks) next_token_choosers.append( NextTokenChooser.from_pb(r.parameters, device, tokenizer) ) @@ -128,8 +124,15 @@ class IdeficsCausalLMBatch(Batch): for inp in inputs: # Each input is encoded into a list, where each element of this input list is either a string or a URL prompt = [] - for chunk in split(inp): - prompt.append(chunk["content"]) + for chunk in inp: + chunk_type = chunk.WhichOneof("chunk") + if chunk_type == "text": + prompt.append(chunk.text) + elif chunk_type == "image": + image = Image.open(BytesIO(chunk.image.data)) + prompt.append(image) + else: + raise RuntimeError(f"Invalid chunk type {chunk_type}") prompts.append(prompt) # The processor replaces the call to tokenizer, and diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py index d9f90590..3133a137 100644 --- a/server/text_generation_server/models/mamba.py +++ b/server/text_generation_server/models/mamba.py @@ -27,6 +27,7 @@ from text_generation_server.models.types import ( Generation, GeneratedText, ) +from text_generation_server.utils.chunks import concat_text_chunks from text_generation_server.utils.tokens import batch_top_tokens, Sampling from dataclasses import dataclass from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling @@ -139,7 +140,7 @@ class MambaBatch(Batch): max_decode_tokens = 0 for i, r in enumerate(pb.requests): requests_idx_mapping[r.id] = i - inputs.append(r.inputs) + inputs.append(concat_text_chunks(r.input_chunks.chunks)) next_token_choosers.append( NextTokenChooser.from_pb(r.parameters, device, tokenizer) ) diff --git a/server/text_generation_server/models/pali_gemma.py b/server/text_generation_server/models/pali_gemma.py index d94b9526..e883ce02 100644 --- a/server/text_generation_server/models/pali_gemma.py +++ b/server/text_generation_server/models/pali_gemma.py @@ -1,55 +1,48 @@ +from io import BytesIO +from PIL import Image import torch import torch.distributed from opentelemetry import trace -from typing import Optional, Tuple +from typing import Iterable, Optional, Tuple from text_generation_server.models.vlm_causal_lm import ( VlmCausalLM, VlmCausalLMBatch, image_text_replacement, - load_data_uri, - split, ) from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import ( PaliGemmaForConditionalGeneration, ) -from transformers import AutoProcessor, AutoConfig, AutoImageProcessor +from transformers import AutoProcessor, AutoConfig + +from text_generation_server.pb.generate_pb2 import Request tracer = trace.get_tracer(__name__) class PaliGemmaBatch(VlmCausalLMBatch): @classmethod - def batch_tokenized_inputs(cls, requests, tokenizer, processor, config): + def batch_tokenized_inputs( + cls, requests: Iterable[Request], tokenizer, processor, config + ): batch_inputs = [] image_inputs = [] max_truncation = 0 for r in requests: - chunks = split(r.inputs) full_text = "" image_id = 0 - for chunk in chunks: - if chunk["type"] == "text": - full_text += "" + chunk["content"] + "\n" - elif chunk["type"] == "image": - image = chunk["content"] - # Should never receive URLs anymore, processing should be done - # On the rust layer. - # This avoid making n queries per TP - # if image.startswith("https://") or image.startswith("http://"): - # image = processor.image_processor.fetch_images(image) - if image.startswith("data:"): - image = load_data_uri(image) - else: - raise RuntimeError( - "Cannot process input image not starting with data:" - ) + for chunk in r.input_chunks.chunks: + chunk_type = chunk.WhichOneof("chunk") + if chunk_type == "text": + full_text += "" + chunk.text + "\n" + elif chunk_type == "image": + image = Image.open(BytesIO(chunk.image.data)) # TODO do_convert_RGB should be on by default ? image = image.convert("RGB") image_input = processor.image_processor(image, return_tensors="pt") full_text += image_text_replacement(image_input, config, image_id) image_inputs.append(image_input) else: - raise RuntimeError(f"Invalid chunk type {chunk['type']}") + raise RuntimeError(f"Invalid chunk type {chunk_type}") batch_inputs.append(full_text) max_truncation = max(max_truncation, r.truncate) diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index 6a0c812f..3bd09556 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -6,6 +6,7 @@ from opentelemetry import trace from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase from typing import Optional, Tuple, List, Type, Dict +from text_generation_server.utils.chunks import concat_text_chunks from text_generation_server.utils.tokens import batch_top_tokens from text_generation_server.models import Model from text_generation_server.models.types import ( @@ -93,7 +94,7 @@ class Seq2SeqLMBatch(Batch): padding_right_offset = 0 max_decode_tokens = 0 for i, r in enumerate(pb.requests): - inputs.append(r.inputs) + inputs.append(concat_text_chunks(r.input_chunks.chunks)) requests_idx_mapping[r.id] = i decoder_input_lengths.append(1) next_token_choosers.append( diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 92d79070..59a6fab1 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -1,12 +1,9 @@ -import re import torch -import math from PIL import Image from io import BytesIO -import base64 from opentelemetry import trace -from typing import Optional, Tuple, List, Type, Dict +from typing import Iterable, Optional, Tuple, List, Type, Dict from transformers import PreTrainedTokenizerBase from transformers.image_processing_utils import select_best_resolution @@ -18,25 +15,6 @@ from text_generation_server.models.flash_mistral import ( tracer = trace.get_tracer(__name__) -IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)") - - -def split(string) -> List[Dict[str, str]]: - parts = [] - cursor = 0 - for pattern in IMAGES.finditer(string): - start = pattern.start() - if start != cursor: - parts.append({"type": "text", "content": string[cursor:start]}) - - parts.append({"type": "image", "content": pattern.group(1)}) - cursor = pattern.end() - - if cursor != len(string): - parts.append({"type": "text", "content": string[cursor:]}) - - return parts - def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): """ @@ -129,13 +107,6 @@ def get_number_of_features(height: int, width: int, config) -> int: return unpadded_features + newline_features + base_features -def load_data_uri(image_uri: str) -> Image.Image: - image_uri = image_uri.split(",")[-1] - content = base64.b64decode(image_uri) - image = Image.open(BytesIO(content)) - return image - - class VlmCausalLMBatch(FlashCausalLMBatch): pixel_values: Optional[List[torch.Tensor]] pixel_attention_mask: Optional[List[torch.Tensor]] @@ -159,35 +130,26 @@ class VlmCausalLMBatch(FlashCausalLMBatch): return batch @classmethod - def batch_tokenized_inputs(cls, requests, tokenizer, processor, config): + def batch_tokenized_inputs( + cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config + ): batch_inputs = [] image_inputs = [] max_truncation = 0 for r in requests: - chunks = split(r.inputs) full_text = "" image_id = 0 - for chunk in chunks: - if chunk["type"] == "text": - full_text += chunk["content"] - elif chunk["type"] == "image": - image = chunk["content"] - # Should never receive URLs anymore, processing should be done - # On the rust layer. - # This avoid making n queries per TP - # if image.startswith("https://") or image.startswith("http://"): - # image = processor.image_processor.fetch_images(image) - if image.startswith("data:"): - image = load_data_uri(image) - else: - raise RuntimeError( - "Cannot process input image not starting with data:" - ) + for chunk in r.input_chunks.chunks: + chunk_type = chunk.WhichOneof("chunk") + if chunk_type == "text": + full_text += chunk.text + elif chunk_type == "image": + image = Image.open(BytesIO(chunk.image.data)) image_input = processor.image_processor(image, return_tensors="pt") full_text += image_text_replacement(image_input, config, image_id) image_inputs.append(image_input) else: - raise RuntimeError(f"Invalid chunk type {chunk['type']}") + raise RuntimeError(f"Invalid chunk type {chunk_type}") batch_inputs.append(full_text) max_truncation = max(max_truncation, r.truncate) diff --git a/server/text_generation_server/utils/chunks.py b/server/text_generation_server/utils/chunks.py new file mode 100644 index 00000000..73962ea3 --- /dev/null +++ b/server/text_generation_server/utils/chunks.py @@ -0,0 +1,27 @@ +from typing import Iterable + +from loguru import logger + +from text_generation_server.pb import generate_pb2 + + +def concat_text_chunks(chunks: Iterable[generate_pb2.InputChunk]) -> str: + """ + Concatenate text in text chunks. Non-text chunks are dropped. + """ + text = None + for chunk in chunks: + chunk_type = chunk.WhichOneof("chunk") + if chunk_type == "text": + if text is None: + text = chunk.text + else: + raise NotImplementedError("Request contained more than one text chunk") + else: + # We cannot reject this, e.g. warmup sends an image chunk. + logger.debug(f"Encountered non-text chunk type {chunk_type}") + + if text is None: + raise NotImplementedError("Request without a text chunk") + + return text