server: use chunked inputs
The router now sends the input as chunks rather than as a single string. This change updates the server to process chunked input instead of plain strings, which also lets us remove the image extraction code from the server.
parent 4dabddb7ea
commit bf3c813782
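For context, a minimal sketch of the new request shape and of how a text-only model consumes it. Field names follow the diffs below; the id, input text, and truncate value are illustrative.

from text_generation_server.pb import generate_pb2
from text_generation_server.utils.chunks import concat_text_chunks

# The router now populates `input_chunks` with typed chunks; the plain
# `inputs` string is still filled in, as in the test fixtures below.
request = generate_pb2.Request(
    id=0,
    inputs="Test",
    input_chunks=generate_pb2.Input(
        chunks=[generate_pb2.InputChunk(text="Test")]
    ),
    truncate=100,
)

# Text-only batches (CausalLM, Mamba, Seq2SeqLM, ...) collapse the chunks
# back into a single string before tokenization.
text = concat_text_chunks(request.input_chunks.chunks)
assert text == "Test"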
@@ -29,6 +29,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="Test",
+        input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="Test")]),
         prefill_logprobs=True,
         truncate=100,
         parameters=default_pb_parameters,
@@ -25,6 +25,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="Test",
+        input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="Test")]),
         prefill_logprobs=True,
         truncate=100,
         parameters=default_pb_parameters,
@@ -15,6 +15,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="def",
+        input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="def")]),
         prefill_logprobs=True,
         truncate=100,
         parameters=default_pb_parameters,

@@ -32,6 +33,13 @@ def default_fim_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="<fim-prefix>def<fim-suffix>world<fim-middle>",
+        input_chunks=generate_pb2.Input(
+            chunks=[
+                generate_pb2.InputChunk(
+                    text="<fim-prefix>def<fim-suffix>world<fim-middle>"
+                )
+            ]
+        ),
         prefill_logprobs=True,
         truncate=100,
         parameters=default_pb_parameters,
@@ -28,6 +28,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="Test",
+        input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="Test")]),
         prefill_logprobs=True,
         truncate=100,
         parameters=default_pb_parameters,
@@ -7,6 +7,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenize
 from typing import Optional, Tuple, List, Type, Dict

 from text_generation_server.models import Model
+from text_generation_server.utils.chunks import concat_text_chunks
 from text_generation_server.utils.tokens import batch_top_tokens
 from text_generation_server.models.types import (
     Batch,

@@ -86,7 +87,8 @@ class CausalLMBatch(Batch):
         max_decode_tokens = 0
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
-            inputs.append(r.inputs)
+            inputs.append(concat_text_chunks(r.input_chunks.chunks))
+
             next_token_choosers.append(
                 NextTokenChooser.from_pb(r.parameters, device, tokenizer)
             )
@@ -11,9 +11,10 @@ from loguru import logger
 from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import PreTrainedTokenizerBase
-from typing import Optional, Tuple, List, Type, Dict
+from typing import Iterable, Optional, Tuple, List, Type, Dict

 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
+from text_generation_server.utils.chunks import concat_text_chunks
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.models import Model
 from text_generation_server.utils.tokens import batch_top_tokens

@@ -127,11 +128,13 @@ class FlashCausalLMBatch(Batch):
         )

     @classmethod
-    def batch_tokenized_inputs(cls, requests, tokenizer):
+    def batch_tokenized_inputs(
+        cls, requests: Iterable[generate_pb2.Request], tokenizer
+    ):
         batch_inputs = []
         max_truncation = 0
         for r in requests:
-            batch_inputs.append(r.inputs)
+            batch_inputs.append(concat_text_chunks(r.input_chunks.chunks))
             max_truncation = max(max_truncation, r.truncate)

         batch_tokenized_inputs = tokenizer(
@@ -20,6 +20,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
+from text_generation_server.utils.chunks import concat_text_chunks

 # CREDIT: Papers with code => https://github.com/paperswithcode/galai/blob/main/galai/utils.py


@@ -91,7 +92,9 @@ class GalacticaCausalLMBatch(CausalLMBatch):
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
             # Add escape_custom_split_sequence to the CausalLMBatch logic
-            inputs.append(escape_custom_split_sequence(r.inputs))
+            inputs.append(
+                escape_custom_split_sequence(concat_text_chunks(r.input_chunks.chunks))
+            )
             next_token_choosers.append(
                 NextTokenChooser.from_pb(r.parameters, device, tokenizer)
             )
@@ -1,4 +1,5 @@
-import torch
+from io import BytesIO
+from PIL import Image
 import torch
 import time


@@ -21,11 +22,6 @@ from text_generation_server.models.types import (
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
-from text_generation_server.models.vlm_causal_lm import split
-
-import re
-
-IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)")


 tracer = trace.get_tracer(__name__)

@@ -109,7 +105,7 @@ class IdeficsCausalLMBatch(Batch):
         max_decode_tokens = 0
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
-            inputs.append(r.inputs)
+            inputs.append(r.input_chunks.chunks)
             next_token_choosers.append(
                 NextTokenChooser.from_pb(r.parameters, device, tokenizer)
             )

@@ -128,8 +124,15 @@ class IdeficsCausalLMBatch(Batch):
         for inp in inputs:
             # Each input is encoded into a list, where each element of this input list is either a string or a URL
             prompt = []
-            for chunk in split(inp):
-                prompt.append(chunk["content"])
+            for chunk in inp:
+                chunk_type = chunk.WhichOneof("chunk")
+                if chunk_type == "text":
+                    prompt.append(chunk.text)
+                elif chunk_type == "image":
+                    image = Image.open(BytesIO(chunk.image.data))
+                    prompt.append(image)
+                else:
+                    raise RuntimeError(f"Invalid chunk type {chunk_type}")
             prompts.append(prompt)

         # The processor replaces the call to tokenizer, and
@@ -27,6 +27,7 @@ from text_generation_server.models.types import (
     Generation,
     GeneratedText,
 )
+from text_generation_server.utils.chunks import concat_text_chunks
 from text_generation_server.utils.tokens import batch_top_tokens, Sampling
 from dataclasses import dataclass
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling

@@ -139,7 +140,7 @@ class MambaBatch(Batch):
         max_decode_tokens = 0
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
-            inputs.append(r.inputs)
+            inputs.append(concat_text_chunks(r.input_chunks.chunks))
             next_token_choosers.append(
                 NextTokenChooser.from_pb(r.parameters, device, tokenizer)
             )
@@ -1,55 +1,48 @@
+from io import BytesIO
+from PIL import Image
 import torch
 import torch.distributed
 from opentelemetry import trace
-from typing import Optional, Tuple
+from typing import Iterable, Optional, Tuple
 from text_generation_server.models.vlm_causal_lm import (
     VlmCausalLM,
     VlmCausalLMBatch,
     image_text_replacement,
-    load_data_uri,
-    split,
 )
 from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
     PaliGemmaForConditionalGeneration,
 )
-from transformers import AutoProcessor, AutoConfig, AutoImageProcessor
+from transformers import AutoProcessor, AutoConfig

+from text_generation_server.pb.generate_pb2 import Request
+
 tracer = trace.get_tracer(__name__)


 class PaliGemmaBatch(VlmCausalLMBatch):
     @classmethod
-    def batch_tokenized_inputs(cls, requests, tokenizer, processor, config):
+    def batch_tokenized_inputs(
+        cls, requests: Iterable[Request], tokenizer, processor, config
+    ):
         batch_inputs = []
         image_inputs = []
         max_truncation = 0
         for r in requests:
-            chunks = split(r.inputs)
             full_text = ""
             image_id = 0
-            for chunk in chunks:
-                if chunk["type"] == "text":
-                    full_text += "<bos>" + chunk["content"] + "\n"
-                elif chunk["type"] == "image":
-                    image = chunk["content"]
-                    # Should never receive URLs anymore, processing should be done
-                    # On the rust layer.
-                    # This avoid making n queries per TP
-                    # if image.startswith("https://") or image.startswith("http://"):
-                    #     image = processor.image_processor.fetch_images(image)
-                    if image.startswith("data:"):
-                        image = load_data_uri(image)
-                    else:
-                        raise RuntimeError(
-                            "Cannot process input image not starting with data:"
-                        )
+            for chunk in r.input_chunks.chunks:
+                chunk_type = chunk.WhichOneof("chunk")
+                if chunk_type == "text":
+                    full_text += "<bos>" + chunk.text + "\n"
+                elif chunk_type == "image":
+                    image = Image.open(BytesIO(chunk.image.data))
                     # TODO do_convert_RGB should be on by default ?
                     image = image.convert("RGB")
                     image_input = processor.image_processor(image, return_tensors="pt")
                     full_text += image_text_replacement(image_input, config, image_id)
                     image_inputs.append(image_input)
                 else:
-                    raise RuntimeError(f"Invalid chunk type {chunk['type']}")
+                    raise RuntimeError(f"Invalid chunk type {chunk_type}")

             batch_inputs.append(full_text)
             max_truncation = max(max_truncation, r.truncate)
@@ -6,6 +6,7 @@ from opentelemetry import trace
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Dict

+from text_generation_server.utils.chunks import concat_text_chunks
 from text_generation_server.utils.tokens import batch_top_tokens
 from text_generation_server.models import Model
 from text_generation_server.models.types import (

@@ -93,7 +94,7 @@ class Seq2SeqLMBatch(Batch):
         padding_right_offset = 0
         max_decode_tokens = 0
         for i, r in enumerate(pb.requests):
-            inputs.append(r.inputs)
+            inputs.append(concat_text_chunks(r.input_chunks.chunks))
             requests_idx_mapping[r.id] = i
             decoder_input_lengths.append(1)
             next_token_choosers.append(
@@ -1,12 +1,9 @@
-import re
 import torch
-import math
 from PIL import Image
 from io import BytesIO
-import base64

 from opentelemetry import trace
-from typing import Optional, Tuple, List, Type, Dict
+from typing import Iterable, Optional, Tuple, List, Type, Dict

 from transformers import PreTrainedTokenizerBase
 from transformers.image_processing_utils import select_best_resolution

@@ -18,25 +15,6 @@ from text_generation_server.models.flash_mistral import (

 tracer = trace.get_tracer(__name__)

-IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)")
-
-
-def split(string) -> List[Dict[str, str]]:
-    parts = []
-    cursor = 0
-    for pattern in IMAGES.finditer(string):
-        start = pattern.start()
-        if start != cursor:
-            parts.append({"type": "text", "content": string[cursor:start]})
-
-        parts.append({"type": "image", "content": pattern.group(1)})
-        cursor = pattern.end()
-
-    if cursor != len(string):
-        parts.append({"type": "text", "content": string[cursor:]})
-
-    return parts
-

 def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
     """

@@ -129,13 +107,6 @@ def get_number_of_features(height: int, width: int, config) -> int:
     return unpadded_features + newline_features + base_features


-def load_data_uri(image_uri: str) -> Image.Image:
-    image_uri = image_uri.split(",")[-1]
-    content = base64.b64decode(image_uri)
-    image = Image.open(BytesIO(content))
-    return image
-
-
 class VlmCausalLMBatch(FlashCausalLMBatch):
     pixel_values: Optional[List[torch.Tensor]]
     pixel_attention_mask: Optional[List[torch.Tensor]]

@@ -159,35 +130,26 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         return batch

     @classmethod
-    def batch_tokenized_inputs(cls, requests, tokenizer, processor, config):
+    def batch_tokenized_inputs(
+        cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config
+    ):
         batch_inputs = []
         image_inputs = []
         max_truncation = 0
         for r in requests:
-            chunks = split(r.inputs)
             full_text = ""
             image_id = 0
-            for chunk in chunks:
-                if chunk["type"] == "text":
-                    full_text += chunk["content"]
-                elif chunk["type"] == "image":
-                    image = chunk["content"]
-                    # Should never receive URLs anymore, processing should be done
-                    # On the rust layer.
-                    # This avoid making n queries per TP
-                    # if image.startswith("https://") or image.startswith("http://"):
-                    #     image = processor.image_processor.fetch_images(image)
-                    if image.startswith("data:"):
-                        image = load_data_uri(image)
-                    else:
-                        raise RuntimeError(
-                            "Cannot process input image not starting with data:"
-                        )
+            for chunk in r.input_chunks.chunks:
+                chunk_type = chunk.WhichOneof("chunk")
+                if chunk_type == "text":
+                    full_text += chunk.text
+                elif chunk_type == "image":
+                    image = Image.open(BytesIO(chunk.image.data))
                     image_input = processor.image_processor(image, return_tensors="pt")
                     full_text += image_text_replacement(image_input, config, image_id)
                     image_inputs.append(image_input)
                 else:
-                    raise RuntimeError(f"Invalid chunk type {chunk['type']}")
+                    raise RuntimeError(f"Invalid chunk type {chunk_type}")

             batch_inputs.append(full_text)
             max_truncation = max(max_truncation, r.truncate)
@@ -0,0 +1,27 @@
+from typing import Iterable
+
+from loguru import logger
+
+from text_generation_server.pb import generate_pb2
+
+
+def concat_text_chunks(chunks: Iterable[generate_pb2.InputChunk]) -> str:
+    """
+    Concatenate text in text chunks. Non-text chunks are dropped.
+    """
+    text = None
+    for chunk in chunks:
+        chunk_type = chunk.WhichOneof("chunk")
+        if chunk_type == "text":
+            if text is None:
+                text = chunk.text
+            else:
+                raise NotImplementedError("Request contained more than one text chunk")
+        else:
+            # We cannot reject this, e.g. warmup sends an image chunk.
+            logger.debug(f"Encountered non-text chunk type {chunk_type}")
+
+    if text is None:
+        raise NotImplementedError("Request without a text chunk")
+
+    return text
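A quick usage sketch for the helper added above (chunk contents are illustrative): at most one text chunk is accepted, and any non-text chunk is skipped with a debug log rather than rejected.

from text_generation_server.pb import generate_pb2
from text_generation_server.utils.chunks import concat_text_chunks

text_chunk = generate_pb2.InputChunk(text="What is in this picture?")

# Assigning a sub-field of `image` selects the image variant of the oneof;
# the bytes here are placeholder data, not a real image.
image_chunk = generate_pb2.InputChunk()
image_chunk.image.data = b"placeholder"

# The image chunk is dropped (with a debug log); only the text is returned.
assert concat_text_chunks([text_chunk, image_chunk]) == "What is in this picture?"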