diff --git a/integration-tests/models/__snapshots__/test_t5_sharded/test_t5_sharded.json b/integration-tests/models/__snapshots__/test_t5_sharded/test_t5_sharded.json
new file mode 100644
index 00000000..6090e2c9
--- /dev/null
+++ b/integration-tests/models/__snapshots__/test_t5_sharded/test_t5_sharded.json
@@ -0,0 +1,60 @@
+{
+  "details": {
+    "best_of_sequences": null,
+    "finish_reason": "eos_token",
+    "generated_tokens": 7,
+    "prefill": [
+      {
+        "id": 0,
+        "logprob": null,
+        "text": ""
+      }
+    ],
+    "seed": null,
+    "tokens": [
+      {
+        "id": 3,
+        "logprob": -0.7001953,
+        "special": false,
+        "text": " "
+      },
+      {
+        "id": 18,
+        "logprob": -1.1943359,
+        "special": false,
+        "text": "-"
+      },
+      {
+        "id": 26937,
+        "logprob": -1.2099609,
+        "special": false,
+        "text": "196"
+      },
+      {
+        "id": 3,
+        "logprob": -1.2451172,
+        "special": false,
+        "text": " "
+      },
+      {
+        "id": 1956,
+        "logprob": -0.3322754,
+        "special": false,
+        "text": "°"
+      },
+      {
+        "id": 254,
+        "logprob": -0.19213867,
+        "special": false,
+        "text": "C"
+      },
+      {
+        "id": 1,
+        "logprob": -0.030151367,
+        "special": true,
+        "text": ""
+      }
+    ]
+  },
+  "generated_text": "-196 °C"
+}
diff --git a/integration-tests/models/__snapshots__/test_t5_sharded/test_t5_sharded_load.json b/integration-tests/models/__snapshots__/test_t5_sharded/test_t5_sharded_load.json
new file mode 100644
index 00000000..3e9af12e
--- /dev/null
+++ b/integration-tests/models/__snapshots__/test_t5_sharded/test_t5_sharded_load.json
@@ -0,0 +1,242 @@
+[
+  {
+    "details": {
+      "best_of_sequences": null,
+      "finish_reason": "eos_token",
+      "generated_tokens": 7,
+      "prefill": [
+        {
+          "id": 0,
+          "logprob": null,
+          "text": ""
+        }
+      ],
+      "seed": null,
+      "tokens": [
+        {
+          "id": 3,
+          "logprob": -0.7001953,
+          "special": false,
+          "text": " "
+        },
+        {
+          "id": 18,
+          "logprob": -1.1943359,
+          "special": false,
+          "text": "-"
+        },
+        {
+          "id": 26937,
+          "logprob": -1.2119141,
+          "special": false,
+          "text": "196"
+        },
+        {
+          "id": 3,
+          "logprob": -1.2480469,
+          "special": false,
+          "text": " "
+        },
+        {
+          "id": 1956,
+          "logprob": -0.33203125,
+          "special": false,
+          "text": "°"
+        },
+        {
+          "id": 254,
+          "logprob": -0.19250488,
+          "special": false,
+          "text": "C"
+        },
+        {
+          "id": 1,
+          "logprob": -0.030166626,
+          "special": true,
+          "text": ""
+        }
+      ]
+    },
+    "generated_text": "-196 °C"
+  },
+  {
+    "details": {
+      "best_of_sequences": null,
+      "finish_reason": "eos_token",
+      "generated_tokens": 7,
+      "prefill": [
+        {
+          "id": 0,
+          "logprob": null,
+          "text": ""
+        }
+      ],
+      "seed": null,
+      "tokens": [
+        {
+          "id": 3,
+          "logprob": -0.7001953,
+          "special": false,
+          "text": " "
+        },
+        {
+          "id": 18,
+          "logprob": -1.1943359,
+          "special": false,
+          "text": "-"
+        },
+        {
+          "id": 26937,
+          "logprob": -1.2119141,
+          "special": false,
+          "text": "196"
+        },
+        {
+          "id": 3,
+          "logprob": -1.2480469,
+          "special": false,
+          "text": " "
+        },
+        {
+          "id": 1956,
+          "logprob": -0.33203125,
+          "special": false,
+          "text": "°"
+        },
+        {
+          "id": 254,
+          "logprob": -0.19250488,
+          "special": false,
+          "text": "C"
+        },
+        {
+          "id": 1,
+          "logprob": -0.030166626,
+          "special": true,
+          "text": ""
+        }
+      ]
+    },
+    "generated_text": "-196 °C"
+  },
+  {
+    "details": {
+      "best_of_sequences": null,
+      "finish_reason": "eos_token",
+      "generated_tokens": 7,
+      "prefill": [
+        {
+          "id": 0,
+          "logprob": null,
+          "text": ""
+        }
+      ],
+      "seed": null,
+      "tokens": [
+        {
+          "id": 3,
+          "logprob": -0.7001953,
+          "special": false,
+          "text": " "
+        },
+        {
+          "id": 18,
+          "logprob": -1.1943359,
+          "special": false,
+          "text": "-"
+        },
+        {
+          "id": 26937,
+          "logprob": -1.2119141,
+          "special": false,
+          "text": "196"
+        },
+        {
+          "id": 3,
+          "logprob": -1.2480469,
+          "special": false,
+          "text": " "
+        },
+        {
+          "id": 1956,
+          "logprob": -0.33203125,
+          "special": false,
+          "text": "°"
+        },
+        {
+          "id": 254,
+          "logprob": -0.19250488,
+          "special": false,
+          "text": "C"
+        },
+        {
+          "id": 1,
+          "logprob": -0.030166626,
+          "special": true,
+          "text": ""
+        }
+      ]
+    },
+    "generated_text": "-196 °C"
+  },
+  {
+    "details": {
+      "best_of_sequences": null,
+      "finish_reason": "eos_token",
+      "generated_tokens": 7,
+      "prefill": [
+        {
+          "id": 0,
+          "logprob": null,
+          "text": ""
+        }
+      ],
+      "seed": null,
+      "tokens": [
+        {
+          "id": 3,
+          "logprob": -0.7001953,
+          "special": false,
+          "text": " "
+        },
+        {
+          "id": 18,
+          "logprob": -1.1943359,
+          "special": false,
+          "text": "-"
+        },
+        {
+          "id": 26937,
+          "logprob": -1.2099609,
+          "special": false,
+          "text": "196"
+        },
+        {
+          "id": 3,
+          "logprob": -1.2451172,
+          "special": false,
+          "text": " "
+        },
+        {
+          "id": 1956,
+          "logprob": -0.3322754,
+          "special": false,
+          "text": "°"
+        },
+        {
+          "id": 254,
+          "logprob": -0.19213867,
+          "special": false,
+          "text": "C"
+        },
+        {
+          "id": 1,
+          "logprob": -0.030151367,
+          "special": true,
+          "text": ""
+        }
+      ]
+    },
+    "generated_text": "-196 °C"
+  }
+]
diff --git a/integration-tests/models/test_flash_neox.py b/integration-tests/models/test_flash_neox.py
index ca5b33c1..ff9915ed 100644
--- a/integration-tests/models/test_flash_neox.py
+++ b/integration-tests/models/test_flash_neox.py
@@ -36,6 +36,8 @@ async def test_flash_neox_load(flash_neox, generate_load, response_snapshot):
     generated_texts = [r.generated_text for r in responses]
 
     assert len(generated_texts) == 4
-    assert generated_texts, all([text == generated_texts[0] for text in generated_texts])
+    assert all(
+        [text == generated_texts[0] for text in generated_texts]
+    )
 
     assert responses == response_snapshot
diff --git a/integration-tests/models/test_t5_sharded.py b/integration-tests/models/test_t5_sharded.py
new file mode 100644
index 00000000..074660c7
--- /dev/null
+++ b/integration-tests/models/test_t5_sharded.py
@@ -0,0 +1,38 @@
+import pytest
+
+
+@pytest.fixture(scope="module")
+def t5_sharded_handle(launcher):
+    with launcher("google/flan-t5-xxl", num_shard=2) as handle:
+        yield handle
+
+
+@pytest.fixture(scope="module")
+async def t5_sharded(t5_sharded_handle):
+    await t5_sharded_handle.health(240)
+    return t5_sharded_handle.client
+
+
+@pytest.mark.asyncio
+async def test_t5_sharded(t5_sharded, response_snapshot):
+    response = await t5_sharded.generate(
+        "Please answer the following question. What is the boiling point of Nitrogen?",
+        max_new_tokens=10,
+    )
+
+    assert response == response_snapshot
+
+
+@pytest.mark.asyncio
+async def test_t5_sharded_load(t5_sharded, generate_load, response_snapshot):
+    responses = await generate_load(
+        t5_sharded,
+        "Please answer the following question. What is the boiling point of Nitrogen?",
+        max_new_tokens=10,
+        n=4,
+    )
+
+    assert len(responses) == 4
+    assert all([r.generated_text == responses[0].generated_text for r in responses])
+
+    assert responses == response_snapshot
diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py
index 390f0a0a..9d609185 100644
--- a/server/text_generation_server/models/bloom.py
+++ b/server/text_generation_server/models/bloom.py
@@ -160,7 +160,10 @@ class BLOOMSharded(BLOOM):
                         # XXX: Hack for Rowlinear to add the bias only once.
                         if rank != 0:
                             tensor = torch.zeros_like(tensor)
-                elif isinstance(module, TensorParallelEmbedding) or name == "lm_head.weight":
+                elif (
+                    isinstance(module, TensorParallelEmbedding)
+                    or name == "lm_head.weight"
+                ):
                     size = slice_.get_shape()[0]
                     block_size = size // world_size
                     start = rank * block_size
diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py
index 2fd67574..7ecf948b 100644
--- a/server/text_generation_server/models/t5.py
+++ b/server/text_generation_server/models/t5.py
@@ -40,7 +40,7 @@ class T5Sharded(Seq2SeqLM):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
+            dtype = torch.float16
         else:
             device = torch.device("cpu")
             dtype = torch.float32
@@ -154,9 +154,15 @@ class T5Sharded(Seq2SeqLM):
                         f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
                     )
 
-                tensor = tensor.contiguous().to(dtype)
+                tensor = tensor.contiguous()
 
-                if quantize == "bitsandbytes":
+                # See: https://github.com/huggingface/transformers/blob/1fe1e3caa44617047f149bcc0c0b566343b714a7/src/transformers/models/t5/modeling_t5.py#LL316C15-L316C71
+                if module_name.endswith("wo"):
+                    tensor = tensor.to(torch.float32)
+                else:
+                    tensor = tensor.to(dtype)
+
+                if quantize == "bitsandbytes" and not module_name.endswith("wo"):
                     if not HAS_BITS_AND_BYTES:
                         raise ImportError(
                             "bitsandbytes is not available on your machine either because it is not installed "
@@ -207,7 +213,7 @@ class T5Sharded(Seq2SeqLM):
 
                     module.linear = replace_linear(state)
 
-                elif quantize == "gptq":
+                elif quantize == "gptq" and not module_name.endswith("wo"):
                     raise NotImplementedError(
                         "`gptq` is not implemented for now"
                     )
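For context on the `wo` special-casing in t5.py: the transformers comment linked in the hunk notes that T5's feed-forward output projection (`DenseReluDense.wo`) must stay in float32 because its weights overflow in float16. The change above therefore casts any module whose name ends in `wo` to float32, casts everything else to the model-wide dtype, and keeps the `wo` weights out of the bitsandbytes/gptq branches. A minimal sketch of that selection rule (the helper names and module names below are illustrative, not from the PR):

from typing import Optional

import torch


def select_dtype(module_name: str, default_dtype: torch.dtype) -> torch.dtype:
    # T5's "wo" projections are kept in float32 to avoid float16 overflow;
    # every other weight uses the model-wide dtype (float16 on GPU here).
    return torch.float32 if module_name.endswith("wo") else default_dtype


def should_quantize(module_name: str, quantize: Optional[str]) -> bool:
    # Quantized kernels run in reduced precision, so the float32 "wo"
    # weights are excluded from bitsandbytes/gptq as well.
    return quantize in ("bitsandbytes", "gptq") and not module_name.endswith("wo")


# Illustrative names; real ones come from the sharded safetensors checkpoint.
for name in [
    "encoder.block.0.layer.1.DenseReluDense.wi",
    "encoder.block.0.layer.1.DenseReluDense.wo",
]:
    print(name, select_dtype(name, torch.float16), should_quantize(name, "bitsandbytes"))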