parent
94377efa78
commit
cfaa858070
|
@ -0,0 +1,60 @@
|
|||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "eos_token",
|
||||
"generated_tokens": 7,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 0,
|
||||
"logprob": null,
|
||||
"text": "<pad>"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 3,
|
||||
"logprob": -0.7001953,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 18,
|
||||
"logprob": -1.1943359,
|
||||
"special": false,
|
||||
"text": "-"
|
||||
},
|
||||
{
|
||||
"id": 26937,
|
||||
"logprob": -1.2099609,
|
||||
"special": false,
|
||||
"text": "196"
|
||||
},
|
||||
{
|
||||
"id": 3,
|
||||
"logprob": -1.2451172,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 1956,
|
||||
"logprob": -0.3322754,
|
||||
"special": false,
|
||||
"text": "°"
|
||||
},
|
||||
{
|
||||
"id": 254,
|
||||
"logprob": -0.19213867,
|
||||
"special": false,
|
||||
"text": "C"
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"logprob": -0.030151367,
|
||||
"special": true,
|
||||
"text": "</s>"
|
||||
}
|
||||
]
|
||||
},
|
||||
"generated_text": "-196 °C"
|
||||
}
|
|
@ -0,0 +1,242 @@
|
|||
[
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "eos_token",
|
||||
"generated_tokens": 7,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 0,
|
||||
"logprob": null,
|
||||
"text": "<pad>"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 3,
|
||||
"logprob": -0.7001953,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 18,
|
||||
"logprob": -1.1943359,
|
||||
"special": false,
|
||||
"text": "-"
|
||||
},
|
||||
{
|
||||
"id": 26937,
|
||||
"logprob": -1.2119141,
|
||||
"special": false,
|
||||
"text": "196"
|
||||
},
|
||||
{
|
||||
"id": 3,
|
||||
"logprob": -1.2480469,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 1956,
|
||||
"logprob": -0.33203125,
|
||||
"special": false,
|
||||
"text": "°"
|
||||
},
|
||||
{
|
||||
"id": 254,
|
||||
"logprob": -0.19250488,
|
||||
"special": false,
|
||||
"text": "C"
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"logprob": -0.030166626,
|
||||
"special": true,
|
||||
"text": "</s>"
|
||||
}
|
||||
]
|
||||
},
|
||||
"generated_text": "-196 °C"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "eos_token",
|
||||
"generated_tokens": 7,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 0,
|
||||
"logprob": null,
|
||||
"text": "<pad>"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 3,
|
||||
"logprob": -0.7001953,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 18,
|
||||
"logprob": -1.1943359,
|
||||
"special": false,
|
||||
"text": "-"
|
||||
},
|
||||
{
|
||||
"id": 26937,
|
||||
"logprob": -1.2119141,
|
||||
"special": false,
|
||||
"text": "196"
|
||||
},
|
||||
{
|
||||
"id": 3,
|
||||
"logprob": -1.2480469,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 1956,
|
||||
"logprob": -0.33203125,
|
||||
"special": false,
|
||||
"text": "°"
|
||||
},
|
||||
{
|
||||
"id": 254,
|
||||
"logprob": -0.19250488,
|
||||
"special": false,
|
||||
"text": "C"
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"logprob": -0.030166626,
|
||||
"special": true,
|
||||
"text": "</s>"
|
||||
}
|
||||
]
|
||||
},
|
||||
"generated_text": "-196 °C"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "eos_token",
|
||||
"generated_tokens": 7,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 0,
|
||||
"logprob": null,
|
||||
"text": "<pad>"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 3,
|
||||
"logprob": -0.7001953,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 18,
|
||||
"logprob": -1.1943359,
|
||||
"special": false,
|
||||
"text": "-"
|
||||
},
|
||||
{
|
||||
"id": 26937,
|
||||
"logprob": -1.2119141,
|
||||
"special": false,
|
||||
"text": "196"
|
||||
},
|
||||
{
|
||||
"id": 3,
|
||||
"logprob": -1.2480469,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 1956,
|
||||
"logprob": -0.33203125,
|
||||
"special": false,
|
||||
"text": "°"
|
||||
},
|
||||
{
|
||||
"id": 254,
|
||||
"logprob": -0.19250488,
|
||||
"special": false,
|
||||
"text": "C"
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"logprob": -0.030166626,
|
||||
"special": true,
|
||||
"text": "</s>"
|
||||
}
|
||||
]
|
||||
},
|
||||
"generated_text": "-196 °C"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "eos_token",
|
||||
"generated_tokens": 7,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 0,
|
||||
"logprob": null,
|
||||
"text": "<pad>"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 3,
|
||||
"logprob": -0.7001953,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 18,
|
||||
"logprob": -1.1943359,
|
||||
"special": false,
|
||||
"text": "-"
|
||||
},
|
||||
{
|
||||
"id": 26937,
|
||||
"logprob": -1.2099609,
|
||||
"special": false,
|
||||
"text": "196"
|
||||
},
|
||||
{
|
||||
"id": 3,
|
||||
"logprob": -1.2451172,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 1956,
|
||||
"logprob": -0.3322754,
|
||||
"special": false,
|
||||
"text": "°"
|
||||
},
|
||||
{
|
||||
"id": 254,
|
||||
"logprob": -0.19213867,
|
||||
"special": false,
|
||||
"text": "C"
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"logprob": -0.030151367,
|
||||
"special": true,
|
||||
"text": "</s>"
|
||||
}
|
||||
]
|
||||
},
|
||||
"generated_text": "-196 °C"
|
||||
}
|
||||
]
|
|
@ -36,6 +36,8 @@ async def test_flash_neox_load(flash_neox, generate_load, response_snapshot):
|
|||
generated_texts = [r.generated_text for r in responses]
|
||||
|
||||
assert len(generated_texts) == 4
|
||||
assert generated_texts, all([text == generated_texts[0] for text in generated_texts])
|
||||
assert generated_texts, all(
|
||||
[text == generated_texts[0] for text in generated_texts]
|
||||
)
|
||||
|
||||
assert responses == response_snapshot
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def t5_sharded_handle(launcher):
|
||||
with launcher("google/flan-t5-xxl", num_shard=2) as handle:
|
||||
yield handle
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def t5_sharded(t5_sharded_handle):
|
||||
await t5_sharded_handle.health(240)
|
||||
return t5_sharded_handle.client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_t5_sharded(t5_sharded, response_snapshot):
|
||||
response = await t5_sharded.generate(
|
||||
"Please answer the following question. What is the boiling point of Nitrogen?",
|
||||
max_new_tokens=10,
|
||||
)
|
||||
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_t5_sharded_load(t5_sharded, generate_load, response_snapshot):
|
||||
responses = await generate_load(
|
||||
t5_sharded,
|
||||
"Please answer the following question. What is the boiling point of Nitrogen?",
|
||||
max_new_tokens=10,
|
||||
n=4,
|
||||
)
|
||||
|
||||
assert len(responses) == 4
|
||||
assert all([r.generated_text == responses[0].generated_text for r in responses])
|
||||
|
||||
assert responses == response_snapshot
|
|
@ -160,7 +160,10 @@ class BLOOMSharded(BLOOM):
|
|||
# XXX: Hack for Rowlinear to add the bias only once.
|
||||
if rank != 0:
|
||||
tensor = torch.zeros_like(tensor)
|
||||
elif isinstance(module, TensorParallelEmbedding) or name == "lm_head.weight":
|
||||
elif (
|
||||
isinstance(module, TensorParallelEmbedding)
|
||||
or name == "lm_head.weight"
|
||||
):
|
||||
size = slice_.get_shape()[0]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
|
|
|
@ -40,7 +40,7 @@ class T5Sharded(Seq2SeqLM):
|
|||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
|
||||
dtype = torch.float16
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32
|
||||
|
@ -154,9 +154,15 @@ class T5Sharded(Seq2SeqLM):
|
|||
f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
|
||||
)
|
||||
|
||||
tensor = tensor.contiguous().to(dtype)
|
||||
tensor = tensor.contiguous()
|
||||
|
||||
if quantize == "bitsandbytes":
|
||||
# See: https://github.com/huggingface/transformers/blob/1fe1e3caa44617047f149bcc0c0b566343b714a7/src/transformers/models/t5/modeling_t5.py#LL316C15-L316C71
|
||||
if module_name.endswith("wo"):
|
||||
tensor = tensor.to(torch.float32)
|
||||
else:
|
||||
tensor = tensor.to(dtype)
|
||||
|
||||
if quantize == "bitsandbytes" and not module_name.endswith("wo"):
|
||||
if not HAS_BITS_AND_BYTES:
|
||||
raise ImportError(
|
||||
"bitsandbytes is not available on your machine either because it is not installed "
|
||||
|
@ -207,7 +213,7 @@ class T5Sharded(Seq2SeqLM):
|
|||
|
||||
module.linear = replace_linear(state)
|
||||
|
||||
elif quantize == "gptq":
|
||||
elif quantize == "gptq" and not module_name.endswith("wo"):
|
||||
raise NotImplementedError(
|
||||
"`gptq` is not implemented for now"
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue