From 986b4044d17067c8ff851df800f22fa9fb2ace45 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Tue, 23 Apr 2024 18:40:05 +0200
Subject: [PATCH] Phi3 support (#1797)

# What does this PR do?

Fixes # (issue)

## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

## Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
---
 .../text_generation_server/models/__init__.py |  2 +-
 .../custom_modeling/flash_llama_modeling.py   | 29 ++++++++++++++-----
 server/text_generation_server/utils/layers.py | 13 +++++++++
 .../text_generation_server/utils/weights.py   | 19 ++++++++----
 4 files changed, 49 insertions(+), 14 deletions(-)

diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 06792b0d..e4e8717d 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -327,7 +327,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    elif model_type == "llama" or model_type == "baichuan":
+    elif model_type == "llama" or model_type == "baichuan" or model_type == "phi3":
         if FLASH_ATTENTION:
             return FlashLlama(
                 model_id,
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 4cf0fcf2..6d796ac3 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -101,6 +101,13 @@ def load_attention(config, prefix, weights):
                 weights=weights,
                 bias=False,
             )
+        elif config.model_type == "phi3":
+            return TensorParallelColumnLinear.load_qkv(
+                config,
+                prefix=f"{prefix}.qkv_proj",
+                weights=weights,
+                bias=False,
+            )
         else:
             return TensorParallelColumnLinear.load_multi(
                 config,
@@ -257,13 +264,21 @@ class LlamaMLP(nn.Module):
             )
         )
         # Fuse gate and up proj
-        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
-            config,
-            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
-            weights=weights,
-            dim=0,
-            bias=False,
-        )
+        if config.model_type == "phi3":
+            self.gate_up_proj = TensorParallelColumnLinear.load_gate_up(
+                config,
+                prefix=f"{prefix}.gate_up_proj",
+                weights=weights,
+                bias=False,
+            )
+        else:
+            self.gate_up_proj = TensorParallelColumnLinear.load_multi(
+                config,
+                prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
+                weights=weights,
+                dim=0,
+                bias=False,
+            )
         self.down_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.down_proj",
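For reviewers: Phi-3 checkpoints store the attention projections as a single fused `qkv_proj` tensor and the MLP projections as a single fused `gate_up_proj` tensor, which is why the hunks above reuse `load_qkv` for `qkv_proj` and add a `load_gate_up` path for `gate_up_proj` instead of going through `load_multi`. Below is a minimal sketch of what the fused `gate_up_proj` layout means for the SwiGLU MLP; it is not code from this repository, and the shapes, names, and the assumed gate-then-up packing order are illustrative assumptions.

```python
# Illustrative only: one matmul on a packed gate_up weight plus chunk() is
# equivalent to separate gate_proj / up_proj matmuls. Sizes are made up.
import torch
import torch.nn.functional as F

hidden_size, intermediate_size = 64, 128

gate_proj = torch.randn(intermediate_size, hidden_size)
up_proj = torch.randn(intermediate_size, hidden_size)

# Assumed packing: gate rows first, then up rows, stacked along dim 0.
gate_up_proj = torch.cat([gate_proj, up_proj], dim=0)

x = torch.randn(4, hidden_size)

# Single matmul on the packed weight, then split the output in two halves.
gate, up = (x @ gate_up_proj.T).chunk(2, dim=-1)
packed_out = F.silu(gate) * up

# Matches the unpacked Llama-style computation.
ref_out = F.silu(x @ gate_proj.T) * (x @ up_proj.T)
torch.testing.assert_close(packed_out, ref_out)
```

Because gate and up are simply stacked along dim 0, tensor-parallel loading has to slice each half separately so every rank keeps matching gate/up rows; that is what the `weights.py` changes below generalize.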
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 9cf5c80f..69bd5e88 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -696,6 +696,19 @@ class TensorParallelHead(SuperLayer):
 
 
 class TensorParallelColumnLinear(SuperLayer):
+    @classmethod
+    def load_gate_up(cls, config, prefix: str, weights, bias: bool):
+        """Specific method when the QKV was joined after the fact"""
+        weight = weights.get_weights_col_packed_gate_up(
+            prefix, quantize=config.quantize
+        )
+        if bias:
+            raise NotImplementedError("packed_gate_up only implemented without bias")
+        else:
+            bias = None
+        linear = get_linear(weight, bias, config.quantize)
+        return cls(linear)
+
     @classmethod
     def load_qkv(cls, config, prefix: str, weights, bias: bool):
         """Specific method when the QKV was joined after the fact"""
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index d0614346..da7aed1a 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -141,6 +141,12 @@ class Weights:
         return weight
 
     def get_weights_col_packed_qkv(self, prefix: str, quantize: str):
+        return self.get_weights_col_packed(prefix, quantize, 3)
+
+    def get_weights_col_packed_gate_up(self, prefix: str, quantize: str):
+        return self.get_weights_col_packed(prefix, quantize, 2)
+
+    def get_weights_col_packed(self, prefix: str, quantize: str, blocks: int):
         """
         Highly specific when the underlying tensor is a simple cat of Q,K,V instead of being
         already alternating Q,K,V within the main tensor
@@ -181,8 +187,8 @@ class Weights:
         else:
             slice_ = self._get_slice(f"{prefix}.weight")
             total_size = slice_.get_shape()[0]
-            assert total_size % 3 == 0, "Prepacked qkv is not divisible by 3"
-            single_size = total_size // 3
+            assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}"
+            single_size = total_size // blocks
             world_size = self.process_group.size()
             rank = self.process_group.rank()
 
@@ -192,10 +198,11 @@ class Weights:
             block_size = single_size // world_size
             start = rank * block_size
             stop = (rank + 1) * block_size
-            q = slice_[start:stop]
-            k = slice_[start + single_size : stop + single_size]
-            v = slice_[start + 2 * single_size : stop + 2 * single_size]
-            weight = torch.cat([q, k, v], dim=0)
+            tensors = []
+            for i in range(blocks):
+                tensor = slice_[start + i * single_size : stop + i * single_size]
+                tensors.append(tensor)
+            weight = torch.cat(tensors, dim=0)
             weight = weight.to(device=self.device)
             weight = weight.to(dtype=self.dtype)
         return weight
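For reviewers: the `weights.py` hunks generalize the old QKV-only splitting into `get_weights_col_packed(prefix, quantize, blocks)`, where `blocks` is 3 for a packed Q/K/V tensor and 2 for a packed gate/up tensor. The packed weight is treated as `blocks` equal blocks stacked along dim 0; each tensor-parallel rank takes its own contiguous slice out of every block and re-concatenates them, so the fused layout is preserved per rank. A standalone sketch of that rule follows, using a plain tensor in place of the safetensors slice and made-up sizes; only the inner loop mirrors the patch, and `shard_packed` is a hypothetical helper, not part of this codebase.

```python
# Hypothetical sketch of the sharding rule generalized in get_weights_col_packed.
import torch


def shard_packed(packed: torch.Tensor, blocks: int, world_size: int, rank: int) -> torch.Tensor:
    """Take this rank's contiguous slice from each of the `blocks` equal blocks."""
    total_size = packed.shape[0]
    assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}"
    single_size = total_size // blocks
    assert single_size % world_size == 0, "block not divisible across ranks"
    block_size = single_size // world_size
    start = rank * block_size
    stop = (rank + 1) * block_size
    # Same loop as the patch: one slice per block (Q/K/V or gate/up).
    tensors = [packed[start + i * single_size : stop + i * single_size] for i in range(blocks)]
    return torch.cat(tensors, dim=0)


intermediate, hidden, world_size = 8, 4, 2
gate = torch.randn(intermediate, hidden)
up = torch.randn(intermediate, hidden)
packed = torch.cat([gate, up], dim=0)  # blocks=2, like a fused gate_up_proj

shards = [shard_packed(packed, blocks=2, world_size=world_size, rank=r) for r in range(world_size)]

# Stitching the ranks' gate halves (and up halves) back together recovers the originals.
rows_per_rank = intermediate // world_size
gate_rebuilt = torch.cat([s[:rows_per_rank] for s in shards], dim=0)
up_rebuilt = torch.cat([s[rows_per_rank:] for s in shards], dim=0)
torch.testing.assert_close(gate_rebuilt, gate)
torch.testing.assert_close(up_rebuilt, up)
```

Each shard keeps its blocks in the packed order, so the per-rank fused matmul and the downstream gate/up (or Q/K/V) split still line up.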