feat(server): support fp16 for t5 (#360)

Fixes #349
OlivierDehaene 2023-05-23 18:16:48 +02:00 committed by GitHub
parent 94377efa78
commit cfaa858070
6 changed files with 357 additions and 6 deletions


@@ -0,0 +1,60 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "eos_token",
"generated_tokens": 7,
"prefill": [
{
"id": 0,
"logprob": null,
"text": "<pad>"
}
],
"seed": null,
"tokens": [
{
"id": 3,
"logprob": -0.7001953,
"special": false,
"text": " "
},
{
"id": 18,
"logprob": -1.1943359,
"special": false,
"text": "-"
},
{
"id": 26937,
"logprob": -1.2099609,
"special": false,
"text": "196"
},
{
"id": 3,
"logprob": -1.2451172,
"special": false,
"text": " "
},
{
"id": 1956,
"logprob": -0.3322754,
"special": false,
"text": "°"
},
{
"id": 254,
"logprob": -0.19213867,
"special": false,
"text": "C"
},
{
"id": 1,
"logprob": -0.030151367,
"special": true,
"text": "</s>"
}
]
},
"generated_text": "-196 °C"
}


@@ -0,0 +1,242 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "eos_token",
"generated_tokens": 7,
"prefill": [
{
"id": 0,
"logprob": null,
"text": "<pad>"
}
],
"seed": null,
"tokens": [
{
"id": 3,
"logprob": -0.7001953,
"special": false,
"text": " "
},
{
"id": 18,
"logprob": -1.1943359,
"special": false,
"text": "-"
},
{
"id": 26937,
"logprob": -1.2119141,
"special": false,
"text": "196"
},
{
"id": 3,
"logprob": -1.2480469,
"special": false,
"text": " "
},
{
"id": 1956,
"logprob": -0.33203125,
"special": false,
"text": "°"
},
{
"id": 254,
"logprob": -0.19250488,
"special": false,
"text": "C"
},
{
"id": 1,
"logprob": -0.030166626,
"special": true,
"text": "</s>"
}
]
},
"generated_text": "-196 °C"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "eos_token",
"generated_tokens": 7,
"prefill": [
{
"id": 0,
"logprob": null,
"text": "<pad>"
}
],
"seed": null,
"tokens": [
{
"id": 3,
"logprob": -0.7001953,
"special": false,
"text": " "
},
{
"id": 18,
"logprob": -1.1943359,
"special": false,
"text": "-"
},
{
"id": 26937,
"logprob": -1.2119141,
"special": false,
"text": "196"
},
{
"id": 3,
"logprob": -1.2480469,
"special": false,
"text": " "
},
{
"id": 1956,
"logprob": -0.33203125,
"special": false,
"text": "°"
},
{
"id": 254,
"logprob": -0.19250488,
"special": false,
"text": "C"
},
{
"id": 1,
"logprob": -0.030166626,
"special": true,
"text": "</s>"
}
]
},
"generated_text": "-196 °C"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "eos_token",
"generated_tokens": 7,
"prefill": [
{
"id": 0,
"logprob": null,
"text": "<pad>"
}
],
"seed": null,
"tokens": [
{
"id": 3,
"logprob": -0.7001953,
"special": false,
"text": " "
},
{
"id": 18,
"logprob": -1.1943359,
"special": false,
"text": "-"
},
{
"id": 26937,
"logprob": -1.2119141,
"special": false,
"text": "196"
},
{
"id": 3,
"logprob": -1.2480469,
"special": false,
"text": " "
},
{
"id": 1956,
"logprob": -0.33203125,
"special": false,
"text": "°"
},
{
"id": 254,
"logprob": -0.19250488,
"special": false,
"text": "C"
},
{
"id": 1,
"logprob": -0.030166626,
"special": true,
"text": "</s>"
}
]
},
"generated_text": "-196 °C"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "eos_token",
"generated_tokens": 7,
"prefill": [
{
"id": 0,
"logprob": null,
"text": "<pad>"
}
],
"seed": null,
"tokens": [
{
"id": 3,
"logprob": -0.7001953,
"special": false,
"text": " "
},
{
"id": 18,
"logprob": -1.1943359,
"special": false,
"text": "-"
},
{
"id": 26937,
"logprob": -1.2099609,
"special": false,
"text": "196"
},
{
"id": 3,
"logprob": -1.2451172,
"special": false,
"text": " "
},
{
"id": 1956,
"logprob": -0.3322754,
"special": false,
"text": "°"
},
{
"id": 254,
"logprob": -0.19213867,
"special": false,
"text": "C"
},
{
"id": 1,
"logprob": -0.030151367,
"special": true,
"text": "</s>"
}
]
},
"generated_text": "-196 °C"
}
]


@@ -36,6 +36,8 @@ async def test_flash_neox_load(flash_neox, generate_load, response_snapshot):
     generated_texts = [r.generated_text for r in responses]
 
     assert len(generated_texts) == 4
-    assert generated_texts, all([text == generated_texts[0] for text in generated_texts])
+    assert generated_texts, all(
+        [text == generated_texts[0] for text in generated_texts]
+    )
 
     assert responses == response_snapshot


@@ -0,0 +1,38 @@
import pytest


@pytest.fixture(scope="module")
def t5_sharded_handle(launcher):
    with launcher("google/flan-t5-xxl", num_shard=2) as handle:
        yield handle


@pytest.fixture(scope="module")
async def t5_sharded(t5_sharded_handle):
    await t5_sharded_handle.health(240)
    return t5_sharded_handle.client


@pytest.mark.asyncio
async def test_t5_sharded(t5_sharded, response_snapshot):
    response = await t5_sharded.generate(
        "Please answer the following question. What is the boiling point of Nitrogen?",
        max_new_tokens=10,
    )

    assert response == response_snapshot


@pytest.mark.asyncio
async def test_t5_sharded_load(t5_sharded, generate_load, response_snapshot):
    responses = await generate_load(
        t5_sharded,
        "Please answer the following question. What is the boiling point of Nitrogen?",
        max_new_tokens=10,
        n=4,
    )

    assert len(responses) == 4
    assert all([r.generated_text == responses[0].generated_text for r in responses])

    assert responses == response_snapshot
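For context, `launcher` and `generate_load` come from the shared integration-test conftest, and the client used above is the repository's own `text_generation` Python client. Below is a minimal standalone sketch of the same request against an already-running server; the URL and port are illustrative assumptions, not part of this diff.

import asyncio

from text_generation import AsyncClient


async def main():
    # Assumed endpoint of a locally launched text-generation-inference server.
    client = AsyncClient("http://127.0.0.1:8080")
    response = await client.generate(
        "Please answer the following question. What is the boiling point of Nitrogen?",
        max_new_tokens=10,
    )
    # The snapshot JSONs above record exactly these fields: generated_text
    # plus details (finish_reason, prefill, per-token logprobs).
    print(response.generated_text)             # e.g. "-196 °C"
    print(response.details.generated_tokens)   # e.g. 7


asyncio.run(main())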


@@ -160,7 +160,10 @@ class BLOOMSharded(BLOOM):
                             # XXX: Hack for Rowlinear to add the bias only once.
                             if rank != 0:
                                 tensor = torch.zeros_like(tensor)
-                    elif isinstance(module, TensorParallelEmbedding) or name == "lm_head.weight":
+                    elif (
+                        isinstance(module, TensorParallelEmbedding)
+                        or name == "lm_head.weight"
+                    ):
                         size = slice_.get_shape()[0]
                         block_size = size // world_size
                         start = rank * block_size
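The branch rewrapped above shards embedding-like weights (and `lm_head.weight`) by rows: each rank takes one contiguous block of `size // world_size` rows. A toy sketch of that arithmetic, independent of the model code; the shapes and `world_size` are made up for illustration:

import torch

world_size = 2
vocab_size, hidden = 32000, 1024  # divides evenly by world_size, as the code assumes
weight = torch.randn(vocab_size, hidden)

block_size = vocab_size // world_size
shards = []
for rank in range(world_size):
    start = rank * block_size
    stop = (rank + 1) * block_size
    shards.append(weight[start:stop])  # each rank would load only its slice

# Concatenating the per-rank slices reconstructs the full weight.
assert torch.equal(torch.cat(shards, dim=0), weight)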


@@ -40,7 +40,7 @@ class T5Sharded(Seq2SeqLM):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
+            dtype = torch.float16
         else:
             device = torch.device("cpu")
             dtype = torch.float32
@@ -154,9 +154,15 @@ class T5Sharded(Seq2SeqLM):
                             f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
                         )
 
-                    tensor = tensor.contiguous().to(dtype)
+                    tensor = tensor.contiguous()
+
+                    # See: https://github.com/huggingface/transformers/blob/1fe1e3caa44617047f149bcc0c0b566343b714a7/src/transformers/models/t5/modeling_t5.py#LL316C15-L316C71
+                    if module_name.endswith("wo"):
+                        tensor = tensor.to(torch.float32)
+                    else:
+                        tensor = tensor.to(dtype)
 
-                    if quantize == "bitsandbytes":
+                    if quantize == "bitsandbytes" and not module_name.endswith("wo"):
                         if not HAS_BITS_AND_BYTES:
                             raise ImportError(
                                 "bitsandbytes is not available on your machine either because it is not installed "
@@ -207,7 +213,7 @@ class T5Sharded(Seq2SeqLM):
 
                         module.linear = replace_linear(state)
 
-                    elif quantize == "gptq":
+                    elif quantize == "gptq" and not module_name.endswith("wo"):
                         raise NotImplementedError(
                             "`gptq` is not implemented for now"
                         )
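Taken together, these hunks define the new loading policy for sharded T5: on CUDA every weight is cast to `torch.float16`, except tensors belonging to a `wo` module (T5's feed-forward output projection), which stay in `torch.float32`; the same `wo` modules are also excluded from bitsandbytes quantization, and the gptq guard mirrors that exclusion. The fp32 carve-out follows the upstream transformers workaround linked in the comment, where `wo` is known to produce values that overflow fp16. A self-contained sketch of the dtype policy on a toy state dict; the module names are illustrative, not taken from the diff:

import torch

dtype = torch.float16  # what the sharded T5 now loads with on CUDA

state_dict = {
    "encoder.block.0.layer.1.DenseReluDense.wi.weight": torch.randn(8, 4),
    "encoder.block.0.layer.1.DenseReluDense.wo.weight": torch.randn(4, 8),
}

for name, tensor in state_dict.items():
    module_name = name.rsplit(".", 1)[0]  # drop the trailing parameter name
    tensor = tensor.contiguous()
    # Keep the feed-forward output projection in float32, mirroring the
    # transformers reference linked above; everything else goes to fp16.
    if module_name.endswith("wo"):
        tensor = tensor.to(torch.float32)
    else:
        tensor = tensor.to(dtype)
    state_dict[name] = tensor

assert state_dict["encoder.block.0.layer.1.DenseReluDense.wo.weight"].dtype == torch.float32
assert state_dict["encoder.block.0.layer.1.DenseReluDense.wi.weight"].dtype == torch.float16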