From 62f91f78ac4ab747505c9fbf51fa53ba2d3f3d61 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Fri, 26 May 2023 12:30:27 +0200 Subject: [PATCH] feat(server): support vectorized warpers in flash causal lm (#317) Co-authored-by: Joel Lamy-Poirier --- .../test_bloom_560m_all_params.json | 96 ++--- .../test_flash_llama_all_params.json | 58 +-- .../test_flash_starcoder_default_params.json | 364 ++++++++++++++-- .../test_mt0_base_all_params.json | 48 +-- integration-tests/models/test_flash_llama.py | 2 +- .../models/test_flash_starcoder.py | 2 +- integration-tests/models/test_mt0_base.py | 2 +- server/tests/models/test_bloom.py | 8 +- server/tests/models/test_causal_lm.py | 12 +- server/tests/models/test_santacoder.py | 10 +- server/tests/models/test_seq2seq_lm.py | 10 +- server/text_generation_server/models/bloom.py | 3 +- .../models/causal_lm.py | 1 + .../models/flash_causal_lm.py | 174 ++++---- .../models/galactica.py | 1 + .../models/seq2seq_lm.py | 1 + server/text_generation_server/models/types.py | 1 + server/text_generation_server/server.py | 2 +- .../text_generation_server/utils/__init__.py | 6 +- .../text_generation_server/utils/convert.py | 6 +- .../utils/logits_process.py | 405 ++++++++++++++++++ server/text_generation_server/utils/tokens.py | 282 ++++++++---- 22 files changed, 1151 insertions(+), 343 deletions(-) create mode 100644 server/text_generation_server/utils/logits_process.py diff --git a/integration-tests/models/__snapshots__/test_bloom_560m/test_bloom_560m_all_params.json b/integration-tests/models/__snapshots__/test_bloom_560m/test_bloom_560m_all_params.json index 93a95804..ace73416 100644 --- a/integration-tests/models/__snapshots__/test_bloom_560m/test_bloom_560m_all_params.json +++ b/integration-tests/models/__snapshots__/test_bloom_560m/test_bloom_560m_all_params.json @@ -34,65 +34,65 @@ "tokens": [ { "id": 408, - "logprob": -1.9267578, + "logprob": -0.07891846, "special": false, "text": " que" }, - { - "id": 20288, - "logprob": -2.9257812, - "special": false, - "text": " l'on" - }, - { - "id": 22255, - "logprob": -2.8964844, - "special": false, - "text": " trouve" - }, - { - "id": 1622, - "logprob": -1.1083984, - "special": false, - "text": " une" - }, - { - "id": 187079, - "logprob": -7.796875, - "special": false, - "text": " posture" - }, - { - "id": 501, - "logprob": -5.390625, - "special": false, - "text": " par" - }, - { - "id": 8741, - "logprob": -0.34936523, - "special": false, - "text": " rapport" - }, - { - "id": 693, - "logprob": 0.0, - "special": false, - "text": " à" - }, { "id": 366, - "logprob": -2.3378906, + "logprob": -1.2939453, "special": false, "text": " la" }, { - "id": 36503, - "logprob": -3.6640625, + "id": 8769, + "logprob": -0.3708496, "special": false, - "text": " pratique" + "text": " personne" + }, + { + "id": 1479, + "logprob": -2.2871094, + "special": false, + "text": " qui" + }, + { + "id": 2997, + "logprob": -0.8671875, + "special": false, + "text": " vous" + }, + { + "id": 35977, + "logprob": -1.5097656, + "special": false, + "text": " suit" + }, + { + "id": 21558, + "logprob": -0.07891846, + "special": false, + "text": " ait" + }, + { + "id": 447, + "logprob": -0.12695312, + "special": false, + "text": " un" + }, + { + "id": 78606, + "logprob": -2.21875, + "special": false, + "text": " profil" + }, + { + "id": 3899, + "logprob": -1.3535156, + "special": false, + "text": " bien" } ] }, - "generated_text": "Pour déguster un ortolan, il faut tout d'abord que l'on trouve une posture par rapport à la pratique" + "generated_text": "Pour 
déguster un ortolan, il faut tout d'abord que la personne qui vous suit ait un profil bien" } diff --git a/integration-tests/models/__snapshots__/test_flash_llama/test_flash_llama_all_params.json b/integration-tests/models/__snapshots__/test_flash_llama/test_flash_llama_all_params.json index 1b6b51a3..5be2870d 100644 --- a/integration-tests/models/__snapshots__/test_flash_llama/test_flash_llama_all_params.json +++ b/integration-tests/models/__snapshots__/test_flash_llama/test_flash_llama_all_params.json @@ -1,8 +1,8 @@ { "details": { "best_of_sequences": null, - "finish_reason": "length", - "generated_tokens": 10, + "finish_reason": "stop_sequence", + "generated_tokens": 5, "prefill": [ { "id": 1, @@ -24,65 +24,35 @@ "tokens": [ { "id": 5229, - "logprob": -3.3085938, + "logprob": -2.5683594, "special": false, "text": " failed" }, - { - "id": 363, - "logprob": -3.984375, - "special": false, - "text": " for" - }, - { - "id": 5641, - "logprob": -6.53125, - "special": false, - "text": " IP" - }, - { - "id": 16428, - "logprob": -3.1835938, - "special": false, - "text": " Address" - }, { "id": 29901, - "logprob": -1.2324219, + "logprob": -0.45336914, "special": false, "text": ":" }, { - "id": 525, - "logprob": -2.6855469, + "id": 4829, + "logprob": -1.8408203, "special": false, - "text": " '" + "text": " Error" }, { - "id": 8516, - "logprob": -7.1601562, + "id": 297, + "logprob": -1.0556641, "special": false, - "text": "None" + "text": " in" }, { - "id": 4286, - "logprob": -2.4433594, + "id": 1243, + "logprob": 0.0, "special": false, - "text": "'." - }, - { - "id": 13, - "logprob": -0.06530762, - "special": false, - "text": "\n" - }, - { - "id": 294, - "logprob": -7.953125, - "special": false, - "text": "as" + "text": " test" } ] }, - "generated_text": "Test requestfailed for IP Address: 'None'.\nas" + "generated_text": "Test requestfailed: Error in test" } diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder/test_flash_starcoder_default_params.json b/integration-tests/models/__snapshots__/test_flash_starcoder/test_flash_starcoder_default_params.json index 21bb509b..afd0b662 100644 --- a/integration-tests/models/__snapshots__/test_flash_starcoder/test_flash_starcoder_default_params.json +++ b/integration-tests/models/__snapshots__/test_flash_starcoder/test_flash_starcoder_default_params.json @@ -1,8 +1,8 @@ { "details": { "best_of_sequences": null, - "finish_reason": "eos_token", - "generated_tokens": 12, + "finish_reason": "length", + "generated_tokens": 60, "prefill": [ { "id": 589, @@ -29,77 +29,365 @@ "tokens": [ { "id": 2262, - "logprob": -0.7451172, + "logprob": -0.042999268, "special": false, "text": "():" }, { "id": 284, - "logprob": -0.21325684, + "logprob": 0.0, "special": false, "text": "\n " }, { - "id": 5741, - "logprob": -5.734375, - "special": false, - "text": " logging" - }, - { - "id": 32, + "id": 1459, "logprob": 0.0, "special": false, - "text": "." 
+ "text": " print" }, { - "id": 1338, - "logprob": -0.3232422, + "id": 440, + "logprob": 0.0, "special": false, - "text": "info" - }, - { - "id": 463, - "logprob": -1.0380859, - "special": false, - "text": "('" + "text": "(\"" }, { "id": 8279, - "logprob": -0.8378906, + "logprob": 0.0, "special": false, "text": "Hello" }, - { - "id": 30, - "logprob": -1.9501953, - "special": false, - "text": "," - }, { "id": 10896, - "logprob": -1.3476562, + "logprob": -0.3659668, "special": false, "text": " World" }, { - "id": 683, - "logprob": -1.796875, + "id": 657, + "logprob": -0.49804688, "special": false, - "text": "')" + "text": "\")" }, { "id": 203, - "logprob": -0.9873047, + "logprob": -0.11279297, "special": false, "text": "\n" }, { - "id": 0, - "logprob": -0.7495117, - "special": true, - "text": "<|endoftext|>" + "id": 203, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 589, + "logprob": -0.20141602, + "special": false, + "text": "def" + }, + { + "id": 1459, + "logprob": 0.0, + "special": false, + "text": " print" + }, + { + "id": 81, + "logprob": 0.0, + "special": false, + "text": "_" + }, + { + "id": 7656, + "logprob": 0.0, + "special": false, + "text": "hello" + }, + { + "id": 81, + "logprob": 0.0, + "special": false, + "text": "_" + }, + { + "id": 426, + "logprob": -0.051635742, + "special": false, + "text": "name" + }, + { + "id": 26, + "logprob": 0.0, + "special": false, + "text": "(" + }, + { + "id": 426, + "logprob": 0.0, + "special": false, + "text": "name" + }, + { + "id": 711, + "logprob": 0.0, + "special": false, + "text": "):" + }, + { + "id": 284, + "logprob": 0.0, + "special": false, + "text": "\n " + }, + { + "id": 1459, + "logprob": 0.0, + "special": false, + "text": " print" + }, + { + "id": 440, + "logprob": -0.16027832, + "special": false, + "text": "(\"" + }, + { + "id": 8279, + "logprob": 0.0, + "special": false, + "text": "Hello" + }, + { + "id": 313, + "logprob": 0.0, + "special": false, + "text": " \"" + }, + { + "id": 474, + "logprob": 0.0, + "special": false, + "text": " +" + }, + { + "id": 636, + "logprob": 0.0, + "special": false, + "text": " name" + }, + { + "id": 27, + "logprob": 0.0, + "special": false, + "text": ")" + }, + { + "id": 203, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 203, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 589, + "logprob": 0.0, + "special": false, + "text": "def" + }, + { + "id": 1459, + "logprob": 0.0, + "special": false, + "text": " print" + }, + { + "id": 81, + "logprob": 0.0, + "special": false, + "text": "_" + }, + { + "id": 7656, + "logprob": 0.0, + "special": false, + "text": "hello" + }, + { + "id": 81, + "logprob": 0.0, + "special": false, + "text": "_" + }, + { + "id": 426, + "logprob": 0.0, + "special": false, + "text": "name" + }, + { + "id": 81, + "logprob": 0.0, + "special": false, + "text": "_" + }, + { + "id": 381, + "logprob": 0.0, + "special": false, + "text": "age" + }, + { + "id": 26, + "logprob": 0.0, + "special": false, + "text": "(" + }, + { + "id": 426, + "logprob": 0.0, + "special": false, + "text": "name" + }, + { + "id": 30, + "logprob": 0.0, + "special": false, + "text": "," + }, + { + "id": 11442, + "logprob": 0.0, + "special": false, + "text": " age" + }, + { + "id": 711, + "logprob": 0.0, + "special": false, + "text": "):" + }, + { + "id": 284, + "logprob": 0.0, + "special": false, + "text": "\n " + }, + { + "id": 1459, + "logprob": 0.0, + "special": false, + "text": " print" + }, + { + "id": 440, + "logprob": 0.0, + "special": false, 
+ "text": "(\"" + }, + { + "id": 8279, + "logprob": 0.0, + "special": false, + "text": "Hello" + }, + { + "id": 313, + "logprob": 0.0, + "special": false, + "text": " \"" + }, + { + "id": 474, + "logprob": 0.0, + "special": false, + "text": " +" + }, + { + "id": 636, + "logprob": 0.0, + "special": false, + "text": " name" + }, + { + "id": 474, + "logprob": 0.0, + "special": false, + "text": " +" + }, + { + "id": 313, + "logprob": -0.6328125, + "special": false, + "text": " \"" + }, + { + "id": 313, + "logprob": -1.7011719, + "special": false, + "text": " \"" + }, + { + "id": 474, + "logprob": 0.0, + "special": false, + "text": " +" + }, + { + "id": 596, + "logprob": 0.0, + "special": false, + "text": " str" + }, + { + "id": 26, + "logprob": 0.0, + "special": false, + "text": "(" + }, + { + "id": 381, + "logprob": 0.0, + "special": false, + "text": "age" + }, + { + "id": 490, + "logprob": 0.0, + "special": false, + "text": "))" + }, + { + "id": 203, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 203, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 589, + "logprob": 0.0, + "special": false, + "text": "def" + }, + { + "id": 1459, + "logprob": 0.0, + "special": false, + "text": " print" } ] }, - "generated_text": "():\n logging.info('Hello, World')\n<|endoftext|>" + "generated_text": "():\n print(\"Hello World\")\n\ndef print_hello_name(name):\n print(\"Hello \" + name)\n\ndef print_hello_name_age(name, age):\n print(\"Hello \" + name + \" \" + str(age))\n\ndef print" } diff --git a/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_all_params.json b/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_all_params.json index 3e9f3d73..024823d0 100644 --- a/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_all_params.json +++ b/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_all_params.json @@ -1,8 +1,8 @@ { "details": { "best_of_sequences": null, - "finish_reason": "length", - "generated_tokens": 10, + "finish_reason": "eos_token", + "generated_tokens": 9, "prefill": [ { "id": 0, @@ -14,65 +14,59 @@ "tokens": [ { "id": 16017, - "logprob": -1.3505859, + "logprob": -0.30908203, "special": false, "text": " blue" }, { "id": 20495, - "logprob": -0.50439453, + "logprob": 0.0, "special": false, "text": " sky" }, { "id": 259, - "logprob": -1.2011719, + "logprob": -0.28271484, "special": false, "text": " " }, { "id": 15484, - "logprob": -2.8378906, + "logprob": -1.7929688, "special": false, "text": "appear" }, { "id": 345, - "logprob": -0.87597656, + "logprob": -0.8935547, "special": false, "text": "ed" }, { - "id": 288, - "logprob": -1.8447266, + "id": 281, + "logprob": 0.0, "special": false, - "text": " to" + "text": " in" }, { - "id": 35622, - "logprob": -7.1445312, + "id": 287, + "logprob": 0.0, "special": false, - "text": " cloud" + "text": " the" }, { - "id": 263, - "logprob": -1.2929688, + "id": 20495, + "logprob": -0.32299805, "special": false, - "text": "s" + "text": " sky" }, { - "id": 14701, - "logprob": -3.0761719, - "special": false, - "text": " above" - }, - { - "id": 751, - "logprob": -4.4375, - "special": false, - "text": " all" + "id": 1, + "logprob": 0.0, + "special": true, + "text": "" } ] }, - "generated_text": "Why is the sky blue?blue sky appeared to clouds above all" + "generated_text": "Why is the sky blue?blue sky appeared in the sky" } diff --git a/integration-tests/models/test_flash_llama.py b/integration-tests/models/test_flash_llama.py index 37468455..bf5b64ba 
100644 --- a/integration-tests/models/test_flash_llama.py +++ b/integration-tests/models/test_flash_llama.py @@ -40,7 +40,7 @@ async def test_flash_llama_all_params(flash_llama, response_snapshot): seed=0, ) - assert response.details.generated_tokens == 10 + assert response.details.generated_tokens == 5 assert response == response_snapshot diff --git a/integration-tests/models/test_flash_starcoder.py b/integration-tests/models/test_flash_starcoder.py index 4c7393a7..c1a68d89 100644 --- a/integration-tests/models/test_flash_starcoder.py +++ b/integration-tests/models/test_flash_starcoder.py @@ -29,7 +29,7 @@ async def test_flash_starcoder_default_params(flash_starcoder, response_snapshot "def print_hello", max_new_tokens=60, temperature=0.2, top_p=0.95, seed=0 ) - assert response.details.generated_tokens == 12 + assert response.details.generated_tokens == 60 assert response == response_snapshot diff --git a/integration-tests/models/test_mt0_base.py b/integration-tests/models/test_mt0_base.py index 15410f73..e347d22a 100644 --- a/integration-tests/models/test_mt0_base.py +++ b/integration-tests/models/test_mt0_base.py @@ -43,7 +43,7 @@ async def test_mt0_base_all_params(mt0_base, response_snapshot): seed=0, ) - assert response.details.generated_tokens == 10 + assert response.details.generated_tokens == 9 assert response == response_snapshot diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py index 105b3573..590ba557 100644 --- a/server/tests/models/test_bloom.py +++ b/server/tests/models/test_bloom.py @@ -38,7 +38,7 @@ def default_pb_batch(default_pb_request): @pytest.fixture def default_bloom_batch(default_pb_batch, bloom_560m_tokenizer): return BloomCausalLMBatch.from_pb( - default_pb_batch, bloom_560m_tokenizer, torch.device("cpu") + default_pb_batch, bloom_560m_tokenizer, torch.float32, torch.device("cpu") ) @@ -52,7 +52,7 @@ def default_multi_requests_bloom_batch(default_pb_request, bloom_560m_tokenizer) batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2) return BloomCausalLMBatch.from_pb( - batch_pb, bloom_560m_tokenizer, torch.device("cpu") + batch_pb, bloom_560m_tokenizer, torch.float32, torch.device("cpu") ) @@ -286,7 +286,9 @@ def test_batch_concatenate( == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens ) - next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id]) + next_batch = next_batch.filter( + [next_batch.requests[0].id, next_batch.requests[1].id] + ) for _ in range( default_bloom_batch.stopping_criterias[0].max_new_tokens diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py index d8d1bd16..3f28f5b3 100644 --- a/server/tests/models/test_causal_lm.py +++ b/server/tests/models/test_causal_lm.py @@ -38,7 +38,9 @@ def default_pb_batch(default_pb_request): @pytest.fixture def default_causal_lm_batch(default_pb_batch, gpt2_tokenizer): - return CausalLMBatch.from_pb(default_pb_batch, gpt2_tokenizer, torch.device("cpu")) + return CausalLMBatch.from_pb( + default_pb_batch, gpt2_tokenizer, torch.float32, torch.device("cpu") + ) @pytest.fixture @@ -50,7 +52,9 @@ def default_multi_requests_causal_lm_batch(default_pb_request, gpt2_tokenizer): req_1.stopping_parameters.max_new_tokens = 5 batch_pb = generate_pb2.Batch(id=1, requests=[req_0, req_1], size=2) - return CausalLMBatch.from_pb(batch_pb, gpt2_tokenizer, torch.device("cpu")) + return CausalLMBatch.from_pb( + batch_pb, gpt2_tokenizer, torch.float32, torch.device("cpu") + ) def 
test_batch_from_pb(default_pb_batch, default_causal_lm_batch): @@ -285,7 +289,9 @@ def test_batch_concatenate( == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens ) - next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id]) + next_batch = next_batch.filter( + [next_batch.requests[0].id, next_batch.requests[1].id] + ) for _ in range( default_causal_lm_batch.stopping_criterias[0].max_new_tokens diff --git a/server/tests/models/test_santacoder.py b/server/tests/models/test_santacoder.py index 8cf66d47..bef8db38 100644 --- a/server/tests/models/test_santacoder.py +++ b/server/tests/models/test_santacoder.py @@ -45,7 +45,10 @@ def default_fim_pb_batch(default_fim_pb_request): @pytest.mark.skip def test_santacoder_generate_token_completion(default_santacoder, default_pb_batch): batch = CausalLMBatch.from_pb( - default_pb_batch, default_santacoder.tokenizer, default_santacoder.device + default_pb_batch, + default_santacoder.tokenizer, + default_santacoder.dtype, + default_santacoder.device, ) next_batch = batch @@ -70,7 +73,10 @@ def test_fim_santacoder_generate_token_completion( default_santacoder, default_fim_pb_batch ): batch = CausalLMBatch.from_pb( - default_fim_pb_batch, default_santacoder.tokenizer, default_santacoder.device + default_fim_pb_batch, + default_santacoder.tokenizer, + default_santacoder.dtype, + default_santacoder.device, ) next_batch = batch diff --git a/server/tests/models/test_seq2seq_lm.py b/server/tests/models/test_seq2seq_lm.py index 8fdeee60..a3199d02 100644 --- a/server/tests/models/test_seq2seq_lm.py +++ b/server/tests/models/test_seq2seq_lm.py @@ -42,7 +42,7 @@ def default_pb_batch(default_pb_request): @pytest.fixture def default_seq2seq_lm_batch(default_pb_batch, mt0_small_tokenizer): return Seq2SeqLMBatch.from_pb( - default_pb_batch, mt0_small_tokenizer, torch.device("cpu") + default_pb_batch, mt0_small_tokenizer, torch.float32, torch.device("cpu") ) @@ -55,7 +55,9 @@ def default_multi_requests_seq2seq_lm_batch(default_pb_request, mt0_small_tokeni req_1.stopping_parameters.max_new_tokens = 5 batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2) - return Seq2SeqLMBatch.from_pb(batch_pb, mt0_small_tokenizer, torch.device("cpu")) + return Seq2SeqLMBatch.from_pb( + batch_pb, mt0_small_tokenizer, torch.float32, torch.device("cpu") + ) def test_batch_from_pb(default_pb_batch, default_seq2seq_lm_batch): @@ -323,7 +325,9 @@ def test_batch_concatenate( ) assert generations[2].generated_text.generated_tokens == 5 - next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id]) + next_batch = next_batch.filter( + [next_batch.requests[0].id, next_batch.requests[1].id] + ) generations, next_batch = default_seq2seq_lm.generate_token(next_batch) assert next_batch is not None diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py index 5eddc8cf..088a1457 100644 --- a/server/text_generation_server/models/bloom.py +++ b/server/text_generation_server/models/bloom.py @@ -39,10 +39,11 @@ class BloomCausalLMBatch(CausalLMBatch): cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, + dtype: torch.dtype, device: torch.device, ) -> "CausalLMBatch": batch = super(BloomCausalLMBatch, cls).from_pb( - pb=pb, tokenizer=tokenizer, device=device + pb=pb, tokenizer=tokenizer, dtype=dtype, device=device ) batch.keys_head_dim_last = False return batch diff --git a/server/text_generation_server/models/causal_lm.py 
b/server/text_generation_server/models/causal_lm.py index 81a5e75e..a20a6143 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -66,6 +66,7 @@ class CausalLMBatch(Batch): cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, + dtype: torch.dtype, device: torch.device, ) -> "CausalLMBatch": inputs = [] diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index baa6cd7f..35cbe174 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -18,11 +18,7 @@ from text_generation_server.models.types import ( GeneratedText, ) from text_generation_server.pb import generate_pb2 -from text_generation_server.utils import ( - NextTokenChooser, - StoppingCriteria, - Sampling, -) +from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser tracer = trace.get_tracer(__name__) @@ -48,7 +44,7 @@ class FlashCausalLMBatch(Batch): # All tokens all_input_ids: List[List[int]] - all_input_ids_tensor: List[torch.Tensor] + all_input_ids_tensor: torch.Tensor # Lengths of all generations present in the batch input_lengths: List[int] @@ -56,7 +52,7 @@ class FlashCausalLMBatch(Batch): read_offsets: List[Optional[int]] # Generation helpers - next_token_choosers: List[NextTokenChooser] + next_token_chooser: HeterogeneousNextTokenChooser stopping_criterias: List[StoppingCriteria] # Maximum number of tokens this batch will grow to @@ -75,6 +71,7 @@ class FlashCausalLMBatch(Batch): cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, + dtype: torch.dtype, device: torch.device, ) -> "FlashCausalLMBatch": position_ids = [] @@ -87,13 +84,14 @@ class FlashCausalLMBatch(Batch): all_input_ids = [] requests_idx_mapping = {} - next_token_choosers = [] + next_token_chooser_parameters = [] stopping_criterias = [] # Cumulative length cumulative_length = 0 max_tokens = 0 + max_length = 0 # Parse batch for i, r in enumerate(pb.requests): @@ -119,7 +117,7 @@ class FlashCausalLMBatch(Batch): # Add cumulative lengths of all previous inputs cu_seqlens.append(cumulative_length + input_length) - next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device)) + next_token_chooser_parameters.append(r.parameters) stopping_criteria = StoppingCriteria.from_pb( r.stopping_parameters, tokenizer @@ -130,11 +128,26 @@ class FlashCausalLMBatch(Batch): # Update cumulative_length += input_length max_tokens += input_length + max_new_tokens + max_length = max(max_length, input_length + max_new_tokens) + + next_token_chooser = HeterogeneousNextTokenChooser.from_pb( + next_token_chooser_parameters, dtype, device + ) + + # Padded all_input_ids_tensor + all_input_ids_tensor = np.zeros( + (len(all_input_ids), max_length), dtype=np.int64 + ) + for i, input_ids in enumerate(all_input_ids): + all_input_ids_tensor[i, : len(input_ids)] = input_ids # Create tensors on device input_ids = torch.tensor( np.concatenate(all_input_ids), dtype=torch.int64, device=device ) + all_input_ids_tensor = torch.tensor( + all_input_ids_tensor, dtype=torch.int64, device=device + ) position_ids = torch.tensor( np.concatenate(position_ids), dtype=torch.int32, device=device ) @@ -154,8 +167,8 @@ class FlashCausalLMBatch(Batch): prefix_offsets=prefix_offsets, read_offsets=read_offsets, all_input_ids=all_input_ids, - all_input_ids_tensor=[], - next_token_choosers=next_token_choosers, + 
all_input_ids_tensor=all_input_ids_tensor, + next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, max_tokens=max_tokens, ) @@ -176,31 +189,29 @@ class FlashCausalLMBatch(Batch): # New values after filtering requests_idx_mapping = {} - input_ids = self.input_ids.new_empty(len(request_ids)) - position_ids = self.position_ids.new_empty(len(request_ids)) + # Used to index into tensors + indices = [] + # Create on CPU to only move to GPU once instead of at every copy cu_seqlens = torch.zeros(len(request_ids) + 1, dtype=torch.int32) - cu_seqlens_q = torch.arange( - 0, len(request_ids) + 1, device=self.cu_seqlens_q.device, dtype=torch.int32 - ) + cu_seqlens_q = self.cu_seqlens_q[: len(request_ids) + 1] max_seqlen = 0 past_key_values = [] requests = [] all_input_ids = [] - all_input_ids_tensor = [] input_lengths = [] prefix_offsets = [] read_offsets = [] - next_token_choosers = [] stopping_criterias = [] max_tokens = 0 for i, request_id in enumerate(request_ids): idx = self.requests_idx_mapping[request_id] + indices.append(idx) requests_idx_mapping[request_id] = i requests.append(self.requests[idx]) @@ -208,10 +219,6 @@ class FlashCausalLMBatch(Batch): # Get length request_input_length = self.input_lengths[idx] - # Copy tensors (GPU) - input_ids[i] = self.input_ids[idx] - position_ids[i] = self.position_ids[idx] - # Copy to tensor (CPU) cu_seqlens[i + 1] = cumulative_length + request_input_length max_seqlen = max(max_seqlen, request_input_length) @@ -222,14 +229,11 @@ class FlashCausalLMBatch(Batch): ) all_input_ids.append(self.all_input_ids[idx]) - all_input_ids_tensor.append(self.all_input_ids_tensor[idx]) input_lengths.append(request_input_length) prefix_offsets.append(self.prefix_offsets[idx]) read_offsets.append(self.read_offsets[idx]) - next_token_choosers.append(self.next_token_choosers[idx]) - stopping_criteria = self.stopping_criterias[idx] stopping_criterias.append(stopping_criteria) @@ -258,6 +262,12 @@ class FlashCausalLMBatch(Batch): # Cat all past past_key_values = torch.cat(past_key_values, dim=1) + # Index into tensors + input_ids = self.input_ids[indices] + position_ids = self.position_ids[indices] + all_input_ids_tensor = self.all_input_ids_tensor[indices] + next_token_chooser = self.next_token_chooser.filter(indices) + # Move to GPU now that we have the whole tensor cu_seqlens = cu_seqlens.to(self.cu_seqlens.device) @@ -276,7 +286,7 @@ class FlashCausalLMBatch(Batch): read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, - next_token_choosers=next_token_choosers, + next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, max_tokens=max_tokens, ) @@ -290,6 +300,7 @@ class FlashCausalLMBatch(Batch): total_batch_size = sum([len(b) for b in batches]) + dtype = batches[0].past_key_values.dtype device = batches[0].input_ids.device input_ids = batches[0].input_ids.new_empty(total_batch_size) @@ -302,19 +313,19 @@ class FlashCausalLMBatch(Batch): past_key_values = [] all_input_ids = [] - all_input_ids_tensor = [] input_lengths = [] prefix_offsets = [] read_offsets = [] - next_token_choosers = [] + next_token_chooser_parameters = [] stopping_criterias = [] # Cumulative length cumulative_batch_size = 0 cumulative_length = 0 max_tokens = 0 + max_length = 0 for i, batch in enumerate(batches): requests.extend(batch.requests) @@ -347,25 +358,54 @@ class FlashCausalLMBatch(Batch): ) all_input_ids.extend(batch.all_input_ids) - all_input_ids_tensor.extend(batch.all_input_ids_tensor) 
input_lengths.extend(batch.input_lengths) prefix_offsets.extend(batch.prefix_offsets) read_offsets.extend(batch.read_offsets) - next_token_choosers.extend(batch.next_token_choosers) + next_token_chooser_parameters.extend([r.parameters for r in batch.requests]) stopping_criterias.extend(batch.stopping_criterias) # Update cumulative_length += batch.cu_seqlens[-1] cumulative_batch_size += len(batch) max_tokens += batch.max_tokens + max_length = max( + max_length, + max( + input_length + + stopping_criteria.max_new_tokens + - stopping_criteria.current_tokens + for input_length, stopping_criteria in zip( + batch.input_lengths, batch.stopping_criterias + ) + ), + ) + + all_input_ids_tensor = torch.zeros( + (total_batch_size, max_length), dtype=torch.int64, device=device + ) + + cumulative_batch_size = 0 + for i, batch in enumerate(batches): + start_index = cumulative_batch_size + end_index = cumulative_batch_size + len(batch) + + all_input_ids_tensor[ + start_index:end_index, : batch.all_input_ids_tensor.shape[1] + ] = batch.all_input_ids_tensor[:, :max_length] + + cumulative_batch_size += len(batch) # Cat past past_key_values = torch.cat(past_key_values, dim=1) # Create final tensor on GPU cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) + next_token_chooser = HeterogeneousNextTokenChooser.from_pb( + next_token_chooser_parameters, dtype=dtype, device=device + ) + return FlashCausalLMBatch( batch_id=batches[0].batch_id, requests=requests, @@ -381,7 +421,7 @@ class FlashCausalLMBatch(Batch): read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, - next_token_choosers=next_token_choosers, + next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, max_tokens=max_tokens, ) @@ -463,6 +503,7 @@ class FlashCausalLM(Model): self, batch: FlashCausalLMBatch ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]: prefill = batch.past_key_values is None + single_request = len(batch) == 1 if prefill and len(batch) == 1: # Ask to pre-allocate kv to its max size @@ -483,6 +524,17 @@ class FlashCausalLM(Model): pre_allocate_past_size, ) + if prefill: + next_token_logits = ( + out[-1:] if single_request else out[batch.cu_seqlens[1:] - 1] + ) + else: + next_token_logits = out + + next_input_ids, next_token_logprobs = batch.next_token_chooser( + batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits + ) + if prefill: if len(batch) > 1: # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs @@ -493,15 +545,11 @@ class FlashCausalLM(Model): batch.cu_seqlens_q = torch.arange( 0, len(batch) + 1, device=self.device, dtype=torch.int32 ) - next_input_ids = batch.input_ids.new_empty(len(batch)) next_position_ids = batch.position_ids.new_empty(len(batch)) else: prefill_logprobs = None - next_input_ids = batch.input_ids next_position_ids = batch.position_ids - next_token_logprobs = out.new_empty(len(batch)) - # Prepare past for next decode if len(batch) > 1: # Used to slice next batch past @@ -552,7 +600,6 @@ class FlashCausalLM(Model): # Zipped iterator iterator = zip( batch.input_lengths, - batch.next_token_choosers, batch.stopping_criterias, batch.all_input_ids, ) @@ -564,7 +611,6 @@ class FlashCausalLM(Model): # For each member of the batch for i, ( input_length, - next_token_chooser, stopping_criteria, all_input_ids, ) in enumerate(iterator): @@ -573,21 +619,6 @@ class FlashCausalLM(Model): end_index = cumulative_length + input_length if prefill: - # Prefill mode - # out is of 
shape [cumulative_sequence_lengths, vocab_size] - # only take last token logit - logits = out[end_index - 1 : end_index] - - # Create all_input_ids_tensor that will be used by token warpers (for example, RepetitionPenalty) - all_input_ids_tensor = batch.input_ids.new_empty( - input_length + stopping_criteria.max_new_tokens - ) - # Copy from batch.input_ids to all_input_ids_tensor - all_input_ids_tensor[:input_length] = batch.input_ids[ - start_index:end_index - ] - batch.all_input_ids_tensor.append(all_input_ids_tensor) - # Initialize position_ids # In decode, we do not need this as we can just increment position ids next_position_ids[i] = batch.position_ids[end_index - 1] @@ -603,25 +634,8 @@ class FlashCausalLM(Model): prefill_tokens_indices = batch.input_ids[ start_index + 1 : end_index ] - else: - # Decode mode - # out is of shape [batch_size, vocab_size] - logits = out[i].view(1, -1) - all_input_ids_tensor = batch.all_input_ids_tensor[i] - - # Select next token - next_token_id, logprob = next_token_chooser( - all_input_ids_tensor[None, :input_length], logits - ) - - # Add to all_input_ids_tensor - next_token_id_squeezed = next_token_id.view(1) - all_input_ids_tensor[input_length] = next_token_id_squeezed - - # Set values - next_input_ids[i] = next_token_id_squeezed - next_token_logprobs[i] = logprob[-1, next_token_id].view(1) + batch.all_input_ids_tensor[i, input_length] = next_input_ids[i] cumulative_length += input_length @@ -651,10 +665,11 @@ class FlashCausalLM(Model): batch.input_lengths, batch.prefix_offsets, batch.read_offsets, - batch.next_token_choosers, batch.stopping_criterias, batch.all_input_ids, batch.all_input_ids_tensor, + batch.next_token_chooser.do_sample, + batch.next_token_chooser.seeds, next_token_ids, next_token_logprobs, ) @@ -665,10 +680,11 @@ class FlashCausalLM(Model): input_length, prefix_offset, read_offset, - next_token_chooser, stopping_criteria, all_input_ids, all_input_ids_tensor, + do_sample, + seed, next_token_id, next_token_logprob, ) in enumerate(iterator): @@ -702,14 +718,11 @@ class FlashCausalLM(Model): output_text = self.decode( all_input_ids[-stopping_criteria.current_tokens :] ) - # Get seed - if isinstance(next_token_chooser.choice, Sampling): - seed = next_token_chooser.choice.seed - else: - seed = None - generated_text = GeneratedText( - output_text, stopping_criteria.current_tokens, reason, seed + output_text, + stopping_criteria.current_tokens, + reason, + seed if do_sample else None, ) else: generated_text = None @@ -751,8 +764,9 @@ class FlashCausalLM(Model): batch.prefix_offsets[i] = prefix_offset batch.read_offsets[i] = read_offset batch.all_input_ids[i] = all_input_ids - batch.max_seqlen = batch.max_seqlen + 1 cumulative_length += input_length + batch.max_seqlen = batch.max_seqlen + 1 + # No need to return a batch if we know that all requests stopped return generations, batch if not stopped else None diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py index bc3096c6..0a3f341b 100644 --- a/server/text_generation_server/models/galactica.py +++ b/server/text_generation_server/models/galactica.py @@ -89,6 +89,7 @@ class GalacticaCausalLMBatch(CausalLMBatch): cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, + dtype: torch.dtype, device: torch.device, ) -> "GalacticaCausalLMBatch": inputs = [] diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index 2abb87ae..68e59dc3 100644 --- 
a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -71,6 +71,7 @@ class Seq2SeqLMBatch(Batch): cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, + dtype: torch.dtype, device: torch.device, ) -> "Seq2SeqLMBatch": """Convert a text_generation_server.v1.Batch protobuf to a Seq2SeqLMBatch""" diff --git a/server/text_generation_server/models/types.py b/server/text_generation_server/models/types.py index 66a8c212..28ca8147 100644 --- a/server/text_generation_server/models/types.py +++ b/server/text_generation_server/models/types.py @@ -21,6 +21,7 @@ class Batch(ABC): cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, + dtype: torch.dtype, device: torch.device, ) -> "Batch": raise NotImplementedError diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index e47fd049..e1bd8412 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -55,7 +55,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): async def Prefill(self, request, context): batch = self.model.batch_type.from_pb( - request.batch, self.model.tokenizer, self.model.device + request.batch, self.model.tokenizer, self.model.dtype, self.model.device ) generations, next_batch = self.model.generate_token(batch) diff --git a/server/text_generation_server/utils/__init__.py b/server/text_generation_server/utils/__init__.py index 50d64518..6a351d66 100644 --- a/server/text_generation_server/utils/__init__.py +++ b/server/text_generation_server/utils/__init__.py @@ -9,12 +9,13 @@ from text_generation_server.utils.hub import ( RevisionNotFoundError, ) from text_generation_server.utils.tokens import ( - Greedy, NextTokenChooser, - Sampling, + HeterogeneousNextTokenChooser, StoppingCriteria, StopSequenceCriteria, FinishReason, + Sampling, + Greedy, ) __all__ = [ @@ -25,6 +26,7 @@ __all__ = [ "weight_hub_files", "download_weights", "EntryNotFoundError", + "HeterogeneousNextTokenChooser", "LocalEntryNotFoundError", "RevisionNotFoundError", "Greedy", diff --git a/server/text_generation_server/utils/convert.py b/server/text_generation_server/utils/convert.py index caf1a764..c43a4464 100644 --- a/server/text_generation_server/utils/convert.py +++ b/server/text_generation_server/utils/convert.py @@ -1,14 +1,10 @@ -import concurrent -import time import datetime import torch -from concurrent.futures import ThreadPoolExecutor from collections import defaultdict -from datetime import timedelta from loguru import logger from pathlib import Path -from safetensors.torch import load_file, save_file +from safetensors.torch import save_file from safetensors import safe_open from typing import Dict, List diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py new file mode 100644 index 00000000..faa94516 --- /dev/null +++ b/server/text_generation_server/utils/logits_process.py @@ -0,0 +1,405 @@ +import math +import torch + +from functools import lru_cache +from typing import Optional, List, Dict, Union + +from transformers import ( + LogitsWarper, + LogitsProcessor, + TemperatureLogitsWarper, + TopKLogitsWarper, + TopPLogitsWarper, + TypicalLogitsWarper, +) + +mempool = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None + + +class StaticWarper: + def __init__( + self, + temperature=1.0, + top_k=None, + top_p=None, + typical_p=None, + ): + self.warpers = [] + + if temperature is 
not None and temperature != 1.0: + temperature = float(temperature) + self.warpers.append(TemperatureLogitsWarper(temperature)) + if top_k is not None and top_k != 0: + self.warpers.append(TopKLogitsWarper(top_k=top_k)) + if top_p is not None and top_p < 1.0: + self.warpers.append(TopPLogitsWarper(top_p=top_p)) + if typical_p is not None and typical_p < 1.0: + self.warpers.append(TypicalLogitsWarper(mass=typical_p)) + + self.cuda_graph = None + self.static_scores = None + self.static_warped_scores = None + self.static_next_logprob = None + + def __call__(self, scores): + if self.cuda_graph is None: + self.static_scores = scores + self.cuda_graph = torch.cuda.CUDAGraph() + + with torch.cuda.graph(self.cuda_graph, pool=mempool): + local_scores = self.static_scores + for warper in self.warpers: + local_scores = warper(None, local_scores) + + self.static_warped_scores = local_scores + # Compute logprobs + self.static_next_logprob = torch.log_softmax( + self.static_warped_scores, -1 + ) + + self.static_scores.copy_(scores) + self.cuda_graph.replay() + + return self.static_warped_scores, self.static_next_logprob + + +@lru_cache(10) +def static_warper( + temperature: Optional[float], + top_k: Optional[int], + top_p: Optional[float], + typical_p: Optional[float], +) -> StaticWarper: + return StaticWarper( + temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p + ) + + +class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor): + r""" + [`LogitsProcessor`] enforcing an exponential penalty on repeated sequences. + This version allows for a separate value for each sample and runs inplace when possible. + It doesn't validate inputs. + + Args: + repetition_penalty (`List[float]`): + The parameter for repetition penalty. 1.0 means no penalty. See [this + paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. + """ + + def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device): + self.penalty = penalty + self.penalty_tensor = torch.tensor( + penalty, dtype=dtype, device=device + ).unsqueeze(1) + + def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: + score = torch.gather(scores, 1, input_ids) + + # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability + score = torch.where( + score < 0, score * self.penalty_tensor, score / self.penalty_tensor + ) + + scores.scatter_(1, input_ids, score) + return scores + + def filter(self, indices): + self.penalty = [self.penalty[i] for i in indices] + if any([x != 1.0 for x in self.penalty]): + self.penalty_tensor = self.penalty_tensor[indices] + return self + return None + + +class HeterogeneousTemperatureLogitsWarper: + r""" + [`LogitsWarper`] for temperature (exponential scaling output probability distribution). + This version allows for a separate value for each sample and runs inplace when possible. + It doesn't validate inputs. + + Args: + temperature (`float`): + The value used to module the logits distribution. 
+ """ + + def __init__( + self, temperature: List[float], dtype: torch.dtype, device: torch.device + ): + self.temperature = temperature + self.temperature_tensor = torch.tensor( + temperature, dtype=dtype, device=device + ).unsqueeze(1) + + def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: + scores.div_(self.temperature_tensor) + return scores + + def filter(self, indices): + self.temperature = [self.temperature[i] for i in indices] + if any([x != 1.0 for x in self.temperature]): + self.temperature_tensor = self.temperature_tensor[indices] + return self + return None + + +class HeterogeneousTopPLogitsWarper(LogitsWarper): + """ + [`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off. + This version allows for a separate value for each sample and runs inplace when possible. + It doesn't validate inputs. + + Args: + top_p (`float`): + If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or + higher are kept for generation. + filter_value (`float`, *optional*, defaults to `-float("Inf")`): + All filtered values will be set to this float value. + min_tokens_to_keep (`int`, *optional*, defaults to 1): + Minimum number of tokens that cannot be filtered. + """ + + def __init__( + self, + top_p: List[float], + dtype: torch.dtype, + device: torch.device, + filter_value: float = -math.inf, + min_tokens_to_keep: int = 1, + ): + self.top_p = top_p + self.top_p_opposite = 1 - torch.tensor( + top_p, dtype=dtype, device=device + ).unsqueeze(1) + self.filter_value = filter_value + self.min_tokens_to_keep = min_tokens_to_keep + + def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: + sorted_logits, sorted_indices = torch.sort(scores, descending=False) + probs = sorted_logits.softmax(dim=-1) + # This is way faster for some reason + for i in range(probs.shape[0]): + probs[i] = probs[i].cumsum(dim=-1) + + # Remove tokens with cumulative top_p above the threshold (token with 0 are kept) + sorted_indices_to_remove = probs <= self.top_p_opposite + if self.min_tokens_to_keep > 1: + # Keep at least min_tokens_to_keep + sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0 + + # scatter sorted tensors to original indexing + indices_to_remove = sorted_indices_to_remove.scatter( + 1, sorted_indices, sorted_indices_to_remove + ) + warped_scores = scores.masked_fill_(indices_to_remove, self.filter_value) + + return warped_scores + + def filter(self, indices): + self.top_p = [self.top_p[i] for i in indices] + if any([x < 1.0 for x in self.top_p]): + self.top_p_opposite = self.top_p_opposite[indices] + return self + return None + + +class HeterogeneousTopKLogitsWarper(LogitsWarper): + r""" + [`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements. + This version allows for a separate value for each sample and runs inplace when possible. + It doesn't validate inputs. + + Args: + top_k (`int`): + The number of highest probability vocabulary tokens to keep for top-k-filtering. + filter_value (`float`, *optional*, defaults to `-float("Inf")`): + All filtered values will be set to this float value. + min_tokens_to_keep (`int`, *optional*, defaults to 1): + Minimum number of tokens that cannot be filtered. 
+ """ + + def __init__( + self, + top_k: List[int], + device: torch.device, + filter_value: float = -math.inf, + min_tokens_to_keep: int = 1, + ): + self.top_k = top_k + self.max_top_k = max(top_k) + # value - 1 as we will use top_k to index and python uses 0 based numbering + self.top_k_tensor = torch.tensor( + [max(x - 1, min_tokens_to_keep - 1) for x in top_k], + dtype=torch.int64, + device=device, + ).unsqueeze(1) + + # 0 is a special value that disables top_k warping for this member of the batch + disabled = [x == 0 for x in top_k] + + if any(disabled): + self.top_k_disabled_mask = torch.tensor( + disabled, dtype=torch.bool, device=device + ).view(-1, 1) + else: + self.top_k_disabled_mask = None + + self.filter_value = filter_value + + def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: + # If max_top_k is superior to the vocab, we need to clamp or the warper will fail + if scores.size(-1) < self.max_top_k: + max_top_k = scores.size(-1) + top_k = torch.clamp_max(self.top_k_tensor, max_top_k) + else: + max_top_k = self.max_top_k + top_k = self.top_k_tensor + + # Get the kth score for each member of the batch + kth_scores = torch.gather(torch.topk(scores, max_top_k)[0], 1, top_k) + + # Mask member of kth_scores that do not want to use top_k warping + if self.top_k_disabled_mask is not None: + kth_scores.masked_fill_(self.top_k_disabled_mask, self.filter_value) + + # Remove all tokens with a probability less than the last token of the top-k + indices_to_remove = scores < kth_scores + scores.masked_fill_(indices_to_remove, self.filter_value) + return scores + + def filter(self, indices): + self.top_k = [self.top_k[i] for i in indices] + disabled = [x == 0 for x in self.top_k] + + if not all(disabled): + self.top_k_tensor = self.top_k_tensor[indices] + self.max_top_k = max(self.top_k) + + if self.top_k_disabled_mask is not None: + self.top_k_disabled_mask = ( + self.top_k_disabled_mask[indices] if any(disabled) else None + ) + + return self + return None + + +class HeterogeneousTypicalLogitsWarper(LogitsWarper): + r""" + [`LogitsWarper`] that performs typical decoding. See [Typical Decoding for Natural Language + Generation](https://arxiv.org/abs/2202.00666) for more information. + This version allows for a separate value for each sample and runs inplace when possible. + It doesn't validate inputs. + + Args: + mass (`float`): + Value of typical_p between 0 and 1 inclusive, defaults to 0.9. + filter_value (`float`, *optional*, defaults to `-float("Inf")`): + All filtered values will be set to this float value. + min_tokens_to_keep (`int`, *optional*, defaults to 1): + Minimum number of tokens that cannot be filtered. 
+ """ + + def __init__( + self, + mass: List[float], + dtype: torch.dtype, + device: torch.device, + filter_value: float = -math.inf, + min_tokens_to_keep: int = 1, + ): + self.mass = mass + self.mass_tensor = torch.tensor(mass, dtype=dtype, device=device).unsqueeze(1) + + # 1 is a special value that disables typical_p warping for this member of the batch + disabled = [x == 1.0 for x in mass] + + if any(disabled): + self.disabled_mask = torch.tensor(disabled, dtype=torch.bool, device=device) + else: + self.disabled_mask = None + + self.filter_value = filter_value + self.min_tokens_to_keep = min_tokens_to_keep + + def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: + # calculate entropy + normalized = torch.nn.functional.log_softmax(scores, dim=-1) + p = torch.exp(normalized) + ent = -(normalized * p).nansum(-1, keepdim=True) + + # shift and sort + shifted_scores = torch.abs((-normalized) - ent) + sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False) + sorted_logits = scores.gather(-1, sorted_indices) + probs = sorted_logits.softmax(dim=-1) + # This is way faster for some reason + for i in range(probs.shape[0]): + probs[i] = probs[i].cumsum(dim=-1) + + # Remove tokens with cumulative mass above the threshold + last_ind = (probs < self.mass_tensor).sum(dim=1) + last_ind[last_ind < 0] = 0 + + if self.disabled_mask is not None: + last_ind.masked_fill_(self.disabled_mask, scores.shape[-1] - 1) + + sorted_indices_to_remove = sorted_scores > sorted_scores.gather( + 1, last_ind.view(-1, 1) + ) + if self.min_tokens_to_keep > 1: + # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) + sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0 + indices_to_remove = sorted_indices_to_remove.scatter( + 1, sorted_indices, sorted_indices_to_remove + ) + + warped_scores = scores.masked_fill_(indices_to_remove, self.filter_value) + + return warped_scores + + def filter(self, indices): + self.mass = [self.mass[i] for i in indices] + disabled = [x == 1.0 for x in self.mass] + + if not all(disabled): + self.mass_tensor = self.mass_tensor[indices] + + if self.disabled_mask is not None: + self.disabled_mask = ( + self.disabled_mask[indices] if any(disabled) else None + ) + + return self + return None + + +class HeterogeneousProcessorWrapper(LogitsProcessor): + r""" + A wrapper for logit warpers or processors without heterogeneous parameter support. + Args: + processors (`Dict[int, Union[LogitsProcessor, LogitsWarper]]`): + A mapping of sample indices to logit warpers or processors, to be run sequentially. 
+ """ + + def __init__( + self, + processors: Dict[int, Union[LogitsProcessor, LogitsWarper]], + ): + self.processors = processors + + def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: + for i, processor in self.processors.items(): + scores[i : i + 1] = processor(input_ids[i : i + 1], scores[i : i + 1]) + return scores + + def filter(self, indices): + new_processors = {} + for i, idx in enumerate(indices): + if idx in self.processors: + new_processors[i] = self.processors[idx] + + if new_processors: + self.processors = new_processors + return self + return None diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index e9fb96b0..e6e512bc 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -1,12 +1,7 @@ import re import torch -from functools import lru_cache from transformers import ( - TemperatureLogitsWarper, - TopKLogitsWarper, - TopPLogitsWarper, - TypicalLogitsWarper, RepetitionPenaltyLogitsProcessor, PreTrainedTokenizerBase, ) @@ -15,82 +10,15 @@ from typing import List, Tuple, Optional from text_generation_server.pb import generate_pb2 from text_generation_server.pb.generate_pb2 import FinishReason from text_generation_server.utils.watermark import WatermarkLogitsProcessor - - -class Sampling: - def __init__(self, seed: int, device: str = "cpu"): - self.generator = torch.Generator(device) - self.generator.manual_seed(seed) - self.seed = seed - - def __call__(self, logits): - probs = torch.nn.functional.softmax(logits, -1) - # Avoid GPU<->CPU sync done by torch multinomial - # See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637 - q = torch.empty_like(probs).exponential_(1, generator=self.generator) - return probs.div_(q).argmax() - - -class Greedy: - def __call__(self, logits): - return logits.argmax() - - -class StaticWarper: - def __init__( - self, - temperature=1.0, - top_k=None, - top_p=None, - typical_p=None, - ): - self.warpers = [] - - if temperature is not None and temperature != 1.0: - temperature = float(temperature) - self.warpers.append(TemperatureLogitsWarper(temperature)) - if top_k is not None and top_k != 0: - self.warpers.append(TopKLogitsWarper(top_k=top_k)) - if top_p is not None and top_p < 1.0: - self.warpers.append(TopPLogitsWarper(top_p=top_p)) - if typical_p is not None and typical_p < 1.0: - self.warpers.append(TypicalLogitsWarper(mass=typical_p)) - - self.cuda_graph = None - self.static_scores = None - self.static_warped_scores = None - self.static_next_logprob = None - - def __call__(self, scores): - if self.cuda_graph is None: - self.static_scores = scores - self.cuda_graph = torch.cuda.CUDAGraph() - - with torch.cuda.graph(self.cuda_graph): - for warper in self.warpers: - self.static_warped_scores = warper(None, self.static_scores) - - # Compute logprobs - self.static_next_logprob = torch.log_softmax( - self.static_warped_scores, -1 - ) - - self.static_scores.copy_(scores) - self.cuda_graph.replay() - - return self.static_warped_scores, self.static_next_logprob - - -@lru_cache(10) -def static_warper( - temperature: Optional[float], - top_k: Optional[int], - top_p: Optional[float], - typical_p: Optional[float], -) -> StaticWarper: - return StaticWarper( - temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p - ) +from text_generation_server.utils.logits_process import ( + static_warper, + 
HeterogeneousRepetitionPenaltyLogitsProcessor, + HeterogeneousTemperatureLogitsWarper, + HeterogeneousTopKLogitsWarper, + HeterogeneousTopPLogitsWarper, + HeterogeneousTypicalLogitsWarper, + HeterogeneousProcessorWrapper, +) class NextTokenChooser: @@ -132,9 +60,9 @@ class NextTokenChooser: self.choice = Sampling(seed, device) if sampling else Greedy() def __call__(self, input_ids, scores): - if self.watermark_processor: + if self.watermark_processor is not None: scores = self.watermark_processor(input_ids, scores) - if self.repetition_processor: + if self.repetition_processor is not None: scores = self.repetition_processor(input_ids, scores) if self.static_warper is None: @@ -221,3 +149,191 @@ class StoppingCriteria: pb.max_new_tokens, pb.ignore_eos_token, ) + + +class HeterogeneousNextTokenChooser: + def __init__( + self, + dtype: torch.dtype, + device: torch.device, + watermark: List[bool], + temperature: List[float], + repetition_penalty: List[float], + top_k: List[int], + top_p: List[float], + typical_p: List[float], + do_sample: List[bool], + seeds: List[int], + ): + warpers = [] + + self.watermark_processor = ( + HeterogeneousProcessorWrapper( + { + i: WatermarkLogitsProcessor(device=device) + for i, do_watermark in enumerate(watermark) + if do_watermark + } + ) + if any(watermark) + else None + ) + + self.repetition_processor = ( + HeterogeneousRepetitionPenaltyLogitsProcessor( + repetition_penalty, dtype, device + ) + if any([x != 1.0 for x in repetition_penalty]) + else None + ) + + if any([x != 1.0 for x in temperature]): + do_sample = [ + sample or x != 1.0 for x, sample in zip(temperature, do_sample) + ] + warpers.append( + HeterogeneousTemperatureLogitsWarper(temperature, dtype, device) + ) + + if any([x != 0 for x in top_k]): + do_sample = [sample or x != 0 for x, sample in zip(top_k, do_sample)] + warpers.append(HeterogeneousTopKLogitsWarper(top_k, device)) + + if any([x < 1.0 for x in top_p]): + do_sample = [sample or x < 1.0 for x, sample in zip(top_p, do_sample)] + warpers.append(HeterogeneousTopPLogitsWarper(top_p, dtype, device)) + + if any([x < 1.0 for x in typical_p]): + do_sample = [sample or x < 1.0 for x, sample in zip(typical_p, do_sample)] + warpers.append(HeterogeneousTypicalLogitsWarper(typical_p, dtype, device)) + + self.warpers = warpers + + if any(do_sample): + self.choice = HeterogeneousSampling(do_sample, seeds, device) + else: + self.choice = Greedy() + + self.seeds = seeds + self.do_sample = do_sample + + def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor): + if self.watermark_processor is not None: + scores = self.watermark_processor(input_ids, scores) + if self.repetition_processor is not None: + scores = self.repetition_processor(input_ids, scores) + + for warper in self.warpers: + scores = warper(input_ids, scores) + + next_ids = self.choice(scores) + next_logprobs = torch.gather( + torch.log_softmax(scores, -1), 1, next_ids.view(-1, 1) + ).view(-1) + + return next_ids, next_logprobs + + def filter(self, indices): + if self.watermark_processor is not None: + self.watermark_processor = self.watermark_processor.filter(indices) + + if self.repetition_processor is not None: + self.repetition_processor = self.repetition_processor.filter(indices) + + filtered_warpers = [] + for warper in self.warpers: + filtered_warper = warper.filter(indices) + if filtered_warper is not None: + filtered_warpers.append(filtered_warper) + self.warpers = filtered_warpers + + self.seeds = [self.seeds[i] for i in indices] + self.do_sample = [self.do_sample[i] 
for i in indices] + + if any(self.do_sample): + self.choice.filter(indices) + else: + self.choice = Greedy() + + return self + + @classmethod + def from_pb( + cls, + pb: List[generate_pb2.NextTokenChooserParameters], + dtype: torch.dtype, + device: torch.device, + ) -> "HeterogeneousNextTokenChooser": + return HeterogeneousNextTokenChooser( + watermark=[pb_.watermark for pb_ in pb], + temperature=[pb_.temperature for pb_ in pb], + repetition_penalty=[pb_.repetition_penalty for pb_ in pb], + top_k=[pb_.top_k for pb_ in pb], + top_p=[pb_.top_p for pb_ in pb], + typical_p=[pb_.typical_p for pb_ in pb], + do_sample=[pb_.do_sample for pb_ in pb], + seeds=[pb_.seed for pb_ in pb], + device=device, + dtype=dtype, + ) + + +class Sampling: + def __init__(self, seed: int, device: str = "cpu"): + self.generator = torch.Generator(device) + self.generator.manual_seed(seed) + self.seed = seed + + def __call__(self, logits): + probs = torch.nn.functional.softmax(logits, -1) + # Avoid GPU<->CPU sync done by torch multinomial + # See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637 + q = torch.empty_like(probs).exponential_(1, generator=self.generator) + return probs.div_(q).argmax() + + +class Greedy: + def __call__(self, logits): + return logits.argmax(dim=-1) + + +class HeterogeneousSampling: + r""" + Mixed greedy and probabilistic sampling. Compute both and pick the right one for each sample. + """ + + def __init__(self, do_sample: List[bool], seeds: List[int], device: torch.device): + self.seeds = seeds + + self.greedy_indices = [] + self.sampling_mapping = {} + for i, (sample, seed) in enumerate(zip(do_sample, seeds)): + if sample: + self.sampling_mapping[i] = Sampling(seed, device) + else: + self.greedy_indices.append(i) + + self.greedy = Greedy() + + def __call__(self, logits): + out = torch.empty(logits.shape[0], dtype=torch.int64, device=logits.device) + if self.greedy_indices: + # Computing for all indices is faster than slicing + torch.argmax(logits, -1, out=out) + + for i, sampling in self.sampling_mapping.items(): + out[i] = sampling(logits[i]) + return out + + def filter(self, indices): + new_greedy_indices = [] + new_sampling_mapping = {} + for i, idx in enumerate(indices): + if idx in self.sampling_mapping: + new_sampling_mapping[i] = self.sampling_mapping[idx] + else: + new_greedy_indices.append(i) + + self.greedy_indices = new_greedy_indices + self.sampling_mapping = new_sampling_mapping + return self
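
For illustration, here is a minimal, standalone sketch of the idea the patch implements: per-request sampling parameters (temperature, top-k, greedy vs. sampling) are turned into batched tensors so a few vectorized ops over the whole [batch, vocab] logits tensor replace a Python loop over requests. This is plain PyTorch on CPU, not the text-generation-inference API; the helper name `toy_heterogeneous_sample` is hypothetical and only mirrors, rather than reuses, `HeterogeneousNextTokenChooser`.

# A minimal, standalone sketch of the "heterogeneous" (per-request) warping idea
# from this patch, written against plain PyTorch on CPU. The helper name
# `toy_heterogeneous_sample` is hypothetical and is not part of the
# text-generation-inference API.
import torch
from typing import List


def toy_heterogeneous_sample(
    logits: torch.Tensor,        # [batch, vocab] next-token logits
    temperature: List[float],    # one temperature per request
    top_k: List[int],            # 0 disables top-k for that request
    do_sample: List[bool],       # greedy vs. sampling, per request
    seed: int = 0,
) -> torch.Tensor:
    batch, vocab = logits.shape
    device = logits.device
    scores = logits.clone()

    # Temperature becomes a [batch, 1] tensor: one vectorized division
    # instead of a per-request Python loop.
    temp = torch.tensor(temperature, dtype=scores.dtype, device=device).unsqueeze(1)
    scores.div_(temp)

    # Top-k: take the k-th best score of each row and mask everything below it.
    # Rows with top_k == 0 keep all tokens.
    max_k = min(max(max(top_k), 1), vocab)
    kth_index = torch.tensor(
        [min(max(k - 1, 0), max_k - 1) for k in top_k], device=device
    ).unsqueeze(1)
    kth_scores = torch.gather(torch.topk(scores, max_k).values, 1, kth_index)
    disabled = torch.tensor([k == 0 for k in top_k], device=device).unsqueeze(1)
    kth_scores = kth_scores.masked_fill(disabled, float("-inf"))
    scores = scores.masked_fill(scores < kth_scores, float("-inf"))

    # Mixed greedy / sampling: compute argmax for every row, then overwrite
    # the rows that asked for sampling with a multinomial draw.
    greedy_ids = scores.argmax(dim=-1)
    probs = torch.softmax(scores, dim=-1)
    generator = torch.Generator(device="cpu").manual_seed(seed)  # CPU-only demo
    sampled_ids = torch.multinomial(probs, num_samples=1, generator=generator).squeeze(1)
    sample_mask = torch.tensor(do_sample, device=device)
    return torch.where(sample_mask, sampled_ids, greedy_ids)


if __name__ == "__main__":
    logits = torch.randn(3, 32)
    next_ids = toy_heterogeneous_sample(
        logits,
        temperature=[1.0, 0.7, 1.0],
        top_k=[0, 10, 5],
        do_sample=[False, True, True],
    )
    print(next_ids)  # one token id per request, e.g. tensor([ 4, 17,  9])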