From 5a58226130fcd35774cd14d8c5ba638a4adb2bf4 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Tue, 16 May 2023 23:23:27 +0200 Subject: [PATCH] fix(server): fix decode token (#334) Fixes #333 --------- Co-authored-by: Nicolas Patry --- .github/workflows/build.yaml | 3 +- .github/workflows/tests.yaml | 3 +- Makefile | 2 +- integration-tests/conftest.py | 16 +- .../test_flash_llama_load.json | 176 +++++++++--------- .../test_mt0_base/test_mt0_base.json | 2 +- .../test_mt0_base_all_params.json | 2 +- .../test_mt0_base/test_mt0_base_load.json | 114 ++++++------ router/src/main.rs | 2 +- server/Makefile | 2 +- server/tests/models/test_model.py | 78 ++++++++ server/tests/models/test_seq2seq_lm.py | 2 +- server/text_generation_server/models/bloom.py | 5 +- .../models/causal_lm.py | 71 +++---- .../models/flash_causal_lm.py | 77 ++++---- .../models/flash_llama.py | 4 +- .../models/flash_neox.py | 2 +- .../models/flash_santacoder.py | 6 +- .../models/galactica.py | 16 +- .../text_generation_server/models/gpt_neox.py | 2 +- server/text_generation_server/models/model.py | 68 +++---- server/text_generation_server/models/opt.py | 2 +- .../models/santacoder.py | 20 +- .../models/seq2seq_lm.py | 66 +++---- server/text_generation_server/models/t5.py | 5 +- 25 files changed, 396 insertions(+), 350 deletions(-) create mode 100644 server/tests/models/test_model.py diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 79b3c777..9305b8e7 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -213,12 +213,13 @@ jobs: sudo mount /dev/nvme1n1 ${{ env.DOCKER_VOLUME }} - name: Install run: | + pip install pytest-xdist make install-integration-tests - name: Run tests run: | export DOCKER_IMAGE=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT }} export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }} - pytest -s -vv integration-tests + pytest -s -vv -n 2 --dist loadfile integration-tests stop-runner: name: Stop self-hosted EC2 runner diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index d9858a3b..7e5ba52c 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -66,7 +66,8 @@ jobs: - name: Run server tests run: | pip install pytest - make python-server-tests + export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }} + pytest -s -vv server/tests - name: Run Rust fmt run: | cargo fmt --check diff --git a/Makefile b/Makefile index 29c318fa..7309aaee 100644 --- a/Makefile +++ b/Makefile @@ -31,7 +31,7 @@ update-integration-tests: install-integration-tests pytest -s -vv --snapshot-update integration-tests python-server-tests: - HF_HUB_ENABLE_HF_TRANSFER=1 pytest server/tests + HF_HUB_ENABLE_HF_TRANSFER=1 pytest -s -vv -m "not private" server/tests python-client-tests: pytest clients/python/tests diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index ba1abca9..3086ecda 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -1,3 +1,4 @@ +import sys import subprocess import contextlib import pytest @@ -7,6 +8,7 @@ import docker import json import math import time +import random from docker.errors import NotFound from typing import Optional, List, Dict @@ -205,10 +207,12 @@ def launcher(event_loop): def local_launcher( model_id: str, num_shard: Optional[int] = None, quantize: Optional[str] = None ): - port = 9999 - master_port = 19999 + port = random.randint(8000, 10_000) + master_port = random.randint(10_000, 20_000) - shard_uds_path = f"/tmp/{model_id.replace('/', '--')}-server" + shard_uds_path = ( + f"/tmp/tgi-tests-{model_id.split('/')[-1]}-{num_shard}-{quantize}-server" + ) args = [ "text-generation-launcher", @@ -236,7 +240,7 @@ def launcher(event_loop): process.wait(60) launcher_output = process.stdout.read().decode("utf-8") - print(launcher_output) + print(launcher_output, file=sys.stderr) process.stdout.close() process.stderr.close() @@ -245,7 +249,7 @@ def launcher(event_loop): def docker_launcher( model_id: str, num_shard: Optional[int] = None, quantize: Optional[str] = None ): - port = 9999 + port = random.randint(8000, 10_000) args = ["--model-id", model_id, "--env"] @@ -298,7 +302,7 @@ def launcher(event_loop): pass container_output = container.logs().decode("utf-8") - print(container_output) + print(container_output, file=sys.stderr) container.remove() diff --git a/integration-tests/models/__snapshots__/test_flash_llama/test_flash_llama_load.json b/integration-tests/models/__snapshots__/test_flash_llama/test_flash_llama_load.json index 5a8ba217..9bbb5322 100644 --- a/integration-tests/models/__snapshots__/test_flash_llama/test_flash_llama_load.json +++ b/integration-tests/models/__snapshots__/test_flash_llama/test_flash_llama_load.json @@ -1,92 +1,4 @@ [ - { - "details": { - "best_of_sequences": null, - "finish_reason": "length", - "generated_tokens": 10, - "prefill": [ - { - "id": 1, - "logprob": null, - "text": "" - }, - { - "id": 4321, - "logprob": -8.6875, - "text": "Test" - }, - { - "id": 2009, - "logprob": -11.5546875, - "text": "request" - } - ], - "seed": null, - "tokens": [ - { - "id": 363, - "logprob": -1.5322266, - "special": false, - "text": " for" - }, - { - "id": 847, - "logprob": -2.5585938, - "special": false, - "text": " /" - }, - { - "id": 2754, - "logprob": -2.265625, - "special": false, - "text": "api" - }, - { - "id": 29914, - "logprob": -0.034088135, - "special": false, - "text": "/" - }, - { - "id": 29894, - "logprob": -0.96240234, - "special": false, - "text": "v" - }, - { - "id": 29896, - "logprob": -0.36816406, - "special": false, - "text": "1" - }, - { - "id": 29914, - "logprob": -0.013191223, - "special": false, - "text": "/" - }, - { - "id": 16418, - "logprob": -3.15625, - "special": false, - "text": "projects" - }, - { - "id": 29914, - "logprob": -0.43774414, - "special": false, - "text": "/" - }, - { - "id": 29896, - "logprob": -1.9443359, - "special": false, - "text": "1" - } - ] - }, - "generated_text": "for /api/v1/projects/1" - }, { "details": { "best_of_sequences": null, @@ -263,6 +175,94 @@ }, "generated_text": "for /api/v1/projects/1" }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -8.6875, + "text": "Test" + }, + { + "id": 2009, + "logprob": -11.5546875, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 363, + "logprob": -1.5322266, + "special": false, + "text": " for" + }, + { + "id": 847, + "logprob": -2.5585938, + "special": false, + "text": " /" + }, + { + "id": 2754, + "logprob": -2.265625, + "special": false, + "text": "api" + }, + { + "id": 29914, + "logprob": -0.034088135, + "special": false, + "text": "/" + }, + { + "id": 29894, + "logprob": -0.96240234, + "special": false, + "text": "v" + }, + { + "id": 29896, + "logprob": -0.36816406, + "special": false, + "text": "1" + }, + { + "id": 29914, + "logprob": -0.013191223, + "special": false, + "text": "/" + }, + { + "id": 16418, + "logprob": -3.15625, + "special": false, + "text": "projects" + }, + { + "id": 29914, + "logprob": -0.43774414, + "special": false, + "text": "/" + }, + { + "id": 29896, + "logprob": -1.9443359, + "special": false, + "text": "1" + } + ] + }, + "generated_text": "for /api/v1/projects/1" + }, { "details": { "best_of_sequences": null, diff --git a/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base.json b/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base.json index 2a26e3db..c1cd24cd 100644 --- a/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base.json +++ b/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base.json @@ -16,7 +16,7 @@ "id": 926, "logprob": -4.3554688, "special": false, - "text": "To" + "text": " To" }, { "id": 18295, diff --git a/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_all_params.json b/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_all_params.json index fd77252d..3e9f3d73 100644 --- a/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_all_params.json +++ b/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_all_params.json @@ -16,7 +16,7 @@ "id": 16017, "logprob": -1.3505859, "special": false, - "text": "blue" + "text": " blue" }, { "id": 20495, diff --git a/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_load.json b/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_load.json index c9e552b6..c0834ae1 100644 --- a/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_load.json +++ b/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_load.json @@ -1,58 +1,4 @@ [ - { - "details": { - "best_of_sequences": null, - "finish_reason": "eos_token", - "generated_tokens": 6, - "prefill": [ - { - "id": 0, - "logprob": null, - "text": "" - } - ], - "seed": null, - "tokens": [ - { - "id": 259, - "logprob": -1.3789062, - "special": false, - "text": "" - }, - { - "id": 39261, - "logprob": -0.36279297, - "special": false, - "text": "Because" - }, - { - "id": 609, - "logprob": -1.0966797, - "special": false, - "text": " it" - }, - { - "id": 339, - "logprob": -0.8276367, - "special": false, - "text": " is" - }, - { - "id": 16017, - "logprob": -1.6845703, - "special": false, - "text": " blue" - }, - { - "id": 1, - "logprob": -0.72753906, - "special": true, - "text": "" - } - ] - }, - "generated_text": "Because it is blue" - }, { "details": { "best_of_sequences": null, @@ -71,7 +17,7 @@ "id": 259, "logprob": -1.3798828, "special": false, - "text": "" + "text": " " }, { "id": 39261, @@ -125,7 +71,7 @@ "id": 259, "logprob": -1.3789062, "special": false, - "text": "" + "text": " " }, { "id": 39261, @@ -179,7 +125,61 @@ "id": 259, "logprob": -1.3789062, "special": false, - "text": "" + "text": " " + }, + { + "id": 39261, + "logprob": -0.36279297, + "special": false, + "text": "Because" + }, + { + "id": 609, + "logprob": -1.0966797, + "special": false, + "text": " it" + }, + { + "id": 339, + "logprob": -0.8276367, + "special": false, + "text": " is" + }, + { + "id": 16017, + "logprob": -1.6845703, + "special": false, + "text": " blue" + }, + { + "id": 1, + "logprob": -0.72753906, + "special": true, + "text": "" + } + ] + }, + "generated_text": "Because it is blue" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "eos_token", + "generated_tokens": 6, + "prefill": [ + { + "id": 0, + "logprob": null, + "text": "" + } + ], + "seed": null, + "tokens": [ + { + "id": 259, + "logprob": -1.3789062, + "special": false, + "text": " " }, { "id": 39261, diff --git a/router/src/main.rs b/router/src/main.rs index 5ad49003..82bf6ba8 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -146,7 +146,7 @@ fn main() -> Result<(), std::io::Error> { sha: None, pipeline_tag: None, }, - false => get_model_info(&tokenizer_name, &revision, authorization_token).await.unwrap_or({ + false => get_model_info(&tokenizer_name, &revision, authorization_token).await.unwrap_or_else(|| { tracing::warn!("Could not retrieve model info from the Hugging Face hub."); HubModelInfo { model_id: tokenizer_name.to_string(), sha: None, pipeline_tag: None } }), diff --git a/server/Makefile b/server/Makefile index 150d7e4a..6eb56c75 100644 --- a/server/Makefile +++ b/server/Makefile @@ -2,7 +2,7 @@ include Makefile-transformers include Makefile-flash-att unit-tests: - python -m pytest tests + pytest -s -vv -m "not private" tests gen-server: # Compile protos diff --git a/server/tests/models/test_model.py b/server/tests/models/test_model.py new file mode 100644 index 00000000..32bcd45f --- /dev/null +++ b/server/tests/models/test_model.py @@ -0,0 +1,78 @@ +import pytest +import torch + +from transformers import AutoTokenizer + +from text_generation_server.models import Model + + +def get_test_model(): + class TestModel(Model): + def batch_type(self): + raise NotImplementedError + + def generate_token(self, batch): + raise NotImplementedError + + tokenizer = AutoTokenizer.from_pretrained("huggingface/llama-7b") + + model = TestModel( + torch.nn.Linear(1, 1), tokenizer, False, torch.float32, torch.device("cpu") + ) + return model + + +@pytest.mark.private +def test_decode_streaming_english_spaces(): + model = get_test_model() + truth = "Hello here, this is a simple test" + all_input_ids = [15043, 1244, 29892, 445, 338, 263, 2560, 1243] + assert ( + all_input_ids == model.tokenizer(truth, add_special_tokens=False)["input_ids"] + ) + + decoded_text = "" + offset = 0 + token_offset = 0 + for i in range(len(all_input_ids)): + text, offset, token_offset = model.decode_token( + all_input_ids[: i + 1], offset, token_offset + ) + decoded_text += text + + assert decoded_text == truth + + +@pytest.mark.private +def test_decode_streaming_chinese_utf8(): + model = get_test_model() + truth = "我很感谢你的热情" + all_input_ids = [ + 30672, + 232, + 193, + 139, + 233, + 135, + 162, + 235, + 179, + 165, + 30919, + 30210, + 234, + 134, + 176, + 30993, + ] + + decoded_text = "" + offset = 0 + token_offset = 0 + for i in range(len(all_input_ids)): + text, offset, token_offset = model.decode_token( + all_input_ids[: i + 1], offset, token_offset + ) + decoded_text += text + + assert decoded_text == truth diff --git a/server/tests/models/test_seq2seq_lm.py b/server/tests/models/test_seq2seq_lm.py index 65dafa50..ba769e75 100644 --- a/server/tests/models/test_seq2seq_lm.py +++ b/server/tests/models/test_seq2seq_lm.py @@ -149,7 +149,7 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch) assert all([generation.generated_text is None for generation in generations]) assert all([len(generation.prefill_tokens) == 1 for generation in generations]) assert all([generation.token_id.item() == 259 for generation in generations]) - assert all([generation.token_text == "" for generation in generations]) + assert all([generation.token_text == " " for generation in generations]) assert generations[0].request_id == 0 diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py index 9029e954..1f324f77 100644 --- a/server/text_generation_server/models/bloom.py +++ b/server/text_generation_server/models/bloom.py @@ -56,7 +56,7 @@ class BLOOM(CausalLM): quantize: Optional[str] = None, ): super(BLOOM, self).__init__( - model_id=model_id, revision=revision, quantize=quantize, decode_buffer=1 + model_id=model_id, revision=revision, quantize=quantize ) @property @@ -104,14 +104,13 @@ class BLOOMSharded(BLOOM): rank=rank, world_size=world_size, ) - self.model = model.eval() torch.distributed.barrier(group=self.process_group) super(CausalLM, self).__init__( + model=model, tokenizer=tokenizer, requires_padding=True, dtype=dtype, device=device, - decode_buffer=1, rank=rank, world_size=world_size, ) diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 0d521ac4..9d8ae254 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -35,8 +35,8 @@ class CausalLMBatch(Batch): # Lengths of all generations present in the batch input_lengths: List[int] - offsets: List[Optional[int]] - token_offsets: List[Optional[int]] + prefix_offsets: List[int] + read_offsets: List[int] # Generation helpers next_token_choosers: List[NextTokenChooser] @@ -70,8 +70,8 @@ class CausalLMBatch(Batch): inputs = [] next_token_choosers = [] stopping_criterias = [] - offsets = [] - token_offsets = [] + prefix_offsets = [] + read_offsets = [] requests_idx_mapping = {} # Parse batch @@ -81,8 +81,6 @@ class CausalLMBatch(Batch): for i, r in enumerate(pb.requests): requests_idx_mapping[r.id] = i inputs.append(r.inputs) - offsets.append(None) - token_offsets.append(None) next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device)) stopping_criteria = StoppingCriteria.from_pb( r.stopping_parameters, tokenizer @@ -102,6 +100,10 @@ class CausalLMBatch(Batch): truncation=True, max_length=max_truncation, ).to(device) + for _ in pb.requests: + input_len = tokenized_inputs["input_ids"].shape[1] + prefix_offsets.append(0) + read_offsets.append(input_len) input_lengths = tokenized_inputs["attention_mask"].sum(1) max_input_length = input_lengths.max() @@ -130,8 +132,8 @@ class CausalLMBatch(Batch): past_key_values=None, all_input_ids=list(all_input_ids), input_lengths=input_lengths.tolist(), - offsets=offsets, - token_offsets=token_offsets, + prefix_offsets=prefix_offsets, + read_offsets=read_offsets, next_token_choosers=next_token_choosers, stopping_criterias=stopping_criterias, max_input_length=max_input_length.item(), @@ -151,8 +153,8 @@ class CausalLMBatch(Batch): # New values after filtering requests_idx_mapping = {} input_lengths = [] - offsets = [] - token_offsets = [] + prefix_offsets = [] + read_offsets = [] all_input_ids = [] max_input_length = 0 @@ -167,8 +169,8 @@ class CausalLMBatch(Batch): requests_idx_mapping[r.id] = i keep_indices.append(idx) - offsets.append(self.offsets[idx]) - token_offsets.append(self.token_offsets[idx]) + prefix_offsets.append(self.prefix_offsets[idx]) + read_offsets.append(self.read_offsets[idx]) all_input_ids.append(self.all_input_ids[idx]) request_input_length = self.input_lengths[idx] @@ -225,8 +227,8 @@ class CausalLMBatch(Batch): self.position_ids = position_ids self.all_input_ids = all_input_ids self.input_lengths = input_lengths - self.offsets = offsets - self.token_offsets = token_offsets + self.prefix_offsets = prefix_offsets + self.read_offsets = read_offsets self.next_token_choosers = next_token_choosers self.stopping_criterias = stopping_criterias self.max_input_length = max_input_length @@ -251,8 +253,8 @@ class CausalLMBatch(Batch): requests = [] requests_idx_mapping = {} input_lengths = [] - offsets = [] - token_offsets = [] + prefix_offsets = [] + read_offsets = [] all_input_ids = [] next_token_choosers = [] stopping_criterias = [] @@ -270,8 +272,8 @@ class CausalLMBatch(Batch): for i, batch in enumerate(batches): requests.extend(batch.requests) input_lengths.extend(batch.input_lengths) - offsets.extend(batch.offsets) - token_offsets.extend(batch.token_offsets) + prefix_offsets.extend(batch.prefix_offsets) + read_offsets.extend(batch.read_offsets) all_input_ids.extend(batch.all_input_ids) next_token_choosers.extend(batch.next_token_choosers) stopping_criterias.extend(batch.stopping_criterias) @@ -428,8 +430,8 @@ class CausalLMBatch(Batch): past_key_values=past_key_values, all_input_ids=all_input_ids, input_lengths=input_lengths, - offsets=offsets, - token_offsets=token_offsets, + prefix_offsets=prefix_offsets, + read_offsets=read_offsets, next_token_choosers=next_token_choosers, stopping_criterias=stopping_criterias, max_input_length=max_input_length, @@ -448,7 +450,6 @@ class CausalLM(Model): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, - decode_buffer: int = 3, ): if torch.cuda.is_available(): device = torch.device("cuda") @@ -463,25 +464,25 @@ class CausalLM(Model): tokenizer = AutoTokenizer.from_pretrained( model_id, revision=revision, padding_side="left", truncation_side="left" ) - self.model = AutoModelForCausalLM.from_pretrained( + model = AutoModelForCausalLM.from_pretrained( model_id, revision=revision, torch_dtype=dtype, device_map="auto" if torch.cuda.is_available() else None, load_in_8bit=quantize == "bitsandbytes", - ).eval() + ) tokenizer.pad_token_id = ( - self.model.config.pad_token_id - if self.model.config.pad_token_id is not None - else self.model.config.eos_token_id + model.config.pad_token_id + if model.config.pad_token_id is not None + else model.config.eos_token_id ) super(CausalLM, self).__init__( + model=model, tokenizer=tokenizer, requires_padding=True, dtype=dtype, device=device, - decode_buffer=decode_buffer, ) @property @@ -528,8 +529,8 @@ class CausalLM(Model): iterator = zip( batch.requests, batch.input_lengths, - batch.offsets, - batch.token_offsets, + batch.prefix_offsets, + batch.read_offsets, logits, batch.next_token_choosers, batch.stopping_criterias, @@ -540,8 +541,8 @@ class CausalLM(Model): for i, ( request, input_length, - offset, - token_offset, + prefix_offset, + read_offset, logits, next_token_chooser, stopping_criteria, @@ -559,8 +560,8 @@ class CausalLM(Model): # Generated token next_token_logprob = logprobs[-1, next_token_id] next_token_id_squeezed = next_token_id.squeeze() - next_token_text, offset, token_offset = self.decode_token( - all_input_ids[:, 0], offset, token_offset + next_token_text, prefix_offset, read_offset = self.decode_token( + all_input_ids[:, 0], prefix_offset, read_offset ) # Evaluate stopping criteria @@ -628,8 +629,8 @@ class CausalLM(Model): batch.input_ids[i, 0] = next_token_id batch.all_input_ids[i] = all_input_ids batch.input_lengths[i] = new_input_length - batch.offsets[i] = offset - batch.token_offsets[i] = token_offset + batch.prefix_offsets[i] = prefix_offset + batch.read_offsets[i] = read_offset batch.max_input_length = max(batch.max_input_length, new_input_length) # We finished all generations in the batch; there is no next batch diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 0a9fccca..aee0480d 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -52,8 +52,8 @@ class FlashCausalLMBatch(Batch): # Lengths of all generations present in the batch input_lengths: List[int] - offsets: List[Optional[int]] - token_offsets: List[Optional[int]] + prefix_offsets: List[Optional[int]] + read_offsets: List[Optional[int]] # Generation helpers next_token_choosers: List[NextTokenChooser] @@ -82,8 +82,8 @@ class FlashCausalLMBatch(Batch): max_seqlen = 0 input_lengths = [] - offsets = [] - token_offsets = [] + prefix_offsets = [] + read_offsets = [] all_input_ids = [] requests_idx_mapping = {} @@ -108,8 +108,8 @@ class FlashCausalLMBatch(Batch): max_seqlen = max(max_seqlen, input_length) input_lengths.append(input_length) - offsets.append(None) - token_offsets.append(None) + prefix_offsets.append(0) + read_offsets.append(input_length) all_input_ids.append(tokenized_input) @@ -151,8 +151,8 @@ class FlashCausalLMBatch(Batch): max_seqlen=max_seqlen, past_key_values=None, input_lengths=input_lengths, - offsets=offsets, - token_offsets=token_offsets, + prefix_offsets=prefix_offsets, + read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=[], next_token_choosers=next_token_choosers, @@ -190,8 +190,8 @@ class FlashCausalLMBatch(Batch): all_input_ids_tensor = [] input_lengths = [] - offsets = [] - token_offsets = [] + prefix_offsets = [] + read_offsets = [] next_token_choosers = [] stopping_criterias = [] @@ -222,8 +222,8 @@ class FlashCausalLMBatch(Batch): all_input_ids_tensor.append(self.all_input_ids_tensor[idx]) input_lengths.append(request_input_length) - offsets.append(self.offsets[idx]) - token_offsets.append(self.token_offsets[idx]) + prefix_offsets.append(self.prefix_offsets[idx]) + read_offsets.append(self.read_offsets[idx]) next_token_choosers.append(self.next_token_choosers[idx]) @@ -269,8 +269,8 @@ class FlashCausalLMBatch(Batch): max_seqlen=max_seqlen, past_key_values=past_key_values, input_lengths=input_lengths, - offsets=offsets, - token_offsets=token_offsets, + prefix_offsets=prefix_offsets, + read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, next_token_choosers=next_token_choosers, @@ -302,8 +302,8 @@ class FlashCausalLMBatch(Batch): all_input_ids_tensor = [] input_lengths = [] - offsets = [] - token_offsets = [] + prefix_offsets = [] + read_offsets = [] next_token_choosers = [] stopping_criterias = [] @@ -347,8 +347,8 @@ class FlashCausalLMBatch(Batch): all_input_ids_tensor.extend(batch.all_input_ids_tensor) input_lengths.extend(batch.input_lengths) - offsets.extend(batch.offsets) - token_offsets.extend(batch.token_offsets) + prefix_offsets.extend(batch.prefix_offsets) + read_offsets.extend(batch.read_offsets) next_token_choosers.extend(batch.next_token_choosers) stopping_criterias.extend(batch.stopping_criterias) @@ -374,8 +374,8 @@ class FlashCausalLMBatch(Batch): max_seqlen=max_seqlen, past_key_values=past_key_values, input_lengths=input_lengths, - offsets=offsets, - token_offsets=token_offsets, + prefix_offsets=prefix_offsets, + read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, next_token_choosers=next_token_choosers, @@ -394,7 +394,6 @@ class FlashCausalLM(Model): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, - decode_buffer: int = 3, ): if torch.cuda.is_available(): device = torch.device("cuda") @@ -405,23 +404,19 @@ class FlashCausalLM(Model): tokenizer = AutoTokenizer.from_pretrained( model_id, revision=revision, padding_side="left", truncation_side="left" ) - self.model = ( - model_cls.from_pretrained( - model_id, - revision=revision, - torch_dtype=dtype, - load_in_8bit=quantize == "bitsandbytes", - ) - .eval() - .to(device) - ) + model = model_cls.from_pretrained( + model_id, + revision=revision, + torch_dtype=dtype, + load_in_8bit=quantize == "bitsandbytes", + ).to(device) super(FlashCausalLM, self).__init__( + model=model, tokenizer=tokenizer, requires_padding=False, dtype=dtype, device=device, - decode_buffer=decode_buffer, ) @property @@ -645,8 +640,8 @@ class FlashCausalLM(Model): iterator = zip( batch.requests, batch.input_lengths, - batch.offsets, - batch.token_offsets, + batch.prefix_offsets, + batch.read_offsets, batch.next_token_choosers, batch.stopping_criterias, batch.all_input_ids, @@ -659,8 +654,8 @@ class FlashCausalLM(Model): for i, ( request, input_length, - offset, - token_offset, + prefix_offset, + read_offset, next_token_chooser, stopping_criteria, all_input_ids, @@ -675,10 +670,10 @@ class FlashCausalLM(Model): all_input_ids.append(next_token_id) # Generated token - next_token_text, offset, token_offset = self.decode_token( + next_token_text, prefix_offset, read_offset = self.decode_token( all_input_ids, - offset, - token_offset, + prefix_offset, + read_offset, ) # Evaluate stopping criteria @@ -744,8 +739,8 @@ class FlashCausalLM(Model): # Update values batch.input_lengths[i] = new_input_length - batch.offsets[i] = offset - batch.token_offsets[i] = token_offset + batch.prefix_offsets[i] = prefix_offset + batch.read_offsets[i] = read_offset batch.all_input_ids[i] = all_input_ids batch.max_seqlen = batch.max_seqlen + 1 cumulative_length += input_length diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py index b775bd79..ebdbe206 100644 --- a/server/text_generation_server/models/flash_llama.py +++ b/server/text_generation_server/models/flash_llama.py @@ -64,9 +64,9 @@ class FlashLlama(FlashCausalLM): model = FlashLlamaForCausalLM(config) self.load_weights(model, filenames, quantize, device, dtype) - self.model = model.eval().to(device) super(FlashCausalLM, self).__init__( + model=model.to(device), tokenizer=tokenizer, requires_padding=False, dtype=dtype, @@ -189,9 +189,9 @@ class FlashLlamaSharded(FlashLlama): rank=rank, world_size=world_size, ) - self.model = model.eval().to(device) torch.distributed.barrier(group=self.process_group) super(FlashCausalLM, self).__init__( + model=model.to(device), tokenizer=tokenizer, requires_padding=False, dtype=dtype, diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py index 0924f107..cac40bab 100644 --- a/server/text_generation_server/models/flash_neox.py +++ b/server/text_generation_server/models/flash_neox.py @@ -73,9 +73,9 @@ class FlashNeoXSharded(FlashNeoX): rank=rank, world_size=world_size, ) - self.model = model.eval().to(device) torch.distributed.barrier(group=self.process_group) super(FlashCausalLM, self).__init__( + model=model.to(device), tokenizer=tokenizer, requires_padding=False, dtype=dtype, diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py index 031a67eb..5dc31309 100644 --- a/server/text_generation_server/models/flash_santacoder.py +++ b/server/text_generation_server/models/flash_santacoder.py @@ -67,14 +67,13 @@ class FlashSantacoder(FlashCausalLM): dtype, config.architectures[0].startswith("GPT2"), ) - self.model = model.eval().to(device) super(FlashCausalLM, self).__init__( + model=model.to(device), tokenizer=tokenizer, requires_padding=False, dtype=dtype, device=device, - decode_buffer=1, ) @staticmethod @@ -213,16 +212,15 @@ class FlashSantacoderSharded(FlashSantacoder): world_size=world_size, transpose=config.architectures[0].startswith("GPT2"), ) - self.model = model.eval().to(device) torch.distributed.barrier(group=self.process_group) super(FlashCausalLM, self).__init__( + model=model.to(device), tokenizer=tokenizer, requires_padding=False, dtype=dtype, device=device, rank=rank, world_size=world_size, - decode_buffer=1, ) @staticmethod diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py index d1e5e841..24c37c19 100644 --- a/server/text_generation_server/models/galactica.py +++ b/server/text_generation_server/models/galactica.py @@ -94,8 +94,8 @@ class GalacticaCausalLMBatch(CausalLMBatch): inputs = [] next_token_choosers = [] stopping_criterias = [] - offsets = [] - token_offsets = [] + prefix_offsets = [] + read_offsets = [] requests_idx_mapping = {} # Parse batch @@ -106,8 +106,6 @@ class GalacticaCausalLMBatch(CausalLMBatch): requests_idx_mapping[r.id] = i # Add escape_custom_split_sequence to the CausalLMBatch logic inputs.append(escape_custom_split_sequence(r.inputs)) - offsets.append(None) - token_offsets.append(None) next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device)) stopping_criteria = StoppingCriteria.from_pb( r.stopping_parameters, tokenizer @@ -127,6 +125,10 @@ class GalacticaCausalLMBatch(CausalLMBatch): truncation=True, max_length=max_truncation, ).to(device) + for _ in pb.requests: + input_len = tokenized_inputs["input_ids"].shape[1] + prefix_offsets.append(0) + read_offsets.append(input_len) input_lengths = tokenized_inputs["attention_mask"].sum(1) max_input_length = input_lengths.max() @@ -155,8 +157,8 @@ class GalacticaCausalLMBatch(CausalLMBatch): past_key_values=None, all_input_ids=list(all_input_ids), input_lengths=input_lengths.tolist(), - offsets=offsets, - token_offsets=token_offsets, + prefix_offsets=prefix_offsets, + read_offsets=read_offsets, next_token_choosers=next_token_choosers, stopping_criterias=stopping_criterias, max_input_length=max_input_length.item(), @@ -231,9 +233,9 @@ class GalacticaSharded(Galactica): rank=rank, world_size=world_size, ) - self.model = model.eval() torch.distributed.barrier(group=self.process_group) super(CausalLM, self).__init__( + model=model, tokenizer=tokenizer, requires_padding=True, dtype=dtype, diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py index f95e5be2..a10dfcb8 100644 --- a/server/text_generation_server/models/gpt_neox.py +++ b/server/text_generation_server/models/gpt_neox.py @@ -70,9 +70,9 @@ class GPTNeoxSharded(CausalLM): rank=rank, world_size=world_size, ) - self.model = model.eval() torch.distributed.barrier(group=self.process_group) super(CausalLM, self).__init__( + model=model, tokenizer=tokenizer, requires_padding=True, dtype=dtype, diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index 03f14013..29bad321 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -13,23 +13,20 @@ B = TypeVar("B", bound=Batch) class Model(ABC): def __init__( self, + model: torch.nn.Module, tokenizer: PreTrainedTokenizerBase, requires_padding: bool, dtype: torch.dtype, device: torch.device, - decode_buffer: int = 3, rank: int = 0, world_size: int = 1, ): - if decode_buffer < 1: - raise ValueError("decode_buffer must be >= 1") - + self.model = model.eval() self.tokenizer = tokenizer self.all_special_ids = set(tokenizer.all_special_ids) self.requires_padding = requires_padding self.dtype = dtype self.device = device - self.decode_buffer = decode_buffer self.rank = rank self.world_size = world_size self.check_initialized() @@ -54,52 +51,29 @@ class Model(ABC): def decode_token( self, all_input_ids: List[int], - offset: Optional[int] = None, - token_offset: Optional[int] = None, - ) -> Tuple[str, Optional[int], Optional[int]]: + prefix_offset: int = 0, + read_offset: int = 0, + ) -> Tuple[str, int, int]: """Hack to hopefully support generate_stream for the maximum number of tokenizers""" - if all_input_ids[-1] in self.all_special_ids: - return ( - self.tokenizer.decode(all_input_ids[-1], skip_special_tokens=False), - None, - None, - ) - if token_offset is None: - token_offset = len(all_input_ids) - self.decode_buffer - # left token buffer - if self.decode_buffer > 1: - # Decode token_offset token minus last one and token_offset tokens - raw_texts = self.tokenizer.batch_decode( - [all_input_ids[token_offset:-1], all_input_ids[token_offset:]], - skip_special_tokens=False, - ) + # The prefix text is necessary only to defeat cleanup algorithms in the decode + # which decide to add a space or not depending on the surrounding ids. + prefix_text = self.tokenizer.decode( + all_input_ids[prefix_offset:read_offset], skip_special_tokens=False + ) + new_text = self.tokenizer.decode( + all_input_ids[prefix_offset:], skip_special_tokens=False + ) - # default offset is only the last token - offset = len(raw_texts[0]) - sequence_text = raw_texts[1] - else: - # Only decode the last token without using a token buffer - sequence_text = self.tokenizer.decode( - all_input_ids[-1], skip_special_tokens=False - ) - # no offset in this case - offset = 0 + if len(new_text) > len(prefix_text) and not new_text.endswith("�"): + # utf-8 char at the end means it's a potential unfinished byte sequence + # from byte fallback tokenization. + # If it's in the middle, it's probably a real invalid id generated + # by the model + new_text = new_text[len(prefix_text) :] + return new_text, read_offset, len(all_input_ids) else: - assert offset is not None - sequence_text = self.tokenizer.decode( - all_input_ids[token_offset:], - skip_special_tokens=False, - ) - - # get text - token_text = sequence_text[offset:] - - # if text is utf-8 - if token_text and token_text[-1] != "�": - return token_text, None, None - else: - return "", offset, token_offset + return "", prefix_offset, read_offset def check_initialized(self): uninitialized_parameters = [] diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py index 093cf70a..fdae795b 100644 --- a/server/text_generation_server/models/opt.py +++ b/server/text_generation_server/models/opt.py @@ -86,9 +86,9 @@ class OPTSharded(OPT): rank=rank, world_size=world_size, ) - self.model = model.eval() torch.distributed.barrier(group=self.process_group) super(CausalLM, self).__init__( + model=model, tokenizer=tokenizer, requires_padding=True, dtype=dtype, diff --git a/server/text_generation_server/models/santacoder.py b/server/text_generation_server/models/santacoder.py index 4bd56de1..23f89f48 100644 --- a/server/text_generation_server/models/santacoder.py +++ b/server/text_generation_server/models/santacoder.py @@ -46,24 +46,20 @@ class SantaCoder(CausalLM): } ) - self.model = ( - AutoModelForCausalLM.from_pretrained( - model_id, - revision=revision, - torch_dtype=dtype, - load_in_8bit=quantize == "bitsandbytes", - trust_remote_code=True, # required - ) - .to(device) - .eval() - ) + model = AutoModelForCausalLM.from_pretrained( + model_id, + revision=revision, + torch_dtype=dtype, + load_in_8bit=quantize == "bitsandbytes", + trust_remote_code=True, # required + ).to(device) super(CausalLM, self).__init__( + model=model, tokenizer=tokenizer, requires_padding=True, dtype=dtype, device=device, - decode_buffer=1, ) def decode(self, generated_ids: List[int]) -> str: diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index 84854f5d..4f55b22f 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -42,8 +42,8 @@ class Seq2SeqLMBatch(Batch): # Lengths of all generations present in the batch input_lengths: List[int] decoder_input_lengths: List[int] - offsets: List[Optional[int]] - token_offsets: List[Optional[int]] + prefix_offsets: List[int] + read_offsets: List[int] # Generation helpers next_token_choosers: List[NextTokenChooser] @@ -79,8 +79,8 @@ class Seq2SeqLMBatch(Batch): stopping_criterias = [] decoder_input_lengths = [] - offsets = [] - token_offsets = [] + prefix_offsets = [] + read_offsets = [] requests_idx_mapping = {} # Parse batch @@ -91,8 +91,6 @@ class Seq2SeqLMBatch(Batch): inputs.append(r.inputs) requests_idx_mapping[r.id] = i decoder_input_lengths.append(1) - offsets.append(None) - token_offsets.append(None) next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device)) stopping_criteria = StoppingCriteria.from_pb( r.stopping_parameters, tokenizer @@ -123,6 +121,9 @@ class Seq2SeqLMBatch(Batch): .repeat(len(pb.requests)) .view(-1, 1) ) + for _ in pb.requests: + prefix_offsets.append(0) + read_offsets.append(1) all_decoder_input_ids = decoder_input_ids.view(-1).split(1) max_tokens = len(inputs) * max_input_length + max_decode_tokens @@ -140,8 +141,8 @@ class Seq2SeqLMBatch(Batch): past_key_values=None, input_lengths=input_lengths.tolist(), decoder_input_lengths=decoder_input_lengths, - offsets=offsets, - token_offsets=token_offsets, + prefix_offsets=prefix_offsets, + read_offsets=read_offsets, next_token_choosers=next_token_choosers, stopping_criterias=stopping_criterias, max_input_length=max_input_length.item(), @@ -165,8 +166,8 @@ class Seq2SeqLMBatch(Batch): requests_idx_mapping = {} input_lengths = [] decoder_input_lengths = [] - offsets = [] - token_offsets = [] + prefix_offsets = [] + read_offsets = [] all_decoder_input_ids = [] @@ -184,8 +185,8 @@ class Seq2SeqLMBatch(Batch): requests_idx_mapping[r.id] = i keep_indices.append(idx) - offsets.append(self.offsets[idx]) - token_offsets.append(self.token_offsets[idx]) + prefix_offsets.append(self.prefix_offsets[idx]) + read_offsets.append(self.read_offsets[idx]) all_decoder_input_ids.append(self.all_decoder_input_ids[idx]) @@ -248,8 +249,8 @@ class Seq2SeqLMBatch(Batch): self.all_decoder_input_ids = all_decoder_input_ids self.input_lengths = input_lengths self.decoder_input_lengths = decoder_input_lengths - self.offsets = offsets - self.token_offsets = token_offsets + self.prefix_offsets = prefix_offsets + self.read_offsets = read_offsets self.next_token_choosers = next_token_choosers self.stopping_criterias = stopping_criterias self.max_input_length = max_input_length @@ -283,8 +284,8 @@ class Seq2SeqLMBatch(Batch): all_decoder_input_ids = [] input_lengths = [] decoder_input_lengths = [] - offsets = [] - token_offsets = [] + prefix_offsets = [] + read_offsets = [] next_token_choosers = [] stopping_criterias = [] max_tokens = 0 @@ -306,8 +307,8 @@ class Seq2SeqLMBatch(Batch): all_decoder_input_ids.extend(batch.all_decoder_input_ids) input_lengths.extend(batch.input_lengths) decoder_input_lengths.extend(batch.decoder_input_lengths) - offsets.extend(batch.offsets) - token_offsets.extend(batch.token_offsets) + prefix_offsets.extend(batch.prefix_offsets) + read_offsets.extend(batch.read_offsets) next_token_choosers.extend(batch.next_token_choosers) stopping_criterias.extend(batch.stopping_criterias) @@ -482,8 +483,8 @@ class Seq2SeqLMBatch(Batch): past_key_values=past_key_values, input_lengths=input_lengths, decoder_input_lengths=decoder_input_lengths, - offsets=offsets, - token_offsets=token_offsets, + prefix_offsets=prefix_offsets, + read_offsets=read_offsets, next_token_choosers=next_token_choosers, stopping_criterias=stopping_criterias, max_input_length=max_input_length, @@ -502,7 +503,6 @@ class Seq2SeqLM(Model): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, - decode_buffer: int = 3, ): if torch.cuda.is_available(): device = torch.device("cuda") @@ -514,24 +514,24 @@ class Seq2SeqLM(Model): device = torch.device("cpu") dtype = torch.float32 - self.model = AutoModelForSeq2SeqLM.from_pretrained( + model = AutoModelForSeq2SeqLM.from_pretrained( model_id, revision=revision, torch_dtype=dtype, device_map="auto" if torch.cuda.is_available() else None, load_in_8bit=quantize == "bitsandbytes", - ).eval() + ) tokenizer = AutoTokenizer.from_pretrained( model_id, revision=revision, padding_side="left", truncation_side="left" ) - tokenizer.bos_token_id = self.model.config.decoder_start_token_id + tokenizer.bos_token_id = model.config.decoder_start_token_id super(Seq2SeqLM, self).__init__( + model=model, tokenizer=tokenizer, requires_padding=True, dtype=dtype, device=device, - decode_buffer=decode_buffer, ) @property @@ -608,8 +608,8 @@ class Seq2SeqLM(Model): iterator = zip( batch.requests, batch.input_lengths, - batch.offsets, - batch.token_offsets, + batch.prefix_offsets, + batch.read_offsets, batch.decoder_input_lengths, logits, batch.next_token_choosers, @@ -621,8 +621,8 @@ class Seq2SeqLM(Model): for i, ( request, input_length, - offset, - token_offset, + prefix_offset, + read_offset, decoder_input_length, logits, next_token_chooser, @@ -643,8 +643,8 @@ class Seq2SeqLM(Model): # Generated token next_token_logprob = logprobs[-1, next_token_id] next_token_id_squeezed = next_token_id.squeeze() - next_token_text, offset, token_offset = self.decode_token( - all_decoder_input_ids, offset, token_offset + next_token_text, prefix_offset, read_offset = self.decode_token( + all_decoder_input_ids, prefix_offset, read_offset ) # Evaluate stopping criteria @@ -702,8 +702,8 @@ class Seq2SeqLM(Model): batch.all_decoder_input_ids[i] = all_decoder_input_ids batch.input_lengths[i] = input_length batch.decoder_input_lengths[i] = new_decoder_input_length - batch.offsets[i] = offset - batch.token_offsets[i] = token_offset + batch.prefix_offsets[i] = prefix_offset + batch.read_offsets[i] = read_offset batch.max_input_length = max(batch.max_input_length, input_length) batch.max_decoder_input_length = max( batch.max_decoder_input_length, new_decoder_input_length diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py index 8e3826a4..b1ba2432 100644 --- a/server/text_generation_server/models/t5.py +++ b/server/text_generation_server/models/t5.py @@ -16,9 +16,6 @@ from text_generation_server.utils import ( initialize_torch_distributed, weight_files, ) -from text_generation_server.utils.layers import ( - FastLinear, -) from transformers.models.t5.parallel_layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -73,9 +70,9 @@ class T5Sharded(Seq2SeqLM): rank=rank, world_size=world_size, ) - self.model = model.eval() torch.distributed.barrier(group=self.process_group) super(Seq2SeqLM, self).__init__( + model=model, tokenizer=tokenizer, requires_padding=True, dtype=dtype,