diff --git a/README.md b/README.md
index e76e38d0..bb58e281 100644
--- a/README.md
+++ b/README.md
@@ -87,9 +87,4 @@ curl 127.0.0.1:3000/generate \
 ```shell
 make server-dev
 make router-dev
-```
-
-## TODO:
-
-- [ ] Add tests for the `server/model` logic
-- [ ] Backport custom CUDA kernels to Transformers
\ No newline at end of file
+```
\ No newline at end of file
diff --git a/router/src/batcher.rs b/router/src/batcher.rs
index a9b892cc..0c85a406 100644
--- a/router/src/batcher.rs
+++ b/router/src/batcher.rs
@@ -70,7 +70,9 @@ impl Batcher {
 
         // Notify the background task that we have a new entry in the database that needs
         // to be batched
-        self.shared.batching_task.notify_waiters();
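+        // notify_one stores a permit if the background task is not currently
+        // waiting, so this wakeup cannot be lost (notify_waiters would drop it)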
">=2.0.2,<3.0.5 || >3.0.5" +[[package]] +name = "pluggy" +version = "1.0.0" +description = "plugin and hook calling mechanisms for python" +category = "dev" +optional = false +python-versions = ">=3.6" + +[package.extras] +dev = ["pre-commit", "tox"] +testing = ["pytest", "pytest-benchmark"] + [[package]] name = "protobuf" version = "4.21.8" @@ -137,6 +182,26 @@ python-versions = ">=3.6.8" [package.extras] diagrams = ["jinja2", "railroad-diagrams"] +[[package]] +name = "pytest" +version = "7.2.0" +description = "pytest: simple powerful testing with Python" +category = "dev" +optional = false +python-versions = ">=3.7" + +[package.dependencies] +attrs = ">=19.2.0" +colorama = {version = "*", markers = "sys_platform == \"win32\""} +exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""} +iniconfig = "*" +packaging = "*" +pluggy = ">=0.12,<2.0" +tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} + +[package.extras] +testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "xmlschema"] + [[package]] name = "PyYAML" version = "6.0" @@ -178,6 +243,14 @@ category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" +[[package]] +name = "tomli" +version = "2.0.1" +description = "A lil' TOML parser" +category = "dev" +optional = false +python-versions = ">=3.7" + [[package]] name = "torch" version = "1.12.1" @@ -220,13 +293,17 @@ bnb = ["bitsandbytes"] [metadata] lock-version = "1.1" python-versions = "^3.9" -content-hash = "3266187ef14fe8f9e29b3b6530d07781ea952aa670c0fe0de34be43efa231a67" +content-hash = "51693654531e3229ac64bee250932ace20a60e8d45af074ae7b860ed32b25ef8" [metadata.files] accelerate = [ {file = "accelerate-0.12.0-py3-none-any.whl", hash = "sha256:7742ca5c9f15dd1e0a283305599c196e260af4717a561d1f544aeab27d828af6"}, {file = "accelerate-0.12.0.tar.gz", hash = "sha256:e8b119c94fac31877620d5f9de311164ec81fa9dc9e175f0d0d4f50fc8d79473"}, ] +attrs = [ + {file = "attrs-22.1.0-py2.py3-none-any.whl", hash = "sha256:86efa402f67bf2df34f51a335487cf46b1ec130d02b8d39fd248abfd30da551c"}, + {file = "attrs-22.1.0.tar.gz", hash = "sha256:29adc2665447e5191d0e7c568fde78b21f9672d344281d0c6e1ab085429b22b6"}, +] bitsandbytes = [ {file = "bitsandbytes-0.35.1-py3-none-any.whl", hash = "sha256:4506a9e3778359a743938aa5592d8d043fa91d1df66cd01ba8cc6486e64dea45"}, {file = "bitsandbytes-0.35.1.tar.gz", hash = "sha256:63a6f59c87b713a731a685e43d68c19789ee6381e62196cafab293b87eca5d46"}, @@ -239,6 +316,10 @@ colorama = [ {file = "colorama-0.4.5-py2.py3-none-any.whl", hash = "sha256:854bf444933e37f5824ae7bfc1e98d5bce2ebe4160d46b5edf346a89358e99da"}, {file = "colorama-0.4.5.tar.gz", hash = "sha256:e6c6b4334fc50988a639d9b98aa429a0b57da6e17b9a44f0451f930b6967b7a4"}, ] +exceptiongroup = [ + {file = "exceptiongroup-1.0.4-py3-none-any.whl", hash = "sha256:542adf9dea4055530d6e1279602fa5cb11dab2395fa650b8674eaec35fc4a828"}, + {file = "exceptiongroup-1.0.4.tar.gz", hash = "sha256:bd14967b79cd9bdb54d97323216f8fdf533e278df937aa2a90089e7d6e06e5ec"}, +] grpcio = [ {file = "grpcio-1.50.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:906f4d1beb83b3496be91684c47a5d870ee628715227d5d7c54b04a8de802974"}, {file = "grpcio-1.50.0-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:2d9fd6e38b16c4d286a01e1776fdf6c7a4123d99ae8d6b3f0b4a03a34bf6ce45"}, @@ -337,6 +418,10 @@ grpcio-tools = [ {file = "grpcio_tools-1.50.0-cp39-cp39-win32.whl", hash = "sha256:e1a8f9a57bbcc2e633aaf327e39830527f3c1f7add18c7580f3058fe9a0fa780"}, {file = 
"grpcio_tools-1.50.0-cp39-cp39-win_amd64.whl", hash = "sha256:b7eb7a84d9171c0ae1550833f4a6ca52372bed9db0fa10f8c9dbe6ca65f97a8c"}, ] +iniconfig = [ + {file = "iniconfig-1.1.1-py2.py3-none-any.whl", hash = "sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3"}, + {file = "iniconfig-1.1.1.tar.gz", hash = "sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32"}, +] numpy = [ {file = "numpy-1.23.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:95d79ada05005f6f4f337d3bb9de8a7774f259341c70bc88047a1f7b96a4bcb2"}, {file = "numpy-1.23.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:926db372bc4ac1edf81cfb6c59e2a881606b409ddc0d0920b988174b2e2a767f"}, @@ -371,6 +456,10 @@ packaging = [ {file = "packaging-21.3-py3-none-any.whl", hash = "sha256:ef103e05f519cdc783ae24ea4e2e0f508a9c99b2d4969652eed6a2e1ea5bd522"}, {file = "packaging-21.3.tar.gz", hash = "sha256:dd47c42927d89ab911e606518907cc2d3a1f38bbd026385970643f9c5b8ecfeb"}, ] +pluggy = [ + {file = "pluggy-1.0.0-py2.py3-none-any.whl", hash = "sha256:74134bbf457f031a36d68416e1509f34bd5ccc019f0bcc952c7b909d06b37bd3"}, + {file = "pluggy-1.0.0.tar.gz", hash = "sha256:4224373bacce55f955a878bf9cfa763c1e360858e330072059e10bad68531159"}, +] protobuf = [ {file = "protobuf-4.21.8-cp310-abi3-win32.whl", hash = "sha256:c252c55ee15175aa1b21b7b9896e6add5162d066d5202e75c39f96136f08cce3"}, {file = "protobuf-4.21.8-cp310-abi3-win_amd64.whl", hash = "sha256:809ca0b225d3df42655a12f311dd0f4148a943c51f1ad63c38343e457492b689"}, @@ -429,6 +518,10 @@ pyparsing = [ {file = "pyparsing-3.0.9-py3-none-any.whl", hash = "sha256:5026bae9a10eeaefb61dab2f09052b9f4307d44aee4eda64b309723d8d206bbc"}, {file = "pyparsing-3.0.9.tar.gz", hash = "sha256:2b020ecf7d21b687f219b71ecad3631f644a47f01403fa1d1036b0c6416d70fb"}, ] +pytest = [ + {file = "pytest-7.2.0-py3-none-any.whl", hash = "sha256:892f933d339f068883b6fd5a459f03d85bfcb355e4981e146d2c7616c21fef71"}, + {file = "pytest-7.2.0.tar.gz", hash = "sha256:c4014eb40e10f11f355ad4e3c2fb2c6c6d1919c73f3b5a433de4708202cade59"}, +] PyYAML = [ {file = "PyYAML-6.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d4db7c7aef085872ef65a8fd7d6d09a14ae91f691dec3e87ee5ee0539d516f53"}, {file = "PyYAML-6.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9df7ed3b3d2e0ecfe09e14741b857df43adb5a3ddadc919a2d94fbdf78fea53c"}, @@ -512,6 +605,10 @@ six = [ {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"}, {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, ] +tomli = [ + {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"}, + {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, +] torch = [ {file = "torch-1.12.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:9c038662db894a23e49e385df13d47b2a777ffd56d9bcd5b832593fab0a7e286"}, {file = "torch-1.12.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:4e1b9c14cf13fd2ab8d769529050629a0e68a6fc5cb8e84b4a3cc1dd8c4fe541"}, diff --git a/server/pyproject.toml b/server/pyproject.toml index cdf89869..f2628854 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -22,6 +22,7 @@ bnb = ["bitsandbytes"] [tool.poetry.group.dev.dependencies] grpcio-tools = "^1.49.1" +pytest = "^7.2.0" [build-system] requires = ["poetry-core>=1.0.0"] diff --git a/server/tests/conftest.py b/server/tests/conftest.py new file 
+    tokenizer.pad_token_id = 50256
+    return tokenizer
+
+
+@pytest.fixture(scope="session")
+def mt0_small_tokenizer():
+    tokenizer = AutoTokenizer.from_pretrained(
+        "bigscience/mt0-small", padding_side="left"
+    )
+    tokenizer.bos_token_id = 0
+    return tokenizer
diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py
new file mode 100644
index 00000000..49dabb14
--- /dev/null
+++ b/server/tests/models/test_bloom.py
@@ -0,0 +1,283 @@
+import pytest
+import torch
+
+from copy import copy
+
+from text_generation.pb import generate_pb2
+from text_generation.models.causal_lm import CausalLMBatch
+from text_generation.models.bloom import BloomCausalLMBatch, BLOOM
+
+
+@pytest.fixture
+def default_pb_request(default_pb_parameters):
+    return generate_pb2.Request(
+        id=0,
+        inputs="Test",
+        input_length=1,
+        parameters=default_pb_parameters,
+        max_new_tokens=10,
+    )
+
+
+@pytest.fixture
+def default_pb_batch(default_pb_request):
+    return generate_pb2.Batch(id=0, requests=[default_pb_request], size=1)
+
+
+@pytest.fixture
+def default_bloom_batch(default_pb_batch, bloom_560m_tokenizer):
+    return BloomCausalLMBatch.from_pb(
+        default_pb_batch, bloom_560m_tokenizer, torch.device("cpu")
+    )
+
+
+@pytest.fixture
+def default_multi_requests_bloom_batch(default_pb_request, bloom_560m_tokenizer):
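+    # req_0 is a copy that keeps the defaults (id=0, max_new_tokens=10); req_1
+    # mutates the original request in place to use a shorter generation budget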
+    req_0 = copy(default_pb_request)
+    req_1 = default_pb_request
+    req_1.id = 1
+    req_1.max_new_tokens = 5
+
+    batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
+    return BloomCausalLMBatch.from_pb(
+        batch_pb, bloom_560m_tokenizer, torch.device("cpu")
+    )
+
+
+@pytest.fixture(scope="session")
+def default_bloom():
+    return BLOOM("bigscience/bloom-560m")
+
+
+def test_batch_from_pb(default_pb_batch, default_bloom_batch):
+    batch = default_bloom_batch
+
+    assert batch.batch_id == default_pb_batch.id
+    assert batch.requests == default_pb_batch.requests
+
+    assert len(batch.input_ids) == default_pb_batch.size
+    assert len(batch.input_ids[0]) == 8
+    assert batch.input_ids[0][-1] == 10264
+    assert torch.all(batch.input_ids[0][:-1] == 3)
+
+    assert batch.attention_mask[0][-1] == 1
+    assert torch.all(batch.attention_mask[0][:-1] == 0)
+
+    assert batch.past_key_values is None
+
+    assert torch.equal(batch.input_ids, batch.all_input_ids[:, :, 0])
+
+    assert batch.input_lengths == [1]
+
+    assert batch.size == default_pb_batch.size
+    assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size
+
+    assert batch.max_sequence_length == batch.input_lengths[0]
+
+
+def test_batch_concatenate_no_prefill(default_bloom_batch):
+    with pytest.raises(ValueError):
+        BloomCausalLMBatch.concatenate([default_bloom_batch, default_bloom_batch])
+
+
+def test_causal_lm_batch_type(default_bloom):
+    assert default_bloom.batch_type == BloomCausalLMBatch
+
+
+def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
+    generated_texts, next_batch = default_bloom.generate_token(default_bloom_batch)
+
+    assert generated_texts == []
+    assert isinstance(next_batch, CausalLMBatch)
+    assert not next_batch.keys_head_dim_last
+
+    assert len(next_batch.all_input_ids) == next_batch.size
+    assert len(next_batch.all_input_ids[0]) == len(next_batch.attention_mask[0]) == 9
+    assert torch.all(next_batch.all_input_ids[0][-2:] == 10264)
+    assert torch.all(next_batch.all_input_ids[0][:-2] == 3)
+
+    assert torch.all(next_batch.attention_mask[0][-2:] == 1)
+    assert torch.all(next_batch.attention_mask[0][:-2] == 0)
+
+    assert next_batch.input_ids.shape == (next_batch.size, 1)
+    assert next_batch.input_ids[0, 0] == 10264
+
+    assert next_batch.input_lengths == [2]
+    assert next_batch.max_sequence_length == next_batch.input_lengths[0]
+
+    assert next_batch.past_key_values is not None
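+    # BLOOM fuses batch and head dims and keeps keys transposed, seq_length
+    # last: keys (16, 64, 8) vs values (16, 8, 64) for 16 heads of head_dim 64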
+    assert all([p[0].shape == (16, 64, 8) for p in next_batch.past_key_values])
+    assert all([p[1].shape == (16, 8, 64) for p in next_batch.past_key_values])
+
+
+def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch):
+    next_batch = default_bloom_batch
+    for _ in range(default_bloom_batch.stopping_criterias[0].max_new_tokens - 1):
+        generated_texts, next_batch = default_bloom.generate_token(next_batch)
+        assert generated_texts == []
+
+    generated_texts, next_batch = default_bloom.generate_token(next_batch)
+    assert next_batch is None
+
+    assert len(generated_texts) == 1
+    assert generated_texts[0].output == "TestTestTestTestTestTestTestTestTestTestTest"
+    assert generated_texts[0].request == default_bloom_batch.requests[0]
+    assert (
+        generated_texts[0].tokens
+        == default_bloom_batch.stopping_criterias[0].max_new_tokens
+    )
+
+
+def test_causal_lm_generate_token_completion_multi(
+    default_bloom, default_multi_requests_bloom_batch
+):
+    next_batch = default_multi_requests_bloom_batch
+
+    for _ in range(
+        default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 1
+    ):
+        generated_texts, next_batch = default_bloom.generate_token(next_batch)
+        assert generated_texts == []
+
+    generated_texts, next_batch = default_bloom.generate_token(next_batch)
+    assert next_batch is not None
+
+    assert len(generated_texts) == 1
+    assert generated_texts[0].output == "TestTestTestTestTestTest"
+    assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[1]
+    assert (
+        generated_texts[0].tokens
+        == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
+    )
+
+    for _ in range(
+        default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
+        - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
+        - 1
+    ):
+        generated_texts, next_batch = default_bloom.generate_token(next_batch)
+        assert generated_texts == []
+
+    generated_texts, next_batch = default_bloom.generate_token(next_batch)
+    assert next_batch is None
+
+    assert len(generated_texts) == 1
+    assert generated_texts[0].output == "TestTestTestTestTestTestTestTestTestTestTest"
+    assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[0]
+    assert (
+        generated_texts[0].tokens
+        == default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
+    )
+
+
+def test_batch_concatenate(
+    default_bloom, default_bloom_batch, default_multi_requests_bloom_batch
+):
+    next_batch_0 = default_bloom_batch
+    _, next_batch_0 = default_bloom.generate_token(next_batch_0)
+    _, next_batch_0 = default_bloom.generate_token(next_batch_0)
+
+    next_batch_1 = default_multi_requests_bloom_batch
+    _, next_batch_1 = default_bloom.generate_token(next_batch_1)
+
+    next_batch = BloomCausalLMBatch.concatenate([next_batch_0, next_batch_1])
+
+    assert torch.equal(next_batch.all_input_ids[0], next_batch_0.all_input_ids[0])
+    assert torch.equal(next_batch.all_input_ids[1], next_batch_1.all_input_ids[0])
+    assert torch.equal(next_batch.all_input_ids[2], next_batch_1.all_input_ids[1])
+
+    assert torch.all(next_batch.attention_mask[0] == 1)
+    assert torch.all(next_batch.attention_mask[1:, -2:] == 1)
+    assert torch.all(next_batch.attention_mask[1:, :-2] == 0)
+
+    assert next_batch.batch_id == 0
+    assert torch.all(next_batch.input_ids == 10264)
+
+    assert next_batch.input_lengths == [3, 2, 2]
+    assert next_batch.max_sequence_length == 3
+
+    assert next_batch.requests[0] == next_batch_0.requests[0]
+    assert next_batch.requests[1:] == next_batch_1.requests
+
+    assert next_batch.next_token_choosers[0] == next_batch_0.next_token_choosers[0]
+    assert next_batch.next_token_choosers[1:] == next_batch_1.next_token_choosers
+
+    assert next_batch.stopping_criterias[0] == next_batch_0.stopping_criterias[0]
+    assert next_batch.stopping_criterias[1:] == next_batch_1.stopping_criterias
+
+    assert next_batch.past_key_values is not None
+    assert all([p[0].shape == (3, 16, 64, 2) for p in next_batch.past_key_values])
+    assert all([p[1].shape == (3, 16, 2, 64) for p in next_batch.past_key_values])
+
+    for i, past in enumerate(next_batch.past_key_values):
+        assert torch.equal(next_batch_0.past_key_values[i][0][:, :, -2:], past[0][0])
+        assert torch.equal(
+            next_batch_1.past_key_values[i][0][:, :, -1:],
+            past[0][1:, :, :, -1].reshape(-1, 64, 1),
+        )
+
+        assert torch.equal(next_batch_0.past_key_values[i][1][:, -2:, :], past[1][0])
+        assert torch.equal(
+            next_batch_1.past_key_values[i][1][:, -1:, :],
+            past[1][1:, :, -1, :].reshape(-1, 1, 64),
+        )
+
+    for _ in range(
+        default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 2
+    ):
+        generated_texts, next_batch = default_bloom.generate_token(next_batch)
+        assert generated_texts == []
+
+    generated_texts, next_batch = default_bloom.generate_token(next_batch)
+    assert next_batch is not None
+
+    assert len(generated_texts) == 1
+    assert generated_texts[0].output == "TestTestTestTestTestTest"
+    assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[1]
+    assert (
+        generated_texts[0].tokens
+        == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
+    )
+
+    for _ in range(
+        default_bloom_batch.stopping_criterias[0].max_new_tokens
+        - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
+        - 2
+    ):
+        generated_texts, next_batch = default_bloom.generate_token(next_batch)
+        assert generated_texts == []
+
+    generated_texts, next_batch = default_bloom.generate_token(next_batch)
+    assert next_batch is not None
+
+    assert len(generated_texts) == 1
+    assert generated_texts[0].output == "TestTestTestTestTestTestTestTestTestTestTest"
+    assert generated_texts[0].request == default_bloom_batch.requests[0]
+    assert (
+        generated_texts[0].tokens
+        == default_bloom_batch.stopping_criterias[0].max_new_tokens
+    )
+
+    for _ in range(
+        default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
+        - default_bloom_batch.stopping_criterias[0].max_new_tokens
+        - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
+        - 4
+    ):
+        generated_texts, next_batch = default_bloom.generate_token(next_batch)
+        assert generated_texts == []
+
+    generated_texts, next_batch = default_bloom.generate_token(next_batch)
+    assert next_batch is None
+
+    assert len(generated_texts) == 1
+    assert generated_texts[0].output == "TestTestTestTestTestTestTestTestTestTestTest"
+    assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[0]
+    assert (
+        generated_texts[0].tokens
+        == default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
+    )
diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py
new file mode 100644
index 00000000..1bf3e5e6
--- /dev/null
+++ b/server/tests/models/test_causal_lm.py
@@ -0,0 +1,298 @@
+import pytest
+import torch
+
+from copy import copy
+
+from text_generation.pb import generate_pb2
+from text_generation.models.causal_lm import CausalLM, CausalLMBatch
+
+
+@pytest.fixture
+def default_pb_request(default_pb_parameters):
+    return generate_pb2.Request(
+        id=0,
+        inputs="Test",
+        input_length=1,
+        parameters=default_pb_parameters,
+        max_new_tokens=10,
+    )
+
+
+@pytest.fixture
+def default_pb_batch(default_pb_request):
+    return generate_pb2.Batch(id=0, requests=[default_pb_request], size=1)
+
+
+@pytest.fixture
+def default_causal_lm_batch(default_pb_batch, gpt2_tokenizer):
+    return CausalLMBatch.from_pb(default_pb_batch, gpt2_tokenizer, torch.device("cpu"))
+
+
+@pytest.fixture
+def default_multi_requests_causal_lm_batch(default_pb_request, gpt2_tokenizer):
+    req_0 = copy(default_pb_request)
+    req_1 = default_pb_request
+    req_1.id = 1
+    req_1.max_new_tokens = 5
+
+    batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
+    return CausalLMBatch.from_pb(batch_pb, gpt2_tokenizer, torch.device("cpu"))
+
+
+@pytest.fixture(scope="session")
+def default_causal_lm():
+    return CausalLM("gpt2")
+
+
+def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
+    batch = default_causal_lm_batch
+
+    assert batch.batch_id == default_pb_batch.id
+    assert batch.requests == default_pb_batch.requests
+
+    assert len(batch.input_ids) == default_pb_batch.size
+    assert len(batch.input_ids[0]) == 8
+    assert batch.input_ids[0][-1] == 14402
+    assert torch.all(batch.input_ids[0][:-1] == 50256)
+
+    assert batch.attention_mask[0][-1] == 1
+    assert torch.all(batch.attention_mask[0][:-1] == 0)
+
+    assert batch.past_key_values is None
+
+    assert torch.equal(batch.input_ids, batch.all_input_ids[:, :, 0])
+
+    assert batch.input_lengths == [1]
+
+    assert batch.size == default_pb_batch.size
+    assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size
+
+    assert batch.max_sequence_length == batch.input_lengths[0]
+
+
+def test_batch_concatenate_no_prefill(default_causal_lm_batch):
+    with pytest.raises(ValueError):
+        CausalLMBatch.concatenate([default_causal_lm_batch, default_causal_lm_batch])
+
+
+def test_causal_lm_batch_type(default_causal_lm):
+    assert default_causal_lm.batch_type == CausalLMBatch
+
+
+def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
+    generated_texts, next_batch = default_causal_lm.generate_token(
+        default_causal_lm_batch
+    )
+
+    assert generated_texts == []
+    assert isinstance(next_batch, CausalLMBatch)
+
+    assert len(next_batch.all_input_ids) == next_batch.size
+    assert len(next_batch.all_input_ids[0]) == len(next_batch.attention_mask[0]) == 9
+    assert next_batch.all_input_ids[0][-1] == 6208
+    assert next_batch.all_input_ids[0][-2] == 14402
+    assert torch.all(next_batch.all_input_ids[0][:-2] == 50256)
+
+    assert torch.all(next_batch.attention_mask[0][-2:] == 1)
+    assert torch.all(next_batch.attention_mask[0][:-2] == 0)
+
+    assert next_batch.input_ids.shape == (next_batch.size, 1)
+    assert next_batch.input_ids[0, 0] == 6208
+
+    assert next_batch.input_lengths == [2]
+    assert next_batch.max_sequence_length == next_batch.input_lengths[0]
+
+    assert next_batch.past_key_values is not None
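+    # GPT-2 caches keys and values in the standard layout:
+    # (batch_size, num_heads, seq_length, head_dim) for both tensors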
+    assert all([p[0].shape == (1, 12, 8, 64) for p in next_batch.past_key_values])
+    assert all([p[1].shape == (1, 12, 8, 64) for p in next_batch.past_key_values])
+
+
+def test_causal_lm_generate_token_completion(
+    default_causal_lm, default_causal_lm_batch
+):
+    next_batch = default_causal_lm_batch
+    for _ in range(default_causal_lm_batch.stopping_criterias[0].max_new_tokens - 1):
+        generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
+        assert generated_texts == []
+
+    generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
+    assert next_batch is None
+
+    assert len(generated_texts) == 1
+    assert (
+        generated_texts[0].output
+        == "Test Test Test Test Test Test Test Test Test Test Test"
+    )
+    assert generated_texts[0].request == default_causal_lm_batch.requests[0]
+    assert (
+        generated_texts[0].tokens
+        == default_causal_lm_batch.stopping_criterias[0].max_new_tokens
+    )
+
+
+def test_causal_lm_generate_token_completion_multi(
+    default_causal_lm, default_multi_requests_causal_lm_batch
+):
+    next_batch = default_multi_requests_causal_lm_batch
+
+    for _ in range(
+        default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 1
+    ):
+        generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
+        assert generated_texts == []
+
+    generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
+    assert next_batch is not None
+
+    assert len(generated_texts) == 1
+    assert generated_texts[0].output == "Test Test Test Test Test Test"
+    assert (
+        generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[1]
+    )
+    assert (
+        generated_texts[0].tokens
+        == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
+    )
+
+    for _ in range(
+        default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
+        - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
+        - 1
+    ):
+        generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
+        assert generated_texts == []
+
+    generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
+    assert next_batch is None
+
+    assert len(generated_texts) == 1
+    assert (
+        generated_texts[0].output
+        == "Test Test Test Test Test Test Test Test Test Test Test"
+    )
+    assert (
+        generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[0]
+    )
+    assert (
+        generated_texts[0].tokens
+        == default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
+    )
+
+
+def test_batch_concatenate(
+    default_causal_lm, default_causal_lm_batch, default_multi_requests_causal_lm_batch
+):
+    next_batch_0 = default_causal_lm_batch
+    _, next_batch_0 = default_causal_lm.generate_token(next_batch_0)
+    _, next_batch_0 = default_causal_lm.generate_token(next_batch_0)
+
+    next_batch_1 = default_multi_requests_causal_lm_batch
+    _, next_batch_1 = default_causal_lm.generate_token(next_batch_1)
+
+    next_batch = CausalLMBatch.concatenate([next_batch_0, next_batch_1])
+
+    assert torch.equal(next_batch.all_input_ids[0], next_batch_0.all_input_ids[0])
+    assert torch.equal(next_batch.all_input_ids[1], next_batch_1.all_input_ids[0])
+    assert torch.equal(next_batch.all_input_ids[2], next_batch_1.all_input_ids[1])
+
+    assert torch.all(next_batch.attention_mask[0] == 1)
+    assert torch.all(next_batch.attention_mask[1:, -2:] == 1)
+    assert torch.all(next_batch.attention_mask[1:, :-2] == 0)
+
+    assert next_batch.batch_id == 0
+    assert torch.all(next_batch.input_ids == 6208)
+
+    assert next_batch.input_lengths == [3, 2, 2]
+    assert next_batch.max_sequence_length == 3
+
+    assert next_batch.requests[0] == next_batch_0.requests[0]
+    assert next_batch.requests[1:] == next_batch_1.requests
+
+    assert next_batch.next_token_choosers[0] == next_batch_0.next_token_choosers[0]
+    assert next_batch.next_token_choosers[1:] == next_batch_1.next_token_choosers
+
+    assert next_batch.stopping_criterias[0] == next_batch_0.stopping_criterias[0]
+    assert next_batch.stopping_criterias[1:] == next_batch_1.stopping_criterias
+
+    assert next_batch.past_key_values is not None
+    assert all([p[0].shape == (3, 12, 2, 64) for p in next_batch.past_key_values])
+    assert all([p[1].shape == (3, 12, 2, 64) for p in next_batch.past_key_values])
+
+    for i, past in enumerate(next_batch.past_key_values):
+        assert torch.equal(next_batch_0.past_key_values[i][0][0, :, -2:], past[0][0])
+        assert torch.equal(
+            next_batch_1.past_key_values[i][0][:, :, -1:], past[0][1:, :, -1:, :]
+        )
+
+        assert torch.equal(next_batch_0.past_key_values[i][1][0, :, -2:], past[1][0])
+        assert torch.equal(
+            next_batch_1.past_key_values[i][1][:, :, -1:], past[1][1:, :, -1:, :]
+        )
+
+    for _ in range(
+        default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 2
+    ):
+        generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
+        assert generated_texts == []
+
+    generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
+    assert next_batch is not None
+
+    assert len(generated_texts) == 1
+    assert generated_texts[0].output == "Test Test Test Test Test Test"
+    assert (
+        generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[1]
+    )
+    assert (
+        generated_texts[0].tokens
+        == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
+    )
+
+    for _ in range(
+        default_causal_lm_batch.stopping_criterias[0].max_new_tokens
+        - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
+        - 2
+    ):
+        generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
+        assert generated_texts == []
+
+    generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
+    assert next_batch is not None
+
+    assert len(generated_texts) == 1
+    assert (
+        generated_texts[0].output
+        == "Test Test Test Test Test Test Test Test Test Test Test"
+    )
+    assert generated_texts[0].request == default_causal_lm_batch.requests[0]
+    assert (
+        generated_texts[0].tokens
+        == default_causal_lm_batch.stopping_criterias[0].max_new_tokens
+    )
+
+    for _ in range(
+        default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
+        - default_causal_lm_batch.stopping_criterias[0].max_new_tokens
+        - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
+        - 4
+    ):
+        generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
+        assert generated_texts == []
+
+    generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
+    assert next_batch is None
+
+    assert len(generated_texts) == 1
+    assert (
+        generated_texts[0].output
+        == "Test Test Test Test Test Test Test Test Test Test Test"
+    )
+    assert (
+        generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[0]
+    )
+    assert (
+        generated_texts[0].tokens
+        == default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
+    )
diff --git a/server/tests/models/test_seq2seq_lm.py b/server/tests/models/test_seq2seq_lm.py
new file mode 100644
index 00000000..7e4c7fdd
--- /dev/null
+++ b/server/tests/models/test_seq2seq_lm.py
@@ -0,0 +1,308 @@
+import pytest
+import torch
+
+from copy import copy
+
+from text_generation.pb import generate_pb2
+from text_generation.models.seq2seq_lm import Seq2SeqLM, Seq2SeqLMBatch
+
+
+@pytest.fixture
+def default_pb_request(default_pb_parameters):
+    return generate_pb2.Request(
+        id=0,
+        inputs="Test",
+        input_length=2,
+        parameters=default_pb_parameters,
+        max_new_tokens=10,
+    )
+
+
+@pytest.fixture
+def default_pb_batch(default_pb_request):
+    return generate_pb2.Batch(id=0, requests=[default_pb_request], size=1)
+
+
+@pytest.fixture
+def default_seq2seq_lm_batch(default_pb_batch, mt0_small_tokenizer):
+    return Seq2SeqLMBatch.from_pb(
+        default_pb_batch, mt0_small_tokenizer, torch.device("cpu")
+    )
+
+
+@pytest.fixture
+def default_multi_requests_seq2seq_lm_batch(default_pb_request, mt0_small_tokenizer):
+    req_0 = copy(default_pb_request)
+    req_1 = default_pb_request
+    req_1.id = 1
+    req_1.max_new_tokens = 5
+
+    batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
+    return Seq2SeqLMBatch.from_pb(batch_pb, mt0_small_tokenizer, torch.device("cpu"))
+
+
+@pytest.fixture(scope="session")
+def default_seq2seq_lm():
+    return Seq2SeqLM("bigscience/mt0-small")
+
+
+def test_batch_from_pb(default_pb_batch, default_seq2seq_lm_batch):
+    batch = default_seq2seq_lm_batch
+
+    assert batch.batch_id == default_pb_batch.id
+    assert batch.requests == default_pb_batch.requests
+
+    assert batch.input_ids.shape == (default_pb_batch.size, 8)
+    assert batch.input_ids[0][-2] == 4268
+    assert batch.input_ids[0][-1] == 1
+    assert torch.all(batch.input_ids[0][:-2] == 0)
+
+    assert torch.all(batch.attention_mask[0][-2:] == 1)
+    assert torch.all(batch.attention_mask[0][:-2] == 0)
+
+    assert batch.decoder_input_ids.shape == (default_pb_batch.size, 1)
+    assert batch.decoder_attention_mask is None
+    assert batch.encoder_last_hidden_state is None
+
+    assert batch.past_key_values is None
+
+    assert batch.input_lengths == [2]
+    assert batch.decoder_input_lengths == [1]
+
+    assert batch.size == default_pb_batch.size
+    assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size
+
+    assert batch.max_input_length == batch.input_lengths[0]
+    assert batch.max_decoder_input_length == batch.decoder_input_lengths[0]
+
+
+def test_batch_concatenate_no_prefill(default_seq2seq_lm_batch):
+    with pytest.raises(ValueError):
+        Seq2SeqLMBatch.concatenate([default_seq2seq_lm_batch, default_seq2seq_lm_batch])
+
+
+def test_seq2seq_lm_batch_type(default_seq2seq_lm):
+    assert default_seq2seq_lm.batch_type == Seq2SeqLMBatch
+
+
+def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch):
+    generated_texts, next_batch = default_seq2seq_lm.generate_token(
+        default_seq2seq_lm_batch
+    )
+
+    assert generated_texts == []
+    assert isinstance(next_batch, Seq2SeqLMBatch)
+
+    assert torch.equal(next_batch.input_ids, default_seq2seq_lm_batch.input_ids)
+    assert torch.equal(
+        next_batch.attention_mask, default_seq2seq_lm_batch.attention_mask
+    )
+    assert next_batch.input_lengths == default_seq2seq_lm_batch.input_lengths
+    assert next_batch.max_input_length == default_seq2seq_lm_batch.max_input_length
+    assert (
+        next_batch.next_token_choosers == default_seq2seq_lm_batch.next_token_choosers
+    )
+    assert next_batch.stopping_criterias == default_seq2seq_lm_batch.stopping_criterias
+
+    assert next_batch.decoder_input_ids.shape == (next_batch.size, 2)
+    assert next_batch.decoder_input_ids[0, 0] == 0
+    assert next_batch.decoder_input_ids[0, 1] == 259
+    assert next_batch.decoder_attention_mask is None
+    assert next_batch.encoder_last_hidden_state.shape == (1, 8, 512)
+
+    assert next_batch.decoder_input_lengths == [2]
+    assert next_batch.max_decoder_input_length == 2
+
+    assert next_batch.past_key_values is not None
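+    # each layer caches 4 tensors: self-attention k/v grow with the decoded
+    # sequence, cross-attention k/v stay at the encoder input length (8)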
+    assert all(
+        [p[0].shape == (next_batch.size, 6, 1, 64) for p in next_batch.past_key_values]
+    )
+    assert all(
+        [p[1].shape == (next_batch.size, 6, 1, 64) for p in next_batch.past_key_values]
+    )
+    assert all(
+        [p[2].shape == (next_batch.size, 6, 8, 64) for p in next_batch.past_key_values]
+    )
+    assert all(
+        [p[3].shape == (next_batch.size, 6, 8, 64) for p in next_batch.past_key_values]
+    )
+
+
+def test_seq2seq_lm_generate_token_completion(
+    default_seq2seq_lm, default_seq2seq_lm_batch
+):
+    next_batch = default_seq2seq_lm_batch
+    for _ in range(6):
+        generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
+        assert generated_texts == []
+
+    generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    assert next_batch is None
+
+    assert len(generated_texts) == 1
+    assert generated_texts[0].output == "a few weeks"
+    assert generated_texts[0].request == default_seq2seq_lm_batch.requests[0]
+    assert generated_texts[0].tokens == 7
+
+
+def test_seq2seq_lm_generate_token_completion_multi(
+    default_seq2seq_lm, default_multi_requests_seq2seq_lm_batch
+):
+    next_batch = default_multi_requests_seq2seq_lm_batch
+
+    for _ in range(4):
+        generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
+        assert generated_texts == []
+
+    generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    assert next_batch is not None
+
+    assert len(generated_texts) == 1
+    assert generated_texts[0].output == "a few "
+    assert (
+        generated_texts[0].request
+        == default_multi_requests_seq2seq_lm_batch.requests[1]
+    )
+    assert generated_texts[0].tokens == 5
+
+    generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    assert generated_texts == []
+
+    generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    assert next_batch is None
+
+    assert len(generated_texts) == 1
+    assert generated_texts[0].output == "a few weeks"
+    assert (
+        generated_texts[0].request
+        == default_multi_requests_seq2seq_lm_batch.requests[0]
+    )
+    assert generated_texts[0].tokens == 7
+
+
+def test_batch_concatenate(
+    default_seq2seq_lm,
+    default_seq2seq_lm_batch,
+    default_multi_requests_seq2seq_lm_batch,
+):
+    next_batch_0 = default_seq2seq_lm_batch
+    _, next_batch_0 = default_seq2seq_lm.generate_token(next_batch_0)
+    _, next_batch_0 = default_seq2seq_lm.generate_token(next_batch_0)
+
+    next_batch_1 = default_multi_requests_seq2seq_lm_batch
+    _, next_batch_1 = default_seq2seq_lm.generate_token(next_batch_1)
+
+    next_batch = Seq2SeqLMBatch.concatenate([next_batch_0, next_batch_1])
+
+    assert next_batch.batch_id == 0
+
+    assert torch.all(next_batch.input_ids[:, 0] == 4268)
+    assert torch.all(next_batch.input_ids[:, 1] == 1)
+
+    assert torch.all(next_batch.attention_mask == 1)
+
+    assert torch.equal(
+        next_batch.decoder_input_ids[0], next_batch_0.decoder_input_ids[0]
+    )
+    assert torch.all(next_batch.decoder_input_ids[1:, 0] == 0)
+    assert torch.equal(
+        next_batch.decoder_input_ids[1:, -2:], next_batch_1.decoder_input_ids
+    )
+
+    assert torch.all(next_batch.decoder_attention_mask[0] == 1)
+    assert torch.all(next_batch.decoder_attention_mask[1:, 0] == 0)
+    assert torch.all(next_batch.decoder_attention_mask[1:, -2:] == 1)
+
+    assert torch.equal(
+        next_batch.encoder_last_hidden_state[0],
+        next_batch_0.encoder_last_hidden_state[0, -2:],
+    )
+    assert torch.equal(
+        next_batch.encoder_last_hidden_state[1:],
+        next_batch_1.encoder_last_hidden_state[:, -2:],
+    )
+
+    assert next_batch.input_lengths == [2, 2, 2]
+    assert next_batch.decoder_input_lengths == [3, 2, 2]
+    assert next_batch.max_input_length == 2
+    assert next_batch.max_decoder_input_length == 3
+
+    assert next_batch.requests[0] == next_batch_0.requests[0]
+    assert next_batch.requests[1:] == next_batch_1.requests
+
+    assert next_batch.next_token_choosers[0] == next_batch_0.next_token_choosers[0]
+    assert next_batch.next_token_choosers[1:] == next_batch_1.next_token_choosers
+
+    assert next_batch.stopping_criterias[0] == next_batch_0.stopping_criterias[0]
+    assert next_batch.stopping_criterias[1:] == next_batch_1.stopping_criterias
+
+    assert next_batch.past_key_values is not None
+    assert all(
+        [p[0].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values]
+    )
+    assert all(
+        [p[1].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values]
+    )
+    assert all(
+        [p[2].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values]
+    )
+    assert all(
+        [p[3].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values]
+    )
+
+    for i, past in enumerate(next_batch.past_key_values):
+        assert torch.equal(next_batch_0.past_key_values[i][0][0, :, -2:, :], past[0][0])
+        assert torch.equal(
+            next_batch_1.past_key_values[i][0][:, :, -1:, :], past[0][1:, :, -1:, :]
+        )
+
+        assert torch.equal(next_batch_0.past_key_values[i][1][0, :, -2:, :], past[1][0])
+        assert torch.equal(
+            next_batch_1.past_key_values[i][1][:, :, -1:, :], past[1][1:, :, -1:, :]
+        )
+
+        assert torch.equal(next_batch_0.past_key_values[i][2][0, :, -2:, :], past[2][0])
+        assert torch.equal(
+            next_batch_1.past_key_values[i][2][:, :, -2:, :], past[2][1:]
+        )
+
+        assert torch.equal(next_batch_0.past_key_values[i][3][0, :, -2:, :], past[3][0])
+        assert torch.equal(
+            next_batch_1.past_key_values[i][3][:, :, -2:, :], past[3][1:]
+        )
+
+    for _ in range(3):
+        generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
+        assert generated_texts == []
+
+    generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    assert next_batch is not None
+
+    assert len(generated_texts) == 1
+    assert generated_texts[0].output == "a few "
+    assert (
+        generated_texts[0].request
+        == default_multi_requests_seq2seq_lm_batch.requests[1]
+    )
+    assert generated_texts[0].tokens == 5
+
+    generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    assert next_batch is not None
+
+    assert len(generated_texts) == 1
+    assert generated_texts[0].output == "a few weeks"
+    assert generated_texts[0].request == default_seq2seq_lm_batch.requests[0]
+    assert generated_texts[0].tokens == 7
+
+    generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    assert next_batch is None
+
+    assert len(generated_texts) == 1
+    assert generated_texts[0].output == "a few weeks"
+    assert (
+        generated_texts[0].request
+        == default_multi_requests_seq2seq_lm_batch.requests[0]
+    )
+    assert generated_texts[0].tokens == 7
diff --git a/server/tests/test_utils.py b/server/tests/test_utils.py
new file mode
 100644
index 00000000..e630ebda
--- /dev/null
+++ b/server/tests/test_utils.py
@@ -0,0 +1,34 @@
+import pytest
+
+from text_generation.utils import (
+    weight_hub_files,
+    download_weights,
+    weight_files,
+    LocalEntryNotFoundError,
+)
+
+
+def test_weight_hub_files():
+    filenames = weight_hub_files("bigscience/bloom-560m")
+    assert filenames == ["model.safetensors"]
+
+
+def test_weight_hub_files_llm():
+    filenames = weight_hub_files("bigscience/bloom")
+    assert filenames == [f"model_{i:05d}-of-00072.safetensors" for i in range(1, 73)]
+
+
+def test_weight_hub_files_empty():
+    filenames = weight_hub_files("bigscience/bloom", ".errors")
+    assert filenames == []
+
+
+def test_download_weights():
+    files = download_weights("bigscience/bloom-560m")
+    local_files = weight_files("bigscience/bloom-560m")
+    assert files == local_files
+
+
+def test_weight_files_error():
+    with pytest.raises(LocalEntryNotFoundError):
+        weight_files("bert-base-uncased")
diff --git a/server/text_generation/models/__init__.py b/server/text_generation/models/__init__.py
index b364309a..935d74ef 100644
--- a/server/text_generation/models/__init__.py
+++ b/server/text_generation/models/__init__.py
@@ -1,10 +1,10 @@
 from text_generation.models.model import Model
 from text_generation.models.causal_lm import CausalLM
-from text_generation.models.bloom import BLOOMSharded
+from text_generation.models.bloom import BLOOM, BLOOMSharded
 from text_generation.models.seq2seq_lm import Seq2SeqLM
 from text_generation.models.galactica import Galactica, GalacticaSharded
 
-__all__ = ["Model", "BLOOMSharded", "CausalLM", "Seq2SeqLM"]
+__all__ = ["Model", "BLOOM", "BLOOMSharded", "CausalLM", "Seq2SeqLM"]
 
 
 def get_model(model_name: str, sharded: bool, quantize: bool) -> Model:
@@ -12,7 +12,7 @@ def get_model(model_name: str, sharded: bool, quantize: bool) -> Model:
         if sharded:
             return BLOOMSharded(model_name, quantize=quantize)
         else:
-            return CausalLM(model_name, quantize=quantize)
+            return BLOOM(model_name, quantize=quantize)
     elif model_name.startswith("facebook/galactica"):
         if sharded:
             return GalacticaSharded(model_name, quantize=quantize)
diff --git a/server/text_generation/models/bloom.py b/server/text_generation/models/bloom.py
index 2a7405d3..20e26419 100644
--- a/server/text_generation/models/bloom.py
+++ b/server/text_generation/models/bloom.py
@@ -1,7 +1,7 @@
 import torch
 import torch.distributed
 
-from typing import List, Optional
+from typing import List, Optional, Type
 
 from accelerate import init_empty_weights
 from safetensors import safe_open
@@ -13,6 +13,8 @@ from transformers.models.bloom.parallel_layers import (
 )
 
 from text_generation.models import CausalLM
+from text_generation.models.causal_lm import CausalLMBatch
+from text_generation.pb import generate_pb2
 from text_generation.utils import (
     initialize_torch_distributed,
     weight_files,
@@ -29,7 +31,27 @@ except Exception as e:
 
 torch.manual_seed(0)
 
-class BLOOMSharded(CausalLM):
+class BloomCausalLMBatch(CausalLMBatch):
+    @classmethod
+    def from_pb(
+        cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
+    ) -> "CausalLMBatch":
+        batch = super(BloomCausalLMBatch, cls).from_pb(
+            pb=pb, tokenizer=tokenizer, device=device
+        )
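+        # BLOOM returns past keys transposed, with seq_length as the last
+        # dimension; CausalLMBatch uses this flag when padding past keys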
+        batch.keys_head_dim_last = False
+        return batch
+
+
+class BLOOM(CausalLM):
+    @property
+    def batch_type(self) -> Type[CausalLMBatch]:
+        return BloomCausalLMBatch
+
+
+class BLOOMSharded(BLOOM):
     def __init__(self, model_name: str, quantize: bool = False):
         if not model_name.startswith("bigscience/bloom"):
             raise ValueError(f"Model {model_name} is not supported")
diff --git a/server/text_generation/models/causal_lm.py b/server/text_generation/models/causal_lm.py
index 4e66ae3a..2a88c781 100644
--- a/server/text_generation/models/causal_lm.py
+++ b/server/text_generation/models/causal_lm.py
@@ -34,6 +34,9 @@ class CausalLMBatch:
     size: int
     max_sequence_length: int
 
+    # Past metadata
+    keys_head_dim_last: bool = True
+
     def to_pb(self):
         return generate_pb2.Batch(
             id=self.batch_id,
@@ -165,20 +168,16 @@ class CausalLMBatch:
                     head_dim,
                 )
 
+                if batch.keys_head_dim_last:
+                    padded_past_keys_shape = padded_past_values_shape
                 # seq_length is last for BLOOM
-                if past_keys.shape[-2] == head_dim:
-                    past_keys_head_dim_last = False
+                else:
                     padded_past_keys_shape = (
                         total_batch_size,
                         num_heads,
                         head_dim,
                         max_sequence_length - 1,
                     )
-                elif past_keys.shape[-1] == head_dim:
-                    past_keys_head_dim_last = True
-                    padded_past_keys_shape = padded_past_values_shape
-                else:
-                    raise ValueError(f"past_keys shape {past_keys.shape} is not valid")
 
                 # This will run only once per layer
                 if j == len(past_key_values):
@@ -195,7 +194,7 @@ class CausalLMBatch:
                 past_key_values.append((padded_past_keys, padded_past_values))
 
                 # We slice the past keys and values to remove the padding from previous batches
-                if past_keys_head_dim_last:
+                if batch.keys_head_dim_last:
                     past_key_values[j][0][
                         start_index:end_index,
                         :,
@@ -228,6 +227,7 @@ class CausalLMBatch:
             stopping_criterias=stopping_criterias,
             size=total_batch_size,
             max_sequence_length=max_sequence_length,
+            keys_head_dim_last=batches[0].keys_head_dim_last,
         )
 
 
@@ -237,6 +237,9 @@ class CausalLM(Model):
             device = torch.device("cuda")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
         else:
+            if quantize:
+                raise ValueError("quantization is not available on CPU")
+
             device = torch.device("cpu")
             dtype = torch.float32
 
@@ -247,7 +250,11 @@ class CausalLM(Model):
             device_map="auto" if torch.cuda.is_available() else None,
             load_in_8bit=quantize,
         ).eval()
-        tokenizer.pad_token_id = self.model.config.pad_token_id
+        tokenizer.pad_token_id = (
+            self.model.config.pad_token_id
+            if self.model.config.pad_token_id is not None
+            else self.model.config.eos_token_id
+        )
 
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
@@ -397,5 +404,6 @@ class CausalLM(Model):
             stopping_criterias=next_batch_stopping_criterias,
             size=next_batch_size,
             max_sequence_length=next_batch_max_sequence_length,
+            keys_head_dim_last=batch.keys_head_dim_last,
         )
         return generated_texts, next_batch
diff --git a/server/text_generation/models/galactica.py b/server/text_generation/models/galactica.py
index 5de75ab4..8aec1bc7 100644
--- a/server/text_generation/models/galactica.py
+++ b/server/text_generation/models/galactica.py
@@ -83,7 +83,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
     @classmethod
     def from_pb(
         cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
-    ) -> "CausalLMBatch":
+    ) -> "GalacticaCausalLMBatch":
         inputs = []
         next_token_choosers = []
         stopping_criterias = []
diff --git a/server/text_generation/models/seq2seq_lm.py b/server/text_generation/models/seq2seq_lm.py
index e9c65596..3302138f 100644
--- a/server/text_generation/models/seq2seq_lm.py
+++ b/server/text_generation/models/seq2seq_lm.py
@@ -221,8 +221,8 @@ class Seq2SeqLMBatch:
 
         # Copy to correct indices
         encoder_last_hidden_state[
-            start_index:end_index, -batch.max_decoder_input_length :, :
-        ] = batch.encoder_last_hidden_state[:, -batch.max_decoder_input_length :, :]
+            start_index:end_index, -batch.max_input_length :, :
+        ] = batch.encoder_last_hidden_state[:, -batch.max_input_length :, :]
 
         # Iterate over attention layers
         for j, past in enumerate(batch.past_key_values):
@@ -305,6 +305,9 @@ class Seq2SeqLM(Model):
             device = torch.device("cuda")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
         else:
+            if quantize:
+                raise ValueError("quantization is not available on CPU")
+
             device = torch.device("cpu")
             dtype = torch.float32
 
diff --git a/server/text_generation/utils.py b/server/text_generation/utils.py
index e6bfc391..e55eeb64 100644
--- a/server/text_generation/utils.py
+++ b/server/text_generation/utils.py
@@ -137,8 +137,9 @@ def download_weights(model_name, extension=".safetensors"):
             executor.submit(download_function, filename=filename)
             for filename in filenames
         ]
     files = [
-        file
-        for file in tqdm(concurrent.futures.as_completed(futures), total=len(futures))
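+        # result() re-raises any download error and yields the local file path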
+        future.result()
+        for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures))
     ]
     return files