diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 932ab32e..e6fe1372 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -182,7 +182,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
 
-    elif model_type == "llama":
+    elif model_type == "llama" or model_type == "baichuan":
         if FLASH_ATTENTION:
             return FlashLlama(
                 model_id,
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index f0e1236d..55b1aae9 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -149,6 +149,26 @@ class LlamaRMSNorm(nn.Module):
             return normed_hidden_states, res
 
 
+def load_attention(config, prefix, weights):
+    if config.num_attention_heads != config.num_key_value_heads:
+        return _load_gqa(config, prefix, weights)
+    else:
+        if config.model_type == "baichuan":
+            return TensorParallelColumnLinear.load_qkv(
+                config,
+                prefix=f"{prefix}.W_pack",
+                weights=weights,
+                bias=False,
+            )
+        else:
+            return TensorParallelColumnLinear.load_multi(
+                config,
+                prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+                dim=0,
+                weights=weights,
+                bias=False,
+            )
+
 def _load_gqa(config, prefix: str, weights):
     assert config.hidden_size % config.num_attention_heads == 0
     assert config.num_attention_heads % weights.process_group.size() == 0
@@ -205,16 +225,9 @@ class FlashLlamaAttention(torch.nn.Module):
         self.num_key_value_heads = (
             config.num_key_value_heads // weights.process_group.size()
         )
-        if config.num_attention_heads != config.num_key_value_heads:
-            self.query_key_value = _load_gqa(config, prefix, weights)
-        else:
-            self.query_key_value = TensorParallelColumnLinear.load_multi(
-                config,
-                prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
-                dim=0,
-                weights=weights,
-                bias=False,
-            )
+
+        self.query_key_value = load_attention(config, prefix, weights)
+
         self.o_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.o_proj",
diff --git a/server/text_generation_server/utils/convert.py b/server/text_generation_server/utils/convert.py
index 8d414eca..0b62f520 100644
--- a/server/text_generation_server/utils/convert.py
+++ b/server/text_generation_server/utils/convert.py
@@ -29,9 +29,15 @@ def _remove_duplicate_names(
             [name for name in shared if _is_complete(state_dict[name])]
         )
         if not complete_names:
-            raise RuntimeError(
-                f"Error while trying to find names to remove to save state dict, but found no suitable name to keep for saving amongst: {shared}. None is covering the entire storage.Refusing to save/load the model since you could be storing much more memory than needed. Please refer to https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an issue."
-            )
+            if len(shared) == 1:
+                # Force contiguous
+                name = list(shared)[0]
+                state_dict[name] = state_dict[name].clone()
+                complete_names = {name}
+            else:
+                raise RuntimeError(
+                    f"Error while trying to find names to remove to save state dict, but found no suitable name to keep for saving amongst: {shared}. None is covering the entire storage.Refusing to save/load the model since you could be storing much more memory than needed. Please refer to https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an issue."
+                )
 
         keep_name = sorted(list(complete_names))[0]
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 6be54048..c1c36194 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -331,6 +331,19 @@ class TensorParallelHead(SuperLayer):
 
 
 class TensorParallelColumnLinear(SuperLayer):
+    @classmethod
+    def load_qkv(cls, config, prefix: str, weights, bias: bool):
+        """Specific method when the QKV was joined after the fact"""
+        weight = weights.get_weights_col_packed_qkv(
+            prefix, quantize=config.quantize
+        )
+        if bias:
+            raise NotImplementedError("packed_qkv only implemented for baichuan")
+        else:
+            bias = None
+        linear = get_linear(weight, bias, config.quantize)
+        return cls(linear)
+
     @classmethod
     def load(cls, config, prefix: str, weights, bias: bool):
         return cls.load_multi(config, [prefix], weights, bias, dim=0)
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index 9d47a7d3..2ef7ad39 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -62,7 +62,7 @@ class Weights:
     def get_shape(self, tensor_name: str):
         return self._get_slice(tensor_name).get_shape()
 
-    def get_tensor(self, tensor_name: str):
+    def get_tensor(self, tensor_name: str, to_device = True):
         filename, tensor_name = self.get_filename(tensor_name)
         f = self._get_handle(filename)
         tensor = f.get_tensor(tensor_name)
@@ -70,16 +70,17 @@ class Weights:
         # u4 which are disguised as int32
         if tensor.dtype not in [torch.int32, torch.int64]:
             tensor = tensor.to(dtype=self.dtype)
-        tensor = tensor.to(device=self.device)
+        if to_device:
+            tensor = tensor.to(device=self.device)
         return tensor
 
     def get_partial_sharded(self, tensor_name: str, dim: int):
         filename, tensor_name = self.get_filename(tensor_name)
+        f = self._get_handle(filename)
+        slice_ = f.get_slice(tensor_name)
         world_size = self.process_group.size()
         rank = self.process_group.rank()
-        f = self._get_handle(filename)
-        slice_ = f.get_slice(tensor_name)
         size = slice_.get_shape()[dim]
         block_size = size // world_size
         start = rank * block_size
@@ -109,6 +110,66 @@
         ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
         return self.get_partial_sharded(tensor_name, dim)
 
+
+    def _get_qweight(self, name: str):
+        slice_ = self._get_slice(name)
+        total_size = slice_.get_shape()[1]
+        assert total_size % 3 == 0, "Prepacked quantized qkv is not divisible by 3"
+        single_size = total_size // 3
+        world_size = self.process_group.size()
+        rank = self.process_group.rank()
+
+        assert single_size % world_size == 0, f"Prepacked quantized qkv cannot be sharded across {world_size} shards"
+        block_size = single_size // world_size
+        start = rank * block_size
+        stop = (rank + 1) * block_size
+        q = slice_[:, start:stop]
+        k = slice_[:, start+single_size:stop+single_size]
+        v = slice_[:, start+2*single_size:stop+2*single_size]
+        weight = torch.cat([q,k,v], dim=1)
+        weight = weight.to(device=self.device)
+        return weight
+
+    def get_weights_col_packed_qkv(self, prefix: str, quantize: str):
+        """
+        Highly specific when the underlying tensor is a simple cat of Q,K,V instead of being
+        already alternating Q,K,V within the main tensor
+        """
+        if quantize == "gptq":
+            try:
+                qweight = self._get_qweight(f"{prefix}.qweight")
+            except RuntimeError:
+                raise RuntimeError(
+                    "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
+                )
+
+            qzeros = self._get_qweight(f"{prefix}.qzeros")
+            scales = self._get_qweight(f"{prefix}.scales")
+            scales = scales.to(dtype=self.dtype)
+            g_idx = self.get_tensor(f"{prefix}.g_idx")
+
+            bits, groupsize = self._get_gptq_params()
+            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
+        else:
+            slice_ = self._get_slice(f"{prefix}.weight")
+            total_size = slice_.get_shape()[0]
+            assert total_size % 3 == 0, "Prepacked qkv is not divisible by 3"
+            single_size = total_size // 3
+            world_size = self.process_group.size()
+            rank = self.process_group.rank()
+
+            assert single_size % world_size == 0, f"Prepacked qkv cannot be sharded across {world_size} shards"
+            block_size = single_size // world_size
+            start = rank * block_size
+            stop = (rank + 1) * block_size
+            q = slice_[start:stop]
+            k = slice_[start+single_size:stop+single_size]
+            v = slice_[start+2*single_size:stop+2*single_size]
+            weight = torch.cat([q,k,v], dim=0)
+            weight = weight.to(device=self.device)
+            weight = weight.to(dtype=self.dtype)
+        return weight
+
     def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
         if quantize == "gptq":
             try:
@@ -137,6 +198,22 @@
         w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
         weight = torch.cat(w, dim=dim)
         return weight
+
+    def get_tensor_shard(self, var, dim):
+        world_size = self.process_group.size()
+        rank = self.process_group.rank()
+        block_size = var.size()[dim] // world_size
+        start = rank * block_size
+        stop = (rank + 1) * block_size
+        if dim == 0:
+            tensor = var[start:stop]
+        elif dim == 1:
+            tensor = var[:, start:stop]
+        else:
+            raise NotImplementedError("Let's make that generic when needed")
+        tensor = tensor.to(dtype=self.dtype)
+        tensor = tensor.to(device=self.device)
+        return tensor
 
     def get_multi_weights_row(self, prefix: str, quantize: str):
         if quantize == "gptq":
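
Note (not part of the patch): a minimal standalone sketch of what the non-quantized branch of `get_weights_col_packed_qkv` above does with Baichuan's fused `W_pack` projection. All sizes, plus the `world_size` and `rank` values, are made up for illustration; the real code slices a lazy safetensors slice instead of a materialized tensor.

import torch

# Illustrative only: split a fused W_pack weight of assumed shape
# [3 * hidden_size, hidden_size] across tensor-parallel ranks.
hidden_size = 8
world_size = 2   # number of shards (assumed)
rank = 0         # this shard's rank (assumed)

w_pack = torch.randn(3 * hidden_size, hidden_size)

total_size = w_pack.shape[0]
assert total_size % 3 == 0
single_size = total_size // 3          # rows belonging to each of Q, K, V
assert single_size % world_size == 0
block_size = single_size // world_size
start = rank * block_size
stop = (rank + 1) * block_size

# Take this rank's rows from each of the Q, K and V sections, then
# concatenate them so the shard looks like a regular fused QKV projection.
q = w_pack[start:stop]
k = w_pack[start + single_size : stop + single_size]
v = w_pack[start + 2 * single_size : stop + 2 * single_size]
shard = torch.cat([q, k, v], dim=0)

print(shard.shape)  # torch.Size([12, 8]) == (3 * block_size, hidden_size)

The per-rank result has the same layout as what `TensorParallelColumnLinear.load_multi` produces for separate q_proj/k_proj/v_proj tensors, which is why the rest of FlashLlamaAttention can use the Baichuan shard unchanged.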