Merge branch 'main' into ci_amd2

commit c73355b99c
Nicolas Patry, 2024-06-07 10:04:59 +02:00, committed by GitHub
15 changed files with 102 additions and 93 deletions

View File

@@ -152,8 +152,7 @@ jobs:
       group: ${{ github.workflow }}-${{ github.job }}-rocm-${{ github.head_ref || github.run_id }}
       cancel-in-progress: true
     runs-on: [amd-gpu-tgi, multi-gpu, mi250]
-    needs:
-      - build-and-push-image
+    needs: build-and-push-image
     steps:
       - name: Checkout repository
         uses: actions/checkout@v4

View File

@@ -48,7 +48,7 @@ RUN wget -qO - https://repositories.intel.com/gpu/intel-graphics.key | gpg --dea
 RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \
 | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list
 
-RUN apt-get update && apt install -y intel-basekit xpu-smi
+RUN apt-get update && apt install -y intel-basekit xpu-smi cmake python3-dev ninja-build
 
 # Text Generation Inference base env
 ENV HUGGINGFACE_HUB_CACHE=/data \
@@ -57,8 +57,8 @@ ENV HUGGINGFACE_HUB_CACHE=/data \
 WORKDIR /usr/src
-RUN wget https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/intel_extension_for_pytorch-2.1.30a0-cp310-cp310-linux_x86_64.whl
-RUN pip install intel_extension_for_pytorch-2.1.30a0-cp310-cp310-linux_x86_64.whl
+RUN wget https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/torch-2.1.0.post1%2Bcxx11.abi-cp310-cp310-linux_x86_64.whl && pip install torch-2.1.0.post1+cxx11.abi-cp310-cp310-linux_x86_64.whl
+RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout -b group_rope origin/dev/gqa_rope
 
 # Install server
 COPY proto proto
@@ -76,6 +76,10 @@ ENV LIBRARY_PATH=/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/ccl/latest/l
 ENV LD_LIBRARY_PATH=/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib:/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/mkl/latest/lib:/opt/intel/oneapi/compiler/latest/opt/compiler/lib:/opt/intel/oneapi/compiler/latest/lib:/opt/intel/oneapi/lib:/opt/intel/oneapi/lib/intel64:
 ENV PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mpi/latest/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mkl/latest/bin/:/opt/intel/oneapi/compiler/latest/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
 ENV CCL_ZE_IPC_EXCHANGE=sockets
+ENV CMAKE_PREFIX_PATH=/opt/intel/oneapi/mkl/latest/lib/cmake:/opt/intel/oneapi/compiler/latest
+ENV CPATH=/opt/intel/oneapi/mpi/latest/include:/opt/intel/oneapi/ccl/latest/include:/opt/intel/oneapi/mkl/latest/include
+RUN pip uninstall -y intel-extension-for-pytorch && cd intel-extension-for-pytorch && git submodule update --init --recursive && USE_AOT_DEVLIST='pvc' BUILD_SEPARATE_OPS=OFF BUILD_WITH_CPU=OFF USE_XETLA=ON python setup.py install && rm -rf /usr/src/intel-extension-for-pytorch
 
 # Install benchmarker
 COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark

View File

@@ -29,6 +29,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="Test",
+        input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="Test")]),
         prefill_logprobs=True,
         truncate=100,
         parameters=default_pb_parameters,
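For reference, a minimal stand-alone sketch (outside this diff) of the fixture pattern the tests now follow: a request keeps the legacy inputs string and additionally carries the same text as an input_chunks message. Parameters are omitted here; in the real fixtures they come from other fixtures.

from text_generation_server.pb import generate_pb2

# Hedged sketch: `parameters` and `stopping_parameters` are left out for brevity.
request = generate_pb2.Request(
    id=0,
    inputs="Test",
    input_chunks=generate_pb2.Input(
        chunks=[generate_pb2.InputChunk(text="Test")]
    ),
    prefill_logprobs=True,
    truncate=100,
)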

View File

@@ -25,6 +25,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="Test",
+        input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="Test")]),
         prefill_logprobs=True,
         truncate=100,
         parameters=default_pb_parameters,

View File

@@ -15,6 +15,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="def",
+        input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="def")]),
         prefill_logprobs=True,
         truncate=100,
         parameters=default_pb_parameters,
@@ -32,6 +33,13 @@ def default_fim_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="<fim-prefix>def<fim-suffix>world<fim-middle>",
+        input_chunks=generate_pb2.Input(
+            chunks=[
+                generate_pb2.InputChunk(
+                    text="<fim-prefix>def<fim-suffix>world<fim-middle>"
+                )
+            ]
+        ),
         prefill_logprobs=True,
         truncate=100,
         parameters=default_pb_parameters,

View File

@@ -28,6 +28,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="Test",
+        input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="Test")]),
         prefill_logprobs=True,
         truncate=100,
         parameters=default_pb_parameters,

View File

@@ -7,6 +7,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenize
 from typing import Optional, Tuple, List, Type, Dict
 
 from text_generation_server.models import Model
+from text_generation_server.utils.chunks import concat_text_chunks
 from text_generation_server.utils.tokens import batch_top_tokens
 from text_generation_server.models.types import (
     Batch,
@@ -86,7 +87,8 @@ class CausalLMBatch(Batch):
         max_decode_tokens = 0
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
-            inputs.append(r.inputs)
+            inputs.append(concat_text_chunks(r.input_chunks.chunks))
             next_token_choosers.append(
                 NextTokenChooser.from_pb(r.parameters, device, tokenizer)
             )
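The text-only batch builders in this commit (CausalLM, FlashCausalLM, Galactica, Mamba, Seq2SeqLM) all switch from the raw r.inputs string to the chunked field in the same way. A minimal sketch, outside this diff and assuming the generated generate_pb2 stubs, of what that call amounts to:

from text_generation_server.pb import generate_pb2
from text_generation_server.utils.chunks import concat_text_chunks

r = generate_pb2.Request(
    id=0,
    input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="Hello")]),
)
# For a single text chunk this is equivalent to the old `r.inputs` string.
assert concat_text_chunks(r.input_chunks.chunks) == "Hello"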

View File

@@ -11,9 +11,10 @@ from loguru import logger
 from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import PreTrainedTokenizerBase
-from typing import Optional, Tuple, List, Type, Dict
+from typing import Iterable, Optional, Tuple, List, Type, Dict
 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
 
+from text_generation_server.utils.chunks import concat_text_chunks
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.models import Model
 from text_generation_server.utils.tokens import batch_top_tokens
@@ -127,11 +128,13 @@ class FlashCausalLMBatch(Batch):
         )
 
     @classmethod
-    def batch_tokenized_inputs(cls, requests, tokenizer):
+    def batch_tokenized_inputs(
+        cls, requests: Iterable[generate_pb2.Request], tokenizer
+    ):
         batch_inputs = []
         max_truncation = 0
         for r in requests:
-            batch_inputs.append(r.inputs)
+            batch_inputs.append(concat_text_chunks(r.input_chunks.chunks))
             max_truncation = max(max_truncation, r.truncate)
 
         batch_tokenized_inputs = tokenizer(

View File

@@ -20,6 +20,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
+from text_generation_server.utils.chunks import concat_text_chunks
 
 # CREDIT: Papers with code => https://github.com/paperswithcode/galai/blob/main/galai/utils.py
@@ -91,7 +92,9 @@ class GalacticaCausalLMBatch(CausalLMBatch):
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
             # Add escape_custom_split_sequence to the CausalLMBatch logic
-            inputs.append(escape_custom_split_sequence(r.inputs))
+            inputs.append(
+                escape_custom_split_sequence(concat_text_chunks(r.input_chunks.chunks))
+            )
             next_token_choosers.append(
                 NextTokenChooser.from_pb(r.parameters, device, tokenizer)
             )

View File

@@ -1,4 +1,5 @@
-import torch
+from io import BytesIO
+from PIL import Image
 
 import torch
 import time
@@ -21,11 +22,6 @@ from text_generation_server.models.types import (
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
-from text_generation_server.models.vlm_causal_lm import split
-
-import re
-
-IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)")
 
 tracer = trace.get_tracer(__name__)
@@ -109,7 +105,7 @@ class IdeficsCausalLMBatch(Batch):
         max_decode_tokens = 0
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
-            inputs.append(r.inputs)
+            inputs.append(r.input_chunks.chunks)
             next_token_choosers.append(
                 NextTokenChooser.from_pb(r.parameters, device, tokenizer)
             )
@@ -128,8 +124,15 @@ class IdeficsCausalLMBatch(Batch):
         for inp in inputs:
             # Each input is encoded into a list, where each element of this input list is either a string or a URL
             prompt = []
-            for chunk in split(inp):
-                prompt.append(chunk["content"])
+            for chunk in inp:
+                chunk_type = chunk.WhichOneof("chunk")
+                if chunk_type == "text":
+                    prompt.append(chunk.text)
+                elif chunk_type == "image":
+                    image = Image.open(BytesIO(chunk.image.data))
+                    prompt.append(image)
+                else:
+                    raise RuntimeError(f"Invalid chunk type {chunk_type}")
             prompts.append(prompt)
 
         # The processor replaces the call to tokenizer, and
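A sketch of how a mixed text-and-image input might be built and consumed under the new chunked protocol, outside this diff. It assumes the generated generate_pb2 stubs; the generate_pb2.Image message name and its data field are inferred from the chunk.image.data access above and are an assumption here.

from io import BytesIO

from PIL import Image
from text_generation_server.pb import generate_pb2

# Encode a tiny placeholder image as PNG bytes.
png = BytesIO()
Image.new("RGB", (8, 8)).save(png, format="PNG")

# `generate_pb2.Image(data=...)` is assumed, not verified against the proto.
inp = generate_pb2.Input(
    chunks=[
        generate_pb2.InputChunk(text="What is in this picture?"),
        generate_pb2.InputChunk(image=generate_pb2.Image(data=png.getvalue())),
    ]
)

# Consumption mirrors the loop above: dispatch on the oneof field name.
prompt = []
for chunk in inp.chunks:
    chunk_type = chunk.WhichOneof("chunk")
    if chunk_type == "text":
        prompt.append(chunk.text)
    elif chunk_type == "image":
        prompt.append(Image.open(BytesIO(chunk.image.data)))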

View File

@@ -27,6 +27,7 @@ from text_generation_server.models.types import (
     Generation,
     GeneratedText,
 )
+from text_generation_server.utils.chunks import concat_text_chunks
 from text_generation_server.utils.tokens import batch_top_tokens, Sampling
 from dataclasses import dataclass
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
@@ -139,7 +140,7 @@ class MambaBatch(Batch):
         max_decode_tokens = 0
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
-            inputs.append(r.inputs)
+            inputs.append(concat_text_chunks(r.input_chunks.chunks))
             next_token_choosers.append(
                 NextTokenChooser.from_pb(r.parameters, device, tokenizer)
             )

View File

@@ -1,55 +1,48 @@
+from io import BytesIO
+from PIL import Image
 import torch
 import torch.distributed
 
 from opentelemetry import trace
-from typing import Optional, Tuple
+from typing import Iterable, Optional, Tuple
 from text_generation_server.models.vlm_causal_lm import (
     VlmCausalLM,
     VlmCausalLMBatch,
     image_text_replacement,
-    load_data_uri,
-    split,
 )
 from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
     PaliGemmaForConditionalGeneration,
 )
-from transformers import AutoProcessor, AutoConfig, AutoImageProcessor
+from transformers import AutoProcessor, AutoConfig
+from text_generation_server.pb.generate_pb2 import Request
 
 tracer = trace.get_tracer(__name__)
 
 
 class PaliGemmaBatch(VlmCausalLMBatch):
     @classmethod
-    def batch_tokenized_inputs(cls, requests, tokenizer, processor, config):
+    def batch_tokenized_inputs(
+        cls, requests: Iterable[Request], tokenizer, processor, config
+    ):
         batch_inputs = []
         image_inputs = []
         max_truncation = 0
         for r in requests:
-            chunks = split(r.inputs)
             full_text = ""
             image_id = 0
-            for chunk in chunks:
-                if chunk["type"] == "text":
-                    full_text += "<bos>" + chunk["content"] + "\n"
-                elif chunk["type"] == "image":
-                    image = chunk["content"]
-                    # Should never receive URLs anymore, processing should be done
-                    # On the rust layer.
-                    # This avoid making n queries per TP
-                    # if image.startswith("https://") or image.startswith("http://"):
-                    #     image = processor.image_processor.fetch_images(image)
-                    if image.startswith("data:"):
-                        image = load_data_uri(image)
-                    else:
-                        raise RuntimeError(
-                            "Cannot process input image not starting with data:"
-                        )
+            for chunk in r.input_chunks.chunks:
+                chunk_type = chunk.WhichOneof("chunk")
+                if chunk_type == "text":
+                    full_text += "<bos>" + chunk.text + "\n"
+                elif chunk_type == "image":
+                    image = Image.open(BytesIO(chunk.image.data))
                     # TODO do_convert_RGB should be on by default ?
                     image = image.convert("RGB")
                     image_input = processor.image_processor(image, return_tensors="pt")
                     full_text += image_text_replacement(image_input, config, image_id)
                     image_inputs.append(image_input)
                 else:
-                    raise RuntimeError(f"Invalid chunk type {chunk['type']}")
+                    raise RuntimeError(f"Invalid chunk type {chunk_type}")
             batch_inputs.append(full_text)
             max_truncation = max(max_truncation, r.truncate)

View File

@@ -6,6 +6,7 @@ from opentelemetry import trace
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Dict
 
+from text_generation_server.utils.chunks import concat_text_chunks
 from text_generation_server.utils.tokens import batch_top_tokens
 from text_generation_server.models import Model
 from text_generation_server.models.types import (
@@ -93,7 +94,7 @@ class Seq2SeqLMBatch(Batch):
         padding_right_offset = 0
         max_decode_tokens = 0
         for i, r in enumerate(pb.requests):
-            inputs.append(r.inputs)
+            inputs.append(concat_text_chunks(r.input_chunks.chunks))
             requests_idx_mapping[r.id] = i
             decoder_input_lengths.append(1)
             next_token_choosers.append(

View File

@@ -1,12 +1,9 @@
-import re
 import torch
-import math
 
 from PIL import Image
 from io import BytesIO
-import base64
 
 from opentelemetry import trace
-from typing import Optional, Tuple, List, Type, Dict
+from typing import Iterable, Optional, Tuple, List, Type, Dict
 from transformers import PreTrainedTokenizerBase
 from transformers.image_processing_utils import select_best_resolution
@@ -18,25 +15,6 @@ from text_generation_server.models.flash_mistral import (
 
 tracer = trace.get_tracer(__name__)
 
-IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)")
-
-
-def split(string) -> List[Dict[str, str]]:
-    parts = []
-    cursor = 0
-    for pattern in IMAGES.finditer(string):
-        start = pattern.start()
-        if start != cursor:
-            parts.append({"type": "text", "content": string[cursor:start]})
-
-        parts.append({"type": "image", "content": pattern.group(1)})
-        cursor = pattern.end()
-
-    if cursor != len(string):
-        parts.append({"type": "text", "content": string[cursor:]})
-
-    return parts
-
 
 def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
     """
@@ -129,13 +107,6 @@ def get_number_of_features(height: int, width: int, config) -> int:
     return unpadded_features + newline_features + base_features
 
 
-def load_data_uri(image_uri: str) -> Image.Image:
-    image_uri = image_uri.split(",")[-1]
-    content = base64.b64decode(image_uri)
-    image = Image.open(BytesIO(content))
-    return image
-
-
 class VlmCausalLMBatch(FlashCausalLMBatch):
     pixel_values: Optional[List[torch.Tensor]]
     pixel_attention_mask: Optional[List[torch.Tensor]]
@@ -159,35 +130,26 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         return batch
 
     @classmethod
-    def batch_tokenized_inputs(cls, requests, tokenizer, processor, config):
+    def batch_tokenized_inputs(
+        cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config
+    ):
         batch_inputs = []
         image_inputs = []
         max_truncation = 0
         for r in requests:
-            chunks = split(r.inputs)
             full_text = ""
             image_id = 0
-            for chunk in chunks:
-                if chunk["type"] == "text":
-                    full_text += chunk["content"]
-                elif chunk["type"] == "image":
-                    image = chunk["content"]
-                    # Should never receive URLs anymore, processing should be done
-                    # On the rust layer.
-                    # This avoid making n queries per TP
-                    # if image.startswith("https://") or image.startswith("http://"):
-                    #     image = processor.image_processor.fetch_images(image)
-                    if image.startswith("data:"):
-                        image = load_data_uri(image)
-                    else:
-                        raise RuntimeError(
-                            "Cannot process input image not starting with data:"
-                        )
+            for chunk in r.input_chunks.chunks:
+                chunk_type = chunk.WhichOneof("chunk")
+                if chunk_type == "text":
+                    full_text += chunk.text
+                elif chunk_type == "image":
+                    image = Image.open(BytesIO(chunk.image.data))
                     image_input = processor.image_processor(image, return_tensors="pt")
                     full_text += image_text_replacement(image_input, config, image_id)
                     image_inputs.append(image_input)
                 else:
-                    raise RuntimeError(f"Invalid chunk type {chunk['type']}")
+                    raise RuntimeError(f"Invalid chunk type {chunk_type}")
             batch_inputs.append(full_text)
             max_truncation = max(max_truncation, r.truncate)

View File

@@ -0,0 +1,27 @@
+from typing import Iterable
+
+from loguru import logger
+
+from text_generation_server.pb import generate_pb2
+
+
+def concat_text_chunks(chunks: Iterable[generate_pb2.InputChunk]) -> str:
+    """
+    Concatenate text in text chunks. Non-text chunks are dropped.
+    """
+    text = None
+    for chunk in chunks:
+        chunk_type = chunk.WhichOneof("chunk")
+        if chunk_type == "text":
+            if text is None:
+                text = chunk.text
+            else:
+                raise NotImplementedError("Request contained more than one text chunk")
+        else:
+            # We cannot reject this, e.g. warmup sends an image chunk.
+            logger.debug(f"Encountered non-text chunk type {chunk_type}")
+
+    if text is None:
+        raise NotImplementedError("Request without a text chunk")
+
+    return text
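A short usage sketch of the new helper, outside this diff and assuming the generated generate_pb2 stubs are importable. It only exercises the behavior visible in the implementation above.

from text_generation_server.pb import generate_pb2
from text_generation_server.utils.chunks import concat_text_chunks

# A single text chunk is returned as-is.
assert concat_text_chunks([generate_pb2.InputChunk(text="Hello")]) == "Hello"

# Despite the name, more than one text chunk is not concatenated yet: the
# helper raises NotImplementedError, as does an input with no text chunk.
try:
    concat_text_chunks(
        [generate_pb2.InputChunk(text="a"), generate_pb2.InputChunk(text="b")]
    )
except NotImplementedError:
    pass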