From 4cce84301b9f0f3472cf8b54e3119afd87caa09e Mon Sep 17 00:00:00 2001 From: xiaobin <737489727@qq.com> Date: Fri, 8 Sep 2023 22:51:34 +0800 Subject: [PATCH] fit for baichuan models (#981) As more and more people begin to use Baichuan's open-source models, the influence of Baichuan models is growing, especially in China. Many community members are interested in adding support for Baichuan models to TGI. Meanwhile, Baichuan is a very open company, and in the future, it plans to open-source more and more models, taking all this into consideration, we would like to add support for the Baichuan model to TGI. To do this, we need to make some changes, which we hope can be merged into the main branch of TGI. In the future, we would be happy to help maintain support for Baichuan models in TGI. We sincerely hope that our pull request can be accepted. Thank you. By the way, the changes of this time mainly for supporting Baichuan-7B. --------- Co-authored-by: xiaoyuze Co-authored-by: Nicolas Patry --- .../text_generation_server/models/__init__.py | 2 +- .../custom_modeling/flash_llama_modeling.py | 33 ++++--- .../text_generation_server/utils/convert.py | 12 ++- server/text_generation_server/utils/layers.py | 13 +++ .../text_generation_server/utils/weights.py | 85 ++++++++++++++++++- 5 files changed, 127 insertions(+), 18 deletions(-) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 932ab32e..e6fe1372 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -182,7 +182,7 @@ def get_model( trust_remote_code=trust_remote_code, ) - elif model_type == "llama": + elif model_type == "llama" or model_type == "baichuan": if FLASH_ATTENTION: return FlashLlama( model_id, diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index f0e1236d..55b1aae9 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -149,6 +149,26 @@ class LlamaRMSNorm(nn.Module): return normed_hidden_states, res +def load_attention(config, prefix, weights): + if config.num_attention_heads != config.num_key_value_heads: + return _load_gqa(config, prefix, weights) + else: + if config.model_type == "baichuan": + return TensorParallelColumnLinear.load_qkv( + config, + prefix=f"{prefix}.W_pack", + weights=weights, + bias=False, + ) + else: + return TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=False, + ) + def _load_gqa(config, prefix: str, weights): assert config.hidden_size % config.num_attention_heads == 0 assert config.num_attention_heads % weights.process_group.size() == 0 @@ -205,16 +225,9 @@ class FlashLlamaAttention(torch.nn.Module): self.num_key_value_heads = ( config.num_key_value_heads // weights.process_group.size() ) - if config.num_attention_heads != config.num_key_value_heads: - self.query_key_value = _load_gqa(config, prefix, weights) - else: - self.query_key_value = TensorParallelColumnLinear.load_multi( - config, - prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], - dim=0, - weights=weights, - bias=False, - ) + + self.query_key_value = load_attention(config, prefix, weights) + self.o_proj = TensorParallelRowLinear.load( config, prefix=f"{prefix}.o_proj", diff --git a/server/text_generation_server/utils/convert.py b/server/text_generation_server/utils/convert.py index 8d414eca..0b62f520 100644 --- a/server/text_generation_server/utils/convert.py +++ b/server/text_generation_server/utils/convert.py @@ -29,9 +29,15 @@ def _remove_duplicate_names( [name for name in shared if _is_complete(state_dict[name])] ) if not complete_names: - raise RuntimeError( - f"Error while trying to find names to remove to save state dict, but found no suitable name to keep for saving amongst: {shared}. None is covering the entire storage.Refusing to save/load the model since you could be storing much more memory than needed. Please refer to https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an issue." - ) + if len(shared) == 1: + # Force contiguous + name = list(shared)[0] + state_dict[name] = state_dict[name].clone() + complete_names = {name} + else: + raise RuntimeError( + f"Error while trying to find names to remove to save state dict, but found no suitable name to keep for saving amongst: {shared}. None is covering the entire storage.Refusing to save/load the model since you could be storing much more memory than needed. Please refer to https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an issue." + ) keep_name = sorted(list(complete_names))[0] diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index 6be54048..c1c36194 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -331,6 +331,19 @@ class TensorParallelHead(SuperLayer): class TensorParallelColumnLinear(SuperLayer): + @classmethod + def load_qkv(cls, config, prefix: str, weights, bias: bool): + """Specific method when the QKV was joined after the fact""" + weight = weights.get_weights_col_packed_qkv( + prefix, quantize=config.quantize + ) + if bias: + raise NotImplementedError("packed_qkv only implemented for baichuan") + else: + bias = None + linear = get_linear(weight, bias, config.quantize) + return cls(linear) + @classmethod def load(cls, config, prefix: str, weights, bias: bool): return cls.load_multi(config, [prefix], weights, bias, dim=0) diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 9d47a7d3..2ef7ad39 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -62,7 +62,7 @@ class Weights: def get_shape(self, tensor_name: str): return self._get_slice(tensor_name).get_shape() - def get_tensor(self, tensor_name: str): + def get_tensor(self, tensor_name: str, to_device = True): filename, tensor_name = self.get_filename(tensor_name) f = self._get_handle(filename) tensor = f.get_tensor(tensor_name) @@ -70,16 +70,17 @@ class Weights: # u4 which are disguised as int32 if tensor.dtype not in [torch.int32, torch.int64]: tensor = tensor.to(dtype=self.dtype) - tensor = tensor.to(device=self.device) + if to_device: + tensor = tensor.to(device=self.device) return tensor def get_partial_sharded(self, tensor_name: str, dim: int): filename, tensor_name = self.get_filename(tensor_name) + f = self._get_handle(filename) + slice_ = f.get_slice(tensor_name) world_size = self.process_group.size() rank = self.process_group.rank() - f = self._get_handle(filename) - slice_ = f.get_slice(tensor_name) size = slice_.get_shape()[dim] block_size = size // world_size start = rank * block_size @@ -109,6 +110,66 @@ class Weights: ), f"The choosen size {size} is not compatible with sharding on {world_size} shards" return self.get_partial_sharded(tensor_name, dim) + + def _get_qweight(self, name: str): + slice_ = self._get_slice(name) + total_size = slice_.get_shape()[1] + assert total_size % 3 == 0, "Prepacked quantized qkv is not divisible by 3" + single_size = total_size // 3 + world_size = self.process_group.size() + rank = self.process_group.rank() + + assert single_size % world_size == 0, f"Prepacked quantized qkv cannot be sharded across {world_size} shards" + block_size = single_size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + q = slice_[:, start:stop] + k = slice_[:, start+single_size:stop+single_size] + v = slice_[:, start+2*single_size:stop+2*single_size] + weight = torch.cat([q,k,v], dim=1) + weight = weight.to(device=self.device) + return weight + + def get_weights_col_packed_qkv(self, prefix: str, quantize: str): + """ + Highly specific when the underlying tensor is a simple cat of Q,K,V instead of being + already alternating Q,K,V within the main tensor + """ + if quantize == "gptq": + try: + qweight = self._get_qweight(f"{prefix}.qweight") + except RuntimeError: + raise RuntimeError( + "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" + ) + + qzeros = self._get_qweight(f"{prefix}.qzeros") + scales = self._get_qweight(f"{prefix}.scales") + scales = scales.to(dtype=self.dtype) + g_idx = self.get_tensor(f"{prefix}.g_idx") + + bits, groupsize = self._get_gptq_params() + weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False) + else: + slice_ = self._get_slice(f"{prefix}.weight") + total_size = slice_.get_shape()[0] + assert total_size % 3 == 0, "Prepacked qkv is not divisible by 3" + single_size = total_size // 3 + world_size = self.process_group.size() + rank = self.process_group.rank() + + assert single_size % world_size == 0, f"Prepacked qkv cannot be sharded across {world_size} shards" + block_size = single_size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + q = slice_[start:stop] + k = slice_[start+single_size:stop+single_size] + v = slice_[start+2*single_size:stop+2*single_size] + weight = torch.cat([q,k,v], dim=0) + weight = weight.to(device=self.device) + weight = weight.to(dtype=self.dtype) + return weight + def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int): if quantize == "gptq": try: @@ -137,6 +198,22 @@ class Weights: w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes] weight = torch.cat(w, dim=dim) return weight + + def get_tensor_shard(self, var, dim): + world_size = self.process_group.size() + rank = self.process_group.rank() + block_size = var.size()[dim] // world_size + start = rank * block_size + stop = (rank + 1) * block_size + if dim == 0: + tensor = var[start:stop] + elif dim == 1: + tensor = var[:, start:stop] + else: + raise NotImplementedError("Let's make that generic when needed") + tensor = tensor.to(dtype=self.dtype) + tensor = tensor.to(device=self.device) + return tensor def get_multi_weights_row(self, prefix: str, quantize: str): if quantize == "gptq":