feat(server): Improved doc

parent cea6051eff
commit 4236e41b0d
@@ -28,6 +28,7 @@ ENV LANG=C.UTF-8 \
     MODEL_NAME=bigscience/bloom \
     QUANTIZE=false \
     NUM_GPUS=8 \
+    SAFETENSORS_FAST_GPU=1 \
     CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
     NCCL_ASYNC_ERROR_HANDLING=1 \
     CUDA_HOME=/usr/local/cuda \
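Not part of the diff: the new SAFETENSORS_FAST_GPU=1 default is presumably consumed by the safetensors loader at run time. A minimal sketch of reading such a flag from the process environment (the helper name below is hypothetical, not from the repository):

import os

def fast_gpu_loading_enabled() -> bool:
    # Reads the flag baked into the image by the ENV block above.
    return os.environ.get("SAFETENSORS_FAST_GPU", "0") == "1"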
@@ -55,12 +56,6 @@ RUN cd server && make install-torch
 # Install specific version of transformers
 RUN cd server && make install-transformers
 
-# Install specific version of safetensors
-# FIXME: This is a temporary fix while we wait for a new release
-RUN curl https://sh.rustup.rs -sSf | bash -s -- -y
-ENV PATH="/root/.cargo/bin:${PATH}"
-RUN cd server && make install-safetensors
-
 # Install server
 COPY proto proto
 COPY server server
README.md
@@ -6,7 +6,8 @@
 
 </div>
 
-A Rust and gRPC server for text generation inference.
+A Rust and gRPC server for text generation inference. Used in production at [HuggingFace](https://huggingface.co)
+to power Bloom, BloomZ and MT0-XXL api-inference widgets.
 
 ## Features
 
@@ -15,11 +16,11 @@ A Rust and gRPC server for text generation inference.
 - [Safetensors](https://github.com/huggingface/safetensors) weight loading
 - 45ms per token generation for BLOOM with 8xA100 80GB
 
-## Officialy supported models
+## Officially supported models
 
-- BLOOM
-- BLOOMZ
-- BLOOM-560m
+- [BLOOM](https://huggingface.co/bigscience/bloom)
+- [BLOOMZ](https://huggingface.co/bigscience/bloomz)
+- [MT0-XXL](https://huggingface.co/bigscience/mt0-xxl)
 
 Other models are supported on a best effort basis using:
 
@@ -90,5 +91,4 @@ make router-dev
 ## TODO:
 
 - [ ] Add tests for the `server/model` logic
 - [ ] Backport custom CUDA kernels to Transformers
-- [ ] Install safetensors with pip
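The README features above mention Safetensors weight loading. As an illustrative sketch only (the file name is hypothetical; the call is the safetensors Python bindings' load_file helper):

from safetensors.torch import load_file

# Load a checkpoint saved in the safetensors format into a regular state dict.
state_dict = load_file("model.safetensors")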
@@ -295,6 +295,10 @@ fn shard_manager(
             "MASTER_PORT".parse().unwrap(),
             master_port.to_string().parse().unwrap(),
         ),
+        (
+            "SAFETENSORS_FAST_GPU".parse().unwrap(),
+            "1".to_string().parse().unwrap(),
+        ),
     ];
 
     // If the HUGGINGFACE_HUB_CACHE env var is set, pass it to the shard
@@ -16,24 +16,13 @@ install-transformers:
 	mv transformers-7302a24535e8dc5637ea5b4e4572fc971d404098 transformers
 	cd transformers && python setup.py install
 
-install-safetensors:
-	# Install specific version of safetensors
-	pip install setuptools_rust
-	rm safetensors || true
-	rm safetensors-634deccbcbad5eaf417935281f8b3be7ebca69c5 || true
-	curl -L -O https://github.com/huggingface/safetensors/archive/634deccbcbad5eaf417935281f8b3be7ebca69c5.zip
-	unzip 634deccbcbad5eaf417935281f8b3be7ebca69c5.zip
-	rm 634deccbcbad5eaf417935281f8b3be7ebca69c5.zip
-	mv safetensors-634deccbcbad5eaf417935281f8b3be7ebca69c5 safetensors
-	cd safetensors/bindings/python && python setup.py develop
-
 install-torch:
 	# Install specific version of torch
 	pip install torch --extra-index-url https://download.pytorch.org/whl/cu116 --no-cache-dir
 
-install: gen-server install-torch install-transformers install-safetensors
+install: gen-server install-torch install-transformers
 	pip install pip --upgrade
 	pip install -e . --no-cache-dir
 
 run-dev:
-	python -m torch.distributed.run --nproc_per_node=2 text_generation/cli.py serve bigscience/bloom-560m --sharded
+	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation/cli.py serve bigscience/bloom-560m --sharded
@@ -145,6 +145,18 @@ category = "main"
 optional = false
 python-versions = ">=3.6"
 
+[[package]]
+name = "safetensors"
+version = "0.2.4"
+description = "Fast and Safe Tensor serialization"
+category = "main"
+optional = false
+python-versions = "*"
+
+[package.extras]
+dev = ["black (==22.3)", "flake8 (>=3.8.3)", "huggingface-hub", "isort (>=5.5.4)", "numpy", "pytest", "setuptools-rust"]
+testing = ["black (==22.3)", "flake8 (>=3.8.3)", "huggingface-hub", "isort (>=5.5.4)", "numpy", "pytest", "setuptools-rust"]
+
 [[package]]
 name = "setuptools"
 version = "65.5.0"
@@ -208,7 +220,7 @@ bnb = ["bitsandbytes"]
 [metadata]
 lock-version = "1.1"
 python-versions = "^3.9"
-content-hash = "224b1e379d6105fe911bff4563946a90dfa6ff5918cf2e7be59f8d4f7c5cd7cf"
+content-hash = "3266187ef14fe8f9e29b3b6530d07781ea952aa670c0fe0de34be43efa231a67"
 
 [metadata.files]
 accelerate = [
@@ -459,6 +471,39 @@ PyYAML = [
     {file = "PyYAML-6.0-cp39-cp39-win_amd64.whl", hash = "sha256:b3d267842bf12586ba6c734f89d1f5b871df0273157918b0ccefa29deb05c21c"},
     {file = "PyYAML-6.0.tar.gz", hash = "sha256:68fb519c14306fec9720a2a5b45bc9f0c8d1b9c72adf45c37baedfcd949c35a2"},
 ]
+safetensors = [
+    {file = "safetensors-0.2.4-cp310-cp310-macosx_10_11_x86_64.whl", hash = "sha256:79c4a7610d7699c64d8531c43f758ded4990ebaa7b0887c2078640e6de44e726"},
+    {file = "safetensors-0.2.4-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:ef425a4ddd29612fe733a6eeca6ad8f3ee3939f530a032114974aac4c4667b89"},
+    {file = "safetensors-0.2.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:77758f8ba4de6e20bf394dd964854a926dee2efee82eaa95e6c0893e2a7d960c"},
+    {file = "safetensors-0.2.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fb956e9090cce515649f00b491b5ddc0f9c3d989139016a8d69f9dcf57e8d3d9"},
+    {file = "safetensors-0.2.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3e31b02d27249bd519f05ec9d189097c59fc6851c59daa1a86ef347659e33ac3"},
+    {file = "safetensors-0.2.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0c2fead03a1497042efea4358574f3d7acf501b0c82e54d605f393f2b4e2aafe"},
+    {file = "safetensors-0.2.4-cp310-cp310-win32.whl", hash = "sha256:dce6ed3c7d13aafa574737eb3309c928adcb6781e879b41f0861be83b439cf3e"},
+    {file = "safetensors-0.2.4-cp310-cp310-win_amd64.whl", hash = "sha256:1dfe727325a1342767c6725dc2cc1f00463eb40a1f5df37c338d8e03957e27ce"},
+    {file = "safetensors-0.2.4-cp37-cp37m-macosx_10_11_x86_64.whl", hash = "sha256:c066bc7b90a582a01ec468fef61a7581b5c726bf12c50491cb6ea5db215ea5e0"},
+    {file = "safetensors-0.2.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca6ed53dad5d7d0e67eb676528ff2ad345cac3a34010e4dc1e3736972de294a5"},
+    {file = "safetensors-0.2.4-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ada03b44acbb036cfabe7066a8df4ad9b1ac05bb585a6b6c0f285f08e016381d"},
+    {file = "safetensors-0.2.4-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:58a0902708daa7ec2b2293b46e85df61f4fa359ddfe648e7ac025a79e6f59627"},
+    {file = "safetensors-0.2.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e0a4e38f7cbb4bfc513588e52f349b906c941e74fbbe192f2b19fc34221d448"},
+    {file = "safetensors-0.2.4-cp37-cp37m-win32.whl", hash = "sha256:4f8695b77dd847203258f035f8468f8b701c90621cb6b457e109f8d89c27f16c"},
+    {file = "safetensors-0.2.4-cp37-cp37m-win_amd64.whl", hash = "sha256:16b08f33c753c7da64b3999beea7c30d58204a0820961e33881d05a331e3f5c0"},
+    {file = "safetensors-0.2.4-cp38-cp38-macosx_10_11_x86_64.whl", hash = "sha256:a381606804f23db9eede51135f5fbd1f75dda02100415ee150fd39eb1cd6be4c"},
+    {file = "safetensors-0.2.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7aceae84d0c7233d83923029aaf8d184848561e0211ec98c5317327b3db025d6"},
+    {file = "safetensors-0.2.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:da48fc929485cbd9ee22621e388764a7cef27b0205e73aee2ad75aadd7d67662"},
+    {file = "safetensors-0.2.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2619b88f934c4de6b59de90c9dc00eae2d0e30f254a1daebd6eb232ac1f9a7a7"},
+    {file = "safetensors-0.2.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d1f78b987ae1f6b71da8ea110164e4cab2ee31b53835d2a66279df89c5d73f0e"},
+    {file = "safetensors-0.2.4-cp38-cp38-win32.whl", hash = "sha256:34b3e60b5130fb0fe07114705e51d30aa2c7eae4c1d1e77d6f260fa4ade70ede"},
+    {file = "safetensors-0.2.4-cp38-cp38-win_amd64.whl", hash = "sha256:debaa4fa98a7af44ba6dcb6945efee77b8480284c2cb05918ab97cf511c40826"},
+    {file = "safetensors-0.2.4-cp39-cp39-macosx_10_11_x86_64.whl", hash = "sha256:90baaafc0c872a736124b341db54b0bdd61765cbf3a61418371066a37905b18d"},
+    {file = "safetensors-0.2.4-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:b4bf7e23191d6a3ff00de141512869fc776e8ff159c872cb44af018cb04d45eb"},
+    {file = "safetensors-0.2.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cf11a3aba8796e548ceb0a65f34dcd334dcf0c4c891dccabe18a8b53918ae8ab"},
+    {file = "safetensors-0.2.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:95c31935ea71d63a38c546654136d7f0dbf1e7aeb6564dbc2201bc1fe9b34e4c"},
+    {file = "safetensors-0.2.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ef31776e2e081d6f075408eed34a0fbd524cbd19e50268bef02c238b209213b7"},
+    {file = "safetensors-0.2.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d06bb1d68148f6d6934352124d8cbfcf0db092f969db7187e348bd5cbf183db5"},
+    {file = "safetensors-0.2.4-cp39-cp39-win32.whl", hash = "sha256:5d546152b9a5bd58eae97c2ddefba394404d37ddedec305f7639c9b6054513e5"},
+    {file = "safetensors-0.2.4-cp39-cp39-win_amd64.whl", hash = "sha256:553ecfd895d379c1e03a7c9241f7343b3af66573436969ed7eb95df81dfbe9af"},
+    {file = "safetensors-0.2.4.tar.gz", hash = "sha256:35c0719a898f1f1292464f4cd9370bb6c2698032f1db4d677489f078b66b5a75"},
+]
 setuptools = [
     {file = "setuptools-65.5.0-py3-none-any.whl", hash = "sha256:f62ea9da9ed6289bfe868cd6845968a2c854d1427f8548d52cae02a42b4f0356"},
     {file = "setuptools-65.5.0.tar.gz", hash = "sha256:512e5536220e38146176efb833d4a62aa726b7bbff82cfbc8ba9eaa3996e0b17"},
@@ -15,6 +15,7 @@ typer = "^0.6.1"
 grpcio-reflection = "^1.49.1"
 accelerate = "^0.12.0"
 bitsandbytes = "^0.35.1"
+safetensors = "^0.2.4"
 
 [tool.poetry.extras]
 bnb = ["bitsandbytes"]
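With safetensors now declared as a regular dependency ("^0.2.4") instead of being built from source in the Makefile, a purely illustrative way to check which version the environment resolved:

from importlib.metadata import version

# Expected to satisfy ^0.2.4 from pyproject.toml.
print(version("safetensors"))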
@@ -9,17 +9,13 @@ __all__ = ["Model", "BLOOMSharded", "CausalLM", "Seq2SeqLM"]
 def get_model(model_name: str, sharded: bool, quantize: bool) -> Model:
     if model_name.startswith("bigscience/bloom"):
         if sharded:
-            return BLOOMSharded(model_name, quantize)
+            return BLOOMSharded(model_name, quantize=quantize)
         else:
-            if quantize:
-                raise ValueError("quantization is not supported for non-sharded BLOOM")
-            return CausalLM(model_name)
+            return CausalLM(model_name, quantize=quantize)
     else:
         if sharded:
             raise ValueError("sharded is not supported for AutoModel")
-        if quantize:
-            raise ValueError("quantize is not supported for AutoModel")
         try:
-            return CausalLM(model_name)
+            return CausalLM(model_name, quantize=quantize)
         except Exception as e:
-            return Seq2SeqLM(model_name)
+            return Seq2SeqLM(model_name, quantize=quantize)
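A hedged usage sketch for the updated get_model above; the model name is one of those listed in the README, the call site itself is hypothetical:

from text_generation.models import get_model

# quantize is now forwarded to every model class instead of being rejected
# on the non-sharded paths.
model = get_model("bigscience/bloom-560m", sharded=False, quantize=False)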
@@ -2,7 +2,7 @@ import torch
 
 from dataclasses import dataclass
 from transformers import AutoTokenizer, AutoModelForCausalLM
-from typing import Optional, Tuple, List, Dict, Type
+from typing import Optional, Tuple, List, Type
 
 from text_generation.models import Model
 from text_generation.models.types import GeneratedText
@@ -14,11 +14,23 @@ from text_generation.utils import NextTokenChooser, StoppingCriteria
 class CausalLMBatch:
     batch_id: int
     requests: List[generate_pb2.Request]
-    all_input_lengths: List[int]
-    input_ids: Dict[str, torch.Tensor]
+
+    # Decoder values
+    input_ids: torch.Tensor
+    attention_mask: torch.Tensor
+    past_key_values: Optional[List[Tuple]]
+
+    # All tokens
     all_input_ids: List[torch.Tensor]
+
+    # Lengths of all generations present in the batch
+    input_lengths: List[int]
+
+    # Generation helpers
     next_token_choosers: List[NextTokenChooser]
     stopping_criterias: List[StoppingCriteria]
+
+    # Metadata used for padding
     size: int
     max_sequence_length: int
 
@@ -36,12 +48,12 @@ class CausalLMBatch:
         inputs = []
         next_token_choosers = []
         stopping_criterias = []
-        all_input_lengths = []
+        input_lengths = []
 
         # Parse batch
         for r in pb.requests:
             inputs.append(r.inputs)
-            all_input_lengths.append(r.input_length)
+            input_lengths.append(r.input_length)
             next_token_choosers.append(
                 NextTokenChooser(
                     temperature=r.parameters.temperature,
@@ -56,21 +68,23 @@ class CausalLMBatch:
                 )
             )
 
-        input_ids = tokenizer(
+        tokenized_inputs = tokenizer(
             inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8
         ).to(device)
-        all_input_ids = input_ids["input_ids"].unsqueeze(-1)
+        all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)
 
         return cls(
             batch_id=pb.id,
             requests=pb.requests,
-            all_input_lengths=all_input_lengths,
-            input_ids=input_ids,
+            input_ids=tokenized_inputs["input_ids"],
+            attention_mask=tokenized_inputs["attention_mask"],
+            past_key_values=None,
             all_input_ids=all_input_ids,
+            input_lengths=input_lengths,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=pb.size,
-            max_sequence_length=max(all_input_lengths),
+            max_sequence_length=max(input_lengths),
         )
 
     @classmethod
@@ -80,19 +94,23 @@ class CausalLMBatch:
         max_sequence_length = max(batch.max_sequence_length for batch in batches)
 
         # Batch attributes
-        input_ids = {"input_ids": None, "attention_mask": None, "past_key_values": []}
         requests = []
-        all_input_lengths = []
+        input_lengths = []
         all_input_ids = []
         next_token_choosers = []
         stopping_criterias = []
 
+        # Batch tensors
+        input_ids = None
+        attention_mask = None
+        past_key_values = []
+
         # Used for slicing correctly inside the tensors
         # Equivalent to a cumsum on batch sizes
         start_index = 0
         for i, batch in enumerate(batches):
             requests.extend(batch.requests)
-            all_input_lengths.extend(batch.all_input_lengths)
+            input_lengths.extend(batch.input_lengths)
             all_input_ids.extend(batch.all_input_ids)
             next_token_choosers.extend(batch.next_token_choosers)
             stopping_criterias.extend(batch.stopping_criterias)
@@ -101,32 +119,35 @@ class CausalLMBatch:
             end_index = start_index + batch.size
 
             # We only concatenate batches that did at least one step
-            if batch.input_ids["input_ids"].shape[1] > 1:
+            if batch.input_ids.shape[1] > 1:
                 raise ValueError("Batch input_ids should be of shape (batch_size, 1)")
 
-            # Initialize tensors
-            if i == 0:
-                input_ids["input_ids"] = torch.empty(
-                    (total_batch_size, 1),
-                    dtype=batch.input_ids["input_ids"].dtype,
-                    device=batch.input_ids["input_ids"].device,
-                )
-                input_ids["attention_mask"] = torch.zeros(
-                    (total_batch_size, max_sequence_length),
-                    dtype=batch.input_ids["attention_mask"].dtype,
-                    device=batch.input_ids["attention_mask"].device,
-                )
-
-            # input_ids["input_ids"] is always of shape [batch_size, 1]
+            # Create empty tensor
+            # input_ids is always of shape [batch_size, 1]
             # We do not need to pad it
-            input_ids["input_ids"][start_index:end_index] = batch.input_ids["input_ids"]
+            if input_ids is None:
+                input_ids = torch.empty(
+                    (total_batch_size, 1),
+                    dtype=batch.input_ids.dtype,
+                    device=batch.input_ids.device,
+                )
+            # Copy to correct indices
+            input_ids[start_index:end_index] = batch.input_ids
+
+            # Create padded tensor
+            if attention_mask is None:
+                attention_mask = torch.zeros(
+                    (total_batch_size, max_sequence_length),
+                    dtype=batch.attention_mask.dtype,
+                    device=batch.attention_mask.device,
+                )
 
             # We need to slice the attention mask to remove padding from previous steps
-            input_ids["attention_mask"][
+            attention_mask[
                 start_index:end_index, -batch.max_sequence_length :
-            ] = batch.input_ids["attention_mask"][:, -batch.max_sequence_length :]
+            ] = batch.attention_mask[:, -batch.max_sequence_length :]
 
-            for j, past in enumerate(batch.input_ids["past_key_values"]):
+            for j, past in enumerate(batch.past_key_values):
                 # Shenanigans to get dimensions because BLOOM outputs a past with a different shape
                 # BLOOM: [batch_size * num_heads, ...] vs [batch_size, num_heads, ...]
                 head_dim, padded_sequence_length = past[0].shape[-2:]
@@ -137,8 +158,8 @@ class CausalLMBatch:
                 )
 
                 # This will run only once per layer
-                if j == len(input_ids["past_key_values"]):
-                    input_ids["past_key_values"].append([])
+                if j == len(past_key_values):
+                    past_key_values.append([])
 
                 # Decoder past
                 for k, t in enumerate(past):
@@ -172,21 +193,21 @@ class CausalLMBatch:
 
                     # Initialize tensors
                     # This will run only once per layer and per past tensor
-                    if k == len(input_ids["past_key_values"][j]):
-                        input_ids["past_key_values"][j].append(
+                    if k == len(past_key_values[j]):
+                        past_key_values[j].append(
                             torch.zeros(padded_t_shape, dtype=t.dtype, device=t.device)
                         )
 
                     # We slice the past keys and values to remove the padding from previous batches
                     if not head_dim_last:
-                        input_ids["past_key_values"][j][k][
+                        past_key_values[j][k][
                             start_index:end_index,
                             :,
                             :,
                             -(batch.max_sequence_length - 1) :,
                         ] = t[:, :, :, -(batch.max_sequence_length - 1) :]
                     else:
-                        input_ids["past_key_values"][j][k][
+                        past_key_values[j][k][
                             start_index:end_index,
                             :,
                             -(batch.max_sequence_length - 1) :,
@@ -198,9 +219,11 @@ class CausalLMBatch:
         return cls(
             batch_id=batches[0].batch_id,
             requests=requests,
-            all_input_lengths=all_input_lengths,
             input_ids=input_ids,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
             all_input_ids=all_input_ids,
+            input_lengths=input_lengths,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=total_batch_size,
@@ -209,7 +232,7 @@ class CausalLMBatch:
 
 
 class CausalLM(Model):
-    def __init__(self, model_name: str):
+    def __init__(self, model_name: str, quantize=False):
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
@@ -223,6 +246,7 @@ class CausalLM(Model):
             model_name,
             torch_dtype=dtype,
             device_map="auto" if torch.cuda.is_available() else None,
+            load_in_8bit=quantize,
         ).eval()
 
         super(CausalLM, self).__init__(
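The load_in_8bit=quantize argument added above hands quantization off to Transformers. A minimal, illustrative equivalent of that quantized load path (the model name is an example; it relies on the bitsandbytes extra already listed in pyproject.toml):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "bigscience/bloom-560m",
    device_map="auto",
    load_in_8bit=True,  # int8 weights via bitsandbytes
).eval()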
@@ -255,16 +279,19 @@ class CausalLM(Model):
             torch.no_grad if self.device.type == "cpu" else torch.inference_mode
         )
         with context_manager():
-            logits, past = self.forward(**batch.input_ids)
+            logits, past = self.forward(
+                batch.input_ids, batch.attention_mask, batch.past_key_values
+            )
 
         # List of indices to cache
         next_batch_keep_indices = []
 
-        # New input_ids for next forward
+        # New values for next forward
+        next_batch_input_lengths = []
         next_batch_input_ids = []
         next_batch_all_input_ids = []
-        next_all_input_lengths = []
 
+        # Metadata
         next_batch_size = 0
         next_batch_max_sequence_length = 0
 
@@ -274,7 +301,7 @@ class CausalLM(Model):
         # Zipped iterator
         iterator = zip(
             batch.requests,
-            batch.all_input_lengths,
+            batch.input_lengths,
             logits,
             batch.next_token_choosers,
             batch.stopping_criterias,
@@ -313,7 +340,7 @@ class CausalLM(Model):
                 next_batch_all_input_ids.append(all_tokens)
                 next_batch_size += 1
                 new_input_length = input_length + 1
-                next_all_input_lengths.append(new_input_length)
+                next_batch_input_lengths.append(new_input_length)
                 next_batch_max_sequence_length = max(
                     next_batch_max_sequence_length, new_input_length
                 )
@@ -322,15 +349,14 @@ class CausalLM(Model):
         if not next_batch_keep_indices:
             return generated_texts, None
 
-        # If we finished at least one generation
-        next_batch_input_ids = {"input_ids": torch.cat(next_batch_input_ids, dim=0)}
+        next_batch_input_ids = torch.cat(next_batch_input_ids, dim=0)
+        # If we finished at least one generation, we need to evict the indices of the generations that finished
+        # from the values of the next batch
         if generated_texts:
             # Apply indices to attention mask, past key values and other items that need to be cached
-            next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"][
-                next_batch_keep_indices
-            ]
+            next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
             # Force past to be of dim [batch_size, num_heads, ...] for easy indexing
-            next_batch_input_ids["past_key_values"] = [
+            next_batch_past_key_values = [
                 [
                     t.view(-1, self.num_heads, *t.shape[-2:])[next_batch_keep_indices]
                     for t in layer
@@ -345,16 +371,16 @@ class CausalLM(Model):
                 batch.stopping_criterias[i] for i in next_batch_keep_indices
             ]
         else:
-            next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"]
-            next_batch_input_ids["past_key_values"] = past
+            next_batch_attention_mask = batch.attention_mask
+            next_batch_past_key_values = past
             next_batch_requests = batch.requests
             next_batch_next_token_choosers = batch.next_token_choosers
             next_batch_stopping_criterias = batch.stopping_criterias
 
         # Update attention_mask with padding as we added a new token to input_ids
-        next_batch_input_ids["attention_mask"] = torch.cat(
+        next_batch_attention_mask = torch.cat(
             [
-                next_batch_input_ids["attention_mask"],
+                next_batch_attention_mask,
                 torch.ones((next_batch_size, 1)).to(self.device),
             ],
             dim=1,
@@ -363,9 +389,11 @@ class CausalLM(Model):
         next_batch = CausalLMBatch(
             batch_id=batch.batch_id,
             requests=next_batch_requests,
-            all_input_lengths=next_all_input_lengths,
             input_ids=next_batch_input_ids,
+            attention_mask=next_batch_attention_mask,
+            past_key_values=next_batch_past_key_values,
             all_input_ids=next_batch_all_input_ids,
+            input_lengths=next_batch_input_lengths,
             next_token_choosers=next_batch_next_token_choosers,
             stopping_criterias=next_batch_stopping_criterias,
             size=next_batch_size,
@@ -15,26 +15,33 @@ class Seq2SeqLMBatch:
     batch_id: int
     requests: List[generate_pb2.Request]
 
+    # Encoder values
     input_ids: torch.Tensor
     attention_mask: torch.Tensor
 
+    # Decoder values
     decoder_input_ids: torch.Tensor
     decoder_attention_mask: Optional[torch.Tensor]
     encoder_last_hidden_state: Optional[torch.Tensor]
 
+    # Seq2SeqLM keeps track of both encoder and decoder attention keys and values
     past_key_values: Optional[List[Tuple]]
 
+    # Lengths of all generations present in the batch
     input_lengths: List[int]
     decoder_input_lengths: List[int]
 
+    # Generation helpers
     next_token_choosers: List[NextTokenChooser]
     stopping_criterias: List[StoppingCriteria]
 
+    # Metadata used for padding
     size: int
     max_input_length: int
     max_decoder_input_length: int
 
     def to_pb(self):
+        """Convert a Seq2SeqLMBatch to a text_generation.v1.Batch protobuf"""
         return generate_pb2.Batch(
             id=self.batch_id,
             requests=self.requests,
@@ -45,6 +52,7 @@ class Seq2SeqLMBatch:
     def from_pb(
         cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
     ) -> "Seq2SeqLMBatch":
+        """Convert a text_generation.v1.Batch protobuf to a Seq2SeqLMBatch"""
         inputs = []
         next_token_choosers = []
         stopping_criterias = []
@@ -57,6 +65,7 @@ class Seq2SeqLMBatch:
         for r in pb.requests:
             inputs.append(r.inputs)
             input_lengths.append(r.input_length)
+            # Decoder sequence only contains the bos_token
             decoder_input_ids.append(tokenizer.bos_token_id)
             decoder_input_lengths.append(1)
             next_token_choosers.append(
@@ -73,9 +82,11 @@ class Seq2SeqLMBatch:
                 )
             )
 
+        # Tokenize batch
         tokenized_inputs = tokenizer(
             inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8
         ).to(device)
+        # Convert decoder_input_ids to torch tensor of size [batch_size, 1]
         decoder_input_ids = torch.tensor(decoder_input_ids).to(device).unsqueeze(-1)
 
         return cls(
@@ -98,6 +109,8 @@ class Seq2SeqLMBatch:
 
     @classmethod
     def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
+        """Concatenate multiple batches together by padding internal torch tensors"""
+
         # Used for padding
         total_batch_size = sum(batch.size for batch in batches)
         max_input_length = max(batch.max_input_length for batch in batches)
@@ -112,6 +125,7 @@ class Seq2SeqLMBatch:
         next_token_choosers = []
         stopping_criterias = []
 
+        # Batch tensors
        input_ids = None
        attention_mask = None
        decoder_input_ids = None
@@ -122,7 +136,9 @@ class Seq2SeqLMBatch:
         # Used for slicing correctly inside the tensors
         # Equivalent to a cumsum on batch sizes
         start_index = 0
+
         for i, batch in enumerate(batches):
+            # Extend all list attributes
             requests.extend(batch.requests)
             input_lengths.extend(batch.input_lengths)
             decoder_input_lengths.extend(batch.decoder_input_lengths)
@@ -136,51 +152,62 @@ class Seq2SeqLMBatch:
             if batch.encoder_last_hidden_state is None:
                 raise ValueError("Batch encoder_last_hidden_state cannot be None")
 
+            # Create padded tensor
             if input_ids is None:
                 input_ids = torch.zeros(
                     (total_batch_size, max_input_length),
                     dtype=batch.input_ids.dtype,
                     device=batch.input_ids.device,
                 )
+            # Copy to correct indices
             input_ids[
                 start_index:end_index, -batch.max_input_length :
             ] = batch.input_ids[:, -batch.max_input_length :]
 
+            # Create padded tensor
             if attention_mask is None:
                 attention_mask = torch.zeros(
                     (total_batch_size, max_input_length),
                     dtype=batch.attention_mask.dtype,
                     device=batch.attention_mask.device,
                 )
+            # Copy to correct indices
             attention_mask[
                 start_index:end_index, -batch.max_input_length :
             ] = batch.attention_mask[:, -batch.max_input_length :]
 
+            # Create padded tensor
             if decoder_input_ids is None:
                 decoder_input_ids = torch.zeros(
                     (total_batch_size, max_decoder_input_length),
                     dtype=batch.decoder_input_ids.dtype,
                     device=batch.decoder_input_ids.device,
                 )
+            # Copy to correct indices
             decoder_input_ids[
                 start_index:end_index, -batch.max_decoder_input_length :
             ] = batch.decoder_input_ids[:, -batch.max_decoder_input_length :]
 
+            # Create padded tensor
             if decoder_attention_mask is None:
                 decoder_attention_mask = torch.zeros(
                     (total_batch_size, max_decoder_input_length),
-                    dtype=batch.attention_mask.dtype,
-                    device=batch.attention_mask.device,
+                    dtype=batch.attention_mask.dtype,  # As decoder_attention_mask might not exist,
+                    device=batch.attention_mask.device,  # we use `batch.attention_maks` for device here
                 )
+            # If the decoder mask does not exist yet, all generations started at the same time and we never concatenated
+            # this batch. All generations are of length `batch.max_decoder_input_length`.
             if batch.decoder_attention_mask is None:
                 decoder_attention_mask[
                     start_index:end_index, -batch.max_decoder_input_length :
                 ] = 1
+            # If it exists, we need to index
             else:
                 decoder_attention_mask[
                     start_index:end_index, -batch.max_decoder_input_length :
                 ] = batch.decoder_attention_mask[:, -batch.max_decoder_input_length :]
 
+            # Create padded tensor
             if encoder_last_hidden_state is None:
                 encoder_last_hidden_state = torch.zeros(
                     (
@@ -192,10 +219,12 @@ class Seq2SeqLMBatch:
                     device=batch.encoder_last_hidden_state.device,
                 )
 
+            # Copy to correct indices
             encoder_last_hidden_state[
                 start_index:end_index, -batch.max_decoder_input_length :, :
             ] = batch.encoder_last_hidden_state[:, -batch.max_decoder_input_length :, :]
 
+            # Iterate over attention layers
             for j, past in enumerate(batch.past_key_values):
                 _, num_heads, _, head_dim = past[0].shape
 
@@ -271,7 +300,7 @@ class Seq2SeqLMBatch:
 
 
 class Seq2SeqLM(Model):
-    def __init__(self, model_name: str):
+    def __init__(self, model_name: str, quantize=False):
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
@@ -283,6 +312,7 @@ class Seq2SeqLM(Model):
             model_name,
             torch_dtype=dtype,
             device_map="auto" if torch.cuda.is_available() else None,
+            load_in_8bit=quantize,
         ).eval()
         tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
         tokenizer.bos_token_id = self.model.config.decoder_start_token_id
@@ -314,14 +344,17 @@ class Seq2SeqLM(Model):
         if past_key_values is not None:
             decoder_input_ids = decoder_input_ids[:, -1].unsqueeze(-1)
 
+        # Wrap `encoder_last_hidden_state` because for some reason, Transformers does a `encoder_last_hidden_state[0]`
+        # internally...
+        if encoder_last_hidden_state is not None:
+            encoder_last_hidden_state = [encoder_last_hidden_state]
+
         outputs = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
             decoder_attention_mask=decoder_attention_mask,
-            encoder_outputs=[encoder_last_hidden_state]
-            if encoder_last_hidden_state is not None
-            else None,
+            encoder_outputs=encoder_last_hidden_state,
             past_key_values=past_key_values,
             use_cache=True,
         )
@@ -351,11 +384,12 @@ class Seq2SeqLM(Model):
         # List of indices to cache
         next_batch_keep_indices = []
 
-        # New input_ids for next forward
+        # New values for next forward
         next_batch_input_lengths = []
         next_batch_decoder_input_ids = []
         next_batch_decoder_input_lengths = []
 
+        # Metadata
         next_batch_size = 0
         next_batch_max_input_length = 0
         next_batch_max_decoder_input_length = 0
@@ -395,7 +429,7 @@ class Seq2SeqLM(Model):
 
             # Evaluate stopping criteria
             if stopping_criteria(decoder_tokens):
-                # Decode all tokens
+                # Decode tokens
                 output = self.tokenizer.decode(decoder_tokens, skip_special_tokens=True)
                 # Add to the list of finished generations with the original request
                 generated_texts.append(
@@ -420,9 +454,11 @@ class Seq2SeqLM(Model):
         if not next_batch_keep_indices:
             return generated_texts, None
 
-        # If we finished at least one generation
         next_batch_decoder_input_ids = torch.cat(next_batch_decoder_input_ids)
+        # If we finished at least one generation, we need to evict the indices of the generations that finished
+        # from the values of the next batch
         if generated_texts:
+            # Apply indices to attention mask, past key values and other items that need to be cached
             next_batch_input_ids = batch.input_ids[next_batch_keep_indices]
             next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
 
@@ -458,7 +494,7 @@ class Seq2SeqLM(Model):
             next_batch_next_token_choosers = batch.next_token_choosers
             next_batch_stopping_criterias = batch.stopping_criterias
 
-        # Update attention_mask with padding as we added a new token to input_ids
+        # Update decoder_attention_mask with padding as we added a new token to input_ids
         if next_batch_decoder_attention_mask is not None:
             next_batch_decoder_attention_mask = torch.cat(
                 [