feat(server): Improved doc

OlivierDehaene 2022-11-07 12:53:56 +01:00
parent cea6051eff
commit 4236e41b0d
9 changed files with 195 additions and 101 deletions

View File

@@ -28,6 +28,7 @@ ENV LANG=C.UTF-8 \
MODEL_NAME=bigscience/bloom \
QUANTIZE=false \
NUM_GPUS=8 \
+SAFETENSORS_FAST_GPU=1 \
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
NCCL_ASYNC_ERROR_HANDLING=1 \
CUDA_HOME=/usr/local/cuda \
@@ -55,12 +56,6 @@ RUN cd server && make install-torch
# Install specific version of transformers
RUN cd server && make install-transformers
-# Install specific version of safetensors
-# FIXME: This is a temporary fix while we wait for a new release
-RUN curl https://sh.rustup.rs -sSf | bash -s -- -y
-ENV PATH="/root/.cargo/bin:${PATH}"
-RUN cd server && make install-safetensors
# Install server
COPY proto proto
COPY server server

View File

@@ -6,7 +6,8 @@
</div>
-A Rust and gRPC server for text generation inference.
+A Rust and gRPC server for text generation inference. Used in production at [HuggingFace](https://huggingface.co)
+to power Bloom, BloomZ and MT0-XXL api-inference widgets.
## Features
@@ -15,11 +16,11 @@ A Rust and gRPC server for text generation inference.
- [Safetensors](https://github.com/huggingface/safetensors) weight loading
- 45ms per token generation for BLOOM with 8xA100 80GB
-## Officialy supported models
+## Officially supported models
-- BLOOM
+- [BLOOM](https://huggingface.co/bigscience/bloom)
-- BLOOMZ
+- [BLOOMZ](https://huggingface.co/bigscience/bloomz)
-- BLOOM-560m
+- [MT0-XXL](https://huggingface.co/bigscience/mt0-xxl)
Other models are supported on a best effort basis using:
@@ -90,5 +91,4 @@ make router-dev
## TODO:
- [ ] Add tests for the `server/model` logic
- [ ] Backport custom CUDA kernels to Transformers
-- [ ] Install safetensors with pip

View File

@@ -295,6 +295,10 @@ fn shard_manager(
"MASTER_PORT".parse().unwrap(),
master_port.to_string().parse().unwrap(),
),
+(
+"SAFETENSORS_FAST_GPU".parse().unwrap(),
+"1".to_string().parse().unwrap(),
+),
];
// If the HUGGINGFACE_HUB_CACHE env var is set, pass it to the shard
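Both the Dockerfile and the launcher now export SAFETENSORS_FAST_GPU=1 before the Python shards start. A minimal sketch of the loading path this targets is shown below; it assumes the safetensors.torch.load_file(filename, device=...) helper exposed by recent safetensors releases, and the shard file name is purely hypothetical.

    import os
    from safetensors.torch import load_file

    # Mirror the value injected by the Dockerfile / launcher above
    os.environ.setdefault("SAFETENSORS_FAST_GPU", "1")

    # Load a (hypothetical) weight shard directly onto the first GPU
    state_dict = load_file("bloom-shard-00001.safetensors", device="cuda:0")
    print(f"loaded {sum(t.numel() for t in state_dict.values())} parameters")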

View File

@@ -16,24 +16,13 @@ install-transformers:
mv transformers-7302a24535e8dc5637ea5b4e4572fc971d404098 transformers
cd transformers && python setup.py install
-install-safetensors:
-# Install specific version of safetensors
-pip install setuptools_rust
-rm safetensors || true
-rm safetensors-634deccbcbad5eaf417935281f8b3be7ebca69c5 || true
-curl -L -O https://github.com/huggingface/safetensors/archive/634deccbcbad5eaf417935281f8b3be7ebca69c5.zip
-unzip 634deccbcbad5eaf417935281f8b3be7ebca69c5.zip
-rm 634deccbcbad5eaf417935281f8b3be7ebca69c5.zip
-mv safetensors-634deccbcbad5eaf417935281f8b3be7ebca69c5 safetensors
-cd safetensors/bindings/python && python setup.py develop
install-torch:
# Install specific version of torch
pip install torch --extra-index-url https://download.pytorch.org/whl/cu116 --no-cache-dir
-install: gen-server install-torch install-transformers install-safetensors
+install: gen-server install-torch install-transformers
pip install pip --upgrade
pip install -e . --no-cache-dir
run-dev:
-python -m torch.distributed.run --nproc_per_node=2 text_generation/cli.py serve bigscience/bloom-560m --sharded
+SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation/cli.py serve bigscience/bloom-560m --sharded

server/poetry.lock (generated)
View File

@@ -145,6 +145,18 @@ category = "main"
optional = false
python-versions = ">=3.6"
+[[package]]
+name = "safetensors"
+version = "0.2.4"
+description = "Fast and Safe Tensor serialization"
+category = "main"
+optional = false
+python-versions = "*"
+[package.extras]
+dev = ["black (==22.3)", "flake8 (>=3.8.3)", "huggingface-hub", "isort (>=5.5.4)", "numpy", "pytest", "setuptools-rust"]
+testing = ["black (==22.3)", "flake8 (>=3.8.3)", "huggingface-hub", "isort (>=5.5.4)", "numpy", "pytest", "setuptools-rust"]
[[package]]
name = "setuptools"
version = "65.5.0"
@@ -208,7 +220,7 @@ bnb = ["bitsandbytes"]
[metadata]
lock-version = "1.1"
python-versions = "^3.9"
-content-hash = "224b1e379d6105fe911bff4563946a90dfa6ff5918cf2e7be59f8d4f7c5cd7cf"
+content-hash = "3266187ef14fe8f9e29b3b6530d07781ea952aa670c0fe0de34be43efa231a67"
[metadata.files]
accelerate = [
@@ -459,6 +471,39 @@ PyYAML = [
{file = "PyYAML-6.0-cp39-cp39-win_amd64.whl", hash = "sha256:b3d267842bf12586ba6c734f89d1f5b871df0273157918b0ccefa29deb05c21c"},
{file = "PyYAML-6.0.tar.gz", hash = "sha256:68fb519c14306fec9720a2a5b45bc9f0c8d1b9c72adf45c37baedfcd949c35a2"},
]
safetensors = [
{file = "safetensors-0.2.4-cp310-cp310-macosx_10_11_x86_64.whl", hash = "sha256:79c4a7610d7699c64d8531c43f758ded4990ebaa7b0887c2078640e6de44e726"},
{file = "safetensors-0.2.4-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:ef425a4ddd29612fe733a6eeca6ad8f3ee3939f530a032114974aac4c4667b89"},
{file = "safetensors-0.2.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:77758f8ba4de6e20bf394dd964854a926dee2efee82eaa95e6c0893e2a7d960c"},
{file = "safetensors-0.2.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fb956e9090cce515649f00b491b5ddc0f9c3d989139016a8d69f9dcf57e8d3d9"},
{file = "safetensors-0.2.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3e31b02d27249bd519f05ec9d189097c59fc6851c59daa1a86ef347659e33ac3"},
{file = "safetensors-0.2.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0c2fead03a1497042efea4358574f3d7acf501b0c82e54d605f393f2b4e2aafe"},
{file = "safetensors-0.2.4-cp310-cp310-win32.whl", hash = "sha256:dce6ed3c7d13aafa574737eb3309c928adcb6781e879b41f0861be83b439cf3e"},
{file = "safetensors-0.2.4-cp310-cp310-win_amd64.whl", hash = "sha256:1dfe727325a1342767c6725dc2cc1f00463eb40a1f5df37c338d8e03957e27ce"},
{file = "safetensors-0.2.4-cp37-cp37m-macosx_10_11_x86_64.whl", hash = "sha256:c066bc7b90a582a01ec468fef61a7581b5c726bf12c50491cb6ea5db215ea5e0"},
{file = "safetensors-0.2.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca6ed53dad5d7d0e67eb676528ff2ad345cac3a34010e4dc1e3736972de294a5"},
{file = "safetensors-0.2.4-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ada03b44acbb036cfabe7066a8df4ad9b1ac05bb585a6b6c0f285f08e016381d"},
{file = "safetensors-0.2.4-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:58a0902708daa7ec2b2293b46e85df61f4fa359ddfe648e7ac025a79e6f59627"},
{file = "safetensors-0.2.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e0a4e38f7cbb4bfc513588e52f349b906c941e74fbbe192f2b19fc34221d448"},
{file = "safetensors-0.2.4-cp37-cp37m-win32.whl", hash = "sha256:4f8695b77dd847203258f035f8468f8b701c90621cb6b457e109f8d89c27f16c"},
{file = "safetensors-0.2.4-cp37-cp37m-win_amd64.whl", hash = "sha256:16b08f33c753c7da64b3999beea7c30d58204a0820961e33881d05a331e3f5c0"},
{file = "safetensors-0.2.4-cp38-cp38-macosx_10_11_x86_64.whl", hash = "sha256:a381606804f23db9eede51135f5fbd1f75dda02100415ee150fd39eb1cd6be4c"},
{file = "safetensors-0.2.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7aceae84d0c7233d83923029aaf8d184848561e0211ec98c5317327b3db025d6"},
{file = "safetensors-0.2.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:da48fc929485cbd9ee22621e388764a7cef27b0205e73aee2ad75aadd7d67662"},
{file = "safetensors-0.2.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2619b88f934c4de6b59de90c9dc00eae2d0e30f254a1daebd6eb232ac1f9a7a7"},
{file = "safetensors-0.2.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d1f78b987ae1f6b71da8ea110164e4cab2ee31b53835d2a66279df89c5d73f0e"},
{file = "safetensors-0.2.4-cp38-cp38-win32.whl", hash = "sha256:34b3e60b5130fb0fe07114705e51d30aa2c7eae4c1d1e77d6f260fa4ade70ede"},
{file = "safetensors-0.2.4-cp38-cp38-win_amd64.whl", hash = "sha256:debaa4fa98a7af44ba6dcb6945efee77b8480284c2cb05918ab97cf511c40826"},
{file = "safetensors-0.2.4-cp39-cp39-macosx_10_11_x86_64.whl", hash = "sha256:90baaafc0c872a736124b341db54b0bdd61765cbf3a61418371066a37905b18d"},
{file = "safetensors-0.2.4-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:b4bf7e23191d6a3ff00de141512869fc776e8ff159c872cb44af018cb04d45eb"},
{file = "safetensors-0.2.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cf11a3aba8796e548ceb0a65f34dcd334dcf0c4c891dccabe18a8b53918ae8ab"},
{file = "safetensors-0.2.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:95c31935ea71d63a38c546654136d7f0dbf1e7aeb6564dbc2201bc1fe9b34e4c"},
{file = "safetensors-0.2.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ef31776e2e081d6f075408eed34a0fbd524cbd19e50268bef02c238b209213b7"},
{file = "safetensors-0.2.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d06bb1d68148f6d6934352124d8cbfcf0db092f969db7187e348bd5cbf183db5"},
{file = "safetensors-0.2.4-cp39-cp39-win32.whl", hash = "sha256:5d546152b9a5bd58eae97c2ddefba394404d37ddedec305f7639c9b6054513e5"},
{file = "safetensors-0.2.4-cp39-cp39-win_amd64.whl", hash = "sha256:553ecfd895d379c1e03a7c9241f7343b3af66573436969ed7eb95df81dfbe9af"},
{file = "safetensors-0.2.4.tar.gz", hash = "sha256:35c0719a898f1f1292464f4cd9370bb6c2698032f1db4d677489f078b66b5a75"},
]
setuptools = [
{file = "setuptools-65.5.0-py3-none-any.whl", hash = "sha256:f62ea9da9ed6289bfe868cd6845968a2c854d1427f8548d52cae02a42b4f0356"},
{file = "setuptools-65.5.0.tar.gz", hash = "sha256:512e5536220e38146176efb833d4a62aa726b7bbff82cfbc8ba9eaa3996e0b17"},

View File

@@ -15,6 +15,7 @@ typer = "^0.6.1"
grpcio-reflection = "^1.49.1"
accelerate = "^0.12.0"
bitsandbytes = "^0.35.1"
+safetensors = "^0.2.4"
[tool.poetry.extras]
bnb = ["bitsandbytes"]

View File

@@ -9,17 +9,13 @@ __all__ = ["Model", "BLOOMSharded", "CausalLM", "Seq2SeqLM"]
def get_model(model_name: str, sharded: bool, quantize: bool) -> Model:
if model_name.startswith("bigscience/bloom"):
if sharded:
-return BLOOMSharded(model_name, quantize)
+return BLOOMSharded(model_name, quantize=quantize)
else:
-if quantize:
-raise ValueError("quantization is not supported for non-sharded BLOOM")
-return CausalLM(model_name)
+return CausalLM(model_name, quantize=quantize)
else:
if sharded:
raise ValueError("sharded is not supported for AutoModel")
-if quantize:
-raise ValueError("quantize is not supported for AutoModel")
try:
-return CausalLM(model_name)
+return CausalLM(model_name, quantize=quantize)
except Exception as e:
-return Seq2SeqLM(model_name)
+return Seq2SeqLM(model_name, quantize=quantize)
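With this change, `quantize` is forwarded to every model class instead of being rejected for non-sharded models. A hypothetical usage sketch (the checkpoint name is only an example, not taken from this diff):

    from text_generation.models import get_model

    # Non-sharded BLOOM can now be loaded in 8-bit as well
    model = get_model("bigscience/bloom-560m", sharded=False, quantize=True)
    print(type(model).__name__)  # CausalLM, constructed with quantize=True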

View File

@@ -2,7 +2,7 @@ import torch
from dataclasses import dataclass
from transformers import AutoTokenizer, AutoModelForCausalLM
-from typing import Optional, Tuple, List, Dict, Type
+from typing import Optional, Tuple, List, Type
from text_generation.models import Model
from text_generation.models.types import GeneratedText
@@ -14,11 +14,23 @@ from text_generation.utils import NextTokenChooser, StoppingCriteria
class CausalLMBatch:
batch_id: int
requests: List[generate_pb2.Request]
-all_input_lengths: List[int]
-input_ids: Dict[str, torch.Tensor]
+# Decoder values
+input_ids: torch.Tensor
+attention_mask: torch.Tensor
+past_key_values: Optional[List[Tuple]]
+# All tokens
all_input_ids: List[torch.Tensor]
+# Lengths of all generations present in the batch
+input_lengths: List[int]
+# Generation helpers
next_token_choosers: List[NextTokenChooser]
stopping_criterias: List[StoppingCriteria]
+# Metadata used for padding
size: int
max_sequence_length: int
@@ -36,12 +48,12 @@ class CausalLMBatch:
inputs = []
next_token_choosers = []
stopping_criterias = []
-all_input_lengths = []
+input_lengths = []
# Parse batch
for r in pb.requests:
inputs.append(r.inputs)
-all_input_lengths.append(r.input_length)
+input_lengths.append(r.input_length)
next_token_choosers.append(
NextTokenChooser(
temperature=r.parameters.temperature,
@@ -56,21 +68,23 @@ class CausalLMBatch:
)
)
-input_ids = tokenizer(
+tokenized_inputs = tokenizer(
inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8
).to(device)
-all_input_ids = input_ids["input_ids"].unsqueeze(-1)
+all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)
return cls(
batch_id=pb.id,
requests=pb.requests,
-all_input_lengths=all_input_lengths,
-input_ids=input_ids,
+input_ids=tokenized_inputs["input_ids"],
+attention_mask=tokenized_inputs["attention_mask"],
+past_key_values=None,
all_input_ids=all_input_ids,
+input_lengths=input_lengths,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
size=pb.size,
-max_sequence_length=max(all_input_lengths),
+max_sequence_length=max(input_lengths),
)
@classmethod
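The batch now stores `input_ids`, `attention_mask` and `past_key_values` as separate fields instead of one dict. A small sketch of the tokenization step above; the checkpoint and prompts are examples, not taken from this diff, and left padding matches how the model classes configure their tokenizer.

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m", padding_side="left")

    tokenized_inputs = tokenizer(
        ["Hello", "A slightly longer prompt"],
        return_tensors="pt",
        padding=True,
        pad_to_multiple_of=8,
    )
    # Rectangular tensors, padded on the left up to a multiple of 8 tokens
    print(tokenized_inputs["input_ids"].shape)
    print(tokenized_inputs["attention_mask"][0])  # 0s over padding, 1s over real tokens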
@@ -80,19 +94,23 @@ class CausalLMBatch:
max_sequence_length = max(batch.max_sequence_length for batch in batches)
# Batch attributes
-input_ids = {"input_ids": None, "attention_mask": None, "past_key_values": []}
requests = []
-all_input_lengths = []
+input_lengths = []
all_input_ids = []
next_token_choosers = []
stopping_criterias = []
+# Batch tensors
+input_ids = None
+attention_mask = None
+past_key_values = []
# Used for slicing correctly inside the tensors
# Equivalent to a cumsum on batch sizes
start_index = 0
for i, batch in enumerate(batches):
requests.extend(batch.requests)
-all_input_lengths.extend(batch.all_input_lengths)
+input_lengths.extend(batch.input_lengths)
all_input_ids.extend(batch.all_input_ids)
next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias)
@@ -101,32 +119,35 @@ class CausalLMBatch:
end_index = start_index + batch.size
# We only concatenate batches that did at least one step
-if batch.input_ids["input_ids"].shape[1] > 1:
+if batch.input_ids.shape[1] > 1:
raise ValueError("Batch input_ids should be of shape (batch_size, 1)")
-# Initialize tensors
-if i == 0:
-input_ids["input_ids"] = torch.empty(
-(total_batch_size, 1),
-dtype=batch.input_ids["input_ids"].dtype,
-device=batch.input_ids["input_ids"].device,
-)
-input_ids["attention_mask"] = torch.zeros(
-(total_batch_size, max_sequence_length),
-dtype=batch.input_ids["attention_mask"].dtype,
-device=batch.input_ids["attention_mask"].device,
-)
-# input_ids["input_ids"] is always of shape [batch_size, 1]
+# Create empty tensor
+# input_ids is always of shape [batch_size, 1]
# We do not need to pad it
-input_ids["input_ids"][start_index:end_index] = batch.input_ids["input_ids"]
+if input_ids is None:
+input_ids = torch.empty(
+(total_batch_size, 1),
+dtype=batch.input_ids.dtype,
+device=batch.input_ids.device,
+)
+# Copy to correct indices
+input_ids[start_index:end_index] = batch.input_ids
+# Create padded tensor
+if attention_mask is None:
+attention_mask = torch.zeros(
+(total_batch_size, max_sequence_length),
+dtype=batch.attention_mask.dtype,
+device=batch.attention_mask.device,
+)
# We need to slice the attention mask to remove padding from previous steps
-input_ids["attention_mask"][
+attention_mask[
start_index:end_index, -batch.max_sequence_length :
-] = batch.input_ids["attention_mask"][:, -batch.max_sequence_length :]
+] = batch.attention_mask[:, -batch.max_sequence_length :]
-for j, past in enumerate(batch.input_ids["past_key_values"]):
+for j, past in enumerate(batch.past_key_values):
# Shenanigans to get dimensions because BLOOM outputs a past with a different shape
# BLOOM: [batch_size * num_heads, ...] vs [batch_size, num_heads, ...]
head_dim, padded_sequence_length = past[0].shape[-2:]
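A toy illustration (not from the repo) of the padding scheme used here when concatenating batches: each batch's mask is copied into the rightmost columns of the merged tensor so shorter sequences stay left-padded.

    import torch

    total_batch_size, max_sequence_length = 3, 5
    attention_mask = torch.zeros((total_batch_size, max_sequence_length), dtype=torch.int64)

    # First batch: two sequences that have each seen 3 tokens
    mask_a = torch.ones((2, 3), dtype=torch.int64)
    attention_mask[0:2, -3:] = mask_a[:, -3:]

    # Second batch: one sequence that has seen 5 tokens
    mask_b = torch.ones((1, 5), dtype=torch.int64)
    attention_mask[2:3, -5:] = mask_b[:, -5:]

    print(attention_mask)
    # tensor([[0, 0, 1, 1, 1],
    #         [0, 0, 1, 1, 1],
    #         [1, 1, 1, 1, 1]])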
@@ -137,8 +158,8 @@ class CausalLMBatch:
)
# This will run only once per layer
-if j == len(input_ids["past_key_values"]):
-input_ids["past_key_values"].append([])
+if j == len(past_key_values):
+past_key_values.append([])
# Decoder past
for k, t in enumerate(past):
@@ -172,21 +193,21 @@ class CausalLMBatch:
# Initialize tensors
# This will run only once per layer and per past tensor
-if k == len(input_ids["past_key_values"][j]):
-input_ids["past_key_values"][j].append(
+if k == len(past_key_values[j]):
+past_key_values[j].append(
torch.zeros(padded_t_shape, dtype=t.dtype, device=t.device)
)
# We slice the past keys and values to remove the padding from previous batches
if not head_dim_last:
-input_ids["past_key_values"][j][k][
+past_key_values[j][k][
start_index:end_index,
:,
:,
-(batch.max_sequence_length - 1) :,
] = t[:, :, :, -(batch.max_sequence_length - 1) :]
else:
-input_ids["past_key_values"][j][k][
+past_key_values[j][k][
start_index:end_index,
:,
-(batch.max_sequence_length - 1) :,
@@ -198,9 +219,11 @@ class CausalLMBatch:
return cls(
batch_id=batches[0].batch_id,
requests=requests,
-all_input_lengths=all_input_lengths,
input_ids=input_ids,
+attention_mask=attention_mask,
+past_key_values=past_key_values,
all_input_ids=all_input_ids,
+input_lengths=input_lengths,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
size=total_batch_size,
@@ -209,7 +232,7 @@ class CausalLMBatch:
class CausalLM(Model):
-def __init__(self, model_name: str):
+def __init__(self, model_name: str, quantize=False):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
@@ -223,6 +246,7 @@ class CausalLM(Model):
model_name,
torch_dtype=dtype,
device_map="auto" if torch.cuda.is_available() else None,
+load_in_8bit=quantize,
).eval()
super(CausalLM, self).__init__(
@@ -255,16 +279,19 @@ class CausalLM(Model):
torch.no_grad if self.device.type == "cpu" else torch.inference_mode
)
with context_manager():
-logits, past = self.forward(**batch.input_ids)
+logits, past = self.forward(
+batch.input_ids, batch.attention_mask, batch.past_key_values
+)
# List of indices to cache
next_batch_keep_indices = []
-# New input_ids for next forward
+# New values for next forward
+next_batch_input_lengths = []
next_batch_input_ids = []
next_batch_all_input_ids = []
-next_all_input_lengths = []
+# Metadata
next_batch_size = 0
next_batch_max_sequence_length = 0
@@ -274,7 +301,7 @@ class CausalLM(Model):
# Zipped iterator
iterator = zip(
batch.requests,
-batch.all_input_lengths,
+batch.input_lengths,
logits,
batch.next_token_choosers,
batch.stopping_criterias,
@@ -313,7 +340,7 @@ class CausalLM(Model):
next_batch_all_input_ids.append(all_tokens)
next_batch_size += 1
new_input_length = input_length + 1
-next_all_input_lengths.append(new_input_length)
+next_batch_input_lengths.append(new_input_length)
next_batch_max_sequence_length = max(
next_batch_max_sequence_length, new_input_length
)
@@ -322,15 +349,14 @@ class CausalLM(Model):
if not next_batch_keep_indices:
return generated_texts, None
-# If we finished at least one generation
-next_batch_input_ids = {"input_ids": torch.cat(next_batch_input_ids, dim=0)}
+next_batch_input_ids = torch.cat(next_batch_input_ids, dim=0)
+# If we finished at least one generation, we need to evict the indices of the generations that finished
+# from the values of the next batch
if generated_texts:
# Apply indices to attention mask, past key values and other items that need to be cached
-next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"][
-next_batch_keep_indices
-]
+next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
# Force past to be of dim [batch_size, num_heads, ...] for easy indexing
-next_batch_input_ids["past_key_values"] = [
+next_batch_past_key_values = [
[
t.view(-1, self.num_heads, *t.shape[-2:])[next_batch_keep_indices]
for t in layer
@@ -345,16 +371,16 @@ class CausalLM(Model):
batch.stopping_criterias[i] for i in next_batch_keep_indices
]
else:
-next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"]
-next_batch_input_ids["past_key_values"] = past
+next_batch_attention_mask = batch.attention_mask
+next_batch_past_key_values = past
next_batch_requests = batch.requests
next_batch_next_token_choosers = batch.next_token_choosers
next_batch_stopping_criterias = batch.stopping_criterias
# Update attention_mask with padding as we added a new token to input_ids
-next_batch_input_ids["attention_mask"] = torch.cat(
+next_batch_attention_mask = torch.cat(
[
-next_batch_input_ids["attention_mask"],
+next_batch_attention_mask,
torch.ones((next_batch_size, 1)).to(self.device),
],
dim=1,
@@ -363,9 +389,11 @@ class CausalLM(Model):
next_batch = CausalLMBatch(
batch_id=batch.batch_id,
requests=next_batch_requests,
-all_input_lengths=next_all_input_lengths,
input_ids=next_batch_input_ids,
+attention_mask=next_batch_attention_mask,
+past_key_values=next_batch_past_key_values,
all_input_ids=next_batch_all_input_ids,
+input_lengths=next_batch_input_lengths,
next_token_choosers=next_batch_next_token_choosers,
stopping_criterias=next_batch_stopping_criterias,
size=next_batch_size,

View File

@@ -15,26 +15,33 @@ class Seq2SeqLMBatch:
batch_id: int
requests: List[generate_pb2.Request]
+# Encoder values
input_ids: torch.Tensor
attention_mask: torch.Tensor
+# Decoder values
decoder_input_ids: torch.Tensor
decoder_attention_mask: Optional[torch.Tensor]
encoder_last_hidden_state: Optional[torch.Tensor]
+# Seq2SeqLM keeps track of both encoder and decoder attention keys and values
past_key_values: Optional[List[Tuple]]
+# Lengths of all generations present in the batch
input_lengths: List[int]
decoder_input_lengths: List[int]
+# Generation helpers
next_token_choosers: List[NextTokenChooser]
stopping_criterias: List[StoppingCriteria]
+# Metadata used for padding
size: int
max_input_length: int
max_decoder_input_length: int
def to_pb(self):
+"""Convert a Seq2SeqLMBatch to a text_generation.v1.Batch protobuf"""
return generate_pb2.Batch(
id=self.batch_id,
requests=self.requests,
@@ -45,6 +52,7 @@ class Seq2SeqLMBatch:
def from_pb(
cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
) -> "Seq2SeqLMBatch":
+"""Convert a text_generation.v1.Batch protobuf to a Seq2SeqLMBatch"""
inputs = []
next_token_choosers = []
stopping_criterias = []
@@ -57,6 +65,7 @@ class Seq2SeqLMBatch:
for r in pb.requests:
inputs.append(r.inputs)
input_lengths.append(r.input_length)
+# Decoder sequence only contains the bos_token
decoder_input_ids.append(tokenizer.bos_token_id)
decoder_input_lengths.append(1)
next_token_choosers.append(
@@ -73,9 +82,11 @@ class Seq2SeqLMBatch:
)
)
+# Tokenize batch
tokenized_inputs = tokenizer(
inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8
).to(device)
+# Convert decoder_input_ids to torch tensor of size [batch_size, 1]
decoder_input_ids = torch.tensor(decoder_input_ids).to(device).unsqueeze(-1)
return cls(
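A short sketch of the decoder seeding above: each request starts decoding from a single bos token, so the stacked tensor is always of shape [batch_size, 1] (the token id below is a placeholder; the real one comes from the tokenizer).

    import torch

    bos_token_id = 0  # placeholder value
    decoder_input_ids = [bos_token_id] * 4  # one bos per request in the batch
    decoder_input_ids = torch.tensor(decoder_input_ids).unsqueeze(-1)
    print(decoder_input_ids.shape)  # torch.Size([4, 1])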
@@ -98,6 +109,8 @@ class Seq2SeqLMBatch:
@classmethod
def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
+"""Concatenate multiple batches together by padding internal torch tensors"""
# Used for padding
total_batch_size = sum(batch.size for batch in batches)
max_input_length = max(batch.max_input_length for batch in batches)
@@ -112,6 +125,7 @@ class Seq2SeqLMBatch:
next_token_choosers = []
stopping_criterias = []
+# Batch tensors
input_ids = None
attention_mask = None
decoder_input_ids = None
@@ -122,7 +136,9 @@ class Seq2SeqLMBatch:
# Used for slicing correctly inside the tensors
# Equivalent to a cumsum on batch sizes
start_index = 0
for i, batch in enumerate(batches):
+# Extend all list attributes
requests.extend(batch.requests)
input_lengths.extend(batch.input_lengths)
decoder_input_lengths.extend(batch.decoder_input_lengths)
@@ -136,51 +152,62 @@ class Seq2SeqLMBatch:
if batch.encoder_last_hidden_state is None:
raise ValueError("Batch encoder_last_hidden_state cannot be None")
+# Create padded tensor
if input_ids is None:
input_ids = torch.zeros(
(total_batch_size, max_input_length),
dtype=batch.input_ids.dtype,
device=batch.input_ids.device,
)
+# Copy to correct indices
input_ids[
start_index:end_index, -batch.max_input_length :
] = batch.input_ids[:, -batch.max_input_length :]
+# Create padded tensor
if attention_mask is None:
attention_mask = torch.zeros(
(total_batch_size, max_input_length),
dtype=batch.attention_mask.dtype,
device=batch.attention_mask.device,
)
+# Copy to correct indices
attention_mask[
start_index:end_index, -batch.max_input_length :
] = batch.attention_mask[:, -batch.max_input_length :]
+# Create padded tensor
if decoder_input_ids is None:
decoder_input_ids = torch.zeros(
(total_batch_size, max_decoder_input_length),
dtype=batch.decoder_input_ids.dtype,
device=batch.decoder_input_ids.device,
)
+# Copy to correct indices
decoder_input_ids[
start_index:end_index, -batch.max_decoder_input_length :
] = batch.decoder_input_ids[:, -batch.max_decoder_input_length :]
+# Create padded tensor
if decoder_attention_mask is None:
decoder_attention_mask = torch.zeros(
(total_batch_size, max_decoder_input_length),
-dtype=batch.attention_mask.dtype,
-device=batch.attention_mask.device,
+dtype=batch.attention_mask.dtype,  # As decoder_attention_mask might not exist,
+device=batch.attention_mask.device,  # we use `batch.attention_maks` for device here
)
+# If the decoder mask does not exist yet, all generations started at the same time and we never concatenated
+# this batch. All generations are of length `batch.max_decoder_input_length`.
if batch.decoder_attention_mask is None:
decoder_attention_mask[
start_index:end_index, -batch.max_decoder_input_length :
] = 1
+# If it exists, we need to index
else:
decoder_attention_mask[
start_index:end_index, -batch.max_decoder_input_length :
] = batch.decoder_attention_mask[:, -batch.max_decoder_input_length :]
+# Create padded tensor
if encoder_last_hidden_state is None:
encoder_last_hidden_state = torch.zeros(
(
@@ -192,10 +219,12 @@ class Seq2SeqLMBatch:
device=batch.encoder_last_hidden_state.device,
)
+# Copy to correct indices
encoder_last_hidden_state[
start_index:end_index, -batch.max_decoder_input_length :, :
] = batch.encoder_last_hidden_state[:, -batch.max_decoder_input_length :, :]
+# Iterate over attention layers
for j, past in enumerate(batch.past_key_values):
_, num_heads, _, head_dim = past[0].shape
@@ -271,7 +300,7 @@ class Seq2SeqLMBatch:
class Seq2SeqLM(Model):
-def __init__(self, model_name: str):
+def __init__(self, model_name: str, quantize=False):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
@@ -283,6 +312,7 @@ class Seq2SeqLM(Model):
model_name,
torch_dtype=dtype,
device_map="auto" if torch.cuda.is_available() else None,
+load_in_8bit=quantize,
).eval()
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
tokenizer.bos_token_id = self.model.config.decoder_start_token_id
@@ -314,14 +344,17 @@ class Seq2SeqLM(Model):
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1].unsqueeze(-1)
+# Wrap `encoder_last_hidden_state` because for some reason, Transformers does a `encoder_last_hidden_state[0]`
+# internally...
+if encoder_last_hidden_state is not None:
+encoder_last_hidden_state = [encoder_last_hidden_state]
outputs = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
-encoder_outputs=[encoder_last_hidden_state]
-if encoder_last_hidden_state is not None
-else None,
+encoder_outputs=encoder_last_hidden_state,
past_key_values=past_key_values,
use_cache=True,
)
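A minimal sketch (assuming a T5-style checkpoint such as t5-small, which is not part of this diff) of why the hidden state is wrapped in a list: Transformers reads the first element of whatever is passed as `encoder_outputs`, so a one-element list lets later decoding steps reuse the encoder output instead of re-running the encoder.

    import torch
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    model = AutoModelForSeq2SeqLM.from_pretrained("t5-small").eval()

    inputs = tokenizer(["translate English to German: hello"], return_tensors="pt")
    decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])

    with torch.inference_mode():
        # First step: run the encoder once and keep its last hidden state
        first = model(**inputs, decoder_input_ids=decoder_input_ids, use_cache=True)
        encoder_last_hidden_state = first.encoder_last_hidden_state

        # Later steps: feed only the newest decoder token and reuse the encoder output
        next_token = first.logits[:, -1].argmax(-1, keepdim=True)
        second = model(
            attention_mask=inputs["attention_mask"],
            decoder_input_ids=next_token,
            encoder_outputs=[encoder_last_hidden_state],
            past_key_values=first.past_key_values,
            use_cache=True,
        )

    print(second.logits.shape)  # (1, 1, vocab_size)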
@@ -351,11 +384,12 @@ class Seq2SeqLM(Model):
# List of indices to cache
next_batch_keep_indices = []
-# New input_ids for next forward
+# New values for next forward
next_batch_input_lengths = []
next_batch_decoder_input_ids = []
next_batch_decoder_input_lengths = []
+# Metadata
next_batch_size = 0
next_batch_max_input_length = 0
next_batch_max_decoder_input_length = 0
@@ -395,7 +429,7 @@ class Seq2SeqLM(Model):
# Evaluate stopping criteria
if stopping_criteria(decoder_tokens):
-# Decode all tokens
+# Decode tokens
output = self.tokenizer.decode(decoder_tokens, skip_special_tokens=True)
# Add to the list of finished generations with the original request
generated_texts.append(
@@ -420,9 +454,11 @@ class Seq2SeqLM(Model):
if not next_batch_keep_indices:
return generated_texts, None
-# If we finished at least one generation
next_batch_decoder_input_ids = torch.cat(next_batch_decoder_input_ids)
+# If we finished at least one generation, we need to evict the indices of the generations that finished
+# from the values of the next batch
if generated_texts:
+# Apply indices to attention mask, past key values and other items that need to be cached
next_batch_input_ids = batch.input_ids[next_batch_keep_indices]
next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
@@ -458,7 +494,7 @@ class Seq2SeqLM(Model):
next_batch_next_token_choosers = batch.next_token_choosers
next_batch_stopping_criterias = batch.stopping_criterias
-# Update attention_mask with padding as we added a new token to input_ids
+# Update decoder_attention_mask with padding as we added a new token to input_ids
if next_batch_decoder_attention_mask is not None:
next_batch_decoder_attention_mask = torch.cat(
[