Add Phi-3 medium support (#2039)
Add support for Phi-3-medium. The main difference between the medium and mini models is that medium uses grouped query attention (GQA) with a packed QKV matrix. This change adds support for GQA with packed matrices to `Weights.get_weights_col_packed` and uses it for Phi-3. It also lets us remove the custom GQA implementation from the dbrx attention loading code.
commit 85dfc39222 · parent 9b3674d903
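For context on the description above: with grouped query attention, the packed QKV projection stores all query heads followed by a smaller number of key and value heads, so the packed dimension has to be split into blocks proportional to `[num_heads, num_key_value_heads, num_key_value_heads]` instead of three equal thirds. A minimal sketch of that split, with illustrative head counts (not taken from any particular config):

```python
import torch

# Illustrative GQA geometry (not from a real config): 8 query heads,
# 2 key/value heads, head_dim 64.
num_heads, num_key_value_heads, head_dim = 8, 2, 64
hidden_size = num_heads * head_dim

# Packed QKV projection: all Q rows, then K rows, then V rows.
packed_qkv = torch.randn((num_heads + 2 * num_key_value_heads) * head_dim, hidden_size)

# Blocks proportional to [num_heads, kv_heads, kv_heads], the same proportions
# that get_weights_col_packed_qkv now forwards to get_weights_col_packed.
proportions = [num_heads, num_key_value_heads, num_key_value_heads]
part = packed_qkv.shape[0] // sum(proportions)
block_sizes = [part * p for p in proportions]  # [512, 128, 128]

q, k, v = torch.split(packed_qkv, block_sizes, dim=0)
assert q.shape[0] == num_heads * head_dim
assert k.shape[0] == v.shape[0] == num_key_value_heads * head_dim
```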
@@ -129,9 +129,22 @@ class TensorParallelColumnLinear(SuperLayer):
         return cls(linear)
 
     @classmethod
-    def load_qkv(cls, config, prefix: str, weights, bias: bool):
+    def load_qkv(
+        cls,
+        config,
+        prefix: str,
+        weights,
+        bias: bool,
+        num_heads: int,
+        num_key_value_heads: int,
+    ):
         """Specific method when the QKV was joined after the fact"""
-        weight = weights.get_weights_col_packed_qkv(prefix, quantize=config.quantize)
+        weight = weights.get_weights_col_packed_qkv(
+            prefix,
+            quantize=config.quantize,
+            num_heads=num_heads,
+            num_key_value_heads=num_key_value_heads,
+        )
         if bias:
             raise NotImplementedError("packed_qkv only implemented for baichuan")
         else:
@@ -20,7 +20,6 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple, Any
-from loguru import logger
 from text_generation_server.utils.import_utils import SYSTEM
 
 if SYSTEM != "xpu":
@@ -164,129 +163,13 @@ def promote_scalar(x: torch.Tensor) -> torch.Tensor:
 
 
 def load_attention(config, prefix, weights):
-    if config.n_heads != config.attn_config.kv_n_heads:
-        return _load_gqa(config, prefix, weights)
-    else:
-        return TensorParallelColumnLinear.load_qkv(
-            config,
-            prefix=f"{prefix}.Wqkv",
-            weights=weights,
-            bias=False,
-        )
-
-
-def _load_gqa(config, prefix: str, weights):
-    assert config.d_model % config.n_heads == 0
-    assert config.n_heads % weights.process_group.size() == 0
-
-    head_dim = config.d_model // config.n_heads
-    world_size = weights.process_group.size()
-    rank = weights.process_group.rank()
-
-    q_block_size = config.d_model // world_size
-    q_start = rank * q_block_size
-    q_stop = (rank + 1) * q_block_size
-
-    kv_block_size = (config.attn_config.kv_n_heads * head_dim) // world_size
-    k_offset = config.d_model
-    k_start = k_offset + rank * kv_block_size
-    k_stop = k_offset + (rank + 1) * kv_block_size
-
-    v_offset = config.d_model + config.attn_config.kv_n_heads * head_dim
-    v_start = v_offset + rank * kv_block_size
-    v_stop = v_offset + (rank + 1) * kv_block_size
-
-    if config.quantize in ["gptq", "awq"]:
-        from text_generation_server.layers.gptq import GPTQWeight
-
-        try:
-            qweight_slice = weights._get_slice(f"{prefix}.qweight")
-            q_qweight = qweight_slice[:, q_start:q_stop]
-            k_qweight = qweight_slice[:, k_start:k_stop]
-            v_qweight = qweight_slice[:, v_start:v_stop]
-
-            qweight = torch.cat([q_qweight, k_qweight, v_qweight], dim=1)
-        except RuntimeError:
-            raise RuntimeError(
-                f"Cannot load `{config.quantize}` weight, make sure the model is already quantized"
-            )
-
-        qzeros_slice = weights._get_slice(f"{prefix}.qzeros")
-        q_qzeros = qzeros_slice[:, q_start:q_stop]
-        k_qzeros = qzeros_slice[:, k_start:k_stop]
-        v_qzeros = qzeros_slice[:, v_start:v_stop]
-
-        qzeros = torch.cat([q_qzeros, k_qzeros, v_qzeros], dim=1)
-
-        scales_slice = weights._get_slice(f"{prefix}.scales")
-        q_scales = scales_slice[:, q_start:q_stop]
-        k_scales = scales_slice[:, k_start:k_stop]
-        v_scales = scales_slice[:, v_start:v_stop]
-
-        scales = torch.cat([q_scales, k_scales, v_scales], dim=1)
-
-        bits, groupsize, desc_act, quant_method = weights._get_gptq_params()
-
-        from text_generation_server.layers import HAS_EXLLAMA
-
-        use_exllama = (
-            bits == 4 and HAS_EXLLAMA and config.quantize == "gptq" and not desc_act
-        )
-
-        if config.quantize == "gptq" and quant_method == "gptq":
-            g_idx_slice = weights._get_slice(f"{prefix}.g_idx")
-            q_g_idx = g_idx_slice[:, q_start:q_stop]
-            k_g_idx = g_idx_slice[:, k_start:k_stop]
-            v_g_idx = g_idx_slice[:, v_start:v_stop]
-
-            w = [q_g_idx, k_g_idx, v_g_idx]
-            for w2 in w[1:]:
-                torch.testing.assert_close(w2, w[0])
-            g_idx = w[0]
-        elif config.quantize == "gptq" and quant_method == "awq":
-            log_once(
-                logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
-            )
-            from text_generation_server.layers.awq.conveersion_utils import (
-                fast_awq_to_gptq,
-            )
-
-            qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
-            if use_exllama:
-                g_idx = None
-            else:
-                g_idx = (
-                    torch.arange(qweight.shape[0] * (32 // bits), device=qweight.device)
-                    // groupsize
-                ).to(dtype=torch.int32)
-        else:
-            g_idx = None
-
-        weight = GPTQWeight(
-            qweight=qweight,
-            qzeros=qzeros,
-            scales=scales,
-            g_idx=g_idx,
-            bits=bits,
-            groupsize=groupsize,
-            use_exllama=use_exllama,
-        )
-    elif config.quantize == "marlin":
-        # NOTE: at the time marlin support was added, the only model that
-        # exists is LnL-AI/dbrx-base-converted-v2-4bit-gptq-marlin(-v2),
-        # but it requires manual concatenation of weight files.
-        raise RuntimeError("dbrx models with marlin quantization are not yet supported")
-    else:
-        qkv_slice = weights._get_slice(f"{prefix}.Wqkv.weight")
-        q = qkv_slice[q_start:q_stop]
-        k = qkv_slice[k_start:k_stop]
-        v = qkv_slice[v_start:v_stop]
-
-        weight = torch.cat([q, k, v], dim=0)
-        weight = weight.to(dtype=weights.dtype).to(device=weights.device)
-
-    return TensorParallelColumnLinear(
-        get_linear(weight, bias=None, quantize=config.quantize)
+    return TensorParallelColumnLinear.load_qkv(
+        config,
+        prefix=f"{prefix}.Wqkv",
+        weights=weights,
+        bias=False,
+        num_heads=config.n_heads,
+        num_key_value_heads=config.attn_config.kv_n_heads,
     )
 
 
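The removed `_load_gqa` computed per-rank Q/K/V offsets into the packed `Wqkv` tensor by hand; the generic path now reaches the same slices through proportional block sizes. A hedged, self-contained check of that equivalence with toy numbers (not dbrx's real geometry):

```python
# Toy GQA geometry (illustrative only).
d_model, n_heads, kv_n_heads, world_size = 1024, 16, 4, 2
head_dim = d_model // n_heads

for rank in range(world_size):
    # Old dbrx-style manual offsets into the packed [Q | K | V] rows.
    q_block = d_model // world_size
    kv_block = (kv_n_heads * head_dim) // world_size
    old = [
        (rank * q_block, (rank + 1) * q_block),
        (d_model + rank * kv_block, d_model + (rank + 1) * kv_block),
        (d_model + kv_n_heads * head_dim + rank * kv_block,
         d_model + kv_n_heads * head_dim + (rank + 1) * kv_block),
    ]

    # New generic scheme: proportional blocks, then shard each block by rank.
    sizes = [n_heads * head_dim, kv_n_heads * head_dim, kv_n_heads * head_dim]
    new, offset = [], 0
    for size in sizes:
        shard = size // world_size
        new.append((offset + rank * shard, offset + (rank + 1) * shard))
        offset += size

    assert old == new, (rank, old, new)
```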
@@ -59,7 +59,12 @@ def _load_qkv_gptq(config, prefix: str, weights):
     rank = weights.process_group.rank()
 
     # Weights
-    weight = weights.get_weights_col_packed_qkv(f"{prefix}.c_attn", config.quantize)
+    weight = weights.get_weights_col_packed_qkv(
+        f"{prefix}.c_attn",
+        config.quantize,
+        config.num_attention_heads,
+        config.num_attention_heads,
+    )
 
     # Bias
     slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
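GPT-2 uses plain multi-head attention, so both head arguments are `config.num_attention_heads` and the proportional split degenerates to three equal blocks, matching the old hard-coded `blocks=3` behaviour. A tiny sketch with the common GPT-2-small sizes, used here only as an example:

```python
# 12 heads for Q, K and V alike; 3 * 768 packed c_attn output columns.
proportions = [12, 12, 12]
part = (3 * 768) // sum(proportions)
assert [part * p for p in proportions] == [768, 768, 768]
```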
@@ -62,6 +62,8 @@ def load_attention(config, prefix, weights):
             prefix=f"{prefix}.qkv_proj",
             weights=weights,
             bias=bias,
+            num_heads=config.num_attention_heads,
+            num_key_value_heads=config.num_key_value_heads,
         )
     elif config.model_type == "baichuan":
         return TensorParallelColumnLinear.load_qkv(
@@ -69,6 +71,8 @@ def load_attention(config, prefix, weights):
             prefix=f"{prefix}.W_pack",
             weights=weights,
             bias=bias,
+            num_heads=config.num_attention_heads,
+            num_key_value_heads=config.num_key_value_heads,
         )
 
     # otherwise, load the default attention based on the number of heads
@@ -107,6 +111,11 @@ class FlashLlamaAttention(torch.nn.Module):
                 f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                 f"and `num_shards`: {weights.process_group.size()}"
             )
+        if config.num_key_value_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`num_key_value_heads` must be divisible by `num_shards` (got `num_key_value_heads`: {config.num_key_value_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
         self.num_heads = self.num_heads // weights.process_group.size()
         self.num_key_value_heads = (
             config.num_key_value_heads // weights.process_group.size()
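The added guard mirrors the existing `num_heads` check: with GQA, the key/value heads must also divide evenly across shards. A small illustrative sketch (head counts are examples, not read from a real config):

```python
def shards_ok(num_heads: int, num_key_value_heads: int, num_shards: int) -> bool:
    # Both the query heads and the KV heads must split evenly across shards,
    # matching the two ValueError guards above.
    return num_heads % num_shards == 0 and num_key_value_heads % num_shards == 0


# Example: 32 query heads with 8 KV heads shard onto 2, 4 or 8 GPUs,
# but not onto 3 (query heads fail) or 16 (KV heads fail).
assert shards_ok(32, 8, 4)
assert not shards_ok(32, 8, 3)
assert not shards_ok(32, 8, 16)
```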
@@ -1,7 +1,6 @@
 from dataclasses import dataclass, field
 import os
 from pathlib import Path
-from typing import List, Dict, Optional, Set, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
 from safetensors import safe_open, SafetensorError
 import torch
 from loguru import logger
@@ -121,49 +120,62 @@ class Weights:
         ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
         return self.get_partial_sharded(tensor_name, dim)
 
-    def _get_qweight(self, name: str, blocks: int):
+    def _get_qweight(self, name: str, block_sizes: Union[int, List[int]]):
         slice_ = self._get_slice(name)
         total_size = slice_.get_shape()[1]
-        assert (
-            total_size % blocks == 0
-        ), f"Prepacked quantized matrix is not divisible by {blocks}"
-        single_size = total_size // blocks
+        block_sizes = _blocks_to_block_sizes(total_size=total_size, blocks=block_sizes)
 
         world_size = self.process_group.size()
         rank = self.process_group.rank()
 
-        assert (
-            single_size % world_size == 0
-        ), f"Prepacked quantized matrix cannot be sharded across {world_size} shards"
-        block_size = single_size // world_size
-        start = rank * block_size
-        stop = (rank + 1) * block_size
         weights = []
-        for block in range(blocks):
-            weights.append(
-                slice_[:, start + block * single_size : stop + block * single_size]
-            )
+        block_offset = 0
+        for block_size in block_sizes:
+            assert (
+                block_size % world_size == 0
+            ), f"Prepacked qkv cannot be sharded across {world_size} shards"
+            shard_block_size = block_size // world_size
+            start = rank * shard_block_size
+            stop = (rank + 1) * shard_block_size
+            weights.append(slice_[:, block_offset + start : block_offset + stop])
+            block_offset += block_size
 
         weight = torch.cat(weights, dim=1)
         weight = weight.to(device=self.device)
         return weight
 
-    def get_weights_col_packed_qkv(self, prefix: str, quantize: str):
-        return self.get_weights_col_packed(prefix, quantize, 3)
+    def get_weights_col_packed_qkv(
+        self,
+        prefix: str,
+        quantize: str,
+        num_heads: int,
+        num_key_value_heads: int,
+    ):
+        return self.get_weights_col_packed(
+            prefix, quantize, [num_heads, num_key_value_heads, num_key_value_heads]
+        )
 
     def get_weights_col_packed_gate_up(self, prefix: str, quantize: str):
         return self.get_weights_col_packed(prefix, quantize, 2)
 
-    def get_weights_col_packed(self, prefix: str, quantize: str, blocks: int):
+    def get_weights_col_packed(
+        self, prefix: str, quantize: str, block_sizes: Union[int, List[int]]
+    ):
         """
         Highly specific when the underlying tensor is a simple cat of Q,K,V instead of being
-        already alternating Q,K,V within the main tensor
+        already alternating Q,K,V within the main tensor.
+
+        The columns are split in equally sized blocks when blocks is an `int`, or
+        in blocks proportional given to the sizes. For instance `[2, 1, 1]` will
+        divide an input with dimensionality `1024` in `[512, 256, 256]`. This is
+        convenient for e.g. splitting QKV without knowing the storage details of
+        quantized weights.
         """
         if quantize in ["gptq", "awq"]:
             from text_generation_server.layers.gptq import GPTQWeight
 
             try:
-                qweight = self._get_qweight(f"{prefix}.qweight", blocks)
+                qweight = self._get_qweight(f"{prefix}.qweight", block_sizes)
             except RuntimeError:
                 raise RuntimeError(
                     f"Cannot load `{quantize}` weight, make sure the model is already quantized."
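To make the reworked loop concrete: each packed block (Q, then K, then V) is sharded by rank and the shards are concatenated back, so every rank ends up with its own `[q_shard | k_shard | v_shard]` slice. A hedged toy version of the unquantized (dim 0) case, with made-up sizes and a hypothetical helper name:

```python
from typing import List

import torch


def shard_packed(
    weight: torch.Tensor, block_sizes: List[int], rank: int, world_size: int
) -> torch.Tensor:
    # Same walk as the loop above: take this rank's share of each block,
    # then concatenate the per-block shards back together.
    tensors, block_offset = [], 0
    for block_size in block_sizes:
        assert block_size % world_size == 0, "block must shard evenly"
        shard = block_size // world_size
        start, stop = rank * shard, (rank + 1) * shard
        tensors.append(weight[block_offset + start : block_offset + stop])
        block_offset += block_size
    return torch.cat(tensors, dim=0)


# Toy packed QKV: 512 query rows, 128 key rows, 128 value rows (made up).
packed = torch.randn(768, 4)
shards = [shard_packed(packed, [512, 128, 128], rank, 2) for rank in range(2)]
assert all(s.shape == (384, 4) for s in shards)
```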
@@ -171,8 +183,8 @@ class Weights:
 
             bits, groupsize, _, quant_method = self._get_gptq_params()
 
-            qzeros = self._get_qweight(f"{prefix}.qzeros", blocks)
-            scales = self._get_qweight(f"{prefix}.scales", blocks)
+            qzeros = self._get_qweight(f"{prefix}.qzeros", block_sizes)
+            scales = self._get_qweight(f"{prefix}.scales", block_sizes)
             scales = scales.to(dtype=self.dtype)
 
             if quantize == "gptq" and quant_method == "gptq":
@@ -205,27 +217,31 @@ class Weights:
         elif quantize == "marlin":
             from text_generation_server.layers.marlin import MarlinWeight
 
-            B = self._get_qweight(f"{prefix}.B", blocks)
-            s = self._get_qweight(f"{prefix}.s", blocks)
+            B = self._get_qweight(f"{prefix}.B", block_sizes)
+            s = self._get_qweight(f"{prefix}.s", block_sizes)
             weight = MarlinWeight(B=B, s=s)
         else:
             slice_ = self._get_slice(f"{prefix}.weight")
             total_size = slice_.get_shape()[0]
-            assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}"
-            single_size = total_size // blocks
+            block_sizes = _blocks_to_block_sizes(
+                total_size=total_size, blocks=block_sizes
+            )
+
             world_size = self.process_group.size()
             rank = self.process_group.rank()
 
-            assert (
-                single_size % world_size == 0
-            ), f"Prepacked qkv cannot be sharded across {world_size} shards"
-            block_size = single_size // world_size
-            start = rank * block_size
-            stop = (rank + 1) * block_size
             tensors = []
-            for i in range(blocks):
-                tensor = slice_[start + i * single_size : stop + i * single_size]
+            block_offset = 0
+            for block_size in block_sizes:
+                assert (
+                    block_size % world_size == 0
+                ), f"Prepacked weights cannot be sharded across {world_size} shards"
+                shard_block_size = block_size // world_size
+                start = rank * shard_block_size
+                stop = (rank + 1) * shard_block_size
+                tensor = slice_[block_offset + start : block_offset + stop]
                 tensors.append(tensor)
+                block_offset += block_size
             weight = torch.cat(tensors, dim=0)
             weight = weight.to(device=self.device)
             weight = weight.to(dtype=self.dtype)
@@ -593,3 +609,31 @@ class Weights:
                 self.quant_method = "awq"
         except Exception:
             pass
+
+
+def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
+    """
+    Convert block count or proportions to block sizes.
+
+    This function accepts
+
+    - The number of blocks (int), in which case the block size is
+      total_size//blocks; or
+    - A list of block sizes (List[int]).
+
+    In the latter case, if sum(blocks) < total_size, the ratios between
+    the block sizes will be preserved. For instance, if blocks is
+    [2, 1, 1] and total_size is 1024, the returned block sizes are
+    [512, 256, 256].
+    """
+    if isinstance(blocks, list):
+        total_blocks = sum(blocks)
+        assert (
+            total_size % total_blocks == 0
+        ), f"Cannot split {total_size} in proportional blocks: {blocks}"
+        part_size = total_size // total_blocks
+        return [part_size * block for block in blocks]
+    else:
+        assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}"
+        single_size = total_size // blocks
+        return [single_size] * blocks
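A quick doctest-style sketch of the new helper, assuming `_blocks_to_block_sizes` from the hunk above is in scope; the values follow its docstring example plus a GQA-shaped split:

```python
# Equal split: 3 blocks over 1536 columns -> three blocks of 512.
assert _blocks_to_block_sizes(total_size=1536, blocks=3) == [512, 512, 512]

# Proportional split, as in the docstring: [2, 1, 1] over 1024.
assert _blocks_to_block_sizes(total_size=1024, blocks=[2, 1, 1]) == [512, 256, 256]

# GQA-shaped split: 32 query heads and 8 KV heads over 3072 rows.
assert _blocks_to_block_sizes(total_size=3072, blocks=[32, 8, 8]) == [2048, 512, 512]
```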