From b95732180dc52be869e8c3e752a9c54608a6c7a5 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Fri, 26 Jan 2024 14:00:29 +0100
Subject: [PATCH] Reinstate exl2 with tp (#1490)

# What does this PR do?

Fixes # (issue)

## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

## Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
---
 ...t_flash_starcoder_gptq_default_params.json | 26 +++++++++----------
 .../utils/gptq/exllamav2.py                   |  4 +++
 server/text_generation_server/utils/layers.py | 12 ++++-----
 3 files changed, 23 insertions(+), 19 deletions(-)

diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_default_params.json b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_default_params.json
index 5598a2ad..1ace3814 100644
--- a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_default_params.json
+++ b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_default_params.json
@@ -16,52 +16,52 @@
       },
       {
         "id": 21017,
-        "logprob": -9.09375,
+        "logprob": -9.0859375,
         "text": "ometric"
       },
       {
         "id": 81,
-        "logprob": -0.25976562,
+        "logprob": -0.25830078,
         "text": "_"
       },
       {
         "id": 6009,
-        "logprob": -2.2148438,
+        "logprob": -2.1875,
         "text": "mean"
       },
       {
         "id": 26,
-        "logprob": -0.3010254,
+        "logprob": -0.30004883,
         "text": "("
       },
       {
         "id": 62,
-        "logprob": -5.6757812,
+        "logprob": -5.6171875,
         "text": "L"
       },
       {
         "id": 44,
-        "logprob": -3.0898438,
+        "logprob": -3.078125,
         "text": ":"
       },
       {
         "id": 1682,
-        "logprob": -0.6791992,
+        "logprob": -0.68066406,
         "text": " List"
       },
       {
         "id": 77,
-        "logprob": -0.38891602,
+        "logprob": -0.38745117,
         "text": "["
       },
       {
         "id": 1808,
-        "logprob": -0.92041016,
+        "logprob": -0.9453125,
         "text": "float"
       },
       {
         "id": 10794,
-        "logprob": -2.5390625,
+        "logprob": -2.5371094,
         "text": "]):"
       }
     ],
@@ -69,7 +69,7 @@
     "tokens": [
       {
         "id": 284,
-        "logprob": 0.0,
+        "logprob": -0.051635742,
         "special": false,
         "text": "\n "
       },
@@ -81,7 +81,7 @@
       },
       {
         "id": 11665,
-        "logprob": -1.6005859,
+        "logprob": -1.2236328,
         "special": false,
         "text": " reduce"
       },
@@ -159,7 +159,7 @@
       },
       {
         "id": 203,
-        "logprob": -0.11968994,
+        "logprob": -0.12695312,
         "special": false,
         "text": "\n"
       },
diff --git a/server/text_generation_server/utils/gptq/exllamav2.py b/server/text_generation_server/utils/gptq/exllamav2.py
index a24e834b..2b897f25 100644
--- a/server/text_generation_server/utils/gptq/exllamav2.py
+++ b/server/text_generation_server/utils/gptq/exllamav2.py
@@ -185,6 +185,10 @@ class QuantLinear(nn.Module):
             "g_idx": self.g_idx,
         }
         temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size())
+
+        # We NEED to keep a pointer on Python side, otherwise the garbage collector will mess with us,
+        # and `Memory access fault by GPU node-2` will EAT you.
+        self.temp_dq = temp_dq
         self.q_handle = ext_make_q_matrix(self.q_tensors, temp_dq)
 
     def forward(self, x, force_cuda=False):
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 5a0de0d7..c9393d99 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -35,12 +35,12 @@
 except Exception:
     HAS_EXLLAMA = False
 CAN_EXLLAMA = major >= 8
 V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
-if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1:
-    V2 = False
-    log_once(
-        logger.warning,
-        "Disabling exllama v2 and using v1 instead because there are issues when sharding",
-    )
+# if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1:
+#     V2 = False
+#     log_once(
+#         logger.warning,
+#         "Disabling exllama v2 and using v1 instead because there are issues when sharding",
+#     )
 if os.getenv("DISABLE_EXLLAMA") == "True":
     HAS_EXLLAMA = False
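A note on the `self.temp_dq = temp_dq` line added in `exllamav2.py`: the comment in the hunk says the Python side must keep a reference, presumably because the extension handle created by `ext_make_q_matrix` only records the scratch tensor's device pointer, so once Python drops its last reference the allocator can recycle that memory while kernels still use it, producing the `Memory access fault by GPU node-2` crash. The sketch below illustrates the general keep-alive pattern; `make_handle` and `ScratchOwner` are illustrative names, not part of the TGI codebase.

```python
import torch


def make_handle(scratch: torch.Tensor) -> int:
    # Stand-in for an extension call such as ext_make_q_matrix: the C++/CUDA
    # side typically keeps only the raw device pointer, not a Tensor reference.
    return scratch.data_ptr()


class ScratchOwner:
    """Illustrative keep-alive pattern (sketch, not TGI code)."""

    def __init__(self, nbytes: int, device: str = "cuda"):
        temp_dq = torch.empty(nbytes, dtype=torch.uint8, device=device)

        # Keep a Python-side reference for as long as the handle lives.
        # Without this line, `temp_dq` becomes unreachable when __init__
        # returns, the caching allocator may hand the memory to another
        # tensor, and kernels still using the old pointer can fault.
        self.temp_dq = temp_dq

        self.handle = make_handle(temp_dq)
```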
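On the `layers.py` hunk: with the `WORLD_SIZE > 1` downgrade commented out, the choice of exllama kernel is driven only by the GPU's compute capability, `EXLLAMA_VERSION` (default `"2"`), and `DISABLE_EXLLAMA`, even when the model is sharded across ranks. A minimal sketch of that selection as it stands after this patch, simplified and not the actual module (the real file also imports the kernels in a try/except and clears `HAS_EXLLAMA` on failure):

```python
import os


def exllama_config(major: int) -> dict:
    """Simplified view of the exllama flags after this patch (sketch only)."""
    can_exllama = major >= 8  # mirrors CAN_EXLLAMA = major >= 8
    use_v2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
    # The former `if V2 and WORLD_SIZE > 1: V2 = False` guard is gone,
    # so v2 stays selected even with tensor parallelism.
    has_exllama = can_exllama and os.getenv("DISABLE_EXLLAMA") != "True"
    return {"CAN_EXLLAMA": can_exllama, "HAS_EXLLAMA": has_exllama, "V2": use_v2}
```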