From 3f12750a186ecdc12e03b5599c21fa0914048597 Mon Sep 17 00:00:00 2001 From: drbh Date: Fri, 9 Aug 2024 16:39:16 +0000 Subject: [PATCH] fix: marlin repeat scale for fp8 and bump snapshots --- .../test_flash_llama_fp8.json | 64 ++--- .../test_flash_llama_fp8_all_params.json | 66 ++--- .../test_flash_llama_fp8_load.json | 256 +++++++++--------- .../models/test_flash_llama_fp8.py | 19 +- .../layers/marlin/fp8.py | 2 +- 5 files changed, 188 insertions(+), 219 deletions(-) diff --git a/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8.json b/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8.json index 85cfb91f..e307dd93 100644 --- a/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8.json +++ b/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8.json @@ -11,79 +11,79 @@ }, { "id": 2323, - "logprob": -9.421875, + "logprob": -9.6015625, "text": "Test" }, { "id": 1715, - "logprob": -10.546875, + "logprob": -10.515625, "text": " request" } ], "seed": null, "tokens": [ { - "id": 369, - "logprob": -2.1816406, + "id": 25, + "logprob": -2.1914062, "special": false, - "text": " for" - }, - { - "id": 279, - "logprob": -2.6992188, - "special": false, - "text": " the" + "text": ":" }, { "id": 220, - "logprob": -3.6308594, + "logprob": -3.7324219, "special": false, "text": " " }, { - "id": 679, - "logprob": -1.7900391, + "id": 16, + "logprob": -2.2753906, "special": false, - "text": "201" + "text": "1" }, { - "id": 24, - "logprob": -1.3554688, + "id": 13, + "logprob": -1.2070312, "special": false, - "text": "9" + "text": "." }, { - "id": 12, - "logprob": -2.0039062, + "id": 20, + "logprob": -2.765625, "special": false, - "text": "-" + "text": "5" }, { - "id": 2366, - "logprob": -0.4489746, + "id": 13, + "logprob": -1.1884766, "special": false, - "text": "202" + "text": "." }, { "id": 15, - "logprob": -0.037109375, + "logprob": -1.5126953, "special": false, "text": "0" }, { - "id": 2978, - "logprob": -0.8100586, + "id": 12, + "logprob": -2.078125, "special": false, - "text": " school" + "text": "-" }, { - "id": 1060, - "logprob": -0.013015747, + "id": 1310, + "logprob": -0.7158203, "special": false, - "text": " year" + "text": "rc" + }, + { + "id": 16, + "logprob": -1.0234375, + "special": false, + "text": "1" } ], "top_tokens": null }, - "generated_text": " for the 2019-2020 school year" + "generated_text": ": 1.5.0-rc1" } diff --git a/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_all_params.json b/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_all_params.json index bf981e4f..9d3e25ed 100644 --- a/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_all_params.json +++ b/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_all_params.json @@ -1,8 +1,8 @@ { "details": { "best_of_sequences": null, - "finish_reason": "length", - "generated_tokens": 10, + "finish_reason": "stop_sequence", + "generated_tokens": 5, "prefill": [ { "id": 128000, @@ -11,12 +11,12 @@ }, { "id": 2323, - "logprob": -9.5625, + "logprob": -9.6015625, "text": "Test" }, { "id": 1715, - "logprob": -10.375, + "logprob": -10.515625, "text": " request" } ], @@ -24,66 +24,36 @@ "tokens": [ { "id": 25, - "logprob": -0.8984375, + "logprob": -0.81103516, "special": false, "text": ":" }, { - "id": 2209, - "logprob": -2.78125, - "special": false, - "text": " Is" - }, - { - "id": 279, - "logprob": -0.6328125, - "special": false, - "text": " the" - }, - { - "id": 734, + "id": 923, "logprob": -2.703125, "special": false, - "text": " function" + "text": " add" + }, + { + "id": 264, + "logprob": 0.0, + "special": false, + "text": " a" }, { "id": 330, - "logprob": -0.34179688, + "logprob": -0.1862793, "special": false, "text": " \"" }, { - "id": 4110, - "logprob": -2.359375, + "id": 1985, + "logprob": 0.0, "special": false, - "text": "Create" - }, - { - "id": 7575, - "logprob": -2.1875, - "special": false, - "text": "Process" - }, - { - "id": 1, - "logprob": -0.07910156, - "special": false, - "text": "\"" - }, - { - "id": 304, - "logprob": -0.83203125, - "special": false, - "text": " in" - }, - { - "id": 12468, - "logprob": -1.8203125, - "special": false, - "text": " Win" + "text": "test" } ], "top_tokens": null }, - "generated_text": "Test request: Is the function \"CreateProcess\" in Win" + "generated_text": "Test request: add a \"test" } diff --git a/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_load.json b/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_load.json index 36c87c09..5019e7cc 100644 --- a/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_load.json +++ b/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_load.json @@ -12,81 +12,81 @@ }, { "id": 2323, - "logprob": -9.421875, + "logprob": -9.6015625, "text": "Test" }, { "id": 1715, - "logprob": -10.546875, + "logprob": -10.515625, "text": " request" } ], "seed": null, "tokens": [ { - "id": 369, - "logprob": -2.1816406, + "id": 25, + "logprob": -2.1914062, "special": false, - "text": " for" - }, - { - "id": 279, - "logprob": -2.6992188, - "special": false, - "text": " the" + "text": ":" }, { "id": 220, - "logprob": -3.6308594, + "logprob": -3.7421875, "special": false, "text": " " }, { - "id": 679, - "logprob": -1.7988281, + "id": 16, + "logprob": -2.2753906, "special": false, - "text": "201" + "text": "1" }, { - "id": 24, - "logprob": -1.3535156, + "id": 13, + "logprob": -1.2041016, "special": false, - "text": "9" + "text": "." }, { - "id": 12, - "logprob": -2.0058594, + "id": 20, + "logprob": -2.7675781, "special": false, - "text": "-" + "text": "5" }, { - "id": 2366, - "logprob": -0.45410156, + "id": 13, + "logprob": -1.1884766, "special": false, - "text": "202" + "text": "." }, { "id": 15, - "logprob": -0.037109375, + "logprob": -1.5244141, "special": false, "text": "0" }, { - "id": 2978, - "logprob": -0.8095703, + "id": 12, + "logprob": -2.0761719, "special": false, - "text": " school" + "text": "-" }, { - "id": 1060, - "logprob": -0.013053894, + "id": 1310, + "logprob": -0.71484375, "special": false, - "text": " year" + "text": "rc" + }, + { + "id": 16, + "logprob": -1.0244141, + "special": false, + "text": "1" } ], "top_tokens": null }, - "generated_text": " for the 2019-2020 school year" + "generated_text": ": 1.5.0-rc1" }, { "details": { @@ -101,81 +101,81 @@ }, { "id": 2323, - "logprob": -9.421875, + "logprob": -9.6015625, "text": "Test" }, { "id": 1715, - "logprob": -10.546875, + "logprob": -10.515625, "text": " request" } ], "seed": null, "tokens": [ { - "id": 369, - "logprob": -2.1816406, + "id": 25, + "logprob": -2.1914062, "special": false, - "text": " for" - }, - { - "id": 279, - "logprob": -2.6992188, - "special": false, - "text": " the" + "text": ":" }, { "id": 220, - "logprob": -3.6308594, + "logprob": -3.7421875, "special": false, "text": " " }, { - "id": 679, - "logprob": -1.7988281, + "id": 16, + "logprob": -2.2753906, "special": false, - "text": "201" + "text": "1" }, { - "id": 24, - "logprob": -1.3535156, + "id": 13, + "logprob": -1.2041016, "special": false, - "text": "9" + "text": "." }, { - "id": 12, - "logprob": -2.0058594, + "id": 20, + "logprob": -2.7675781, "special": false, - "text": "-" + "text": "5" }, { - "id": 2366, - "logprob": -0.45410156, + "id": 13, + "logprob": -1.1884766, "special": false, - "text": "202" + "text": "." }, { "id": 15, - "logprob": -0.037109375, + "logprob": -1.5244141, "special": false, "text": "0" }, { - "id": 2978, - "logprob": -0.8095703, + "id": 12, + "logprob": -2.0761719, "special": false, - "text": " school" + "text": "-" }, { - "id": 1060, - "logprob": -0.013053894, + "id": 1310, + "logprob": -0.71484375, "special": false, - "text": " year" + "text": "rc" + }, + { + "id": 16, + "logprob": -1.0244141, + "special": false, + "text": "1" } ], "top_tokens": null }, - "generated_text": " for the 2019-2020 school year" + "generated_text": ": 1.5.0-rc1" }, { "details": { @@ -190,81 +190,81 @@ }, { "id": 2323, - "logprob": -9.421875, + "logprob": -9.6015625, "text": "Test" }, { "id": 1715, - "logprob": -10.546875, + "logprob": -10.515625, "text": " request" } ], "seed": null, "tokens": [ { - "id": 369, - "logprob": -2.1816406, + "id": 25, + "logprob": -2.1914062, "special": false, - "text": " for" - }, - { - "id": 279, - "logprob": -2.6992188, - "special": false, - "text": " the" + "text": ":" }, { "id": 220, - "logprob": -3.6308594, + "logprob": -3.7421875, "special": false, "text": " " }, { - "id": 679, - "logprob": -1.7988281, + "id": 16, + "logprob": -2.2753906, "special": false, - "text": "201" + "text": "1" }, { - "id": 24, - "logprob": -1.3535156, + "id": 13, + "logprob": -1.2041016, "special": false, - "text": "9" + "text": "." }, { - "id": 12, - "logprob": -2.0058594, + "id": 20, + "logprob": -2.7675781, "special": false, - "text": "-" + "text": "5" }, { - "id": 2366, - "logprob": -0.45410156, + "id": 13, + "logprob": -1.1884766, "special": false, - "text": "202" + "text": "." }, { "id": 15, - "logprob": -0.037109375, + "logprob": -1.5244141, "special": false, "text": "0" }, { - "id": 2978, - "logprob": -0.8095703, + "id": 12, + "logprob": -2.0761719, "special": false, - "text": " school" + "text": "-" }, { - "id": 1060, - "logprob": -0.013053894, + "id": 1310, + "logprob": -0.71484375, "special": false, - "text": " year" + "text": "rc" + }, + { + "id": 16, + "logprob": -1.0244141, + "special": false, + "text": "1" } ], "top_tokens": null }, - "generated_text": " for the 2019-2020 school year" + "generated_text": ": 1.5.0-rc1" }, { "details": { @@ -279,80 +279,80 @@ }, { "id": 2323, - "logprob": -9.421875, + "logprob": -9.6015625, "text": "Test" }, { "id": 1715, - "logprob": -10.546875, + "logprob": -10.515625, "text": " request" } ], "seed": null, "tokens": [ { - "id": 369, - "logprob": -2.1816406, + "id": 25, + "logprob": -2.1914062, "special": false, - "text": " for" - }, - { - "id": 279, - "logprob": -2.6992188, - "special": false, - "text": " the" + "text": ":" }, { "id": 220, - "logprob": -3.6308594, + "logprob": -3.7421875, "special": false, "text": " " }, { - "id": 679, - "logprob": -1.7988281, + "id": 16, + "logprob": -2.2753906, "special": false, - "text": "201" + "text": "1" }, { - "id": 24, - "logprob": -1.3535156, + "id": 13, + "logprob": -1.2041016, "special": false, - "text": "9" + "text": "." }, { - "id": 12, - "logprob": -2.0058594, + "id": 20, + "logprob": -2.7675781, "special": false, - "text": "-" + "text": "5" }, { - "id": 2366, - "logprob": -0.45410156, + "id": 13, + "logprob": -1.1884766, "special": false, - "text": "202" + "text": "." }, { "id": 15, - "logprob": -0.037109375, + "logprob": -1.5244141, "special": false, "text": "0" }, { - "id": 2978, - "logprob": -0.8095703, + "id": 12, + "logprob": -2.0761719, "special": false, - "text": " school" + "text": "-" }, { - "id": 1060, - "logprob": -0.013053894, + "id": 1310, + "logprob": -0.71484375, "special": false, - "text": " year" + "text": "rc" + }, + { + "id": 16, + "logprob": -1.0244141, + "special": false, + "text": "1" } ], "top_tokens": null }, - "generated_text": " for the 2019-2020 school year" + "generated_text": ": 1.5.0-rc1" } ] diff --git a/integration-tests/models/test_flash_llama_fp8.py b/integration-tests/models/test_flash_llama_fp8.py index 808d1329..bc7458b7 100644 --- a/integration-tests/models/test_flash_llama_fp8.py +++ b/integration-tests/models/test_flash_llama_fp8.py @@ -48,16 +48,15 @@ async def test_flash_llama_fp8_all_params(flash_llama_fp8, response_snapshot): assert response == response_snapshot -# TODO: fix and re-enable # @pytest.mark.release -# @pytest.mark.asyncio -# @pytest.mark.private -# async def test_flash_llama_fp8_load(flash_llama_fp8, generate_load, response_snapshot): -# responses = await generate_load( -# flash_llama_fp8, "Test request", max_new_tokens=10, n=4 -# ) +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_fp8_load(flash_llama_fp8, generate_load, response_snapshot): + responses = await generate_load( + flash_llama_fp8, "Test request", max_new_tokens=10, n=4 + ) -# assert len(responses) == 4 -# assert all([r.generated_text == responses[0].generated_text for r in responses]) + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) -# assert responses == response_snapshot + assert responses == response_snapshot diff --git a/server/text_generation_server/layers/marlin/fp8.py b/server/text_generation_server/layers/marlin/fp8.py index fe55a58a..42b0aa4d 100644 --- a/server/text_generation_server/layers/marlin/fp8.py +++ b/server/text_generation_server/layers/marlin/fp8.py @@ -39,7 +39,7 @@ class GPTQMarlinFP8Linear(nn.Module): log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel") scales = scales.unsqueeze(0) - if scales.shape[1] == 1: + if scales.size(0) == 1: out_features, in_features = qweight.shape scales = scales.repeat(1, out_features) qweight, scales = repack_fp8_for_marlin(qweight, scales)