parent
94377efa78
commit
cfaa858070
|
@ -0,0 +1,60 @@
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "eos_token",
|
||||||
|
"generated_tokens": 7,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 0,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<pad>"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 3,
|
||||||
|
"logprob": -0.7001953,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 18,
|
||||||
|
"logprob": -1.1943359,
|
||||||
|
"special": false,
|
||||||
|
"text": "-"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 26937,
|
||||||
|
"logprob": -1.2099609,
|
||||||
|
"special": false,
|
||||||
|
"text": "196"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3,
|
||||||
|
"logprob": -1.2451172,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1956,
|
||||||
|
"logprob": -0.3322754,
|
||||||
|
"special": false,
|
||||||
|
"text": "°"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 254,
|
||||||
|
"logprob": -0.19213867,
|
||||||
|
"special": false,
|
||||||
|
"text": "C"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"logprob": -0.030151367,
|
||||||
|
"special": true,
|
||||||
|
"text": "</s>"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"generated_text": "-196 °C"
|
||||||
|
}
|
|
@ -0,0 +1,242 @@
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "eos_token",
|
||||||
|
"generated_tokens": 7,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 0,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<pad>"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 3,
|
||||||
|
"logprob": -0.7001953,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 18,
|
||||||
|
"logprob": -1.1943359,
|
||||||
|
"special": false,
|
||||||
|
"text": "-"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 26937,
|
||||||
|
"logprob": -1.2119141,
|
||||||
|
"special": false,
|
||||||
|
"text": "196"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3,
|
||||||
|
"logprob": -1.2480469,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1956,
|
||||||
|
"logprob": -0.33203125,
|
||||||
|
"special": false,
|
||||||
|
"text": "°"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 254,
|
||||||
|
"logprob": -0.19250488,
|
||||||
|
"special": false,
|
||||||
|
"text": "C"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"logprob": -0.030166626,
|
||||||
|
"special": true,
|
||||||
|
"text": "</s>"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"generated_text": "-196 °C"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "eos_token",
|
||||||
|
"generated_tokens": 7,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 0,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<pad>"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 3,
|
||||||
|
"logprob": -0.7001953,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 18,
|
||||||
|
"logprob": -1.1943359,
|
||||||
|
"special": false,
|
||||||
|
"text": "-"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 26937,
|
||||||
|
"logprob": -1.2119141,
|
||||||
|
"special": false,
|
||||||
|
"text": "196"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3,
|
||||||
|
"logprob": -1.2480469,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1956,
|
||||||
|
"logprob": -0.33203125,
|
||||||
|
"special": false,
|
||||||
|
"text": "°"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 254,
|
||||||
|
"logprob": -0.19250488,
|
||||||
|
"special": false,
|
||||||
|
"text": "C"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"logprob": -0.030166626,
|
||||||
|
"special": true,
|
||||||
|
"text": "</s>"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"generated_text": "-196 °C"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "eos_token",
|
||||||
|
"generated_tokens": 7,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 0,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<pad>"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 3,
|
||||||
|
"logprob": -0.7001953,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 18,
|
||||||
|
"logprob": -1.1943359,
|
||||||
|
"special": false,
|
||||||
|
"text": "-"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 26937,
|
||||||
|
"logprob": -1.2119141,
|
||||||
|
"special": false,
|
||||||
|
"text": "196"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3,
|
||||||
|
"logprob": -1.2480469,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1956,
|
||||||
|
"logprob": -0.33203125,
|
||||||
|
"special": false,
|
||||||
|
"text": "°"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 254,
|
||||||
|
"logprob": -0.19250488,
|
||||||
|
"special": false,
|
||||||
|
"text": "C"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"logprob": -0.030166626,
|
||||||
|
"special": true,
|
||||||
|
"text": "</s>"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"generated_text": "-196 °C"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "eos_token",
|
||||||
|
"generated_tokens": 7,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 0,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<pad>"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 3,
|
||||||
|
"logprob": -0.7001953,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 18,
|
||||||
|
"logprob": -1.1943359,
|
||||||
|
"special": false,
|
||||||
|
"text": "-"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 26937,
|
||||||
|
"logprob": -1.2099609,
|
||||||
|
"special": false,
|
||||||
|
"text": "196"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3,
|
||||||
|
"logprob": -1.2451172,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1956,
|
||||||
|
"logprob": -0.3322754,
|
||||||
|
"special": false,
|
||||||
|
"text": "°"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 254,
|
||||||
|
"logprob": -0.19213867,
|
||||||
|
"special": false,
|
||||||
|
"text": "C"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"logprob": -0.030151367,
|
||||||
|
"special": true,
|
||||||
|
"text": "</s>"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"generated_text": "-196 °C"
|
||||||
|
}
|
||||||
|
]
|
|
@ -36,6 +36,8 @@ async def test_flash_neox_load(flash_neox, generate_load, response_snapshot):
|
||||||
generated_texts = [r.generated_text for r in responses]
|
generated_texts = [r.generated_text for r in responses]
|
||||||
|
|
||||||
assert len(generated_texts) == 4
|
assert len(generated_texts) == 4
|
||||||
assert generated_texts, all([text == generated_texts[0] for text in generated_texts])
|
assert generated_texts, all(
|
||||||
|
[text == generated_texts[0] for text in generated_texts]
|
||||||
|
)
|
||||||
|
|
||||||
assert responses == response_snapshot
|
assert responses == response_snapshot
|
||||||
|
|
|
@ -0,0 +1,38 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def t5_sharded_handle(launcher):
|
||||||
|
with launcher("google/flan-t5-xxl", num_shard=2) as handle:
|
||||||
|
yield handle
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
async def t5_sharded(t5_sharded_handle):
|
||||||
|
await t5_sharded_handle.health(240)
|
||||||
|
return t5_sharded_handle.client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_t5_sharded(t5_sharded, response_snapshot):
|
||||||
|
response = await t5_sharded.generate(
|
||||||
|
"Please answer the following question. What is the boiling point of Nitrogen?",
|
||||||
|
max_new_tokens=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_t5_sharded_load(t5_sharded, generate_load, response_snapshot):
|
||||||
|
responses = await generate_load(
|
||||||
|
t5_sharded,
|
||||||
|
"Please answer the following question. What is the boiling point of Nitrogen?",
|
||||||
|
max_new_tokens=10,
|
||||||
|
n=4,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(responses) == 4
|
||||||
|
assert all([r.generated_text == responses[0].generated_text for r in responses])
|
||||||
|
|
||||||
|
assert responses == response_snapshot
|
|
@ -160,7 +160,10 @@ class BLOOMSharded(BLOOM):
|
||||||
# XXX: Hack for Rowlinear to add the bias only once.
|
# XXX: Hack for Rowlinear to add the bias only once.
|
||||||
if rank != 0:
|
if rank != 0:
|
||||||
tensor = torch.zeros_like(tensor)
|
tensor = torch.zeros_like(tensor)
|
||||||
elif isinstance(module, TensorParallelEmbedding) or name == "lm_head.weight":
|
elif (
|
||||||
|
isinstance(module, TensorParallelEmbedding)
|
||||||
|
or name == "lm_head.weight"
|
||||||
|
):
|
||||||
size = slice_.get_shape()[0]
|
size = slice_.get_shape()[0]
|
||||||
block_size = size // world_size
|
block_size = size // world_size
|
||||||
start = rank * block_size
|
start = rank * block_size
|
||||||
|
|
|
@ -40,7 +40,7 @@ class T5Sharded(Seq2SeqLM):
|
||||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device(f"cuda:{rank}")
|
device = torch.device(f"cuda:{rank}")
|
||||||
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
|
dtype = torch.float16
|
||||||
else:
|
else:
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
dtype = torch.float32
|
dtype = torch.float32
|
||||||
|
@ -154,9 +154,15 @@ class T5Sharded(Seq2SeqLM):
|
||||||
f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
|
f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
|
||||||
)
|
)
|
||||||
|
|
||||||
tensor = tensor.contiguous().to(dtype)
|
tensor = tensor.contiguous()
|
||||||
|
|
||||||
if quantize == "bitsandbytes":
|
# See: https://github.com/huggingface/transformers/blob/1fe1e3caa44617047f149bcc0c0b566343b714a7/src/transformers/models/t5/modeling_t5.py#LL316C15-L316C71
|
||||||
|
if module_name.endswith("wo"):
|
||||||
|
tensor = tensor.to(torch.float32)
|
||||||
|
else:
|
||||||
|
tensor = tensor.to(dtype)
|
||||||
|
|
||||||
|
if quantize == "bitsandbytes" and not module_name.endswith("wo"):
|
||||||
if not HAS_BITS_AND_BYTES:
|
if not HAS_BITS_AND_BYTES:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"bitsandbytes is not available on your machine either because it is not installed "
|
"bitsandbytes is not available on your machine either because it is not installed "
|
||||||
|
@ -207,7 +213,7 @@ class T5Sharded(Seq2SeqLM):
|
||||||
|
|
||||||
module.linear = replace_linear(state)
|
module.linear = replace_linear(state)
|
||||||
|
|
||||||
elif quantize == "gptq":
|
elif quantize == "gptq" and not module_name.endswith("wo"):
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"`gptq` is not implemented for now"
|
"`gptq` is not implemented for now"
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue