feat(server): Add model tests (#6)
This commit is contained in:
parent
31d76e238d
commit
a2985036aa
|
@ -87,9 +87,4 @@ curl 127.0.0.1:3000/generate \
|
|||
```shell
|
||||
make server-dev
|
||||
make router-dev
|
||||
```
|
||||
|
||||
## TODO:
|
||||
|
||||
- [ ] Add tests for the `server/model` logic
|
||||
- [ ] Backport custom CUDA kernels to Transformers
|
||||
```
|
|
@ -70,7 +70,7 @@ impl Batcher {
|
|||
|
||||
// Notify the background task that we have a new entry in the database that needs
|
||||
// to be batched
|
||||
self.shared.batching_task.notify_waiters();
|
||||
self.shared.batching_task.notify_one();
|
||||
|
||||
// Await on the response from the background task
|
||||
// We can safely unwrap as the background task will never drop the sender
|
||||
|
|
|
@ -8,8 +8,9 @@ gen-server:
|
|||
|
||||
install-transformers:
|
||||
# Install specific version of transformers with custom cuda kernels
|
||||
rm transformers || true
|
||||
rm transformers-text_generation_inference || true
|
||||
pip uninstall transformers -y || true
|
||||
rm -rf transformers || true
|
||||
rm -rf transformers-text_generation_inference || true
|
||||
curl -L -O https://github.com/OlivierDehaene/transformers/archive/refs/heads/text_generation_inference.zip
|
||||
unzip text_generation_inference.zip
|
||||
rm text_generation_inference.zip
|
||||
|
|
|
@ -22,6 +22,20 @@ test_prod = ["parameterized", "pytest", "pytest-subtests", "pytest-xdist"]
|
|||
test_trackers = ["comet-ml", "tensorboard", "wandb"]
|
||||
testing = ["datasets", "deepspeed (<0.7.0)", "evaluate", "parameterized", "pytest", "pytest-subtests", "pytest-xdist", "scipy", "sklearn", "tqdm", "transformers"]
|
||||
|
||||
[[package]]
|
||||
name = "attrs"
|
||||
version = "22.1.0"
|
||||
description = "Classes Without Boilerplate"
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=3.5"
|
||||
|
||||
[package.extras]
|
||||
dev = ["cloudpickle", "coverage[toml] (>=5.0.2)", "furo", "hypothesis", "mypy (>=0.900,!=0.940)", "pre-commit", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "sphinx", "sphinx-notfound-page", "zope.interface"]
|
||||
docs = ["furo", "sphinx", "sphinx-notfound-page", "zope.interface"]
|
||||
tests = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy (>=0.900,!=0.940)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "zope.interface"]
|
||||
tests_no_zope = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy (>=0.900,!=0.940)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins"]
|
||||
|
||||
[[package]]
|
||||
name = "bitsandbytes"
|
||||
version = "0.35.1"
|
||||
|
@ -49,6 +63,17 @@ category = "main"
|
|||
optional = false
|
||||
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
|
||||
|
||||
[[package]]
|
||||
name = "exceptiongroup"
|
||||
version = "1.0.4"
|
||||
description = "Backport of PEP 654 (exception groups)"
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
|
||||
[package.extras]
|
||||
test = ["pytest (>=6)"]
|
||||
|
||||
[[package]]
|
||||
name = "grpcio"
|
||||
version = "1.50.0"
|
||||
|
@ -88,6 +113,14 @@ grpcio = ">=1.50.0"
|
|||
protobuf = ">=4.21.6,<5.0dev"
|
||||
setuptools = "*"
|
||||
|
||||
[[package]]
|
||||
name = "iniconfig"
|
||||
version = "1.1.1"
|
||||
description = "iniconfig: brain-dead simple config-ini parsing"
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
|
||||
[[package]]
|
||||
name = "numpy"
|
||||
version = "1.23.4"
|
||||
|
@ -107,6 +140,18 @@ python-versions = ">=3.6"
|
|||
[package.dependencies]
|
||||
pyparsing = ">=2.0.2,<3.0.5 || >3.0.5"
|
||||
|
||||
[[package]]
|
||||
name = "pluggy"
|
||||
version = "1.0.0"
|
||||
description = "plugin and hook calling mechanisms for python"
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=3.6"
|
||||
|
||||
[package.extras]
|
||||
dev = ["pre-commit", "tox"]
|
||||
testing = ["pytest", "pytest-benchmark"]
|
||||
|
||||
[[package]]
|
||||
name = "protobuf"
|
||||
version = "4.21.8"
|
||||
|
@ -137,6 +182,26 @@ python-versions = ">=3.6.8"
|
|||
[package.extras]
|
||||
diagrams = ["jinja2", "railroad-diagrams"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest"
|
||||
version = "7.2.0"
|
||||
description = "pytest: simple powerful testing with Python"
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
|
||||
[package.dependencies]
|
||||
attrs = ">=19.2.0"
|
||||
colorama = {version = "*", markers = "sys_platform == \"win32\""}
|
||||
exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""}
|
||||
iniconfig = "*"
|
||||
packaging = "*"
|
||||
pluggy = ">=0.12,<2.0"
|
||||
tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""}
|
||||
|
||||
[package.extras]
|
||||
testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "xmlschema"]
|
||||
|
||||
[[package]]
|
||||
name = "PyYAML"
|
||||
version = "6.0"
|
||||
|
@ -178,6 +243,14 @@ category = "main"
|
|||
optional = false
|
||||
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*"
|
||||
|
||||
[[package]]
|
||||
name = "tomli"
|
||||
version = "2.0.1"
|
||||
description = "A lil' TOML parser"
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
|
||||
[[package]]
|
||||
name = "torch"
|
||||
version = "1.12.1"
|
||||
|
@ -220,13 +293,17 @@ bnb = ["bitsandbytes"]
|
|||
[metadata]
|
||||
lock-version = "1.1"
|
||||
python-versions = "^3.9"
|
||||
content-hash = "3266187ef14fe8f9e29b3b6530d07781ea952aa670c0fe0de34be43efa231a67"
|
||||
content-hash = "51693654531e3229ac64bee250932ace20a60e8d45af074ae7b860ed32b25ef8"
|
||||
|
||||
[metadata.files]
|
||||
accelerate = [
|
||||
{file = "accelerate-0.12.0-py3-none-any.whl", hash = "sha256:7742ca5c9f15dd1e0a283305599c196e260af4717a561d1f544aeab27d828af6"},
|
||||
{file = "accelerate-0.12.0.tar.gz", hash = "sha256:e8b119c94fac31877620d5f9de311164ec81fa9dc9e175f0d0d4f50fc8d79473"},
|
||||
]
|
||||
attrs = [
|
||||
{file = "attrs-22.1.0-py2.py3-none-any.whl", hash = "sha256:86efa402f67bf2df34f51a335487cf46b1ec130d02b8d39fd248abfd30da551c"},
|
||||
{file = "attrs-22.1.0.tar.gz", hash = "sha256:29adc2665447e5191d0e7c568fde78b21f9672d344281d0c6e1ab085429b22b6"},
|
||||
]
|
||||
bitsandbytes = [
|
||||
{file = "bitsandbytes-0.35.1-py3-none-any.whl", hash = "sha256:4506a9e3778359a743938aa5592d8d043fa91d1df66cd01ba8cc6486e64dea45"},
|
||||
{file = "bitsandbytes-0.35.1.tar.gz", hash = "sha256:63a6f59c87b713a731a685e43d68c19789ee6381e62196cafab293b87eca5d46"},
|
||||
|
@ -239,6 +316,10 @@ colorama = [
|
|||
{file = "colorama-0.4.5-py2.py3-none-any.whl", hash = "sha256:854bf444933e37f5824ae7bfc1e98d5bce2ebe4160d46b5edf346a89358e99da"},
|
||||
{file = "colorama-0.4.5.tar.gz", hash = "sha256:e6c6b4334fc50988a639d9b98aa429a0b57da6e17b9a44f0451f930b6967b7a4"},
|
||||
]
|
||||
exceptiongroup = [
|
||||
{file = "exceptiongroup-1.0.4-py3-none-any.whl", hash = "sha256:542adf9dea4055530d6e1279602fa5cb11dab2395fa650b8674eaec35fc4a828"},
|
||||
{file = "exceptiongroup-1.0.4.tar.gz", hash = "sha256:bd14967b79cd9bdb54d97323216f8fdf533e278df937aa2a90089e7d6e06e5ec"},
|
||||
]
|
||||
grpcio = [
|
||||
{file = "grpcio-1.50.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:906f4d1beb83b3496be91684c47a5d870ee628715227d5d7c54b04a8de802974"},
|
||||
{file = "grpcio-1.50.0-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:2d9fd6e38b16c4d286a01e1776fdf6c7a4123d99ae8d6b3f0b4a03a34bf6ce45"},
|
||||
|
@ -337,6 +418,10 @@ grpcio-tools = [
|
|||
{file = "grpcio_tools-1.50.0-cp39-cp39-win32.whl", hash = "sha256:e1a8f9a57bbcc2e633aaf327e39830527f3c1f7add18c7580f3058fe9a0fa780"},
|
||||
{file = "grpcio_tools-1.50.0-cp39-cp39-win_amd64.whl", hash = "sha256:b7eb7a84d9171c0ae1550833f4a6ca52372bed9db0fa10f8c9dbe6ca65f97a8c"},
|
||||
]
|
||||
iniconfig = [
|
||||
{file = "iniconfig-1.1.1-py2.py3-none-any.whl", hash = "sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3"},
|
||||
{file = "iniconfig-1.1.1.tar.gz", hash = "sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32"},
|
||||
]
|
||||
numpy = [
|
||||
{file = "numpy-1.23.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:95d79ada05005f6f4f337d3bb9de8a7774f259341c70bc88047a1f7b96a4bcb2"},
|
||||
{file = "numpy-1.23.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:926db372bc4ac1edf81cfb6c59e2a881606b409ddc0d0920b988174b2e2a767f"},
|
||||
|
@ -371,6 +456,10 @@ packaging = [
|
|||
{file = "packaging-21.3-py3-none-any.whl", hash = "sha256:ef103e05f519cdc783ae24ea4e2e0f508a9c99b2d4969652eed6a2e1ea5bd522"},
|
||||
{file = "packaging-21.3.tar.gz", hash = "sha256:dd47c42927d89ab911e606518907cc2d3a1f38bbd026385970643f9c5b8ecfeb"},
|
||||
]
|
||||
pluggy = [
|
||||
{file = "pluggy-1.0.0-py2.py3-none-any.whl", hash = "sha256:74134bbf457f031a36d68416e1509f34bd5ccc019f0bcc952c7b909d06b37bd3"},
|
||||
{file = "pluggy-1.0.0.tar.gz", hash = "sha256:4224373bacce55f955a878bf9cfa763c1e360858e330072059e10bad68531159"},
|
||||
]
|
||||
protobuf = [
|
||||
{file = "protobuf-4.21.8-cp310-abi3-win32.whl", hash = "sha256:c252c55ee15175aa1b21b7b9896e6add5162d066d5202e75c39f96136f08cce3"},
|
||||
{file = "protobuf-4.21.8-cp310-abi3-win_amd64.whl", hash = "sha256:809ca0b225d3df42655a12f311dd0f4148a943c51f1ad63c38343e457492b689"},
|
||||
|
@ -429,6 +518,10 @@ pyparsing = [
|
|||
{file = "pyparsing-3.0.9-py3-none-any.whl", hash = "sha256:5026bae9a10eeaefb61dab2f09052b9f4307d44aee4eda64b309723d8d206bbc"},
|
||||
{file = "pyparsing-3.0.9.tar.gz", hash = "sha256:2b020ecf7d21b687f219b71ecad3631f644a47f01403fa1d1036b0c6416d70fb"},
|
||||
]
|
||||
pytest = [
|
||||
{file = "pytest-7.2.0-py3-none-any.whl", hash = "sha256:892f933d339f068883b6fd5a459f03d85bfcb355e4981e146d2c7616c21fef71"},
|
||||
{file = "pytest-7.2.0.tar.gz", hash = "sha256:c4014eb40e10f11f355ad4e3c2fb2c6c6d1919c73f3b5a433de4708202cade59"},
|
||||
]
|
||||
PyYAML = [
|
||||
{file = "PyYAML-6.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d4db7c7aef085872ef65a8fd7d6d09a14ae91f691dec3e87ee5ee0539d516f53"},
|
||||
{file = "PyYAML-6.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9df7ed3b3d2e0ecfe09e14741b857df43adb5a3ddadc919a2d94fbdf78fea53c"},
|
||||
|
@ -512,6 +605,10 @@ six = [
|
|||
{file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"},
|
||||
{file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"},
|
||||
]
|
||||
tomli = [
|
||||
{file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"},
|
||||
{file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"},
|
||||
]
|
||||
torch = [
|
||||
{file = "torch-1.12.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:9c038662db894a23e49e385df13d47b2a777ffd56d9bcd5b832593fab0a7e286"},
|
||||
{file = "torch-1.12.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:4e1b9c14cf13fd2ab8d769529050629a0e68a6fc5cb8e84b4a3cc1dd8c4fe541"},
|
||||
|
|
|
@ -22,6 +22,7 @@ bnb = ["bitsandbytes"]
|
|||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
grpcio-tools = "^1.49.1"
|
||||
pytest = "^7.2.0"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.0.0"]
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
import pytest
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from text_generation.pb import generate_pb2
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_pb_parameters():
|
||||
return generate_pb2.LogitsWarperParameters(
|
||||
temperature=1.0,
|
||||
top_k=0,
|
||||
top_p=1.0,
|
||||
do_sample=False,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def bloom_560m_tokenizer():
|
||||
return AutoTokenizer.from_pretrained("bigscience/bloom-560m", padding_side="left")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def gpt2_tokenizer():
|
||||
tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
|
||||
tokenizer.pad_token_id = 50256
|
||||
return tokenizer
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def mt0_small_tokenizer():
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"bigscience/mt0-small", padding_side="left"
|
||||
)
|
||||
tokenizer.bos_token_id = 0
|
||||
return tokenizer
|
|
@ -0,0 +1,279 @@
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from copy import copy
|
||||
|
||||
from text_generation.pb import generate_pb2
|
||||
from text_generation.models.causal_lm import CausalLMBatch
|
||||
from text_generation.models.bloom import BloomCausalLMBatch, BLOOM
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_pb_request(default_pb_parameters):
|
||||
return generate_pb2.Request(
|
||||
id=0,
|
||||
inputs="Test",
|
||||
input_length=1,
|
||||
parameters=default_pb_parameters,
|
||||
max_new_tokens=10,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_pb_batch(default_pb_request):
|
||||
return generate_pb2.Batch(id=0, requests=[default_pb_request], size=1)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_bloom_batch(default_pb_batch, bloom_560m_tokenizer):
|
||||
return BloomCausalLMBatch.from_pb(
|
||||
default_pb_batch, bloom_560m_tokenizer, torch.device("cpu")
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_multi_requests_bloom_batch(default_pb_request, bloom_560m_tokenizer):
|
||||
req_0 = copy(default_pb_request)
|
||||
req_1 = default_pb_request
|
||||
req_1.id = 1
|
||||
req_1.max_new_tokens = 5
|
||||
|
||||
batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
|
||||
return BloomCausalLMBatch.from_pb(
|
||||
batch_pb, bloom_560m_tokenizer, torch.device("cpu")
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def default_bloom():
|
||||
return BLOOM("bigscience/bloom-560m")
|
||||
|
||||
|
||||
def test_batch_from_pb(default_pb_batch, default_bloom_batch):
|
||||
batch = default_bloom_batch
|
||||
|
||||
assert batch.batch_id == default_pb_batch.id
|
||||
assert batch.requests == default_pb_batch.requests
|
||||
|
||||
assert len(batch.input_ids) == default_pb_batch.size
|
||||
assert len(batch.input_ids[0]) == 8
|
||||
assert batch.input_ids[0][-1] == 10264
|
||||
assert torch.all(batch.input_ids[0][:-1] == 3)
|
||||
|
||||
assert batch.attention_mask[0][-1] == 1
|
||||
assert torch.all(batch.attention_mask[0][:-1] == 0)
|
||||
|
||||
assert batch.past_key_values is None
|
||||
|
||||
assert torch.equal(batch.input_ids, batch.all_input_ids[:, :, 0])
|
||||
|
||||
assert batch.input_lengths == [1]
|
||||
|
||||
assert batch.size == default_pb_batch.size
|
||||
assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size
|
||||
|
||||
assert batch.max_sequence_length == batch.input_lengths[0]
|
||||
|
||||
|
||||
def test_batch_concatenate_no_prefill(default_bloom_batch):
|
||||
with pytest.raises(ValueError):
|
||||
BloomCausalLMBatch.concatenate([default_bloom_batch, default_bloom_batch])
|
||||
|
||||
|
||||
def test_causal_lm_batch_type(default_bloom):
|
||||
assert default_bloom.batch_type == BloomCausalLMBatch
|
||||
|
||||
|
||||
def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
|
||||
generated_texts, next_batch = default_bloom.generate_token(default_bloom_batch)
|
||||
|
||||
assert generated_texts == []
|
||||
assert isinstance(next_batch, CausalLMBatch)
|
||||
assert not next_batch.keys_head_dim_last
|
||||
|
||||
assert len(next_batch.all_input_ids) == next_batch.size
|
||||
assert len(next_batch.all_input_ids[0]) == len(next_batch.attention_mask[0]) == 9
|
||||
assert torch.all(next_batch.all_input_ids[0][-2:] == 10264)
|
||||
assert torch.all(next_batch.all_input_ids[0][:-2] == 3)
|
||||
|
||||
assert torch.all(next_batch.attention_mask[0][-2:] == 1)
|
||||
assert torch.all(next_batch.attention_mask[0][:-2] == 0)
|
||||
|
||||
assert next_batch.input_ids.shape == (next_batch.size, 1)
|
||||
assert next_batch.input_ids[0, 0] == 10264
|
||||
|
||||
assert next_batch.input_lengths == [2]
|
||||
assert next_batch.max_sequence_length == next_batch.input_lengths[0]
|
||||
|
||||
assert next_batch.past_key_values is not None
|
||||
assert all([p[0].shape == (16, 64, 8) for p in next_batch.past_key_values])
|
||||
assert all([p[1].shape == (16, 8, 64) for p in next_batch.past_key_values])
|
||||
|
||||
|
||||
def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch):
|
||||
next_batch = default_bloom_batch
|
||||
for _ in range(default_bloom_batch.stopping_criterias[0].max_new_tokens - 1):
|
||||
generated_texts, next_batch = default_bloom.generate_token(next_batch)
|
||||
assert generated_texts == []
|
||||
|
||||
generated_texts, next_batch = default_bloom.generate_token(next_batch)
|
||||
assert next_batch is None
|
||||
|
||||
assert len(generated_texts) == 1
|
||||
assert generated_texts[0].output == "TestTestTestTestTestTestTestTestTestTestTest"
|
||||
assert generated_texts[0].request == default_bloom_batch.requests[0]
|
||||
assert (
|
||||
generated_texts[0].tokens
|
||||
== default_bloom_batch.stopping_criterias[0].max_new_tokens
|
||||
)
|
||||
|
||||
|
||||
def test_causal_lm_generate_token_completion_multi(
|
||||
default_bloom, default_multi_requests_bloom_batch
|
||||
):
|
||||
next_batch = default_multi_requests_bloom_batch
|
||||
|
||||
for i in range(
|
||||
default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 1
|
||||
):
|
||||
generated_texts, next_batch = default_bloom.generate_token(next_batch)
|
||||
assert generated_texts == []
|
||||
|
||||
generated_texts, next_batch = default_bloom.generate_token(next_batch)
|
||||
assert next_batch is not None
|
||||
|
||||
assert len(generated_texts) == 1
|
||||
assert generated_texts[0].output == "TestTestTestTestTestTest"
|
||||
assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[1]
|
||||
assert (
|
||||
generated_texts[0].tokens
|
||||
== default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
|
||||
)
|
||||
|
||||
for _ in range(
|
||||
default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
|
||||
- default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
|
||||
- 1
|
||||
):
|
||||
generated_texts, next_batch = default_bloom.generate_token(next_batch)
|
||||
assert generated_texts == []
|
||||
|
||||
generated_texts, next_batch = default_bloom.generate_token(next_batch)
|
||||
assert next_batch is None
|
||||
|
||||
assert len(generated_texts) == 1
|
||||
assert generated_texts[0].output == "TestTestTestTestTestTestTestTestTestTestTest"
|
||||
assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[0]
|
||||
assert (
|
||||
generated_texts[0].tokens
|
||||
== default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
|
||||
)
|
||||
|
||||
|
||||
def test_batch_concatenate(
|
||||
default_bloom, default_bloom_batch, default_multi_requests_bloom_batch
|
||||
):
|
||||
next_batch_0 = default_bloom_batch
|
||||
_, next_batch_0 = default_bloom.generate_token(next_batch_0)
|
||||
_, next_batch_0 = default_bloom.generate_token(next_batch_0)
|
||||
|
||||
next_batch_1 = default_multi_requests_bloom_batch
|
||||
_, next_batch_1 = default_bloom.generate_token(next_batch_1)
|
||||
|
||||
next_batch = BloomCausalLMBatch.concatenate([next_batch_0, next_batch_1])
|
||||
|
||||
assert torch.equal(next_batch.all_input_ids[0], next_batch_0.all_input_ids[0])
|
||||
assert torch.equal(next_batch.all_input_ids[1], next_batch_1.all_input_ids[0])
|
||||
assert torch.equal(next_batch.all_input_ids[2], next_batch_1.all_input_ids[1])
|
||||
|
||||
assert torch.all(next_batch.attention_mask[0] == 1)
|
||||
assert torch.all(next_batch.attention_mask[1:, -2:] == 1)
|
||||
assert torch.all(next_batch.attention_mask[1:, :-2] == 0)
|
||||
|
||||
assert next_batch.batch_id == 0
|
||||
assert torch.all(next_batch.input_ids == 10264)
|
||||
|
||||
assert next_batch.input_lengths == [3, 2, 2]
|
||||
assert next_batch.max_sequence_length == 3
|
||||
|
||||
assert next_batch.requests[0] == next_batch_0.requests[0]
|
||||
assert next_batch.requests[1:] == next_batch_1.requests
|
||||
|
||||
assert next_batch.next_token_choosers[0] == next_batch_0.next_token_choosers[0]
|
||||
assert next_batch.next_token_choosers[1:] == next_batch_1.next_token_choosers
|
||||
|
||||
assert next_batch.stopping_criterias[0] == next_batch_0.stopping_criterias[0]
|
||||
assert next_batch.stopping_criterias[1:] == next_batch_1.stopping_criterias
|
||||
|
||||
assert next_batch.past_key_values is not None
|
||||
assert all([p[0].shape == (3, 16, 64, 2) for p in next_batch.past_key_values])
|
||||
assert all([p[1].shape == (3, 16, 2, 64) for p in next_batch.past_key_values])
|
||||
|
||||
for i, past in enumerate(next_batch.past_key_values):
|
||||
assert torch.equal(next_batch_0.past_key_values[i][0][:, :, -2:], past[0][0])
|
||||
assert torch.equal(
|
||||
next_batch_1.past_key_values[i][0][:, :, -1:],
|
||||
past[0][1:, :, :, -1].reshape(-1, 64, 1),
|
||||
)
|
||||
|
||||
assert torch.equal(next_batch_0.past_key_values[i][1][:, -2:, :], past[1][0])
|
||||
assert torch.equal(
|
||||
next_batch_1.past_key_values[i][1][:, -1:, :],
|
||||
past[1][1:, :, -1, :].reshape(-1, 1, 64),
|
||||
)
|
||||
|
||||
for _ in range(
|
||||
default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 2
|
||||
):
|
||||
generated_texts, next_batch = default_bloom.generate_token(next_batch)
|
||||
assert generated_texts == []
|
||||
|
||||
generated_texts, next_batch = default_bloom.generate_token(next_batch)
|
||||
assert next_batch is not None
|
||||
|
||||
assert len(generated_texts) == 1
|
||||
assert generated_texts[0].output == "TestTestTestTestTestTest"
|
||||
assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[1]
|
||||
assert (
|
||||
generated_texts[0].tokens
|
||||
== default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
|
||||
)
|
||||
|
||||
for _ in range(
|
||||
default_bloom_batch.stopping_criterias[0].max_new_tokens
|
||||
- default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
|
||||
- 2
|
||||
):
|
||||
generated_texts, next_batch = default_bloom.generate_token(next_batch)
|
||||
assert generated_texts == []
|
||||
|
||||
generated_texts, next_batch = default_bloom.generate_token(next_batch)
|
||||
assert next_batch is not None
|
||||
|
||||
assert len(generated_texts) == 1
|
||||
assert generated_texts[0].output == "TestTestTestTestTestTestTestTestTestTestTest"
|
||||
assert generated_texts[0].request == default_bloom_batch.requests[0]
|
||||
assert (
|
||||
generated_texts[0].tokens
|
||||
== default_bloom_batch.stopping_criterias[0].max_new_tokens
|
||||
)
|
||||
|
||||
for _ in range(
|
||||
default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
|
||||
- default_bloom_batch.stopping_criterias[0].max_new_tokens
|
||||
- default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
|
||||
- 4
|
||||
):
|
||||
generated_texts, next_batch = default_bloom.generate_token(next_batch)
|
||||
assert generated_texts == []
|
||||
|
||||
generated_texts, next_batch = default_bloom.generate_token(next_batch)
|
||||
assert next_batch is None
|
||||
|
||||
assert len(generated_texts) == 1
|
||||
assert generated_texts[0].output == "TestTestTestTestTestTestTestTestTestTestTest"
|
||||
assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[0]
|
||||
assert (
|
||||
generated_texts[0].tokens
|
||||
== default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
|
||||
)
|
|
@ -0,0 +1,296 @@
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from copy import copy
|
||||
|
||||
from text_generation.pb import generate_pb2
|
||||
from text_generation.models.causal_lm import CausalLM, CausalLMBatch
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_pb_request(default_pb_parameters):
|
||||
return generate_pb2.Request(
|
||||
id=0,
|
||||
inputs="Test",
|
||||
input_length=1,
|
||||
parameters=default_pb_parameters,
|
||||
max_new_tokens=10,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_pb_batch(default_pb_request):
|
||||
return generate_pb2.Batch(id=0, requests=[default_pb_request], size=1)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_causal_lm_batch(default_pb_batch, gpt2_tokenizer):
|
||||
return CausalLMBatch.from_pb(default_pb_batch, gpt2_tokenizer, torch.device("cpu"))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_multi_requests_causal_lm_batch(default_pb_request, gpt2_tokenizer):
|
||||
req_0 = copy(default_pb_request)
|
||||
req_1 = default_pb_request
|
||||
req_1.id = 1
|
||||
req_1.max_new_tokens = 5
|
||||
|
||||
batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
|
||||
return CausalLMBatch.from_pb(batch_pb, gpt2_tokenizer, torch.device("cpu"))
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def default_causal_lm():
|
||||
return CausalLM("gpt2")
|
||||
|
||||
|
||||
def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
|
||||
batch = default_causal_lm_batch
|
||||
|
||||
assert batch.batch_id == default_pb_batch.id
|
||||
assert batch.requests == default_pb_batch.requests
|
||||
|
||||
assert len(batch.input_ids) == default_pb_batch.size
|
||||
assert len(batch.input_ids[0]) == 8
|
||||
assert batch.input_ids[0][-1] == 14402
|
||||
assert torch.all(batch.input_ids[0][:-1] == 50256)
|
||||
|
||||
assert batch.attention_mask[0][-1] == 1
|
||||
assert torch.all(batch.attention_mask[0][:-1] == 0)
|
||||
|
||||
assert batch.past_key_values is None
|
||||
|
||||
assert torch.equal(batch.input_ids, batch.all_input_ids[:, :, 0])
|
||||
|
||||
assert batch.input_lengths == [1]
|
||||
|
||||
assert batch.size == default_pb_batch.size
|
||||
assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size
|
||||
|
||||
assert batch.max_sequence_length == batch.input_lengths[0]
|
||||
|
||||
|
||||
def test_batch_concatenate_no_prefill(default_causal_lm_batch):
|
||||
with pytest.raises(ValueError):
|
||||
CausalLMBatch.concatenate([default_causal_lm_batch, default_causal_lm_batch])
|
||||
|
||||
|
||||
def test_causal_lm_batch_type(default_causal_lm):
|
||||
assert default_causal_lm.batch_type == CausalLMBatch
|
||||
|
||||
|
||||
def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
|
||||
generated_texts, next_batch = default_causal_lm.generate_token(
|
||||
default_causal_lm_batch
|
||||
)
|
||||
|
||||
assert generated_texts == []
|
||||
assert isinstance(next_batch, CausalLMBatch)
|
||||
|
||||
assert len(next_batch.all_input_ids) == next_batch.size
|
||||
assert len(next_batch.all_input_ids[0]) == len(next_batch.attention_mask[0]) == 9
|
||||
assert next_batch.all_input_ids[0][-1] == 6208
|
||||
assert next_batch.all_input_ids[0][-2] == 14402
|
||||
assert torch.all(next_batch.all_input_ids[0][:-2] == 50256)
|
||||
|
||||
assert torch.all(next_batch.attention_mask[0][-2:] == 1)
|
||||
assert torch.all(next_batch.attention_mask[0][:-2] == 0)
|
||||
|
||||
assert next_batch.input_ids.shape == (next_batch.size, 1)
|
||||
assert next_batch.input_ids[0, 0] == 6208
|
||||
|
||||
assert next_batch.input_lengths == [2]
|
||||
assert next_batch.max_sequence_length == next_batch.input_lengths[0]
|
||||
|
||||
assert next_batch.past_key_values is not None
|
||||
assert all([p[0].shape == (1, 12, 8, 64) for p in next_batch.past_key_values])
|
||||
assert all([p[1].shape == (1, 12, 8, 64) for p in next_batch.past_key_values])
|
||||
|
||||
|
||||
def test_causal_lm_generate_token_completion(
|
||||
default_causal_lm, default_causal_lm_batch
|
||||
):
|
||||
next_batch = default_causal_lm_batch
|
||||
for _ in range(default_causal_lm_batch.stopping_criterias[0].max_new_tokens - 1):
|
||||
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
|
||||
assert generated_texts == []
|
||||
|
||||
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
|
||||
assert next_batch is None
|
||||
|
||||
assert len(generated_texts) == 1
|
||||
assert (
|
||||
generated_texts[0].output
|
||||
== "Test Test Test Test Test Test Test Test Test Test Test"
|
||||
)
|
||||
assert generated_texts[0].request == default_causal_lm_batch.requests[0]
|
||||
assert (
|
||||
generated_texts[0].tokens
|
||||
== default_causal_lm_batch.stopping_criterias[0].max_new_tokens
|
||||
)
|
||||
|
||||
|
||||
def test_causal_lm_generate_token_completion_multi(
|
||||
default_causal_lm, default_multi_requests_causal_lm_batch
|
||||
):
|
||||
next_batch = default_multi_requests_causal_lm_batch
|
||||
|
||||
for i in range(
|
||||
default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 1
|
||||
):
|
||||
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
|
||||
assert generated_texts == []
|
||||
|
||||
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
|
||||
assert next_batch is not None
|
||||
|
||||
assert len(generated_texts) == 1
|
||||
assert generated_texts[0].output == "Test Test Test Test Test Test"
|
||||
assert (
|
||||
generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[1]
|
||||
)
|
||||
assert (
|
||||
generated_texts[0].tokens
|
||||
== default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
|
||||
)
|
||||
|
||||
for _ in range(
|
||||
default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
|
||||
- default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
|
||||
- 1
|
||||
):
|
||||
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
|
||||
assert generated_texts == []
|
||||
|
||||
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
|
||||
assert next_batch is None
|
||||
|
||||
assert len(generated_texts) == 1
|
||||
assert (
|
||||
generated_texts[0].output
|
||||
== "Test Test Test Test Test Test Test Test Test Test Test"
|
||||
)
|
||||
assert (
|
||||
generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[0]
|
||||
)
|
||||
assert (
|
||||
generated_texts[0].tokens
|
||||
== default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
|
||||
)
|
||||
|
||||
|
||||
def test_batch_concatenate(
|
||||
default_causal_lm, default_causal_lm_batch, default_multi_requests_causal_lm_batch
|
||||
):
|
||||
next_batch_0 = default_causal_lm_batch
|
||||
_, next_batch_0 = default_causal_lm.generate_token(next_batch_0)
|
||||
_, next_batch_0 = default_causal_lm.generate_token(next_batch_0)
|
||||
|
||||
next_batch_1 = default_multi_requests_causal_lm_batch
|
||||
_, next_batch_1 = default_causal_lm.generate_token(next_batch_1)
|
||||
|
||||
next_batch = CausalLMBatch.concatenate([next_batch_0, next_batch_1])
|
||||
|
||||
assert torch.equal(next_batch.all_input_ids[0], next_batch_0.all_input_ids[0])
|
||||
assert torch.equal(next_batch.all_input_ids[1], next_batch_1.all_input_ids[0])
|
||||
assert torch.equal(next_batch.all_input_ids[2], next_batch_1.all_input_ids[1])
|
||||
|
||||
assert torch.all(next_batch.attention_mask[0] == 1)
|
||||
assert torch.all(next_batch.attention_mask[1:, -2:] == 1)
|
||||
assert torch.all(next_batch.attention_mask[1:, :-2] == 0)
|
||||
|
||||
assert next_batch.batch_id == 0
|
||||
assert torch.all(next_batch.input_ids == 6208)
|
||||
|
||||
assert next_batch.input_lengths == [3, 2, 2]
|
||||
assert next_batch.max_sequence_length == 3
|
||||
|
||||
assert next_batch.requests[0] == next_batch_0.requests[0]
|
||||
assert next_batch.requests[1:] == next_batch_1.requests
|
||||
|
||||
assert next_batch.next_token_choosers[0] == next_batch_0.next_token_choosers[0]
|
||||
assert next_batch.next_token_choosers[1:] == next_batch_1.next_token_choosers
|
||||
|
||||
assert next_batch.stopping_criterias[0] == next_batch_0.stopping_criterias[0]
|
||||
assert next_batch.stopping_criterias[1:] == next_batch_1.stopping_criterias
|
||||
|
||||
assert next_batch.past_key_values is not None
|
||||
assert all([p[0].shape == (3, 12, 2, 64) for p in next_batch.past_key_values])
|
||||
assert all([p[1].shape == (3, 12, 2, 64) for p in next_batch.past_key_values])
|
||||
|
||||
for i, past in enumerate(next_batch.past_key_values):
|
||||
assert torch.equal(next_batch_0.past_key_values[i][0][0, :, -2:], past[0][0])
|
||||
assert torch.equal(
|
||||
next_batch_1.past_key_values[i][0][:, :, -1:], past[0][1:, :, -1:, :]
|
||||
)
|
||||
|
||||
assert torch.equal(next_batch_0.past_key_values[i][1][0, :, -2:], past[1][0])
|
||||
assert torch.equal(
|
||||
next_batch_1.past_key_values[i][1][:, :, -1:], past[1][1:, :, -1:, :]
|
||||
)
|
||||
|
||||
for _ in range(
|
||||
default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 2
|
||||
):
|
||||
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
|
||||
assert generated_texts == []
|
||||
|
||||
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
|
||||
assert next_batch is not None
|
||||
|
||||
assert len(generated_texts) == 1
|
||||
assert generated_texts[0].output == "Test Test Test Test Test Test"
|
||||
assert (
|
||||
generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[1]
|
||||
)
|
||||
assert (
|
||||
generated_texts[0].tokens
|
||||
== default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
|
||||
)
|
||||
|
||||
for _ in range(
|
||||
default_causal_lm_batch.stopping_criterias[0].max_new_tokens
|
||||
- default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
|
||||
- 2
|
||||
):
|
||||
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
|
||||
assert generated_texts == []
|
||||
|
||||
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
|
||||
assert next_batch is not None
|
||||
|
||||
assert len(generated_texts) == 1
|
||||
assert (
|
||||
generated_texts[0].output
|
||||
== "Test Test Test Test Test Test Test Test Test Test Test"
|
||||
)
|
||||
assert generated_texts[0].request == default_causal_lm_batch.requests[0]
|
||||
assert (
|
||||
generated_texts[0].tokens
|
||||
== default_causal_lm_batch.stopping_criterias[0].max_new_tokens
|
||||
)
|
||||
|
||||
for _ in range(
|
||||
default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
|
||||
- default_causal_lm_batch.stopping_criterias[0].max_new_tokens
|
||||
- default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
|
||||
- 4
|
||||
):
|
||||
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
|
||||
assert generated_texts == []
|
||||
|
||||
generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
|
||||
assert next_batch is None
|
||||
|
||||
assert len(generated_texts) == 1
|
||||
assert (
|
||||
generated_texts[0].output
|
||||
== "Test Test Test Test Test Test Test Test Test Test Test"
|
||||
)
|
||||
assert (
|
||||
generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[0]
|
||||
)
|
||||
assert (
|
||||
generated_texts[0].tokens
|
||||
== default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
|
||||
)
|
|
@ -0,0 +1,306 @@
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from copy import copy
|
||||
|
||||
from text_generation.pb import generate_pb2
|
||||
from text_generation.models.seq2seq_lm import Seq2SeqLM, Seq2SeqLMBatch
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_pb_request(default_pb_parameters):
|
||||
return generate_pb2.Request(
|
||||
id=0,
|
||||
inputs="Test",
|
||||
input_length=2,
|
||||
parameters=default_pb_parameters,
|
||||
max_new_tokens=10,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_pb_batch(default_pb_request):
|
||||
return generate_pb2.Batch(id=0, requests=[default_pb_request], size=1)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_seq2seq_lm_batch(default_pb_batch, mt0_small_tokenizer):
|
||||
return Seq2SeqLMBatch.from_pb(
|
||||
default_pb_batch, mt0_small_tokenizer, torch.device("cpu")
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_multi_requests_seq2seq_lm_batch(default_pb_request, mt0_small_tokenizer):
|
||||
req_0 = copy(default_pb_request)
|
||||
req_1 = default_pb_request
|
||||
req_1.id = 1
|
||||
req_1.max_new_tokens = 5
|
||||
|
||||
batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
|
||||
return Seq2SeqLMBatch.from_pb(batch_pb, mt0_small_tokenizer, torch.device("cpu"))
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def default_seq2seq_lm():
|
||||
return Seq2SeqLM("bigscience/mt0-small")
|
||||
|
||||
|
||||
def test_batch_from_pb(default_pb_batch, default_seq2seq_lm_batch):
|
||||
batch = default_seq2seq_lm_batch
|
||||
|
||||
assert batch.batch_id == default_pb_batch.id
|
||||
assert batch.requests == default_pb_batch.requests
|
||||
|
||||
assert batch.input_ids.shape == (default_pb_batch.size, 8)
|
||||
assert batch.input_ids[0][-2] == 4268
|
||||
assert batch.input_ids[0][-1] == 1
|
||||
assert torch.all(batch.input_ids[0][:-2] == 0)
|
||||
|
||||
assert torch.all(batch.attention_mask[0][-2:] == 1)
|
||||
assert torch.all(batch.attention_mask[0][:-2] == 0)
|
||||
|
||||
assert batch.decoder_input_ids.shape == (default_pb_batch.size, 1)
|
||||
assert batch.decoder_attention_mask is None
|
||||
assert batch.encoder_last_hidden_state is None
|
||||
|
||||
assert batch.past_key_values is None
|
||||
|
||||
assert batch.input_lengths == [2]
|
||||
assert batch.decoder_input_lengths == [1]
|
||||
|
||||
assert batch.size == default_pb_batch.size
|
||||
assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size
|
||||
|
||||
assert batch.max_input_length == batch.input_lengths[0]
|
||||
assert batch.max_decoder_input_length == batch.decoder_input_lengths[0]
|
||||
|
||||
|
||||
def test_batch_concatenate_no_prefill(default_seq2seq_lm_batch):
|
||||
with pytest.raises(ValueError):
|
||||
Seq2SeqLMBatch.concatenate([default_seq2seq_lm_batch, default_seq2seq_lm_batch])
|
||||
|
||||
|
||||
def test_seq2seq_lm_batch_type(default_seq2seq_lm):
|
||||
assert default_seq2seq_lm.batch_type == Seq2SeqLMBatch
|
||||
|
||||
|
||||
def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch):
|
||||
generated_texts, next_batch = default_seq2seq_lm.generate_token(
|
||||
default_seq2seq_lm_batch
|
||||
)
|
||||
|
||||
assert generated_texts == []
|
||||
assert isinstance(next_batch, Seq2SeqLMBatch)
|
||||
|
||||
assert torch.equal(next_batch.input_ids, default_seq2seq_lm_batch.input_ids)
|
||||
assert torch.equal(
|
||||
next_batch.attention_mask, default_seq2seq_lm_batch.attention_mask
|
||||
)
|
||||
assert next_batch.input_lengths == default_seq2seq_lm_batch.input_lengths
|
||||
assert next_batch.max_input_length == default_seq2seq_lm_batch.max_input_length
|
||||
assert (
|
||||
next_batch.next_token_choosers == default_seq2seq_lm_batch.next_token_choosers
|
||||
)
|
||||
assert next_batch.stopping_criterias == default_seq2seq_lm_batch.stopping_criterias
|
||||
|
||||
assert next_batch.decoder_input_ids.shape == (next_batch.size, 2)
|
||||
assert next_batch.decoder_input_ids[0, 0] == 0
|
||||
assert next_batch.decoder_input_ids[0, 1] == 259
|
||||
assert next_batch.decoder_attention_mask is None
|
||||
assert next_batch.encoder_last_hidden_state.shape == (1, 8, 512)
|
||||
|
||||
assert next_batch.decoder_input_lengths == [2]
|
||||
assert next_batch.max_decoder_input_length == 2
|
||||
|
||||
assert next_batch.past_key_values is not None
|
||||
assert all(
|
||||
[p[0].shape == (next_batch.size, 6, 1, 64) for p in next_batch.past_key_values]
|
||||
)
|
||||
assert all(
|
||||
[p[1].shape == (next_batch.size, 6, 1, 64) for p in next_batch.past_key_values]
|
||||
)
|
||||
assert all(
|
||||
[p[2].shape == (next_batch.size, 6, 8, 64) for p in next_batch.past_key_values]
|
||||
)
|
||||
assert all(
|
||||
[p[3].shape == (next_batch.size, 6, 8, 64) for p in next_batch.past_key_values]
|
||||
)
|
||||
|
||||
|
||||
def test_seq2seq_lm_generate_token_completion(
|
||||
default_seq2seq_lm, default_seq2seq_lm_batch
|
||||
):
|
||||
next_batch = default_seq2seq_lm_batch
|
||||
for _ in range(6):
|
||||
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
||||
assert generated_texts == []
|
||||
|
||||
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
||||
assert next_batch is None
|
||||
|
||||
assert len(generated_texts) == 1
|
||||
assert generated_texts[0].output == "a few weeks"
|
||||
assert generated_texts[0].request == default_seq2seq_lm_batch.requests[0]
|
||||
assert generated_texts[0].tokens == 7
|
||||
|
||||
|
||||
def test_seq2seq_lm_generate_token_completion_multi(
|
||||
default_seq2seq_lm, default_multi_requests_seq2seq_lm_batch
|
||||
):
|
||||
next_batch = default_multi_requests_seq2seq_lm_batch
|
||||
|
||||
for i in range(4):
|
||||
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
||||
assert generated_texts == []
|
||||
|
||||
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
||||
assert next_batch is not None
|
||||
|
||||
assert len(generated_texts) == 1
|
||||
assert generated_texts[0].output == "a few "
|
||||
assert (
|
||||
generated_texts[0].request
|
||||
== default_multi_requests_seq2seq_lm_batch.requests[1]
|
||||
)
|
||||
assert generated_texts[0].tokens == 5
|
||||
|
||||
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
||||
assert generated_texts == []
|
||||
|
||||
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
||||
assert next_batch is None
|
||||
|
||||
assert len(generated_texts) == 1
|
||||
assert generated_texts[0].output == "a few weeks"
|
||||
assert (
|
||||
generated_texts[0].request
|
||||
== default_multi_requests_seq2seq_lm_batch.requests[0]
|
||||
)
|
||||
assert generated_texts[0].tokens == 7
|
||||
|
||||
|
||||
def test_batch_concatenate(
|
||||
default_seq2seq_lm,
|
||||
default_seq2seq_lm_batch,
|
||||
default_multi_requests_seq2seq_lm_batch,
|
||||
):
|
||||
next_batch_0 = default_seq2seq_lm_batch
|
||||
_, next_batch_0 = default_seq2seq_lm.generate_token(next_batch_0)
|
||||
_, next_batch_0 = default_seq2seq_lm.generate_token(next_batch_0)
|
||||
|
||||
next_batch_1 = default_multi_requests_seq2seq_lm_batch
|
||||
_, next_batch_1 = default_seq2seq_lm.generate_token(next_batch_1)
|
||||
|
||||
next_batch = Seq2SeqLMBatch.concatenate([next_batch_0, next_batch_1])
|
||||
|
||||
assert next_batch.batch_id == 0
|
||||
|
||||
assert torch.all(next_batch.input_ids[:, 0] == 4268)
|
||||
assert torch.all(next_batch.input_ids[:, 1] == 1)
|
||||
|
||||
assert torch.all(next_batch.attention_mask == 1)
|
||||
|
||||
assert torch.equal(
|
||||
next_batch.decoder_input_ids[0], next_batch_0.decoder_input_ids[0]
|
||||
)
|
||||
assert torch.all(next_batch.decoder_input_ids[1:, 0] == 0)
|
||||
assert torch.equal(
|
||||
next_batch.decoder_input_ids[1:, -2:], next_batch_1.decoder_input_ids
|
||||
)
|
||||
|
||||
assert torch.all(next_batch.decoder_attention_mask[0] == 1)
|
||||
assert torch.all(next_batch.decoder_attention_mask[1:, 0] == 0)
|
||||
assert torch.all(next_batch.decoder_attention_mask[1:, -2:] == 1)
|
||||
|
||||
assert torch.equal(
|
||||
next_batch.encoder_last_hidden_state[0],
|
||||
next_batch_0.encoder_last_hidden_state[0, -2:],
|
||||
)
|
||||
assert torch.equal(
|
||||
next_batch.encoder_last_hidden_state[1:],
|
||||
next_batch_1.encoder_last_hidden_state[:, -2:],
|
||||
)
|
||||
|
||||
assert next_batch.input_lengths == [2, 2, 2]
|
||||
assert next_batch.decoder_input_lengths == [3, 2, 2]
|
||||
assert next_batch.max_input_length == 2
|
||||
assert next_batch.max_decoder_input_length == 3
|
||||
|
||||
assert next_batch.requests[0] == next_batch_0.requests[0]
|
||||
assert next_batch.requests[1:] == next_batch_1.requests
|
||||
|
||||
assert next_batch.next_token_choosers[0] == next_batch_0.next_token_choosers[0]
|
||||
assert next_batch.next_token_choosers[1:] == next_batch_1.next_token_choosers
|
||||
|
||||
assert next_batch.stopping_criterias[0] == next_batch_0.stopping_criterias[0]
|
||||
assert next_batch.stopping_criterias[1:] == next_batch_1.stopping_criterias
|
||||
|
||||
assert next_batch.past_key_values is not None
|
||||
assert all(
|
||||
[p[0].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values]
|
||||
)
|
||||
assert all(
|
||||
[p[1].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values]
|
||||
)
|
||||
assert all(
|
||||
[p[2].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values]
|
||||
)
|
||||
assert all(
|
||||
[p[3].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values]
|
||||
)
|
||||
|
||||
for i, past in enumerate(next_batch.past_key_values):
|
||||
assert torch.equal(next_batch_0.past_key_values[i][0][0, :, -2:, :], past[0][0])
|
||||
assert torch.equal(
|
||||
next_batch_1.past_key_values[i][0][:, :, -1:, :], past[0][1:, :, -1:, :]
|
||||
)
|
||||
|
||||
assert torch.equal(next_batch_0.past_key_values[i][1][0, :, -2:, :], past[1][0])
|
||||
assert torch.equal(
|
||||
next_batch_1.past_key_values[i][1][:, :, -1:, :], past[1][1:, :, -1:, :]
|
||||
)
|
||||
|
||||
assert torch.equal(next_batch_0.past_key_values[i][2][0, :, -2:, :], past[2][0])
|
||||
assert torch.equal(
|
||||
next_batch_1.past_key_values[i][2][:, :, -2:, :], past[2][1:]
|
||||
)
|
||||
|
||||
assert torch.equal(next_batch_0.past_key_values[i][3][0, :, -2:, :], past[3][0])
|
||||
assert torch.equal(
|
||||
next_batch_1.past_key_values[i][3][:, :, -2:, :], past[3][1:]
|
||||
)
|
||||
|
||||
for _ in range(3):
|
||||
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
||||
assert generated_texts == []
|
||||
|
||||
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
||||
assert next_batch is not None
|
||||
|
||||
assert len(generated_texts) == 1
|
||||
assert generated_texts[0].output == "a few "
|
||||
assert (
|
||||
generated_texts[0].request
|
||||
== default_multi_requests_seq2seq_lm_batch.requests[1]
|
||||
)
|
||||
assert generated_texts[0].tokens == 5
|
||||
|
||||
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
||||
assert next_batch is not None
|
||||
|
||||
assert len(generated_texts) == 1
|
||||
assert generated_texts[0].output == "a few weeks"
|
||||
assert generated_texts[0].request == default_seq2seq_lm_batch.requests[0]
|
||||
assert generated_texts[0].tokens == 7
|
||||
|
||||
generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
|
||||
assert next_batch is None
|
||||
|
||||
assert len(generated_texts) == 1
|
||||
assert generated_texts[0].output == "a few weeks"
|
||||
assert (
|
||||
generated_texts[0].request
|
||||
== default_multi_requests_seq2seq_lm_batch.requests[0]
|
||||
)
|
||||
assert generated_texts[0].tokens == 7
|
|
@ -0,0 +1,34 @@
|
|||
import pytest
|
||||
|
||||
from text_generation.utils import (
|
||||
weight_hub_files,
|
||||
download_weights,
|
||||
weight_files,
|
||||
LocalEntryNotFoundError,
|
||||
)
|
||||
|
||||
|
||||
def test_weight_hub_files():
|
||||
filenames = weight_hub_files("bigscience/bloom-560m")
|
||||
assert filenames == ["model.safetensors"]
|
||||
|
||||
|
||||
def test_weight_hub_files_llm():
|
||||
filenames = weight_hub_files("bigscience/bloom")
|
||||
assert filenames == [f"model_{i:05d}-of-00072.safetensors" for i in range(1, 73)]
|
||||
|
||||
|
||||
def test_weight_hub_files_empty():
|
||||
filenames = weight_hub_files("bigscience/bloom", ".errors")
|
||||
assert filenames == []
|
||||
|
||||
|
||||
def test_download_weights():
|
||||
files = download_weights("bigscience/bloom-560m")
|
||||
local_files = weight_files("bigscience/bloom-560m")
|
||||
assert files == local_files
|
||||
|
||||
|
||||
def test_weight_files_error():
|
||||
with pytest.raises(LocalEntryNotFoundError):
|
||||
weight_files("bert-base-uncased")
|
|
@ -1,10 +1,10 @@
|
|||
from text_generation.models.model import Model
|
||||
from text_generation.models.causal_lm import CausalLM
|
||||
from text_generation.models.bloom import BLOOMSharded
|
||||
from text_generation.models.bloom import BLOOM, BLOOMSharded
|
||||
from text_generation.models.seq2seq_lm import Seq2SeqLM
|
||||
from text_generation.models.galactica import Galactica, GalacticaSharded
|
||||
|
||||
__all__ = ["Model", "BLOOMSharded", "CausalLM", "Seq2SeqLM"]
|
||||
__all__ = ["Model", "BLOOM", "BLOOMSharded", "CausalLM", "Seq2SeqLM"]
|
||||
|
||||
|
||||
def get_model(model_name: str, sharded: bool, quantize: bool) -> Model:
|
||||
|
@ -12,7 +12,7 @@ def get_model(model_name: str, sharded: bool, quantize: bool) -> Model:
|
|||
if sharded:
|
||||
return BLOOMSharded(model_name, quantize=quantize)
|
||||
else:
|
||||
return CausalLM(model_name, quantize=quantize)
|
||||
return BLOOM(model_name, quantize=quantize)
|
||||
elif model_name.startswith("facebook/galactica"):
|
||||
if sharded:
|
||||
return GalacticaSharded(model_name, quantize=quantize)
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Type
|
||||
|
||||
from accelerate import init_empty_weights
|
||||
from safetensors import safe_open
|
||||
|
@ -13,6 +13,8 @@ from transformers.models.bloom.parallel_layers import (
|
|||
)
|
||||
|
||||
from text_generation.models import CausalLM
|
||||
from text_generation.models.causal_lm import CausalLMBatch
|
||||
from text_generation.pb import generate_pb2
|
||||
from text_generation.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
|
@ -29,7 +31,25 @@ except Exception as e:
|
|||
torch.manual_seed(0)
|
||||
|
||||
|
||||
class BLOOMSharded(CausalLM):
|
||||
class BloomCausalLMBatch(CausalLMBatch):
|
||||
@classmethod
|
||||
def from_pb(
|
||||
cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
|
||||
) -> "CausalLMBatch":
|
||||
batch = super(BloomCausalLMBatch, cls).from_pb(
|
||||
pb=pb, tokenizer=tokenizer, device=device
|
||||
)
|
||||
batch.keys_head_dim_last = False
|
||||
return batch
|
||||
|
||||
|
||||
class BLOOM(CausalLM):
|
||||
@property
|
||||
def batch_type(self) -> Type[CausalLMBatch]:
|
||||
return BloomCausalLMBatch
|
||||
|
||||
|
||||
class BLOOMSharded(BLOOM):
|
||||
def __init__(self, model_name: str, quantize: bool = False):
|
||||
if not model_name.startswith("bigscience/bloom"):
|
||||
raise ValueError(f"Model {model_name} is not supported")
|
||||
|
|
|
@ -34,6 +34,9 @@ class CausalLMBatch:
|
|||
size: int
|
||||
max_sequence_length: int
|
||||
|
||||
# Past metadata
|
||||
keys_head_dim_last: bool = True
|
||||
|
||||
def to_pb(self):
|
||||
return generate_pb2.Batch(
|
||||
id=self.batch_id,
|
||||
|
@ -165,20 +168,16 @@ class CausalLMBatch:
|
|||
head_dim,
|
||||
)
|
||||
|
||||
if batch.keys_head_dim_last:
|
||||
padded_past_keys_shape = padded_past_values_shape
|
||||
# seq_length is last for BLOOM
|
||||
if past_keys.shape[-2] == head_dim:
|
||||
past_keys_head_dim_last = False
|
||||
else:
|
||||
padded_past_keys_shape = (
|
||||
total_batch_size,
|
||||
num_heads,
|
||||
head_dim,
|
||||
max_sequence_length - 1,
|
||||
)
|
||||
elif past_keys.shape[-1] == head_dim:
|
||||
past_keys_head_dim_last = True
|
||||
padded_past_keys_shape = padded_past_values_shape
|
||||
else:
|
||||
raise ValueError(f"past_keys shape {past_keys.shape} is not valid")
|
||||
|
||||
# This will run only once per layer
|
||||
if j == len(past_key_values):
|
||||
|
@ -195,7 +194,7 @@ class CausalLMBatch:
|
|||
past_key_values.append((padded_past_keys, padded_past_values))
|
||||
|
||||
# We slice the past keys and values to remove the padding from previous batches
|
||||
if past_keys_head_dim_last:
|
||||
if batch.keys_head_dim_last:
|
||||
past_key_values[j][0][
|
||||
start_index:end_index,
|
||||
:,
|
||||
|
@ -228,6 +227,7 @@ class CausalLMBatch:
|
|||
stopping_criterias=stopping_criterias,
|
||||
size=total_batch_size,
|
||||
max_sequence_length=max_sequence_length,
|
||||
keys_head_dim_last=batches[0].keys_head_dim_last,
|
||||
)
|
||||
|
||||
|
||||
|
@ -237,6 +237,9 @@ class CausalLM(Model):
|
|||
device = torch.device("cuda")
|
||||
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
|
||||
else:
|
||||
if quantize:
|
||||
raise ValueError("quantization is not available on CPU")
|
||||
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32
|
||||
|
||||
|
@ -247,7 +250,11 @@ class CausalLM(Model):
|
|||
device_map="auto" if torch.cuda.is_available() else None,
|
||||
load_in_8bit=quantize,
|
||||
).eval()
|
||||
tokenizer.pad_token_id = self.model.config.pad_token_id
|
||||
tokenizer.pad_token_id = (
|
||||
self.model.config.pad_token_id
|
||||
if self.model.config.pad_token_id is not None
|
||||
else self.model.config.eos_token_id
|
||||
)
|
||||
|
||||
super(CausalLM, self).__init__(
|
||||
tokenizer=tokenizer,
|
||||
|
@ -397,5 +404,6 @@ class CausalLM(Model):
|
|||
stopping_criterias=next_batch_stopping_criterias,
|
||||
size=next_batch_size,
|
||||
max_sequence_length=next_batch_max_sequence_length,
|
||||
keys_head_dim_last=batch.keys_head_dim_last,
|
||||
)
|
||||
return generated_texts, next_batch
|
||||
|
|
|
@ -83,7 +83,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
|
|||
@classmethod
|
||||
def from_pb(
|
||||
cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
|
||||
) -> "CausalLMBatch":
|
||||
) -> "GalacticaCausalLMBatch":
|
||||
inputs = []
|
||||
next_token_choosers = []
|
||||
stopping_criterias = []
|
||||
|
|
|
@ -221,8 +221,8 @@ class Seq2SeqLMBatch:
|
|||
|
||||
# Copy to correct indices
|
||||
encoder_last_hidden_state[
|
||||
start_index:end_index, -batch.max_decoder_input_length :, :
|
||||
] = batch.encoder_last_hidden_state[:, -batch.max_decoder_input_length :, :]
|
||||
start_index:end_index, -batch.max_input_length :, :
|
||||
] = batch.encoder_last_hidden_state[:, -batch.max_input_length :, :]
|
||||
|
||||
# Iterate over attention layers
|
||||
for j, past in enumerate(batch.past_key_values):
|
||||
|
@ -305,6 +305,9 @@ class Seq2SeqLM(Model):
|
|||
device = torch.device("cuda")
|
||||
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
|
||||
else:
|
||||
if quantize:
|
||||
raise ValueError("quantization is not available on CPU")
|
||||
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32
|
||||
|
||||
|
|
|
@ -137,8 +137,8 @@ def download_weights(model_name, extension=".safetensors"):
|
|||
executor.submit(download_function, filename=filename) for filename in filenames
|
||||
]
|
||||
files = [
|
||||
file
|
||||
for file in tqdm(concurrent.futures.as_completed(futures), total=len(futures))
|
||||
future.result()
|
||||
for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures))
|
||||
]
|
||||
|
||||
return files
|
||||
|
|
Loading…
Reference in New Issue