fix: marlin repeat scale for fp8 and bump snapshots

This commit is contained in:
drbh 2024-08-09 16:39:16 +00:00
parent df9eb38733
commit 3f12750a18
5 changed files with 188 additions and 219 deletions

View File

@ -11,79 +11,79 @@
},
{
"id": 2323,
"logprob": -9.421875,
"logprob": -9.6015625,
"text": "Test"
},
{
"id": 1715,
"logprob": -10.546875,
"logprob": -10.515625,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 369,
"logprob": -2.1816406,
"id": 25,
"logprob": -2.1914062,
"special": false,
"text": " for"
},
{
"id": 279,
"logprob": -2.6992188,
"special": false,
"text": " the"
"text": ":"
},
{
"id": 220,
"logprob": -3.6308594,
"logprob": -3.7324219,
"special": false,
"text": " "
},
{
"id": 679,
"logprob": -1.7900391,
"id": 16,
"logprob": -2.2753906,
"special": false,
"text": "201"
"text": "1"
},
{
"id": 24,
"logprob": -1.3554688,
"id": 13,
"logprob": -1.2070312,
"special": false,
"text": "9"
"text": "."
},
{
"id": 12,
"logprob": -2.0039062,
"id": 20,
"logprob": -2.765625,
"special": false,
"text": "-"
"text": "5"
},
{
"id": 2366,
"logprob": -0.4489746,
"id": 13,
"logprob": -1.1884766,
"special": false,
"text": "202"
"text": "."
},
{
"id": 15,
"logprob": -0.037109375,
"logprob": -1.5126953,
"special": false,
"text": "0"
},
{
"id": 2978,
"logprob": -0.8100586,
"id": 12,
"logprob": -2.078125,
"special": false,
"text": " school"
"text": "-"
},
{
"id": 1060,
"logprob": -0.013015747,
"id": 1310,
"logprob": -0.7158203,
"special": false,
"text": " year"
"text": "rc"
},
{
"id": 16,
"logprob": -1.0234375,
"special": false,
"text": "1"
}
],
"top_tokens": null
},
"generated_text": " for the 2019-2020 school year"
"generated_text": ": 1.5.0-rc1"
}

View File

@ -1,8 +1,8 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"finish_reason": "stop_sequence",
"generated_tokens": 5,
"prefill": [
{
"id": 128000,
@ -11,12 +11,12 @@
},
{
"id": 2323,
"logprob": -9.5625,
"logprob": -9.6015625,
"text": "Test"
},
{
"id": 1715,
"logprob": -10.375,
"logprob": -10.515625,
"text": " request"
}
],
@ -24,66 +24,36 @@
"tokens": [
{
"id": 25,
"logprob": -0.8984375,
"logprob": -0.81103516,
"special": false,
"text": ":"
},
{
"id": 2209,
"logprob": -2.78125,
"special": false,
"text": " Is"
},
{
"id": 279,
"logprob": -0.6328125,
"special": false,
"text": " the"
},
{
"id": 734,
"id": 923,
"logprob": -2.703125,
"special": false,
"text": " function"
"text": " add"
},
{
"id": 264,
"logprob": 0.0,
"special": false,
"text": " a"
},
{
"id": 330,
"logprob": -0.34179688,
"logprob": -0.1862793,
"special": false,
"text": " \""
},
{
"id": 4110,
"logprob": -2.359375,
"id": 1985,
"logprob": 0.0,
"special": false,
"text": "Create"
},
{
"id": 7575,
"logprob": -2.1875,
"special": false,
"text": "Process"
},
{
"id": 1,
"logprob": -0.07910156,
"special": false,
"text": "\""
},
{
"id": 304,
"logprob": -0.83203125,
"special": false,
"text": " in"
},
{
"id": 12468,
"logprob": -1.8203125,
"special": false,
"text": " Win"
"text": "test"
}
],
"top_tokens": null
},
"generated_text": "Test request: Is the function \"CreateProcess\" in Win"
"generated_text": "Test request: add a \"test"
}

View File

@ -12,81 +12,81 @@
},
{
"id": 2323,
"logprob": -9.421875,
"logprob": -9.6015625,
"text": "Test"
},
{
"id": 1715,
"logprob": -10.546875,
"logprob": -10.515625,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 369,
"logprob": -2.1816406,
"id": 25,
"logprob": -2.1914062,
"special": false,
"text": " for"
},
{
"id": 279,
"logprob": -2.6992188,
"special": false,
"text": " the"
"text": ":"
},
{
"id": 220,
"logprob": -3.6308594,
"logprob": -3.7421875,
"special": false,
"text": " "
},
{
"id": 679,
"logprob": -1.7988281,
"id": 16,
"logprob": -2.2753906,
"special": false,
"text": "201"
"text": "1"
},
{
"id": 24,
"logprob": -1.3535156,
"id": 13,
"logprob": -1.2041016,
"special": false,
"text": "9"
"text": "."
},
{
"id": 12,
"logprob": -2.0058594,
"id": 20,
"logprob": -2.7675781,
"special": false,
"text": "-"
"text": "5"
},
{
"id": 2366,
"logprob": -0.45410156,
"id": 13,
"logprob": -1.1884766,
"special": false,
"text": "202"
"text": "."
},
{
"id": 15,
"logprob": -0.037109375,
"logprob": -1.5244141,
"special": false,
"text": "0"
},
{
"id": 2978,
"logprob": -0.8095703,
"id": 12,
"logprob": -2.0761719,
"special": false,
"text": " school"
"text": "-"
},
{
"id": 1060,
"logprob": -0.013053894,
"id": 1310,
"logprob": -0.71484375,
"special": false,
"text": " year"
"text": "rc"
},
{
"id": 16,
"logprob": -1.0244141,
"special": false,
"text": "1"
}
],
"top_tokens": null
},
"generated_text": " for the 2019-2020 school year"
"generated_text": ": 1.5.0-rc1"
},
{
"details": {
@ -101,81 +101,81 @@
},
{
"id": 2323,
"logprob": -9.421875,
"logprob": -9.6015625,
"text": "Test"
},
{
"id": 1715,
"logprob": -10.546875,
"logprob": -10.515625,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 369,
"logprob": -2.1816406,
"id": 25,
"logprob": -2.1914062,
"special": false,
"text": " for"
},
{
"id": 279,
"logprob": -2.6992188,
"special": false,
"text": " the"
"text": ":"
},
{
"id": 220,
"logprob": -3.6308594,
"logprob": -3.7421875,
"special": false,
"text": " "
},
{
"id": 679,
"logprob": -1.7988281,
"id": 16,
"logprob": -2.2753906,
"special": false,
"text": "201"
"text": "1"
},
{
"id": 24,
"logprob": -1.3535156,
"id": 13,
"logprob": -1.2041016,
"special": false,
"text": "9"
"text": "."
},
{
"id": 12,
"logprob": -2.0058594,
"id": 20,
"logprob": -2.7675781,
"special": false,
"text": "-"
"text": "5"
},
{
"id": 2366,
"logprob": -0.45410156,
"id": 13,
"logprob": -1.1884766,
"special": false,
"text": "202"
"text": "."
},
{
"id": 15,
"logprob": -0.037109375,
"logprob": -1.5244141,
"special": false,
"text": "0"
},
{
"id": 2978,
"logprob": -0.8095703,
"id": 12,
"logprob": -2.0761719,
"special": false,
"text": " school"
"text": "-"
},
{
"id": 1060,
"logprob": -0.013053894,
"id": 1310,
"logprob": -0.71484375,
"special": false,
"text": " year"
"text": "rc"
},
{
"id": 16,
"logprob": -1.0244141,
"special": false,
"text": "1"
}
],
"top_tokens": null
},
"generated_text": " for the 2019-2020 school year"
"generated_text": ": 1.5.0-rc1"
},
{
"details": {
@ -190,81 +190,81 @@
},
{
"id": 2323,
"logprob": -9.421875,
"logprob": -9.6015625,
"text": "Test"
},
{
"id": 1715,
"logprob": -10.546875,
"logprob": -10.515625,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 369,
"logprob": -2.1816406,
"id": 25,
"logprob": -2.1914062,
"special": false,
"text": " for"
},
{
"id": 279,
"logprob": -2.6992188,
"special": false,
"text": " the"
"text": ":"
},
{
"id": 220,
"logprob": -3.6308594,
"logprob": -3.7421875,
"special": false,
"text": " "
},
{
"id": 679,
"logprob": -1.7988281,
"id": 16,
"logprob": -2.2753906,
"special": false,
"text": "201"
"text": "1"
},
{
"id": 24,
"logprob": -1.3535156,
"id": 13,
"logprob": -1.2041016,
"special": false,
"text": "9"
"text": "."
},
{
"id": 12,
"logprob": -2.0058594,
"id": 20,
"logprob": -2.7675781,
"special": false,
"text": "-"
"text": "5"
},
{
"id": 2366,
"logprob": -0.45410156,
"id": 13,
"logprob": -1.1884766,
"special": false,
"text": "202"
"text": "."
},
{
"id": 15,
"logprob": -0.037109375,
"logprob": -1.5244141,
"special": false,
"text": "0"
},
{
"id": 2978,
"logprob": -0.8095703,
"id": 12,
"logprob": -2.0761719,
"special": false,
"text": " school"
"text": "-"
},
{
"id": 1060,
"logprob": -0.013053894,
"id": 1310,
"logprob": -0.71484375,
"special": false,
"text": " year"
"text": "rc"
},
{
"id": 16,
"logprob": -1.0244141,
"special": false,
"text": "1"
}
],
"top_tokens": null
},
"generated_text": " for the 2019-2020 school year"
"generated_text": ": 1.5.0-rc1"
},
{
"details": {
@ -279,80 +279,80 @@
},
{
"id": 2323,
"logprob": -9.421875,
"logprob": -9.6015625,
"text": "Test"
},
{
"id": 1715,
"logprob": -10.546875,
"logprob": -10.515625,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 369,
"logprob": -2.1816406,
"id": 25,
"logprob": -2.1914062,
"special": false,
"text": " for"
},
{
"id": 279,
"logprob": -2.6992188,
"special": false,
"text": " the"
"text": ":"
},
{
"id": 220,
"logprob": -3.6308594,
"logprob": -3.7421875,
"special": false,
"text": " "
},
{
"id": 679,
"logprob": -1.7988281,
"id": 16,
"logprob": -2.2753906,
"special": false,
"text": "201"
"text": "1"
},
{
"id": 24,
"logprob": -1.3535156,
"id": 13,
"logprob": -1.2041016,
"special": false,
"text": "9"
"text": "."
},
{
"id": 12,
"logprob": -2.0058594,
"id": 20,
"logprob": -2.7675781,
"special": false,
"text": "-"
"text": "5"
},
{
"id": 2366,
"logprob": -0.45410156,
"id": 13,
"logprob": -1.1884766,
"special": false,
"text": "202"
"text": "."
},
{
"id": 15,
"logprob": -0.037109375,
"logprob": -1.5244141,
"special": false,
"text": "0"
},
{
"id": 2978,
"logprob": -0.8095703,
"id": 12,
"logprob": -2.0761719,
"special": false,
"text": " school"
"text": "-"
},
{
"id": 1060,
"logprob": -0.013053894,
"id": 1310,
"logprob": -0.71484375,
"special": false,
"text": " year"
"text": "rc"
},
{
"id": 16,
"logprob": -1.0244141,
"special": false,
"text": "1"
}
],
"top_tokens": null
},
"generated_text": " for the 2019-2020 school year"
"generated_text": ": 1.5.0-rc1"
}
]

View File

@ -48,16 +48,15 @@ async def test_flash_llama_fp8_all_params(flash_llama_fp8, response_snapshot):
assert response == response_snapshot
# TODO: fix and re-enable
# @pytest.mark.release
# @pytest.mark.asyncio
# @pytest.mark.private
# async def test_flash_llama_fp8_load(flash_llama_fp8, generate_load, response_snapshot):
# responses = await generate_load(
# flash_llama_fp8, "Test request", max_new_tokens=10, n=4
# )
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_fp8_load(flash_llama_fp8, generate_load, response_snapshot):
responses = await generate_load(
flash_llama_fp8, "Test request", max_new_tokens=10, n=4
)
# assert len(responses) == 4
# assert all([r.generated_text == responses[0].generated_text for r in responses])
assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses])
# assert responses == response_snapshot
assert responses == response_snapshot

View File

@ -39,7 +39,7 @@ class GPTQMarlinFP8Linear(nn.Module):
log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel")
scales = scales.unsqueeze(0)
if scales.shape[1] == 1:
if scales.size(0) == 1:
out_features, in_features = qweight.shape
scales = scales.repeat(1, out_features)
qweight, scales = repack_fp8_for_marlin(qweight, scales)