From 85dfc39222798b75559c891789283de23c679ca5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Mon, 10 Jun 2024 09:22:29 +0200 Subject: [PATCH] Add Phi-3 medium support (#2039) Add support for Phi-3-medium The main difference between the medium and mini models is that medium uses grouped query attention with a packed QKV matrix. This change adds support for GQA with packed matrixes to `Weights.get_weights_col_packed` and uses it for Phi-3. This also allows us to remove the custom implementation of GQA from dbrx attention loading. --- .../layers/tensor_parallel.py | 17 ++- .../custom_modeling/flash_dbrx_modeling.py | 131 +----------------- .../custom_modeling/flash_gpt2_modeling.py | 7 +- .../custom_modeling/flash_llama_modeling.py | 9 ++ .../text_generation_server/utils/weights.py | 118 +++++++++++----- 5 files changed, 118 insertions(+), 164 deletions(-) diff --git a/server/text_generation_server/layers/tensor_parallel.py b/server/text_generation_server/layers/tensor_parallel.py index 192c2b42..6005f737 100644 --- a/server/text_generation_server/layers/tensor_parallel.py +++ b/server/text_generation_server/layers/tensor_parallel.py @@ -129,9 +129,22 @@ class TensorParallelColumnLinear(SuperLayer): return cls(linear) @classmethod - def load_qkv(cls, config, prefix: str, weights, bias: bool): + def load_qkv( + cls, + config, + prefix: str, + weights, + bias: bool, + num_heads: int, + num_key_value_heads: int, + ): """Specific method when the QKV was joined after the fact""" - weight = weights.get_weights_col_packed_qkv(prefix, quantize=config.quantize) + weight = weights.get_weights_col_packed_qkv( + prefix, + quantize=config.quantize, + num_heads=num_heads, + num_key_value_heads=num_key_value_heads, + ) if bias: raise NotImplementedError("packed_qkv only implemented for baichuan") else: diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index 63ce6543..94cf6452 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -20,7 +20,6 @@ from torch import nn from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple, Any -from loguru import logger from text_generation_server.utils.import_utils import SYSTEM if SYSTEM != "xpu": @@ -164,129 +163,13 @@ def promote_scalar(x: torch.Tensor) -> torch.Tensor: def load_attention(config, prefix, weights): - if config.n_heads != config.attn_config.kv_n_heads: - return _load_gqa(config, prefix, weights) - else: - return TensorParallelColumnLinear.load_qkv( - config, - prefix=f"{prefix}.Wqkv", - weights=weights, - bias=False, - ) - - -def _load_gqa(config, prefix: str, weights): - assert config.d_model % config.n_heads == 0 - assert config.n_heads % weights.process_group.size() == 0 - - head_dim = config.d_model // config.n_heads - world_size = weights.process_group.size() - rank = weights.process_group.rank() - - q_block_size = config.d_model // world_size - q_start = rank * q_block_size - q_stop = (rank + 1) * q_block_size - - kv_block_size = (config.attn_config.kv_n_heads * head_dim) // world_size - k_offset = config.d_model - k_start = k_offset + rank * kv_block_size - k_stop = k_offset + (rank + 1) * kv_block_size - - v_offset = config.d_model + config.attn_config.kv_n_heads * head_dim - v_start = v_offset + rank * 
kv_block_size - v_stop = v_offset + (rank + 1) * kv_block_size - - if config.quantize in ["gptq", "awq"]: - from text_generation_server.layers.gptq import GPTQWeight - - try: - qweight_slice = weights._get_slice(f"{prefix}.qweight") - q_qweight = qweight_slice[:, q_start:q_stop] - k_qweight = qweight_slice[:, k_start:k_stop] - v_qweight = qweight_slice[:, v_start:v_stop] - - qweight = torch.cat([q_qweight, k_qweight, v_qweight], dim=1) - except RuntimeError: - raise RuntimeError( - f"Cannot load `{config.quantize}` weight, make sure the model is already quantized" - ) - - qzeros_slice = weights._get_slice(f"{prefix}.qzeros") - q_qzeros = qzeros_slice[:, q_start:q_stop] - k_qzeros = qzeros_slice[:, k_start:k_stop] - v_qzeros = qzeros_slice[:, v_start:v_stop] - - qzeros = torch.cat([q_qzeros, k_qzeros, v_qzeros], dim=1) - - scales_slice = weights._get_slice(f"{prefix}.scales") - q_scales = scales_slice[:, q_start:q_stop] - k_scales = scales_slice[:, k_start:k_stop] - v_scales = scales_slice[:, v_start:v_stop] - - scales = torch.cat([q_scales, k_scales, v_scales], dim=1) - - bits, groupsize, desc_act, quant_method = weights._get_gptq_params() - - from text_generation_server.layers import HAS_EXLLAMA - - use_exllama = ( - bits == 4 and HAS_EXLLAMA and config.quantize == "gptq" and not desc_act - ) - - if config.quantize == "gptq" and quant_method == "gptq": - g_idx_slice = weights._get_slice(f"{prefix}.g_idx") - q_g_idx = g_idx_slice[:, q_start:q_stop] - k_g_idx = g_idx_slice[:, k_start:k_stop] - v_g_idx = g_idx_slice[:, v_start:v_stop] - - w = [q_g_idx, k_g_idx, v_g_idx] - for w2 in w[1:]: - torch.testing.assert_close(w2, w[0]) - g_idx = w[0] - elif config.quantize == "gptq" and quant_method == "awq": - log_once( - logger.info, "Converting AWQ model to Exllama/GPTQ packing format." - ) - from text_generation_server.layers.awq.conveersion_utils import ( - fast_awq_to_gptq, - ) - - qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) - if use_exllama: - g_idx = None - else: - g_idx = ( - torch.arange(qweight.shape[0] * (32 // bits), device=qweight.device) - // groupsize - ).to(dtype=torch.int32) - else: - g_idx = None - - weight = GPTQWeight( - qweight=qweight, - qzeros=qzeros, - scales=scales, - g_idx=g_idx, - bits=bits, - groupsize=groupsize, - use_exllama=use_exllama, - ) - elif config.quantize == "marlin": - # NOTE: at the time marlin support was added, the only model that - # exists is LnL-AI/dbrx-base-converted-v2-4bit-gptq-marlin(-v2), - # but it requires manual concatenation of weight files. 
- raise RuntimeError("dbrx models with marlin quantization are not yet supported") - else: - qkv_slice = weights._get_slice(f"{prefix}.Wqkv.weight") - q = qkv_slice[q_start:q_stop] - k = qkv_slice[k_start:k_stop] - v = qkv_slice[v_start:v_stop] - - weight = torch.cat([q, k, v], dim=0) - weight = weight.to(dtype=weights.dtype).to(device=weights.device) - - return TensorParallelColumnLinear( - get_linear(weight, bias=None, quantize=config.quantize) + return TensorParallelColumnLinear.load_qkv( + config, + prefix=f"{prefix}.Wqkv", + weights=weights, + bias=False, + num_heads=config.n_heads, + num_key_value_heads=config.attn_config.kv_n_heads, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index 0178c911..0c01f56a 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -59,7 +59,12 @@ def _load_qkv_gptq(config, prefix: str, weights): rank = weights.process_group.rank() # Weights - weight = weights.get_weights_col_packed_qkv(f"{prefix}.c_attn", config.quantize) + weight = weights.get_weights_col_packed_qkv( + f"{prefix}.c_attn", + config.quantize, + config.num_attention_heads, + config.num_attention_heads, + ) # Bias slice_ = weights._get_slice(f"{prefix}.c_attn.bias") diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index cef712f0..0d06d104 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -62,6 +62,8 @@ def load_attention(config, prefix, weights): prefix=f"{prefix}.qkv_proj", weights=weights, bias=bias, + num_heads=config.num_attention_heads, + num_key_value_heads=config.num_key_value_heads, ) elif config.model_type == "baichuan": return TensorParallelColumnLinear.load_qkv( @@ -69,6 +71,8 @@ def load_attention(config, prefix, weights): prefix=f"{prefix}.W_pack", weights=weights, bias=bias, + num_heads=config.num_attention_heads, + num_key_value_heads=config.num_key_value_heads, ) # otherwise, load the default attention based on the number of heads @@ -107,6 +111,11 @@ class FlashLlamaAttention(torch.nn.Module): f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " f"and `num_shards`: {weights.process_group.size()}" ) + if config.num_key_value_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_key_value_heads` must be divisible by `num_shards` (got `num_key_value_heads`: {config.num_key_value_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) self.num_heads = self.num_heads // weights.process_group.size() self.num_key_value_heads = ( config.num_key_value_heads // weights.process_group.size() diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 557656e7..4d5fcb25 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -1,7 +1,6 @@ -from dataclasses import dataclass, field import os from pathlib import Path -from typing import List, Dict, Optional, Set, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union from safetensors import safe_open, SafetensorError import torch from loguru import logger @@ -121,49 +120,62 @@ 
class Weights: ), f"The choosen size {size} is not compatible with sharding on {world_size} shards" return self.get_partial_sharded(tensor_name, dim) - def _get_qweight(self, name: str, blocks: int): + def _get_qweight(self, name: str, block_sizes: Union[int, List[int]]): slice_ = self._get_slice(name) total_size = slice_.get_shape()[1] - assert ( - total_size % blocks == 0 - ), f"Prepacked quantized matrix is not divisible by {blocks}" - single_size = total_size // blocks + block_sizes = _blocks_to_block_sizes(total_size=total_size, blocks=block_sizes) + world_size = self.process_group.size() rank = self.process_group.rank() - assert ( - single_size % world_size == 0 - ), f"Prepacked quantized matrix cannot be sharded across {world_size} shards" - block_size = single_size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - weights = [] - for block in range(blocks): - weights.append( - slice_[:, start + block * single_size : stop + block * single_size] - ) + block_offset = 0 + for block_size in block_sizes: + assert ( + block_size % world_size == 0 + ), f"Prepacked qkv cannot be sharded across {world_size} shards" + shard_block_size = block_size // world_size + start = rank * shard_block_size + stop = (rank + 1) * shard_block_size + weights.append(slice_[:, block_offset + start : block_offset + stop]) + block_offset += block_size weight = torch.cat(weights, dim=1) weight = weight.to(device=self.device) return weight - def get_weights_col_packed_qkv(self, prefix: str, quantize: str): - return self.get_weights_col_packed(prefix, quantize, 3) + def get_weights_col_packed_qkv( + self, + prefix: str, + quantize: str, + num_heads: int, + num_key_value_heads: int, + ): + return self.get_weights_col_packed( + prefix, quantize, [num_heads, num_key_value_heads, num_key_value_heads] + ) def get_weights_col_packed_gate_up(self, prefix: str, quantize: str): return self.get_weights_col_packed(prefix, quantize, 2) - def get_weights_col_packed(self, prefix: str, quantize: str, blocks: int): + def get_weights_col_packed( + self, prefix: str, quantize: str, block_sizes: Union[int, List[int]] + ): """ Highly specific when the underlying tensor is a simple cat of Q,K,V instead of being - already alternating Q,K,V within the main tensor + already alternating Q,K,V within the main tensor. + + The columns are split in equally sized blocks when blocks is an `int`, or + in blocks proportional given to the sizes. For instance `[2, 1, 1]` will + divide an input with dimensionality `1024` in `[512, 256, 256]`. This is + convenient for e.g. splitting QKV without knowing the storage details of + quantized weights. """ if quantize in ["gptq", "awq"]: from text_generation_server.layers.gptq import GPTQWeight try: - qweight = self._get_qweight(f"{prefix}.qweight", blocks) + qweight = self._get_qweight(f"{prefix}.qweight", block_sizes) except RuntimeError: raise RuntimeError( f"Cannot load `{quantize}` weight, make sure the model is already quantized." 
@@ -171,8 +183,8 @@ class Weights: bits, groupsize, _, quant_method = self._get_gptq_params() - qzeros = self._get_qweight(f"{prefix}.qzeros", blocks) - scales = self._get_qweight(f"{prefix}.scales", blocks) + qzeros = self._get_qweight(f"{prefix}.qzeros", block_sizes) + scales = self._get_qweight(f"{prefix}.scales", block_sizes) scales = scales.to(dtype=self.dtype) if quantize == "gptq" and quant_method == "gptq": @@ -205,27 +217,31 @@ class Weights: elif quantize == "marlin": from text_generation_server.layers.marlin import MarlinWeight - B = self._get_qweight(f"{prefix}.B", blocks) - s = self._get_qweight(f"{prefix}.s", blocks) + B = self._get_qweight(f"{prefix}.B", block_sizes) + s = self._get_qweight(f"{prefix}.s", block_sizes) weight = MarlinWeight(B=B, s=s) else: slice_ = self._get_slice(f"{prefix}.weight") total_size = slice_.get_shape()[0] - assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}" - single_size = total_size // blocks + block_sizes = _blocks_to_block_sizes( + total_size=total_size, blocks=block_sizes + ) + world_size = self.process_group.size() rank = self.process_group.rank() - assert ( - single_size % world_size == 0 - ), f"Prepacked qkv cannot be sharded across {world_size} shards" - block_size = single_size // world_size - start = rank * block_size - stop = (rank + 1) * block_size tensors = [] - for i in range(blocks): - tensor = slice_[start + i * single_size : stop + i * single_size] + block_offset = 0 + for block_size in block_sizes: + assert ( + block_size % world_size == 0 + ), f"Prepacked weights cannot be sharded across {world_size} shards" + shard_block_size = block_size // world_size + start = rank * shard_block_size + stop = (rank + 1) * shard_block_size + tensor = slice_[block_offset + start : block_offset + stop] tensors.append(tensor) + block_offset += block_size weight = torch.cat(tensors, dim=0) weight = weight.to(device=self.device) weight = weight.to(dtype=self.dtype) @@ -593,3 +609,31 @@ class Weights: self.quant_method = "awq" except Exception: pass + + +def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]: + """ + Convert block count or proportions to block sizes. + + This function accepts + + - The number of blocks (int), in which case the block size is + total_size//blocks; or + - A list of block sizes (List[int]). + + In the latter case, if sum(blocks) < total_size, the ratios between + the block sizes will be preserved. For instance, if blocks is + [2, 1, 1] and total_size is 1024, the returned block sizes are + [512, 256, 256]. + """ + if isinstance(blocks, list): + total_blocks = sum(blocks) + assert ( + total_size % total_blocks == 0 + ), f"Cannot split {total_size} in proportional blocks: {blocks}" + part_size = total_size // total_blocks + return [part_size * block for block in blocks] + else: + assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}" + single_size = total_size // blocks + return [single_size] * blocks
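For reference, a worked example of the proportional column splitting this patch introduces. It is a standalone sketch, not part of the patch: it only assumes plain Python plus Phi-3-medium-style head counts (40 query heads, 10 KV heads, head_dim 128, taken from the published config and used here purely for illustration), and it mirrors the logic of the new `_blocks_to_block_sizes` helper rather than importing it from `utils/weights.py`.

    # Standalone sketch, not part of the patch: mirrors the logic of the new
    # _blocks_to_block_sizes helper to show how a packed QKV matrix is carved
    # into Q/K/V column blocks for grouped query attention.
    from typing import List, Union

    def blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
        if isinstance(blocks, list):
            # Proportional split: block sizes scale with the given ratios.
            total_blocks = sum(blocks)
            assert total_size % total_blocks == 0
            part_size = total_size // total_blocks
            return [part_size * block for block in blocks]
        # Integer: fall back to equally sized blocks.
        assert total_size % blocks == 0
        return [total_size // blocks] * blocks

    # Assumed Phi-3-medium-like shapes: 40 query heads, 10 KV heads, head_dim 128.
    num_heads, num_key_value_heads, head_dim = 40, 10, 128
    packed_dim = (num_heads + 2 * num_key_value_heads) * head_dim  # 7680

    # GQA split as used by the new get_weights_col_packed_qkv:
    print(blocks_to_block_sizes(packed_dim, [num_heads, num_key_value_heads, num_key_value_heads]))
    # [5120, 1280, 1280] -> Q columns, then K, then V

    # Passing an int keeps the old equal-split behaviour (MHA models, gate_up):
    print(blocks_to_block_sizes(packed_dim, 3))
    # [2560, 2560, 2560]

The `[2, 1, 1]` example in the docstring works the same way: 1024 // 4 = 256 per part, giving [512, 256, 256].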
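A second sketch, using only torch and toy dimensions (none of this is in the patch), to illustrate why the hand-written dbrx `_load_gqa` could be dropped: the per-block shard loop that `get_weights_col_packed` now runs over proportional block sizes selects exactly the same rows as the removed explicit Q/K/V offset arithmetic.

    # Sketch with assumed toy sizes; compares the removed dbrx-style offset
    # slicing against the new generic per-block sharding for one rank.
    import torch

    d_model, n_heads, kv_n_heads, world_size, rank = 64, 8, 2, 2, 1
    head_dim = d_model // n_heads

    # Packed [Q; K; V] projection; rows are output features.
    wqkv = torch.randn(d_model + 2 * kv_n_heads * head_dim, 16)

    # Old _load_gqa arithmetic: explicit offsets into the packed matrix.
    q_block = d_model // world_size
    kv_block = (kv_n_heads * head_dim) // world_size
    k_off = d_model
    v_off = d_model + kv_n_heads * head_dim
    old = torch.cat(
        [
            wqkv[rank * q_block : (rank + 1) * q_block],
            wqkv[k_off + rank * kv_block : k_off + (rank + 1) * kv_block],
            wqkv[v_off + rank * kv_block : v_off + (rank + 1) * kv_block],
        ],
        dim=0,
    )

    # New generic path: proportional block sizes, then the same shard per block.
    new_parts, offset = [], 0
    for block in (n_heads * head_dim, kv_n_heads * head_dim, kv_n_heads * head_dim):
        shard = block // world_size
        new_parts.append(wqkv[offset + rank * shard : offset + (rank + 1) * shard])
        offset += block
    new = torch.cat(new_parts, dim=0)

    assert torch.equal(old, new)

The quantized branches (`gptq`/`awq`/`marlin`) slice along dim 1 instead of dim 0, but the block-offset bookkeeping in `_get_qweight` is the same.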