feat(server): Improved doc

This commit is contained in:
OlivierDehaene 2022-11-07 12:53:56 +01:00
parent cea6051eff
commit 4236e41b0d
9 changed files with 195 additions and 101 deletions

View File

@ -28,6 +28,7 @@ ENV LANG=C.UTF-8 \
MODEL_NAME=bigscience/bloom \
QUANTIZE=false \
NUM_GPUS=8 \
SAFETENSORS_FAST_GPU=1 \
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
NCCL_ASYNC_ERROR_HANDLING=1 \
CUDA_HOME=/usr/local/cuda \
@ -55,12 +56,6 @@ RUN cd server && make install-torch
# Install specific version of transformers
RUN cd server && make install-transformers
# Install specific version of safetensors
# FIXME: This is a temporary fix while we wait for a new release
RUN curl https://sh.rustup.rs -sSf | bash -s -- -y
ENV PATH="/root/.cargo/bin:${PATH}"
RUN cd server && make install-safetensors
# Install server
COPY proto proto
COPY server server

View File

@ -6,7 +6,8 @@
</div>
A Rust and gRPC server for text generation inference.
A Rust and gRPC server for text generation inference. Used in production at [HuggingFace](https://huggingface.co)
to power Bloom, BloomZ and MT0-XXL api-inference widgets.
## Features
@ -15,11 +16,11 @@ A Rust and gRPC server for text generation inference.
- [Safetensors](https://github.com/huggingface/safetensors) weight loading
- 45ms per token generation for BLOOM with 8xA100 80GB
## Officialy supported models
## Officially supported models
- BLOOM
- BLOOMZ
- BLOOM-560m
- [BLOOM](https://huggingface.co/bigscience/bloom)
- [BLOOMZ](https://huggingface.co/bigscience/bloomz)
- [MT0-XXL](https://huggingface.co/bigscience/mt0-xxl)
Other models are supported on a best effort basis using:
@ -91,4 +92,3 @@ make router-dev
- [ ] Add tests for the `server/model` logic
- [ ] Backport custom CUDA kernels to Transformers
- [ ] Install safetensors with pip

View File

@ -295,6 +295,10 @@ fn shard_manager(
"MASTER_PORT".parse().unwrap(),
master_port.to_string().parse().unwrap(),
),
(
"SAFETENSORS_FAST_GPU".parse().unwrap(),
"1".to_string().parse().unwrap(),
),
];
// If the HUGGINGFACE_HUB_CACHE env var is set, pass it to the shard

View File

@ -16,24 +16,13 @@ install-transformers:
mv transformers-7302a24535e8dc5637ea5b4e4572fc971d404098 transformers
cd transformers && python setup.py install
install-safetensors:
# Install specific version of safetensors
pip install setuptools_rust
rm safetensors || true
rm safetensors-634deccbcbad5eaf417935281f8b3be7ebca69c5 || true
curl -L -O https://github.com/huggingface/safetensors/archive/634deccbcbad5eaf417935281f8b3be7ebca69c5.zip
unzip 634deccbcbad5eaf417935281f8b3be7ebca69c5.zip
rm 634deccbcbad5eaf417935281f8b3be7ebca69c5.zip
mv safetensors-634deccbcbad5eaf417935281f8b3be7ebca69c5 safetensors
cd safetensors/bindings/python && python setup.py develop
install-torch:
# Install specific version of torch
pip install torch --extra-index-url https://download.pytorch.org/whl/cu116 --no-cache-dir
install: gen-server install-torch install-transformers install-safetensors
install: gen-server install-torch install-transformers
pip install pip --upgrade
pip install -e . --no-cache-dir
run-dev:
python -m torch.distributed.run --nproc_per_node=2 text_generation/cli.py serve bigscience/bloom-560m --sharded
SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation/cli.py serve bigscience/bloom-560m --sharded

47
server/poetry.lock generated
View File

@ -145,6 +145,18 @@ category = "main"
optional = false
python-versions = ">=3.6"
[[package]]
name = "safetensors"
version = "0.2.4"
description = "Fast and Safe Tensor serialization"
category = "main"
optional = false
python-versions = "*"
[package.extras]
dev = ["black (==22.3)", "flake8 (>=3.8.3)", "huggingface-hub", "isort (>=5.5.4)", "numpy", "pytest", "setuptools-rust"]
testing = ["black (==22.3)", "flake8 (>=3.8.3)", "huggingface-hub", "isort (>=5.5.4)", "numpy", "pytest", "setuptools-rust"]
[[package]]
name = "setuptools"
version = "65.5.0"
@ -208,7 +220,7 @@ bnb = ["bitsandbytes"]
[metadata]
lock-version = "1.1"
python-versions = "^3.9"
content-hash = "224b1e379d6105fe911bff4563946a90dfa6ff5918cf2e7be59f8d4f7c5cd7cf"
content-hash = "3266187ef14fe8f9e29b3b6530d07781ea952aa670c0fe0de34be43efa231a67"
[metadata.files]
accelerate = [
@ -459,6 +471,39 @@ PyYAML = [
{file = "PyYAML-6.0-cp39-cp39-win_amd64.whl", hash = "sha256:b3d267842bf12586ba6c734f89d1f5b871df0273157918b0ccefa29deb05c21c"},
{file = "PyYAML-6.0.tar.gz", hash = "sha256:68fb519c14306fec9720a2a5b45bc9f0c8d1b9c72adf45c37baedfcd949c35a2"},
]
safetensors = [
{file = "safetensors-0.2.4-cp310-cp310-macosx_10_11_x86_64.whl", hash = "sha256:79c4a7610d7699c64d8531c43f758ded4990ebaa7b0887c2078640e6de44e726"},
{file = "safetensors-0.2.4-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:ef425a4ddd29612fe733a6eeca6ad8f3ee3939f530a032114974aac4c4667b89"},
{file = "safetensors-0.2.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:77758f8ba4de6e20bf394dd964854a926dee2efee82eaa95e6c0893e2a7d960c"},
{file = "safetensors-0.2.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fb956e9090cce515649f00b491b5ddc0f9c3d989139016a8d69f9dcf57e8d3d9"},
{file = "safetensors-0.2.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3e31b02d27249bd519f05ec9d189097c59fc6851c59daa1a86ef347659e33ac3"},
{file = "safetensors-0.2.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0c2fead03a1497042efea4358574f3d7acf501b0c82e54d605f393f2b4e2aafe"},
{file = "safetensors-0.2.4-cp310-cp310-win32.whl", hash = "sha256:dce6ed3c7d13aafa574737eb3309c928adcb6781e879b41f0861be83b439cf3e"},
{file = "safetensors-0.2.4-cp310-cp310-win_amd64.whl", hash = "sha256:1dfe727325a1342767c6725dc2cc1f00463eb40a1f5df37c338d8e03957e27ce"},
{file = "safetensors-0.2.4-cp37-cp37m-macosx_10_11_x86_64.whl", hash = "sha256:c066bc7b90a582a01ec468fef61a7581b5c726bf12c50491cb6ea5db215ea5e0"},
{file = "safetensors-0.2.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca6ed53dad5d7d0e67eb676528ff2ad345cac3a34010e4dc1e3736972de294a5"},
{file = "safetensors-0.2.4-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ada03b44acbb036cfabe7066a8df4ad9b1ac05bb585a6b6c0f285f08e016381d"},
{file = "safetensors-0.2.4-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:58a0902708daa7ec2b2293b46e85df61f4fa359ddfe648e7ac025a79e6f59627"},
{file = "safetensors-0.2.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e0a4e38f7cbb4bfc513588e52f349b906c941e74fbbe192f2b19fc34221d448"},
{file = "safetensors-0.2.4-cp37-cp37m-win32.whl", hash = "sha256:4f8695b77dd847203258f035f8468f8b701c90621cb6b457e109f8d89c27f16c"},
{file = "safetensors-0.2.4-cp37-cp37m-win_amd64.whl", hash = "sha256:16b08f33c753c7da64b3999beea7c30d58204a0820961e33881d05a331e3f5c0"},
{file = "safetensors-0.2.4-cp38-cp38-macosx_10_11_x86_64.whl", hash = "sha256:a381606804f23db9eede51135f5fbd1f75dda02100415ee150fd39eb1cd6be4c"},
{file = "safetensors-0.2.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7aceae84d0c7233d83923029aaf8d184848561e0211ec98c5317327b3db025d6"},
{file = "safetensors-0.2.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:da48fc929485cbd9ee22621e388764a7cef27b0205e73aee2ad75aadd7d67662"},
{file = "safetensors-0.2.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2619b88f934c4de6b59de90c9dc00eae2d0e30f254a1daebd6eb232ac1f9a7a7"},
{file = "safetensors-0.2.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d1f78b987ae1f6b71da8ea110164e4cab2ee31b53835d2a66279df89c5d73f0e"},
{file = "safetensors-0.2.4-cp38-cp38-win32.whl", hash = "sha256:34b3e60b5130fb0fe07114705e51d30aa2c7eae4c1d1e77d6f260fa4ade70ede"},
{file = "safetensors-0.2.4-cp38-cp38-win_amd64.whl", hash = "sha256:debaa4fa98a7af44ba6dcb6945efee77b8480284c2cb05918ab97cf511c40826"},
{file = "safetensors-0.2.4-cp39-cp39-macosx_10_11_x86_64.whl", hash = "sha256:90baaafc0c872a736124b341db54b0bdd61765cbf3a61418371066a37905b18d"},
{file = "safetensors-0.2.4-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:b4bf7e23191d6a3ff00de141512869fc776e8ff159c872cb44af018cb04d45eb"},
{file = "safetensors-0.2.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cf11a3aba8796e548ceb0a65f34dcd334dcf0c4c891dccabe18a8b53918ae8ab"},
{file = "safetensors-0.2.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:95c31935ea71d63a38c546654136d7f0dbf1e7aeb6564dbc2201bc1fe9b34e4c"},
{file = "safetensors-0.2.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ef31776e2e081d6f075408eed34a0fbd524cbd19e50268bef02c238b209213b7"},
{file = "safetensors-0.2.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d06bb1d68148f6d6934352124d8cbfcf0db092f969db7187e348bd5cbf183db5"},
{file = "safetensors-0.2.4-cp39-cp39-win32.whl", hash = "sha256:5d546152b9a5bd58eae97c2ddefba394404d37ddedec305f7639c9b6054513e5"},
{file = "safetensors-0.2.4-cp39-cp39-win_amd64.whl", hash = "sha256:553ecfd895d379c1e03a7c9241f7343b3af66573436969ed7eb95df81dfbe9af"},
{file = "safetensors-0.2.4.tar.gz", hash = "sha256:35c0719a898f1f1292464f4cd9370bb6c2698032f1db4d677489f078b66b5a75"},
]
setuptools = [
{file = "setuptools-65.5.0-py3-none-any.whl", hash = "sha256:f62ea9da9ed6289bfe868cd6845968a2c854d1427f8548d52cae02a42b4f0356"},
{file = "setuptools-65.5.0.tar.gz", hash = "sha256:512e5536220e38146176efb833d4a62aa726b7bbff82cfbc8ba9eaa3996e0b17"},

View File

@ -15,6 +15,7 @@ typer = "^0.6.1"
grpcio-reflection = "^1.49.1"
accelerate = "^0.12.0"
bitsandbytes = "^0.35.1"
safetensors = "^0.2.4"
[tool.poetry.extras]
bnb = ["bitsandbytes"]

View File

@ -9,17 +9,13 @@ __all__ = ["Model", "BLOOMSharded", "CausalLM", "Seq2SeqLM"]
def get_model(model_name: str, sharded: bool, quantize: bool) -> Model:
if model_name.startswith("bigscience/bloom"):
if sharded:
return BLOOMSharded(model_name, quantize)
return BLOOMSharded(model_name, quantize=quantize)
else:
if quantize:
raise ValueError("quantization is not supported for non-sharded BLOOM")
return CausalLM(model_name)
return CausalLM(model_name, quantize=quantize)
else:
if sharded:
raise ValueError("sharded is not supported for AutoModel")
if quantize:
raise ValueError("quantize is not supported for AutoModel")
try:
return CausalLM(model_name)
return CausalLM(model_name, quantize=quantize)
except Exception as e:
return Seq2SeqLM(model_name)
return Seq2SeqLM(model_name, quantize=quantize)

View File

@ -2,7 +2,7 @@ import torch
from dataclasses import dataclass
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import Optional, Tuple, List, Dict, Type
from typing import Optional, Tuple, List, Type
from text_generation.models import Model
from text_generation.models.types import GeneratedText
@ -14,11 +14,23 @@ from text_generation.utils import NextTokenChooser, StoppingCriteria
class CausalLMBatch:
batch_id: int
requests: List[generate_pb2.Request]
all_input_lengths: List[int]
input_ids: Dict[str, torch.Tensor]
# Decoder values
input_ids: torch.Tensor
attention_mask: torch.Tensor
past_key_values: Optional[List[Tuple]]
# All tokens
all_input_ids: List[torch.Tensor]
# Lengths of all generations present in the batch
input_lengths: List[int]
# Generation helpers
next_token_choosers: List[NextTokenChooser]
stopping_criterias: List[StoppingCriteria]
# Metadata used for padding
size: int
max_sequence_length: int
@ -36,12 +48,12 @@ class CausalLMBatch:
inputs = []
next_token_choosers = []
stopping_criterias = []
all_input_lengths = []
input_lengths = []
# Parse batch
for r in pb.requests:
inputs.append(r.inputs)
all_input_lengths.append(r.input_length)
input_lengths.append(r.input_length)
next_token_choosers.append(
NextTokenChooser(
temperature=r.parameters.temperature,
@ -56,21 +68,23 @@ class CausalLMBatch:
)
)
input_ids = tokenizer(
tokenized_inputs = tokenizer(
inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8
).to(device)
all_input_ids = input_ids["input_ids"].unsqueeze(-1)
all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)
return cls(
batch_id=pb.id,
requests=pb.requests,
all_input_lengths=all_input_lengths,
input_ids=input_ids,
input_ids=tokenized_inputs["input_ids"],
attention_mask=tokenized_inputs["attention_mask"],
past_key_values=None,
all_input_ids=all_input_ids,
input_lengths=input_lengths,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
size=pb.size,
max_sequence_length=max(all_input_lengths),
max_sequence_length=max(input_lengths),
)
@classmethod
@ -80,19 +94,23 @@ class CausalLMBatch:
max_sequence_length = max(batch.max_sequence_length for batch in batches)
# Batch attributes
input_ids = {"input_ids": None, "attention_mask": None, "past_key_values": []}
requests = []
all_input_lengths = []
input_lengths = []
all_input_ids = []
next_token_choosers = []
stopping_criterias = []
# Batch tensors
input_ids = None
attention_mask = None
past_key_values = []
# Used for slicing correctly inside the tensors
# Equivalent to a cumsum on batch sizes
start_index = 0
for i, batch in enumerate(batches):
requests.extend(batch.requests)
all_input_lengths.extend(batch.all_input_lengths)
input_lengths.extend(batch.input_lengths)
all_input_ids.extend(batch.all_input_ids)
next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias)
@ -101,32 +119,35 @@ class CausalLMBatch:
end_index = start_index + batch.size
# We only concatenate batches that did at least one step
if batch.input_ids["input_ids"].shape[1] > 1:
if batch.input_ids.shape[1] > 1:
raise ValueError("Batch input_ids should be of shape (batch_size, 1)")
# Initialize tensors
if i == 0:
input_ids["input_ids"] = torch.empty(
(total_batch_size, 1),
dtype=batch.input_ids["input_ids"].dtype,
device=batch.input_ids["input_ids"].device,
)
input_ids["attention_mask"] = torch.zeros(
(total_batch_size, max_sequence_length),
dtype=batch.input_ids["attention_mask"].dtype,
device=batch.input_ids["attention_mask"].device,
)
# input_ids["input_ids"] is always of shape [batch_size, 1]
# Create empty tensor
# input_ids is always of shape [batch_size, 1]
# We do not need to pad it
input_ids["input_ids"][start_index:end_index] = batch.input_ids["input_ids"]
if input_ids is None:
input_ids = torch.empty(
(total_batch_size, 1),
dtype=batch.input_ids.dtype,
device=batch.input_ids.device,
)
# Copy to correct indices
input_ids[start_index:end_index] = batch.input_ids
# Create padded tensor
if attention_mask is None:
attention_mask = torch.zeros(
(total_batch_size, max_sequence_length),
dtype=batch.attention_mask.dtype,
device=batch.attention_mask.device,
)
# We need to slice the attention mask to remove padding from previous steps
input_ids["attention_mask"][
attention_mask[
start_index:end_index, -batch.max_sequence_length :
] = batch.input_ids["attention_mask"][:, -batch.max_sequence_length :]
] = batch.attention_mask[:, -batch.max_sequence_length :]
for j, past in enumerate(batch.input_ids["past_key_values"]):
for j, past in enumerate(batch.past_key_values):
# Shenanigans to get dimensions because BLOOM outputs a past with a different shape
# BLOOM: [batch_size * num_heads, ...] vs [batch_size, num_heads, ...]
head_dim, padded_sequence_length = past[0].shape[-2:]
@ -137,8 +158,8 @@ class CausalLMBatch:
)
# This will run only once per layer
if j == len(input_ids["past_key_values"]):
input_ids["past_key_values"].append([])
if j == len(past_key_values):
past_key_values.append([])
# Decoder past
for k, t in enumerate(past):
@ -172,21 +193,21 @@ class CausalLMBatch:
# Initialize tensors
# This will run only once per layer and per past tensor
if k == len(input_ids["past_key_values"][j]):
input_ids["past_key_values"][j].append(
if k == len(past_key_values[j]):
past_key_values[j].append(
torch.zeros(padded_t_shape, dtype=t.dtype, device=t.device)
)
# We slice the past keys and values to remove the padding from previous batches
if not head_dim_last:
input_ids["past_key_values"][j][k][
past_key_values[j][k][
start_index:end_index,
:,
:,
-(batch.max_sequence_length - 1) :,
] = t[:, :, :, -(batch.max_sequence_length - 1) :]
else:
input_ids["past_key_values"][j][k][
past_key_values[j][k][
start_index:end_index,
:,
-(batch.max_sequence_length - 1) :,
@ -198,9 +219,11 @@ class CausalLMBatch:
return cls(
batch_id=batches[0].batch_id,
requests=requests,
all_input_lengths=all_input_lengths,
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
all_input_ids=all_input_ids,
input_lengths=input_lengths,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
size=total_batch_size,
@ -209,7 +232,7 @@ class CausalLMBatch:
class CausalLM(Model):
def __init__(self, model_name: str):
def __init__(self, model_name: str, quantize=False):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
@ -223,6 +246,7 @@ class CausalLM(Model):
model_name,
torch_dtype=dtype,
device_map="auto" if torch.cuda.is_available() else None,
load_in_8bit=quantize,
).eval()
super(CausalLM, self).__init__(
@ -255,16 +279,19 @@ class CausalLM(Model):
torch.no_grad if self.device.type == "cpu" else torch.inference_mode
)
with context_manager():
logits, past = self.forward(**batch.input_ids)
logits, past = self.forward(
batch.input_ids, batch.attention_mask, batch.past_key_values
)
# List of indices to cache
next_batch_keep_indices = []
# New input_ids for next forward
# New values for next forward
next_batch_input_lengths = []
next_batch_input_ids = []
next_batch_all_input_ids = []
next_all_input_lengths = []
# Metadata
next_batch_size = 0
next_batch_max_sequence_length = 0
@ -274,7 +301,7 @@ class CausalLM(Model):
# Zipped iterator
iterator = zip(
batch.requests,
batch.all_input_lengths,
batch.input_lengths,
logits,
batch.next_token_choosers,
batch.stopping_criterias,
@ -313,7 +340,7 @@ class CausalLM(Model):
next_batch_all_input_ids.append(all_tokens)
next_batch_size += 1
new_input_length = input_length + 1
next_all_input_lengths.append(new_input_length)
next_batch_input_lengths.append(new_input_length)
next_batch_max_sequence_length = max(
next_batch_max_sequence_length, new_input_length
)
@ -322,15 +349,14 @@ class CausalLM(Model):
if not next_batch_keep_indices:
return generated_texts, None
# If we finished at least one generation
next_batch_input_ids = {"input_ids": torch.cat(next_batch_input_ids, dim=0)}
next_batch_input_ids = torch.cat(next_batch_input_ids, dim=0)
# If we finished at least one generation, we need to evict the indices of the generations that finished
# from the values of the next batch
if generated_texts:
# Apply indices to attention mask, past key values and other items that need to be cached
next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"][
next_batch_keep_indices
]
next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
# Force past to be of dim [batch_size, num_heads, ...] for easy indexing
next_batch_input_ids["past_key_values"] = [
next_batch_past_key_values = [
[
t.view(-1, self.num_heads, *t.shape[-2:])[next_batch_keep_indices]
for t in layer
@ -345,16 +371,16 @@ class CausalLM(Model):
batch.stopping_criterias[i] for i in next_batch_keep_indices
]
else:
next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"]
next_batch_input_ids["past_key_values"] = past
next_batch_attention_mask = batch.attention_mask
next_batch_past_key_values = past
next_batch_requests = batch.requests
next_batch_next_token_choosers = batch.next_token_choosers
next_batch_stopping_criterias = batch.stopping_criterias
# Update attention_mask with padding as we added a new token to input_ids
next_batch_input_ids["attention_mask"] = torch.cat(
next_batch_attention_mask = torch.cat(
[
next_batch_input_ids["attention_mask"],
next_batch_attention_mask,
torch.ones((next_batch_size, 1)).to(self.device),
],
dim=1,
@ -363,9 +389,11 @@ class CausalLM(Model):
next_batch = CausalLMBatch(
batch_id=batch.batch_id,
requests=next_batch_requests,
all_input_lengths=next_all_input_lengths,
input_ids=next_batch_input_ids,
attention_mask=next_batch_attention_mask,
past_key_values=next_batch_past_key_values,
all_input_ids=next_batch_all_input_ids,
input_lengths=next_batch_input_lengths,
next_token_choosers=next_batch_next_token_choosers,
stopping_criterias=next_batch_stopping_criterias,
size=next_batch_size,

View File

@ -15,26 +15,33 @@ class Seq2SeqLMBatch:
batch_id: int
requests: List[generate_pb2.Request]
# Encoder values
input_ids: torch.Tensor
attention_mask: torch.Tensor
# Decoder values
decoder_input_ids: torch.Tensor
decoder_attention_mask: Optional[torch.Tensor]
encoder_last_hidden_state: Optional[torch.Tensor]
# Seq2SeqLM keeps track of both encoder and decoder attention keys and values
past_key_values: Optional[List[Tuple]]
# Lengths of all generations present in the batch
input_lengths: List[int]
decoder_input_lengths: List[int]
# Generation helpers
next_token_choosers: List[NextTokenChooser]
stopping_criterias: List[StoppingCriteria]
# Metadata used for padding
size: int
max_input_length: int
max_decoder_input_length: int
def to_pb(self):
"""Convert a Seq2SeqLMBatch to a text_generation.v1.Batch protobuf"""
return generate_pb2.Batch(
id=self.batch_id,
requests=self.requests,
@ -45,6 +52,7 @@ class Seq2SeqLMBatch:
def from_pb(
cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
) -> "Seq2SeqLMBatch":
"""Convert a text_generation.v1.Batch protobuf to a Seq2SeqLMBatch"""
inputs = []
next_token_choosers = []
stopping_criterias = []
@ -57,6 +65,7 @@ class Seq2SeqLMBatch:
for r in pb.requests:
inputs.append(r.inputs)
input_lengths.append(r.input_length)
# Decoder sequence only contains the bos_token
decoder_input_ids.append(tokenizer.bos_token_id)
decoder_input_lengths.append(1)
next_token_choosers.append(
@ -73,9 +82,11 @@ class Seq2SeqLMBatch:
)
)
# Tokenize batch
tokenized_inputs = tokenizer(
inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8
).to(device)
# Convert decoder_input_ids to torch tensor of size [batch_size, 1]
decoder_input_ids = torch.tensor(decoder_input_ids).to(device).unsqueeze(-1)
return cls(
@ -98,6 +109,8 @@ class Seq2SeqLMBatch:
@classmethod
def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
"""Concatenate multiple batches together by padding internal torch tensors"""
# Used for padding
total_batch_size = sum(batch.size for batch in batches)
max_input_length = max(batch.max_input_length for batch in batches)
@ -112,6 +125,7 @@ class Seq2SeqLMBatch:
next_token_choosers = []
stopping_criterias = []
# Batch tensors
input_ids = None
attention_mask = None
decoder_input_ids = None
@ -122,7 +136,9 @@ class Seq2SeqLMBatch:
# Used for slicing correctly inside the tensors
# Equivalent to a cumsum on batch sizes
start_index = 0
for i, batch in enumerate(batches):
# Extend all list attributes
requests.extend(batch.requests)
input_lengths.extend(batch.input_lengths)
decoder_input_lengths.extend(batch.decoder_input_lengths)
@ -136,51 +152,62 @@ class Seq2SeqLMBatch:
if batch.encoder_last_hidden_state is None:
raise ValueError("Batch encoder_last_hidden_state cannot be None")
# Create padded tensor
if input_ids is None:
input_ids = torch.zeros(
(total_batch_size, max_input_length),
dtype=batch.input_ids.dtype,
device=batch.input_ids.device,
)
# Copy to correct indices
input_ids[
start_index:end_index, -batch.max_input_length :
] = batch.input_ids[:, -batch.max_input_length :]
# Create padded tensor
if attention_mask is None:
attention_mask = torch.zeros(
(total_batch_size, max_input_length),
dtype=batch.attention_mask.dtype,
device=batch.attention_mask.device,
)
# Copy to correct indices
attention_mask[
start_index:end_index, -batch.max_input_length :
] = batch.attention_mask[:, -batch.max_input_length :]
# Create padded tensor
if decoder_input_ids is None:
decoder_input_ids = torch.zeros(
(total_batch_size, max_decoder_input_length),
dtype=batch.decoder_input_ids.dtype,
device=batch.decoder_input_ids.device,
)
# Copy to correct indices
decoder_input_ids[
start_index:end_index, -batch.max_decoder_input_length :
] = batch.decoder_input_ids[:, -batch.max_decoder_input_length :]
# Create padded tensor
if decoder_attention_mask is None:
decoder_attention_mask = torch.zeros(
(total_batch_size, max_decoder_input_length),
dtype=batch.attention_mask.dtype,
device=batch.attention_mask.device,
dtype=batch.attention_mask.dtype, # As decoder_attention_mask might not exist,
device=batch.attention_mask.device, # we use `batch.attention_maks` for device here
)
# If the decoder mask does not exist yet, all generations started at the same time and we never concatenated
# this batch. All generations are of length `batch.max_decoder_input_length`.
if batch.decoder_attention_mask is None:
decoder_attention_mask[
start_index:end_index, -batch.max_decoder_input_length :
] = 1
# If it exists, we need to index
else:
decoder_attention_mask[
start_index:end_index, -batch.max_decoder_input_length :
] = batch.decoder_attention_mask[:, -batch.max_decoder_input_length :]
# Create padded tensor
if encoder_last_hidden_state is None:
encoder_last_hidden_state = torch.zeros(
(
@ -192,10 +219,12 @@ class Seq2SeqLMBatch:
device=batch.encoder_last_hidden_state.device,
)
# Copy to correct indices
encoder_last_hidden_state[
start_index:end_index, -batch.max_decoder_input_length :, :
] = batch.encoder_last_hidden_state[:, -batch.max_decoder_input_length :, :]
# Iterate over attention layers
for j, past in enumerate(batch.past_key_values):
_, num_heads, _, head_dim = past[0].shape
@ -271,7 +300,7 @@ class Seq2SeqLMBatch:
class Seq2SeqLM(Model):
def __init__(self, model_name: str):
def __init__(self, model_name: str, quantize=False):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
@ -283,6 +312,7 @@ class Seq2SeqLM(Model):
model_name,
torch_dtype=dtype,
device_map="auto" if torch.cuda.is_available() else None,
load_in_8bit=quantize,
).eval()
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
tokenizer.bos_token_id = self.model.config.decoder_start_token_id
@ -314,14 +344,17 @@ class Seq2SeqLM(Model):
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1].unsqueeze(-1)
# Wrap `encoder_last_hidden_state` because for some reason, Transformers does a `encoder_last_hidden_state[0]`
# internally...
if encoder_last_hidden_state is not None:
encoder_last_hidden_state = [encoder_last_hidden_state]
outputs = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
encoder_outputs=[encoder_last_hidden_state]
if encoder_last_hidden_state is not None
else None,
encoder_outputs=encoder_last_hidden_state,
past_key_values=past_key_values,
use_cache=True,
)
@ -351,11 +384,12 @@ class Seq2SeqLM(Model):
# List of indices to cache
next_batch_keep_indices = []
# New input_ids for next forward
# New values for next forward
next_batch_input_lengths = []
next_batch_decoder_input_ids = []
next_batch_decoder_input_lengths = []
# Metadata
next_batch_size = 0
next_batch_max_input_length = 0
next_batch_max_decoder_input_length = 0
@ -395,7 +429,7 @@ class Seq2SeqLM(Model):
# Evaluate stopping criteria
if stopping_criteria(decoder_tokens):
# Decode all tokens
# Decode tokens
output = self.tokenizer.decode(decoder_tokens, skip_special_tokens=True)
# Add to the list of finished generations with the original request
generated_texts.append(
@ -420,9 +454,11 @@ class Seq2SeqLM(Model):
if not next_batch_keep_indices:
return generated_texts, None
# If we finished at least one generation
next_batch_decoder_input_ids = torch.cat(next_batch_decoder_input_ids)
# If we finished at least one generation, we need to evict the indices of the generations that finished
# from the values of the next batch
if generated_texts:
# Apply indices to attention mask, past key values and other items that need to be cached
next_batch_input_ids = batch.input_ids[next_batch_keep_indices]
next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
@ -458,7 +494,7 @@ class Seq2SeqLM(Model):
next_batch_next_token_choosers = batch.next_token_choosers
next_batch_stopping_criterias = batch.stopping_criterias
# Update attention_mask with padding as we added a new token to input_ids
# Update decoder_attention_mask with padding as we added a new token to input_ids
if next_batch_decoder_attention_mask is not None:
next_batch_decoder_attention_mask = torch.cat(
[