feat(server): Add model tests (#6)

This commit is contained in:
OlivierDehaene 2022-12-08 18:49:33 +01:00 committed by GitHub
parent 31d76e238d
commit a2985036aa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 1105 additions and 29 deletions

View File

@ -87,9 +87,4 @@ curl 127.0.0.1:3000/generate \
```shell
make server-dev
make router-dev
```
## TODO:
- [ ] Add tests for the `server/model` logic
- [ ] Backport custom CUDA kernels to Transformers
```

View File

@ -70,7 +70,7 @@ impl Batcher {
// Notify the background task that we have a new entry in the database that needs
// to be batched
self.shared.batching_task.notify_waiters();
self.shared.batching_task.notify_one();
// Await on the response from the background task
// We can safely unwrap as the background task will never drop the sender

View File

@ -8,8 +8,9 @@ gen-server:
install-transformers:
# Install specific version of transformers with custom cuda kernels
rm transformers || true
rm transformers-text_generation_inference || true
pip uninstall transformers -y || true
rm -rf transformers || true
rm -rf transformers-text_generation_inference || true
curl -L -O https://github.com/OlivierDehaene/transformers/archive/refs/heads/text_generation_inference.zip
unzip text_generation_inference.zip
rm text_generation_inference.zip

99
server/poetry.lock generated
View File

@ -22,6 +22,20 @@ test_prod = ["parameterized", "pytest", "pytest-subtests", "pytest-xdist"]
test_trackers = ["comet-ml", "tensorboard", "wandb"]
testing = ["datasets", "deepspeed (<0.7.0)", "evaluate", "parameterized", "pytest", "pytest-subtests", "pytest-xdist", "scipy", "sklearn", "tqdm", "transformers"]
[[package]]
name = "attrs"
version = "22.1.0"
description = "Classes Without Boilerplate"
category = "dev"
optional = false
python-versions = ">=3.5"
[package.extras]
dev = ["cloudpickle", "coverage[toml] (>=5.0.2)", "furo", "hypothesis", "mypy (>=0.900,!=0.940)", "pre-commit", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "sphinx", "sphinx-notfound-page", "zope.interface"]
docs = ["furo", "sphinx", "sphinx-notfound-page", "zope.interface"]
tests = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy (>=0.900,!=0.940)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "zope.interface"]
tests_no_zope = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy (>=0.900,!=0.940)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins"]
[[package]]
name = "bitsandbytes"
version = "0.35.1"
@ -49,6 +63,17 @@ category = "main"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
[[package]]
name = "exceptiongroup"
version = "1.0.4"
description = "Backport of PEP 654 (exception groups)"
category = "dev"
optional = false
python-versions = ">=3.7"
[package.extras]
test = ["pytest (>=6)"]
[[package]]
name = "grpcio"
version = "1.50.0"
@ -88,6 +113,14 @@ grpcio = ">=1.50.0"
protobuf = ">=4.21.6,<5.0dev"
setuptools = "*"
[[package]]
name = "iniconfig"
version = "1.1.1"
description = "iniconfig: brain-dead simple config-ini parsing"
category = "dev"
optional = false
python-versions = "*"
[[package]]
name = "numpy"
version = "1.23.4"
@ -107,6 +140,18 @@ python-versions = ">=3.6"
[package.dependencies]
pyparsing = ">=2.0.2,<3.0.5 || >3.0.5"
[[package]]
name = "pluggy"
version = "1.0.0"
description = "plugin and hook calling mechanisms for python"
category = "dev"
optional = false
python-versions = ">=3.6"
[package.extras]
dev = ["pre-commit", "tox"]
testing = ["pytest", "pytest-benchmark"]
[[package]]
name = "protobuf"
version = "4.21.8"
@ -137,6 +182,26 @@ python-versions = ">=3.6.8"
[package.extras]
diagrams = ["jinja2", "railroad-diagrams"]
[[package]]
name = "pytest"
version = "7.2.0"
description = "pytest: simple powerful testing with Python"
category = "dev"
optional = false
python-versions = ">=3.7"
[package.dependencies]
attrs = ">=19.2.0"
colorama = {version = "*", markers = "sys_platform == \"win32\""}
exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""}
iniconfig = "*"
packaging = "*"
pluggy = ">=0.12,<2.0"
tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""}
[package.extras]
testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "xmlschema"]
[[package]]
name = "PyYAML"
version = "6.0"
@ -178,6 +243,14 @@ category = "main"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*"
[[package]]
name = "tomli"
version = "2.0.1"
description = "A lil' TOML parser"
category = "dev"
optional = false
python-versions = ">=3.7"
[[package]]
name = "torch"
version = "1.12.1"
@ -220,13 +293,17 @@ bnb = ["bitsandbytes"]
[metadata]
lock-version = "1.1"
python-versions = "^3.9"
content-hash = "3266187ef14fe8f9e29b3b6530d07781ea952aa670c0fe0de34be43efa231a67"
content-hash = "51693654531e3229ac64bee250932ace20a60e8d45af074ae7b860ed32b25ef8"
[metadata.files]
accelerate = [
{file = "accelerate-0.12.0-py3-none-any.whl", hash = "sha256:7742ca5c9f15dd1e0a283305599c196e260af4717a561d1f544aeab27d828af6"},
{file = "accelerate-0.12.0.tar.gz", hash = "sha256:e8b119c94fac31877620d5f9de311164ec81fa9dc9e175f0d0d4f50fc8d79473"},
]
attrs = [
{file = "attrs-22.1.0-py2.py3-none-any.whl", hash = "sha256:86efa402f67bf2df34f51a335487cf46b1ec130d02b8d39fd248abfd30da551c"},
{file = "attrs-22.1.0.tar.gz", hash = "sha256:29adc2665447e5191d0e7c568fde78b21f9672d344281d0c6e1ab085429b22b6"},
]
bitsandbytes = [
{file = "bitsandbytes-0.35.1-py3-none-any.whl", hash = "sha256:4506a9e3778359a743938aa5592d8d043fa91d1df66cd01ba8cc6486e64dea45"},
{file = "bitsandbytes-0.35.1.tar.gz", hash = "sha256:63a6f59c87b713a731a685e43d68c19789ee6381e62196cafab293b87eca5d46"},
@ -239,6 +316,10 @@ colorama = [
{file = "colorama-0.4.5-py2.py3-none-any.whl", hash = "sha256:854bf444933e37f5824ae7bfc1e98d5bce2ebe4160d46b5edf346a89358e99da"},
{file = "colorama-0.4.5.tar.gz", hash = "sha256:e6c6b4334fc50988a639d9b98aa429a0b57da6e17b9a44f0451f930b6967b7a4"},
]
exceptiongroup = [
{file = "exceptiongroup-1.0.4-py3-none-any.whl", hash = "sha256:542adf9dea4055530d6e1279602fa5cb11dab2395fa650b8674eaec35fc4a828"},
{file = "exceptiongroup-1.0.4.tar.gz", hash = "sha256:bd14967b79cd9bdb54d97323216f8fdf533e278df937aa2a90089e7d6e06e5ec"},
]
grpcio = [
{file = "grpcio-1.50.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:906f4d1beb83b3496be91684c47a5d870ee628715227d5d7c54b04a8de802974"},
{file = "grpcio-1.50.0-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:2d9fd6e38b16c4d286a01e1776fdf6c7a4123d99ae8d6b3f0b4a03a34bf6ce45"},
@ -337,6 +418,10 @@ grpcio-tools = [
{file = "grpcio_tools-1.50.0-cp39-cp39-win32.whl", hash = "sha256:e1a8f9a57bbcc2e633aaf327e39830527f3c1f7add18c7580f3058fe9a0fa780"},
{file = "grpcio_tools-1.50.0-cp39-cp39-win_amd64.whl", hash = "sha256:b7eb7a84d9171c0ae1550833f4a6ca52372bed9db0fa10f8c9dbe6ca65f97a8c"},
]
iniconfig = [
{file = "iniconfig-1.1.1-py2.py3-none-any.whl", hash = "sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3"},
{file = "iniconfig-1.1.1.tar.gz", hash = "sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32"},
]
numpy = [
{file = "numpy-1.23.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:95d79ada05005f6f4f337d3bb9de8a7774f259341c70bc88047a1f7b96a4bcb2"},
{file = "numpy-1.23.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:926db372bc4ac1edf81cfb6c59e2a881606b409ddc0d0920b988174b2e2a767f"},
@ -371,6 +456,10 @@ packaging = [
{file = "packaging-21.3-py3-none-any.whl", hash = "sha256:ef103e05f519cdc783ae24ea4e2e0f508a9c99b2d4969652eed6a2e1ea5bd522"},
{file = "packaging-21.3.tar.gz", hash = "sha256:dd47c42927d89ab911e606518907cc2d3a1f38bbd026385970643f9c5b8ecfeb"},
]
pluggy = [
{file = "pluggy-1.0.0-py2.py3-none-any.whl", hash = "sha256:74134bbf457f031a36d68416e1509f34bd5ccc019f0bcc952c7b909d06b37bd3"},
{file = "pluggy-1.0.0.tar.gz", hash = "sha256:4224373bacce55f955a878bf9cfa763c1e360858e330072059e10bad68531159"},
]
protobuf = [
{file = "protobuf-4.21.8-cp310-abi3-win32.whl", hash = "sha256:c252c55ee15175aa1b21b7b9896e6add5162d066d5202e75c39f96136f08cce3"},
{file = "protobuf-4.21.8-cp310-abi3-win_amd64.whl", hash = "sha256:809ca0b225d3df42655a12f311dd0f4148a943c51f1ad63c38343e457492b689"},
@ -429,6 +518,10 @@ pyparsing = [
{file = "pyparsing-3.0.9-py3-none-any.whl", hash = "sha256:5026bae9a10eeaefb61dab2f09052b9f4307d44aee4eda64b309723d8d206bbc"},
{file = "pyparsing-3.0.9.tar.gz", hash = "sha256:2b020ecf7d21b687f219b71ecad3631f644a47f01403fa1d1036b0c6416d70fb"},
]
pytest = [
{file = "pytest-7.2.0-py3-none-any.whl", hash = "sha256:892f933d339f068883b6fd5a459f03d85bfcb355e4981e146d2c7616c21fef71"},
{file = "pytest-7.2.0.tar.gz", hash = "sha256:c4014eb40e10f11f355ad4e3c2fb2c6c6d1919c73f3b5a433de4708202cade59"},
]
PyYAML = [
{file = "PyYAML-6.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d4db7c7aef085872ef65a8fd7d6d09a14ae91f691dec3e87ee5ee0539d516f53"},
{file = "PyYAML-6.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9df7ed3b3d2e0ecfe09e14741b857df43adb5a3ddadc919a2d94fbdf78fea53c"},
@ -512,6 +605,10 @@ six = [
{file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"},
{file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"},
]
tomli = [
{file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"},
{file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"},
]
torch = [
{file = "torch-1.12.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:9c038662db894a23e49e385df13d47b2a777ffd56d9bcd5b832593fab0a7e286"},
{file = "torch-1.12.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:4e1b9c14cf13fd2ab8d769529050629a0e68a6fc5cb8e84b4a3cc1dd8c4fe541"},

View File

@ -22,6 +22,7 @@ bnb = ["bitsandbytes"]
[tool.poetry.group.dev.dependencies]
grpcio-tools = "^1.49.1"
pytest = "^7.2.0"
[build-system]
requires = ["poetry-core>=1.0.0"]

36
server/tests/conftest.py Normal file
View File

@ -0,0 +1,36 @@
import pytest
from transformers import AutoTokenizer
from text_generation.pb import generate_pb2
@pytest.fixture
def default_pb_parameters():
return generate_pb2.LogitsWarperParameters(
temperature=1.0,
top_k=0,
top_p=1.0,
do_sample=False,
)
@pytest.fixture(scope="session")
def bloom_560m_tokenizer():
return AutoTokenizer.from_pretrained("bigscience/bloom-560m", padding_side="left")
@pytest.fixture(scope="session")
def gpt2_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
tokenizer.pad_token_id = 50256
return tokenizer
@pytest.fixture(scope="session")
def mt0_small_tokenizer():
tokenizer = AutoTokenizer.from_pretrained(
"bigscience/mt0-small", padding_side="left"
)
tokenizer.bos_token_id = 0
return tokenizer

View File

@ -0,0 +1,279 @@
import pytest
import torch
from copy import copy
from text_generation.pb import generate_pb2
from text_generation.models.causal_lm import CausalLMBatch
from text_generation.models.bloom import BloomCausalLMBatch, BLOOM
@pytest.fixture
def default_pb_request(default_pb_parameters):
return generate_pb2.Request(
id=0,
inputs="Test",
input_length=1,
parameters=default_pb_parameters,
max_new_tokens=10,
)
@pytest.fixture
def default_pb_batch(default_pb_request):
return generate_pb2.Batch(id=0, requests=[default_pb_request], size=1)
@pytest.fixture
def default_bloom_batch(default_pb_batch, bloom_560m_tokenizer):
return BloomCausalLMBatch.from_pb(
default_pb_batch, bloom_560m_tokenizer, torch.device("cpu")
)
@pytest.fixture
def default_multi_requests_bloom_batch(default_pb_request, bloom_560m_tokenizer):
req_0 = copy(default_pb_request)
req_1 = default_pb_request
req_1.id = 1
req_1.max_new_tokens = 5
batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
return BloomCausalLMBatch.from_pb(
batch_pb, bloom_560m_tokenizer, torch.device("cpu")
)
@pytest.fixture(scope="session")
def default_bloom():
return BLOOM("bigscience/bloom-560m")
def test_batch_from_pb(default_pb_batch, default_bloom_batch):
batch = default_bloom_batch
assert batch.batch_id == default_pb_batch.id
assert batch.requests == default_pb_batch.requests
assert len(batch.input_ids) == default_pb_batch.size
assert len(batch.input_ids[0]) == 8
assert batch.input_ids[0][-1] == 10264
assert torch.all(batch.input_ids[0][:-1] == 3)
assert batch.attention_mask[0][-1] == 1
assert torch.all(batch.attention_mask[0][:-1] == 0)
assert batch.past_key_values is None
assert torch.equal(batch.input_ids, batch.all_input_ids[:, :, 0])
assert batch.input_lengths == [1]
assert batch.size == default_pb_batch.size
assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size
assert batch.max_sequence_length == batch.input_lengths[0]
def test_batch_concatenate_no_prefill(default_bloom_batch):
with pytest.raises(ValueError):
BloomCausalLMBatch.concatenate([default_bloom_batch, default_bloom_batch])
def test_causal_lm_batch_type(default_bloom):
assert default_bloom.batch_type == BloomCausalLMBatch
def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
generated_texts, next_batch = default_bloom.generate_token(default_bloom_batch)
assert generated_texts == []
assert isinstance(next_batch, CausalLMBatch)
assert not next_batch.keys_head_dim_last
assert len(next_batch.all_input_ids) == next_batch.size
assert len(next_batch.all_input_ids[0]) == len(next_batch.attention_mask[0]) == 9
assert torch.all(next_batch.all_input_ids[0][-2:] == 10264)
assert torch.all(next_batch.all_input_ids[0][:-2] == 3)
assert torch.all(next_batch.attention_mask[0][-2:] == 1)
assert torch.all(next_batch.attention_mask[0][:-2] == 0)
assert next_batch.input_ids.shape == (next_batch.size, 1)
assert next_batch.input_ids[0, 0] == 10264
assert next_batch.input_lengths == [2]
assert next_batch.max_sequence_length == next_batch.input_lengths[0]
assert next_batch.past_key_values is not None
assert all([p[0].shape == (16, 64, 8) for p in next_batch.past_key_values])
assert all([p[1].shape == (16, 8, 64) for p in next_batch.past_key_values])
def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch):
next_batch = default_bloom_batch
for _ in range(default_bloom_batch.stopping_criterias[0].max_new_tokens - 1):
generated_texts, next_batch = default_bloom.generate_token(next_batch)
assert generated_texts == []
generated_texts, next_batch = default_bloom.generate_token(next_batch)
assert next_batch is None
assert len(generated_texts) == 1
assert generated_texts[0].output == "TestTestTestTestTestTestTestTestTestTestTest"
assert generated_texts[0].request == default_bloom_batch.requests[0]
assert (
generated_texts[0].tokens
== default_bloom_batch.stopping_criterias[0].max_new_tokens
)
def test_causal_lm_generate_token_completion_multi(
default_bloom, default_multi_requests_bloom_batch
):
next_batch = default_multi_requests_bloom_batch
for i in range(
default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 1
):
generated_texts, next_batch = default_bloom.generate_token(next_batch)
assert generated_texts == []
generated_texts, next_batch = default_bloom.generate_token(next_batch)
assert next_batch is not None
assert len(generated_texts) == 1
assert generated_texts[0].output == "TestTestTestTestTestTest"
assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[1]
assert (
generated_texts[0].tokens
== default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
)
for _ in range(
default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
- default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
- 1
):
generated_texts, next_batch = default_bloom.generate_token(next_batch)
assert generated_texts == []
generated_texts, next_batch = default_bloom.generate_token(next_batch)
assert next_batch is None
assert len(generated_texts) == 1
assert generated_texts[0].output == "TestTestTestTestTestTestTestTestTestTestTest"
assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[0]
assert (
generated_texts[0].tokens
== default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
)
def test_batch_concatenate(
default_bloom, default_bloom_batch, default_multi_requests_bloom_batch
):
next_batch_0 = default_bloom_batch
_, next_batch_0 = default_bloom.generate_token(next_batch_0)
_, next_batch_0 = default_bloom.generate_token(next_batch_0)
next_batch_1 = default_multi_requests_bloom_batch
_, next_batch_1 = default_bloom.generate_token(next_batch_1)
next_batch = BloomCausalLMBatch.concatenate([next_batch_0, next_batch_1])
assert torch.equal(next_batch.all_input_ids[0], next_batch_0.all_input_ids[0])
assert torch.equal(next_batch.all_input_ids[1], next_batch_1.all_input_ids[0])
assert torch.equal(next_batch.all_input_ids[2], next_batch_1.all_input_ids[1])
assert torch.all(next_batch.attention_mask[0] == 1)
assert torch.all(next_batch.attention_mask[1:, -2:] == 1)
assert torch.all(next_batch.attention_mask[1:, :-2] == 0)
assert next_batch.batch_id == 0
assert torch.all(next_batch.input_ids == 10264)
assert next_batch.input_lengths == [3, 2, 2]
assert next_batch.max_sequence_length == 3
assert next_batch.requests[0] == next_batch_0.requests[0]
assert next_batch.requests[1:] == next_batch_1.requests
assert next_batch.next_token_choosers[0] == next_batch_0.next_token_choosers[0]
assert next_batch.next_token_choosers[1:] == next_batch_1.next_token_choosers
assert next_batch.stopping_criterias[0] == next_batch_0.stopping_criterias[0]
assert next_batch.stopping_criterias[1:] == next_batch_1.stopping_criterias
assert next_batch.past_key_values is not None
assert all([p[0].shape == (3, 16, 64, 2) for p in next_batch.past_key_values])
assert all([p[1].shape == (3, 16, 2, 64) for p in next_batch.past_key_values])
for i, past in enumerate(next_batch.past_key_values):
assert torch.equal(next_batch_0.past_key_values[i][0][:, :, -2:], past[0][0])
assert torch.equal(
next_batch_1.past_key_values[i][0][:, :, -1:],
past[0][1:, :, :, -1].reshape(-1, 64, 1),
)
assert torch.equal(next_batch_0.past_key_values[i][1][:, -2:, :], past[1][0])
assert torch.equal(
next_batch_1.past_key_values[i][1][:, -1:, :],
past[1][1:, :, -1, :].reshape(-1, 1, 64),
)
for _ in range(
default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 2
):
generated_texts, next_batch = default_bloom.generate_token(next_batch)
assert generated_texts == []
generated_texts, next_batch = default_bloom.generate_token(next_batch)
assert next_batch is not None
assert len(generated_texts) == 1
assert generated_texts[0].output == "TestTestTestTestTestTest"
assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[1]
assert (
generated_texts[0].tokens
== default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
)
for _ in range(
default_bloom_batch.stopping_criterias[0].max_new_tokens
- default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
- 2
):
generated_texts, next_batch = default_bloom.generate_token(next_batch)
assert generated_texts == []
generated_texts, next_batch = default_bloom.generate_token(next_batch)
assert next_batch is not None
assert len(generated_texts) == 1
assert generated_texts[0].output == "TestTestTestTestTestTestTestTestTestTestTest"
assert generated_texts[0].request == default_bloom_batch.requests[0]
assert (
generated_texts[0].tokens
== default_bloom_batch.stopping_criterias[0].max_new_tokens
)
for _ in range(
default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
- default_bloom_batch.stopping_criterias[0].max_new_tokens
- default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
- 4
):
generated_texts, next_batch = default_bloom.generate_token(next_batch)
assert generated_texts == []
generated_texts, next_batch = default_bloom.generate_token(next_batch)
assert next_batch is None
assert len(generated_texts) == 1
assert generated_texts[0].output == "TestTestTestTestTestTestTestTestTestTestTest"
assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[0]
assert (
generated_texts[0].tokens
== default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
)

View File

@ -0,0 +1,296 @@
import pytest
import torch
from copy import copy
from text_generation.pb import generate_pb2
from text_generation.models.causal_lm import CausalLM, CausalLMBatch
@pytest.fixture
def default_pb_request(default_pb_parameters):
return generate_pb2.Request(
id=0,
inputs="Test",
input_length=1,
parameters=default_pb_parameters,
max_new_tokens=10,
)
@pytest.fixture
def default_pb_batch(default_pb_request):
return generate_pb2.Batch(id=0, requests=[default_pb_request], size=1)
@pytest.fixture
def default_causal_lm_batch(default_pb_batch, gpt2_tokenizer):
return CausalLMBatch.from_pb(default_pb_batch, gpt2_tokenizer, torch.device("cpu"))
@pytest.fixture
def default_multi_requests_causal_lm_batch(default_pb_request, gpt2_tokenizer):
req_0 = copy(default_pb_request)
req_1 = default_pb_request
req_1.id = 1
req_1.max_new_tokens = 5
batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
return CausalLMBatch.from_pb(batch_pb, gpt2_tokenizer, torch.device("cpu"))
@pytest.fixture(scope="session")
def default_causal_lm():
return CausalLM("gpt2")
def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
batch = default_causal_lm_batch
assert batch.batch_id == default_pb_batch.id
assert batch.requests == default_pb_batch.requests
assert len(batch.input_ids) == default_pb_batch.size
assert len(batch.input_ids[0]) == 8
assert batch.input_ids[0][-1] == 14402
assert torch.all(batch.input_ids[0][:-1] == 50256)
assert batch.attention_mask[0][-1] == 1
assert torch.all(batch.attention_mask[0][:-1] == 0)
assert batch.past_key_values is None
assert torch.equal(batch.input_ids, batch.all_input_ids[:, :, 0])
assert batch.input_lengths == [1]
assert batch.size == default_pb_batch.size
assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size
assert batch.max_sequence_length == batch.input_lengths[0]
def test_batch_concatenate_no_prefill(default_causal_lm_batch):
with pytest.raises(ValueError):
CausalLMBatch.concatenate([default_causal_lm_batch, default_causal_lm_batch])
def test_causal_lm_batch_type(default_causal_lm):
assert default_causal_lm.batch_type == CausalLMBatch
def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
generated_texts, next_batch = default_causal_lm.generate_token(
default_causal_lm_batch
)
assert generated_texts == []
assert isinstance(next_batch, CausalLMBatch)
assert len(next_batch.all_input_ids) == next_batch.size
assert len(next_batch.all_input_ids[0]) == len(next_batch.attention_mask[0]) == 9
assert next_batch.all_input_ids[0][-1] == 6208
assert next_batch.all_input_ids[0][-2] == 14402
assert torch.all(next_batch.all_input_ids[0][:-2] == 50256)
assert torch.all(next_batch.attention_mask[0][-2:] == 1)
assert torch.all(next_batch.attention_mask[0][:-2] == 0)
assert next_batch.input_ids.shape == (next_batch.size, 1)
assert next_batch.input_ids[0, 0] == 6208
assert next_batch.input_lengths == [2]
assert next_batch.max_sequence_length == next_batch.input_lengths[0]
assert next_batch.past_key_values is not None
assert all([p[0].shape == (1, 12, 8, 64) for p in next_batch.past_key_values])
assert all([p[1].shape == (1, 12, 8, 64) for p in next_batch.past_key_values])
def test_causal_lm_generate_token_completion(
default_causal_lm, default_causal_lm_batch
):
next_batch = default_causal_lm_batch
for _ in range(default_causal_lm_batch.stopping_criterias[0].max_new_tokens - 1):
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
assert generated_texts == []
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
assert next_batch is None
assert len(generated_texts) == 1
assert (
generated_texts[0].output
== "Test Test Test Test Test Test Test Test Test Test Test"
)
assert generated_texts[0].request == default_causal_lm_batch.requests[0]
assert (
generated_texts[0].tokens
== default_causal_lm_batch.stopping_criterias[0].max_new_tokens
)
def test_causal_lm_generate_token_completion_multi(
default_causal_lm, default_multi_requests_causal_lm_batch
):
next_batch = default_multi_requests_causal_lm_batch
for i in range(
default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 1
):
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
assert generated_texts == []
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
assert next_batch is not None
assert len(generated_texts) == 1
assert generated_texts[0].output == "Test Test Test Test Test Test"
assert (
generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[1]
)
assert (
generated_texts[0].tokens
== default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
)
for _ in range(
default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
- default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
- 1
):
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
assert generated_texts == []
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
assert next_batch is None
assert len(generated_texts) == 1
assert (
generated_texts[0].output
== "Test Test Test Test Test Test Test Test Test Test Test"
)
assert (
generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[0]
)
assert (
generated_texts[0].tokens
== default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
)
def test_batch_concatenate(
default_causal_lm, default_causal_lm_batch, default_multi_requests_causal_lm_batch
):
next_batch_0 = default_causal_lm_batch
_, next_batch_0 = default_causal_lm.generate_token(next_batch_0)
_, next_batch_0 = default_causal_lm.generate_token(next_batch_0)
next_batch_1 = default_multi_requests_causal_lm_batch
_, next_batch_1 = default_causal_lm.generate_token(next_batch_1)
next_batch = CausalLMBatch.concatenate([next_batch_0, next_batch_1])
assert torch.equal(next_batch.all_input_ids[0], next_batch_0.all_input_ids[0])
assert torch.equal(next_batch.all_input_ids[1], next_batch_1.all_input_ids[0])
assert torch.equal(next_batch.all_input_ids[2], next_batch_1.all_input_ids[1])
assert torch.all(next_batch.attention_mask[0] == 1)
assert torch.all(next_batch.attention_mask[1:, -2:] == 1)
assert torch.all(next_batch.attention_mask[1:, :-2] == 0)
assert next_batch.batch_id == 0
assert torch.all(next_batch.input_ids == 6208)
assert next_batch.input_lengths == [3, 2, 2]
assert next_batch.max_sequence_length == 3
assert next_batch.requests[0] == next_batch_0.requests[0]
assert next_batch.requests[1:] == next_batch_1.requests
assert next_batch.next_token_choosers[0] == next_batch_0.next_token_choosers[0]
assert next_batch.next_token_choosers[1:] == next_batch_1.next_token_choosers
assert next_batch.stopping_criterias[0] == next_batch_0.stopping_criterias[0]
assert next_batch.stopping_criterias[1:] == next_batch_1.stopping_criterias
assert next_batch.past_key_values is not None
assert all([p[0].shape == (3, 12, 2, 64) for p in next_batch.past_key_values])
assert all([p[1].shape == (3, 12, 2, 64) for p in next_batch.past_key_values])
for i, past in enumerate(next_batch.past_key_values):
assert torch.equal(next_batch_0.past_key_values[i][0][0, :, -2:], past[0][0])
assert torch.equal(
next_batch_1.past_key_values[i][0][:, :, -1:], past[0][1:, :, -1:, :]
)
assert torch.equal(next_batch_0.past_key_values[i][1][0, :, -2:], past[1][0])
assert torch.equal(
next_batch_1.past_key_values[i][1][:, :, -1:], past[1][1:, :, -1:, :]
)
for _ in range(
default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 2
):
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
assert generated_texts == []
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
assert next_batch is not None
assert len(generated_texts) == 1
assert generated_texts[0].output == "Test Test Test Test Test Test"
assert (
generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[1]
)
assert (
generated_texts[0].tokens
== default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
)
for _ in range(
default_causal_lm_batch.stopping_criterias[0].max_new_tokens
- default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
- 2
):
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
assert generated_texts == []
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
assert next_batch is not None
assert len(generated_texts) == 1
assert (
generated_texts[0].output
== "Test Test Test Test Test Test Test Test Test Test Test"
)
assert generated_texts[0].request == default_causal_lm_batch.requests[0]
assert (
generated_texts[0].tokens
== default_causal_lm_batch.stopping_criterias[0].max_new_tokens
)
for _ in range(
default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
- default_causal_lm_batch.stopping_criterias[0].max_new_tokens
- default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
- 4
):
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
assert generated_texts == []
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
assert next_batch is None
assert len(generated_texts) == 1
assert (
generated_texts[0].output
== "Test Test Test Test Test Test Test Test Test Test Test"
)
assert (
generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[0]
)
assert (
generated_texts[0].tokens
== default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
)

View File

@ -0,0 +1,306 @@
import pytest
import torch
from copy import copy
from text_generation.pb import generate_pb2
from text_generation.models.seq2seq_lm import Seq2SeqLM, Seq2SeqLMBatch
@pytest.fixture
def default_pb_request(default_pb_parameters):
return generate_pb2.Request(
id=0,
inputs="Test",
input_length=2,
parameters=default_pb_parameters,
max_new_tokens=10,
)
@pytest.fixture
def default_pb_batch(default_pb_request):
return generate_pb2.Batch(id=0, requests=[default_pb_request], size=1)
@pytest.fixture
def default_seq2seq_lm_batch(default_pb_batch, mt0_small_tokenizer):
return Seq2SeqLMBatch.from_pb(
default_pb_batch, mt0_small_tokenizer, torch.device("cpu")
)
@pytest.fixture
def default_multi_requests_seq2seq_lm_batch(default_pb_request, mt0_small_tokenizer):
req_0 = copy(default_pb_request)
req_1 = default_pb_request
req_1.id = 1
req_1.max_new_tokens = 5
batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
return Seq2SeqLMBatch.from_pb(batch_pb, mt0_small_tokenizer, torch.device("cpu"))
@pytest.fixture(scope="session")
def default_seq2seq_lm():
return Seq2SeqLM("bigscience/mt0-small")
def test_batch_from_pb(default_pb_batch, default_seq2seq_lm_batch):
batch = default_seq2seq_lm_batch
assert batch.batch_id == default_pb_batch.id
assert batch.requests == default_pb_batch.requests
assert batch.input_ids.shape == (default_pb_batch.size, 8)
assert batch.input_ids[0][-2] == 4268
assert batch.input_ids[0][-1] == 1
assert torch.all(batch.input_ids[0][:-2] == 0)
assert torch.all(batch.attention_mask[0][-2:] == 1)
assert torch.all(batch.attention_mask[0][:-2] == 0)
assert batch.decoder_input_ids.shape == (default_pb_batch.size, 1)
assert batch.decoder_attention_mask is None
assert batch.encoder_last_hidden_state is None
assert batch.past_key_values is None
assert batch.input_lengths == [2]
assert batch.decoder_input_lengths == [1]
assert batch.size == default_pb_batch.size
assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size
assert batch.max_input_length == batch.input_lengths[0]
assert batch.max_decoder_input_length == batch.decoder_input_lengths[0]
def test_batch_concatenate_no_prefill(default_seq2seq_lm_batch):
with pytest.raises(ValueError):
Seq2SeqLMBatch.concatenate([default_seq2seq_lm_batch, default_seq2seq_lm_batch])
def test_seq2seq_lm_batch_type(default_seq2seq_lm):
assert default_seq2seq_lm.batch_type == Seq2SeqLMBatch
def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch):
generated_texts, next_batch = default_seq2seq_lm.generate_token(
default_seq2seq_lm_batch
)
assert generated_texts == []
assert isinstance(next_batch, Seq2SeqLMBatch)
assert torch.equal(next_batch.input_ids, default_seq2seq_lm_batch.input_ids)
assert torch.equal(
next_batch.attention_mask, default_seq2seq_lm_batch.attention_mask
)
assert next_batch.input_lengths == default_seq2seq_lm_batch.input_lengths
assert next_batch.max_input_length == default_seq2seq_lm_batch.max_input_length
assert (
next_batch.next_token_choosers == default_seq2seq_lm_batch.next_token_choosers
)
assert next_batch.stopping_criterias == default_seq2seq_lm_batch.stopping_criterias
assert next_batch.decoder_input_ids.shape == (next_batch.size, 2)
assert next_batch.decoder_input_ids[0, 0] == 0
assert next_batch.decoder_input_ids[0, 1] == 259
assert next_batch.decoder_attention_mask is None
assert next_batch.encoder_last_hidden_state.shape == (1, 8, 512)
assert next_batch.decoder_input_lengths == [2]
assert next_batch.max_decoder_input_length == 2
assert next_batch.past_key_values is not None
assert all(
[p[0].shape == (next_batch.size, 6, 1, 64) for p in next_batch.past_key_values]
)
assert all(
[p[1].shape == (next_batch.size, 6, 1, 64) for p in next_batch.past_key_values]
)
assert all(
[p[2].shape == (next_batch.size, 6, 8, 64) for p in next_batch.past_key_values]
)
assert all(
[p[3].shape == (next_batch.size, 6, 8, 64) for p in next_batch.past_key_values]
)
def test_seq2seq_lm_generate_token_completion(
default_seq2seq_lm, default_seq2seq_lm_batch
):
next_batch = default_seq2seq_lm_batch
for _ in range(6):
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert generated_texts == []
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is None
assert len(generated_texts) == 1
assert generated_texts[0].output == "a few weeks"
assert generated_texts[0].request == default_seq2seq_lm_batch.requests[0]
assert generated_texts[0].tokens == 7
def test_seq2seq_lm_generate_token_completion_multi(
default_seq2seq_lm, default_multi_requests_seq2seq_lm_batch
):
next_batch = default_multi_requests_seq2seq_lm_batch
for i in range(4):
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert generated_texts == []
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is not None
assert len(generated_texts) == 1
assert generated_texts[0].output == "a few "
assert (
generated_texts[0].request
== default_multi_requests_seq2seq_lm_batch.requests[1]
)
assert generated_texts[0].tokens == 5
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert generated_texts == []
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is None
assert len(generated_texts) == 1
assert generated_texts[0].output == "a few weeks"
assert (
generated_texts[0].request
== default_multi_requests_seq2seq_lm_batch.requests[0]
)
assert generated_texts[0].tokens == 7
def test_batch_concatenate(
default_seq2seq_lm,
default_seq2seq_lm_batch,
default_multi_requests_seq2seq_lm_batch,
):
next_batch_0 = default_seq2seq_lm_batch
_, next_batch_0 = default_seq2seq_lm.generate_token(next_batch_0)
_, next_batch_0 = default_seq2seq_lm.generate_token(next_batch_0)
next_batch_1 = default_multi_requests_seq2seq_lm_batch
_, next_batch_1 = default_seq2seq_lm.generate_token(next_batch_1)
next_batch = Seq2SeqLMBatch.concatenate([next_batch_0, next_batch_1])
assert next_batch.batch_id == 0
assert torch.all(next_batch.input_ids[:, 0] == 4268)
assert torch.all(next_batch.input_ids[:, 1] == 1)
assert torch.all(next_batch.attention_mask == 1)
assert torch.equal(
next_batch.decoder_input_ids[0], next_batch_0.decoder_input_ids[0]
)
assert torch.all(next_batch.decoder_input_ids[1:, 0] == 0)
assert torch.equal(
next_batch.decoder_input_ids[1:, -2:], next_batch_1.decoder_input_ids
)
assert torch.all(next_batch.decoder_attention_mask[0] == 1)
assert torch.all(next_batch.decoder_attention_mask[1:, 0] == 0)
assert torch.all(next_batch.decoder_attention_mask[1:, -2:] == 1)
assert torch.equal(
next_batch.encoder_last_hidden_state[0],
next_batch_0.encoder_last_hidden_state[0, -2:],
)
assert torch.equal(
next_batch.encoder_last_hidden_state[1:],
next_batch_1.encoder_last_hidden_state[:, -2:],
)
assert next_batch.input_lengths == [2, 2, 2]
assert next_batch.decoder_input_lengths == [3, 2, 2]
assert next_batch.max_input_length == 2
assert next_batch.max_decoder_input_length == 3
assert next_batch.requests[0] == next_batch_0.requests[0]
assert next_batch.requests[1:] == next_batch_1.requests
assert next_batch.next_token_choosers[0] == next_batch_0.next_token_choosers[0]
assert next_batch.next_token_choosers[1:] == next_batch_1.next_token_choosers
assert next_batch.stopping_criterias[0] == next_batch_0.stopping_criterias[0]
assert next_batch.stopping_criterias[1:] == next_batch_1.stopping_criterias
assert next_batch.past_key_values is not None
assert all(
[p[0].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values]
)
assert all(
[p[1].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values]
)
assert all(
[p[2].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values]
)
assert all(
[p[3].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values]
)
for i, past in enumerate(next_batch.past_key_values):
assert torch.equal(next_batch_0.past_key_values[i][0][0, :, -2:, :], past[0][0])
assert torch.equal(
next_batch_1.past_key_values[i][0][:, :, -1:, :], past[0][1:, :, -1:, :]
)
assert torch.equal(next_batch_0.past_key_values[i][1][0, :, -2:, :], past[1][0])
assert torch.equal(
next_batch_1.past_key_values[i][1][:, :, -1:, :], past[1][1:, :, -1:, :]
)
assert torch.equal(next_batch_0.past_key_values[i][2][0, :, -2:, :], past[2][0])
assert torch.equal(
next_batch_1.past_key_values[i][2][:, :, -2:, :], past[2][1:]
)
assert torch.equal(next_batch_0.past_key_values[i][3][0, :, -2:, :], past[3][0])
assert torch.equal(
next_batch_1.past_key_values[i][3][:, :, -2:, :], past[3][1:]
)
for _ in range(3):
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert generated_texts == []
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is not None
assert len(generated_texts) == 1
assert generated_texts[0].output == "a few "
assert (
generated_texts[0].request
== default_multi_requests_seq2seq_lm_batch.requests[1]
)
assert generated_texts[0].tokens == 5
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is not None
assert len(generated_texts) == 1
assert generated_texts[0].output == "a few weeks"
assert generated_texts[0].request == default_seq2seq_lm_batch.requests[0]
assert generated_texts[0].tokens == 7
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is None
assert len(generated_texts) == 1
assert generated_texts[0].output == "a few weeks"
assert (
generated_texts[0].request
== default_multi_requests_seq2seq_lm_batch.requests[0]
)
assert generated_texts[0].tokens == 7

View File

@ -0,0 +1,34 @@
import pytest
from text_generation.utils import (
weight_hub_files,
download_weights,
weight_files,
LocalEntryNotFoundError,
)
def test_weight_hub_files():
filenames = weight_hub_files("bigscience/bloom-560m")
assert filenames == ["model.safetensors"]
def test_weight_hub_files_llm():
filenames = weight_hub_files("bigscience/bloom")
assert filenames == [f"model_{i:05d}-of-00072.safetensors" for i in range(1, 73)]
def test_weight_hub_files_empty():
filenames = weight_hub_files("bigscience/bloom", ".errors")
assert filenames == []
def test_download_weights():
files = download_weights("bigscience/bloom-560m")
local_files = weight_files("bigscience/bloom-560m")
assert files == local_files
def test_weight_files_error():
with pytest.raises(LocalEntryNotFoundError):
weight_files("bert-base-uncased")

View File

@ -1,10 +1,10 @@
from text_generation.models.model import Model
from text_generation.models.causal_lm import CausalLM
from text_generation.models.bloom import BLOOMSharded
from text_generation.models.bloom import BLOOM, BLOOMSharded
from text_generation.models.seq2seq_lm import Seq2SeqLM
from text_generation.models.galactica import Galactica, GalacticaSharded
__all__ = ["Model", "BLOOMSharded", "CausalLM", "Seq2SeqLM"]
__all__ = ["Model", "BLOOM", "BLOOMSharded", "CausalLM", "Seq2SeqLM"]
def get_model(model_name: str, sharded: bool, quantize: bool) -> Model:
@ -12,7 +12,7 @@ def get_model(model_name: str, sharded: bool, quantize: bool) -> Model:
if sharded:
return BLOOMSharded(model_name, quantize=quantize)
else:
return CausalLM(model_name, quantize=quantize)
return BLOOM(model_name, quantize=quantize)
elif model_name.startswith("facebook/galactica"):
if sharded:
return GalacticaSharded(model_name, quantize=quantize)

View File

@ -1,7 +1,7 @@
import torch
import torch.distributed
from typing import List, Optional
from typing import List, Optional, Type
from accelerate import init_empty_weights
from safetensors import safe_open
@ -13,6 +13,8 @@ from transformers.models.bloom.parallel_layers import (
)
from text_generation.models import CausalLM
from text_generation.models.causal_lm import CausalLMBatch
from text_generation.pb import generate_pb2
from text_generation.utils import (
initialize_torch_distributed,
weight_files,
@ -29,7 +31,25 @@ except Exception as e:
torch.manual_seed(0)
class BLOOMSharded(CausalLM):
class BloomCausalLMBatch(CausalLMBatch):
@classmethod
def from_pb(
cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
) -> "CausalLMBatch":
batch = super(BloomCausalLMBatch, cls).from_pb(
pb=pb, tokenizer=tokenizer, device=device
)
batch.keys_head_dim_last = False
return batch
class BLOOM(CausalLM):
@property
def batch_type(self) -> Type[CausalLMBatch]:
return BloomCausalLMBatch
class BLOOMSharded(BLOOM):
def __init__(self, model_name: str, quantize: bool = False):
if not model_name.startswith("bigscience/bloom"):
raise ValueError(f"Model {model_name} is not supported")

View File

@ -34,6 +34,9 @@ class CausalLMBatch:
size: int
max_sequence_length: int
# Past metadata
keys_head_dim_last: bool = True
def to_pb(self):
return generate_pb2.Batch(
id=self.batch_id,
@ -165,20 +168,16 @@ class CausalLMBatch:
head_dim,
)
if batch.keys_head_dim_last:
padded_past_keys_shape = padded_past_values_shape
# seq_length is last for BLOOM
if past_keys.shape[-2] == head_dim:
past_keys_head_dim_last = False
else:
padded_past_keys_shape = (
total_batch_size,
num_heads,
head_dim,
max_sequence_length - 1,
)
elif past_keys.shape[-1] == head_dim:
past_keys_head_dim_last = True
padded_past_keys_shape = padded_past_values_shape
else:
raise ValueError(f"past_keys shape {past_keys.shape} is not valid")
# This will run only once per layer
if j == len(past_key_values):
@ -195,7 +194,7 @@ class CausalLMBatch:
past_key_values.append((padded_past_keys, padded_past_values))
# We slice the past keys and values to remove the padding from previous batches
if past_keys_head_dim_last:
if batch.keys_head_dim_last:
past_key_values[j][0][
start_index:end_index,
:,
@ -228,6 +227,7 @@ class CausalLMBatch:
stopping_criterias=stopping_criterias,
size=total_batch_size,
max_sequence_length=max_sequence_length,
keys_head_dim_last=batches[0].keys_head_dim_last,
)
@ -237,6 +237,9 @@ class CausalLM(Model):
device = torch.device("cuda")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
else:
if quantize:
raise ValueError("quantization is not available on CPU")
device = torch.device("cpu")
dtype = torch.float32
@ -247,7 +250,11 @@ class CausalLM(Model):
device_map="auto" if torch.cuda.is_available() else None,
load_in_8bit=quantize,
).eval()
tokenizer.pad_token_id = self.model.config.pad_token_id
tokenizer.pad_token_id = (
self.model.config.pad_token_id
if self.model.config.pad_token_id is not None
else self.model.config.eos_token_id
)
super(CausalLM, self).__init__(
tokenizer=tokenizer,
@ -397,5 +404,6 @@ class CausalLM(Model):
stopping_criterias=next_batch_stopping_criterias,
size=next_batch_size,
max_sequence_length=next_batch_max_sequence_length,
keys_head_dim_last=batch.keys_head_dim_last,
)
return generated_texts, next_batch

View File

@ -83,7 +83,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
@classmethod
def from_pb(
cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
) -> "CausalLMBatch":
) -> "GalacticaCausalLMBatch":
inputs = []
next_token_choosers = []
stopping_criterias = []

View File

@ -221,8 +221,8 @@ class Seq2SeqLMBatch:
# Copy to correct indices
encoder_last_hidden_state[
start_index:end_index, -batch.max_decoder_input_length :, :
] = batch.encoder_last_hidden_state[:, -batch.max_decoder_input_length :, :]
start_index:end_index, -batch.max_input_length :, :
] = batch.encoder_last_hidden_state[:, -batch.max_input_length :, :]
# Iterate over attention layers
for j, past in enumerate(batch.past_key_values):
@ -305,6 +305,9 @@ class Seq2SeqLM(Model):
device = torch.device("cuda")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
else:
if quantize:
raise ValueError("quantization is not available on CPU")
device = torch.device("cpu")
dtype = torch.float32

View File

@ -137,8 +137,8 @@ def download_weights(model_name, extension=".safetensors"):
executor.submit(download_function, filename=filename) for filename in filenames
]
files = [
file
for file in tqdm(concurrent.futures.as_completed(futures), total=len(futures))
future.result()
for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures))
]
return files