Add Phi-3 medium support (#2039)

Add support for Phi-3-medium

The main difference between the medium and mini models is that medium
uses grouped query attention with a packed QKV matrix. This change adds
support for GQA with packed matrixes to `Weights.get_weights_col_packed`
and uses it for Phi-3. This also allows us to remove the custom
implementation of GQA from dbrx attention loading.
This commit is contained in:
Daniël de Kok 2024-06-10 09:22:29 +02:00 committed by GitHub
parent 9b3674d903
commit 85dfc39222
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 118 additions and 164 deletions

View File

@ -129,9 +129,22 @@ class TensorParallelColumnLinear(SuperLayer):
return cls(linear)
@classmethod
def load_qkv(cls, config, prefix: str, weights, bias: bool):
def load_qkv(
cls,
config,
prefix: str,
weights,
bias: bool,
num_heads: int,
num_key_value_heads: int,
):
"""Specific method when the QKV was joined after the fact"""
weight = weights.get_weights_col_packed_qkv(prefix, quantize=config.quantize)
weight = weights.get_weights_col_packed_qkv(
prefix,
quantize=config.quantize,
num_heads=num_heads,
num_key_value_heads=num_key_value_heads,
)
if bias:
raise NotImplementedError("packed_qkv only implemented for baichuan")
else:

View File

@ -20,7 +20,6 @@ from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple, Any
from loguru import logger
from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM != "xpu":
@ -164,129 +163,13 @@ def promote_scalar(x: torch.Tensor) -> torch.Tensor:
def load_attention(config, prefix, weights):
if config.n_heads != config.attn_config.kv_n_heads:
return _load_gqa(config, prefix, weights)
else:
return TensorParallelColumnLinear.load_qkv(
config,
prefix=f"{prefix}.Wqkv",
weights=weights,
bias=False,
)
def _load_gqa(config, prefix: str, weights):
assert config.d_model % config.n_heads == 0
assert config.n_heads % weights.process_group.size() == 0
head_dim = config.d_model // config.n_heads
world_size = weights.process_group.size()
rank = weights.process_group.rank()
q_block_size = config.d_model // world_size
q_start = rank * q_block_size
q_stop = (rank + 1) * q_block_size
kv_block_size = (config.attn_config.kv_n_heads * head_dim) // world_size
k_offset = config.d_model
k_start = k_offset + rank * kv_block_size
k_stop = k_offset + (rank + 1) * kv_block_size
v_offset = config.d_model + config.attn_config.kv_n_heads * head_dim
v_start = v_offset + rank * kv_block_size
v_stop = v_offset + (rank + 1) * kv_block_size
if config.quantize in ["gptq", "awq"]:
from text_generation_server.layers.gptq import GPTQWeight
try:
qweight_slice = weights._get_slice(f"{prefix}.qweight")
q_qweight = qweight_slice[:, q_start:q_stop]
k_qweight = qweight_slice[:, k_start:k_stop]
v_qweight = qweight_slice[:, v_start:v_stop]
qweight = torch.cat([q_qweight, k_qweight, v_qweight], dim=1)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{config.quantize}` weight, make sure the model is already quantized"
)
qzeros_slice = weights._get_slice(f"{prefix}.qzeros")
q_qzeros = qzeros_slice[:, q_start:q_stop]
k_qzeros = qzeros_slice[:, k_start:k_stop]
v_qzeros = qzeros_slice[:, v_start:v_stop]
qzeros = torch.cat([q_qzeros, k_qzeros, v_qzeros], dim=1)
scales_slice = weights._get_slice(f"{prefix}.scales")
q_scales = scales_slice[:, q_start:q_stop]
k_scales = scales_slice[:, k_start:k_stop]
v_scales = scales_slice[:, v_start:v_stop]
scales = torch.cat([q_scales, k_scales, v_scales], dim=1)
bits, groupsize, desc_act, quant_method = weights._get_gptq_params()
from text_generation_server.layers import HAS_EXLLAMA
use_exllama = (
bits == 4 and HAS_EXLLAMA and config.quantize == "gptq" and not desc_act
)
if config.quantize == "gptq" and quant_method == "gptq":
g_idx_slice = weights._get_slice(f"{prefix}.g_idx")
q_g_idx = g_idx_slice[:, q_start:q_stop]
k_g_idx = g_idx_slice[:, k_start:k_stop]
v_g_idx = g_idx_slice[:, v_start:v_stop]
w = [q_g_idx, k_g_idx, v_g_idx]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
elif config.quantize == "gptq" and quant_method == "awq":
log_once(
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
)
from text_generation_server.layers.awq.conveersion_utils import (
fast_awq_to_gptq,
)
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
if use_exllama:
g_idx = None
else:
g_idx = (
torch.arange(qweight.shape[0] * (32 // bits), device=qweight.device)
// groupsize
).to(dtype=torch.int32)
else:
g_idx = None
weight = GPTQWeight(
qweight=qweight,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=bits,
groupsize=groupsize,
use_exllama=use_exllama,
)
elif config.quantize == "marlin":
# NOTE: at the time marlin support was added, the only model that
# exists is LnL-AI/dbrx-base-converted-v2-4bit-gptq-marlin(-v2),
# but it requires manual concatenation of weight files.
raise RuntimeError("dbrx models with marlin quantization are not yet supported")
else:
qkv_slice = weights._get_slice(f"{prefix}.Wqkv.weight")
q = qkv_slice[q_start:q_stop]
k = qkv_slice[k_start:k_stop]
v = qkv_slice[v_start:v_stop]
weight = torch.cat([q, k, v], dim=0)
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
return TensorParallelColumnLinear(
get_linear(weight, bias=None, quantize=config.quantize)
return TensorParallelColumnLinear.load_qkv(
config,
prefix=f"{prefix}.Wqkv",
weights=weights,
bias=False,
num_heads=config.n_heads,
num_key_value_heads=config.attn_config.kv_n_heads,
)

View File

@ -59,7 +59,12 @@ def _load_qkv_gptq(config, prefix: str, weights):
rank = weights.process_group.rank()
# Weights
weight = weights.get_weights_col_packed_qkv(f"{prefix}.c_attn", config.quantize)
weight = weights.get_weights_col_packed_qkv(
f"{prefix}.c_attn",
config.quantize,
config.num_attention_heads,
config.num_attention_heads,
)
# Bias
slice_ = weights._get_slice(f"{prefix}.c_attn.bias")

View File

@ -62,6 +62,8 @@ def load_attention(config, prefix, weights):
prefix=f"{prefix}.qkv_proj",
weights=weights,
bias=bias,
num_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
)
elif config.model_type == "baichuan":
return TensorParallelColumnLinear.load_qkv(
@ -69,6 +71,8 @@ def load_attention(config, prefix, weights):
prefix=f"{prefix}.W_pack",
weights=weights,
bias=bias,
num_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
)
# otherwise, load the default attention based on the number of heads
@ -107,6 +111,11 @@ class FlashLlamaAttention(torch.nn.Module):
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
f"and `num_shards`: {weights.process_group.size()}"
)
if config.num_key_value_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_key_value_heads` must be divisible by `num_shards` (got `num_key_value_heads`: {config.num_key_value_heads} "
f"and `num_shards`: {weights.process_group.size()}"
)
self.num_heads = self.num_heads // weights.process_group.size()
self.num_key_value_heads = (
config.num_key_value_heads // weights.process_group.size()

View File

@ -1,7 +1,6 @@
from dataclasses import dataclass, field
import os
from pathlib import Path
from typing import List, Dict, Optional, Set, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union
from safetensors import safe_open, SafetensorError
import torch
from loguru import logger
@ -121,49 +120,62 @@ class Weights:
), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
return self.get_partial_sharded(tensor_name, dim)
def _get_qweight(self, name: str, blocks: int):
def _get_qweight(self, name: str, block_sizes: Union[int, List[int]]):
slice_ = self._get_slice(name)
total_size = slice_.get_shape()[1]
assert (
total_size % blocks == 0
), f"Prepacked quantized matrix is not divisible by {blocks}"
single_size = total_size // blocks
block_sizes = _blocks_to_block_sizes(total_size=total_size, blocks=block_sizes)
world_size = self.process_group.size()
rank = self.process_group.rank()
assert (
single_size % world_size == 0
), f"Prepacked quantized matrix cannot be sharded across {world_size} shards"
block_size = single_size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
weights = []
for block in range(blocks):
weights.append(
slice_[:, start + block * single_size : stop + block * single_size]
)
block_offset = 0
for block_size in block_sizes:
assert (
block_size % world_size == 0
), f"Prepacked qkv cannot be sharded across {world_size} shards"
shard_block_size = block_size // world_size
start = rank * shard_block_size
stop = (rank + 1) * shard_block_size
weights.append(slice_[:, block_offset + start : block_offset + stop])
block_offset += block_size
weight = torch.cat(weights, dim=1)
weight = weight.to(device=self.device)
return weight
def get_weights_col_packed_qkv(self, prefix: str, quantize: str):
return self.get_weights_col_packed(prefix, quantize, 3)
def get_weights_col_packed_qkv(
self,
prefix: str,
quantize: str,
num_heads: int,
num_key_value_heads: int,
):
return self.get_weights_col_packed(
prefix, quantize, [num_heads, num_key_value_heads, num_key_value_heads]
)
def get_weights_col_packed_gate_up(self, prefix: str, quantize: str):
return self.get_weights_col_packed(prefix, quantize, 2)
def get_weights_col_packed(self, prefix: str, quantize: str, blocks: int):
def get_weights_col_packed(
self, prefix: str, quantize: str, block_sizes: Union[int, List[int]]
):
"""
Highly specific when the underlying tensor is a simple cat of Q,K,V instead of being
already alternating Q,K,V within the main tensor
already alternating Q,K,V within the main tensor.
The columns are split in equally sized blocks when blocks is an `int`, or
in blocks proportional given to the sizes. For instance `[2, 1, 1]` will
divide an input with dimensionality `1024` in `[512, 256, 256]`. This is
convenient for e.g. splitting QKV without knowing the storage details of
quantized weights.
"""
if quantize in ["gptq", "awq"]:
from text_generation_server.layers.gptq import GPTQWeight
try:
qweight = self._get_qweight(f"{prefix}.qweight", blocks)
qweight = self._get_qweight(f"{prefix}.qweight", block_sizes)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{quantize}` weight, make sure the model is already quantized."
@ -171,8 +183,8 @@ class Weights:
bits, groupsize, _, quant_method = self._get_gptq_params()
qzeros = self._get_qweight(f"{prefix}.qzeros", blocks)
scales = self._get_qweight(f"{prefix}.scales", blocks)
qzeros = self._get_qweight(f"{prefix}.qzeros", block_sizes)
scales = self._get_qweight(f"{prefix}.scales", block_sizes)
scales = scales.to(dtype=self.dtype)
if quantize == "gptq" and quant_method == "gptq":
@ -205,27 +217,31 @@ class Weights:
elif quantize == "marlin":
from text_generation_server.layers.marlin import MarlinWeight
B = self._get_qweight(f"{prefix}.B", blocks)
s = self._get_qweight(f"{prefix}.s", blocks)
B = self._get_qweight(f"{prefix}.B", block_sizes)
s = self._get_qweight(f"{prefix}.s", block_sizes)
weight = MarlinWeight(B=B, s=s)
else:
slice_ = self._get_slice(f"{prefix}.weight")
total_size = slice_.get_shape()[0]
assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}"
single_size = total_size // blocks
block_sizes = _blocks_to_block_sizes(
total_size=total_size, blocks=block_sizes
)
world_size = self.process_group.size()
rank = self.process_group.rank()
assert (
single_size % world_size == 0
), f"Prepacked qkv cannot be sharded across {world_size} shards"
block_size = single_size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensors = []
for i in range(blocks):
tensor = slice_[start + i * single_size : stop + i * single_size]
block_offset = 0
for block_size in block_sizes:
assert (
block_size % world_size == 0
), f"Prepacked weights cannot be sharded across {world_size} shards"
shard_block_size = block_size // world_size
start = rank * shard_block_size
stop = (rank + 1) * shard_block_size
tensor = slice_[block_offset + start : block_offset + stop]
tensors.append(tensor)
block_offset += block_size
weight = torch.cat(tensors, dim=0)
weight = weight.to(device=self.device)
weight = weight.to(dtype=self.dtype)
@ -593,3 +609,31 @@ class Weights:
self.quant_method = "awq"
except Exception:
pass
def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
"""
Convert block count or proportions to block sizes.
This function accepts
- The number of blocks (int), in which case the block size is
total_size//blocks; or
- A list of block sizes (List[int]).
In the latter case, if sum(blocks) < total_size, the ratios between
the block sizes will be preserved. For instance, if blocks is
[2, 1, 1] and total_size is 1024, the returned block sizes are
[512, 256, 256].
"""
if isinstance(blocks, list):
total_blocks = sum(blocks)
assert (
total_size % total_blocks == 0
), f"Cannot split {total_size} in proportional blocks: {blocks}"
part_size = total_size // total_blocks
return [part_size * block for block in blocks]
else:
assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}"
single_size = total_size // blocks
return [single_size] * blocks