From 7c2e0af2a61745a36b26fae2c817f608be757a4c Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Wed, 30 Aug 2023 11:09:46 +0200
Subject: [PATCH] Fix f180 (#951)

# What does this PR do?

Fixes # (issue)

## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

## Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
---
 .../models/custom_modeling/flash_rw_modeling.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index 14caa23d..8419fa4f 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -241,19 +241,21 @@ class FlashRWLargeAttention(torch.nn.Module):
 
         hidden_size = config.hidden_size
         num_heads = config.n_head
-        num_heads_kv = config.n_head_kv
+        # num_heads_kv = config.n_head_kv
+        num_groups = config.n_head_kv
 
         self.hidden_size = hidden_size
         self.head_size = hidden_size // num_heads
+        self.num_groups = num_groups
 
         self.rotary_emb = PositionRotaryEmbedding.static(
             config=config, dim=self.head_size, base=10000.0, device=weights.device
         )
         self.softmax_scale = self.head_size ** (-0.5)
 
-        self.num_groups = num_heads // (num_heads_kv * 2)
+        # self.num_groups = num_heads // (num_heads_kv * 2)
         self.num_heads = num_heads // self.num_groups
-        self.num_heads_kv = num_heads_kv // self.num_groups
+        # self.num_heads_kv = num_heads_kv // self.num_groups
         process_group = weights.process_group
 
         if process_group.size() > self.num_groups:
@@ -264,6 +266,7 @@ class FlashRWLargeAttention(torch.nn.Module):
             raise NotImplementedError(
                 f"Tensor Parallelism is not implemented for {self.num_groups} not divisible by {process_group.size()}"
             )
+        self.num_groups = self.num_groups // process_group.size()
 
         self.query_key_value = TensorParallelColumnLinear.load(
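
For reviewers, a minimal sketch of what the `num_groups` change amounts to. The head counts used below are illustrative assumptions (roughly Falcon-40B-style vs. Falcon-180B-style configs), not values taken from this patch:

```python
# Sketch of the head-grouping computation before and after this patch.
# The n_head / n_head_kv values are assumed for illustration only.

def old_grouping(n_head: int, n_head_kv: int) -> int:
    # Before: the number of groups was derived from the head counts.
    return n_head // (n_head_kv * 2)

def new_grouping(n_head: int, n_head_kv: int) -> int:
    # After: the number of groups is taken directly from config.n_head_kv.
    return n_head_kv

# Falcon-40B-style head counts (assumed): both formulas agree.
print(old_grouping(128, 8), new_grouping(128, 8))  # 8 8

# Falcon-180B-style head counts (assumed): the derived value (14) no longer
# matches n_head_kv (8), which is the kind of mismatch this patch avoids.
print(old_grouping(232, 8), new_grouping(232, 8))  # 14 8
```

With the second hunk, the per-rank group count is then reduced explicitly for tensor parallelism (`num_groups // process_group.size()`), after the divisibility check has passed.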