Support exl2-quantized Qwen2 models (#2085)

Fixes #2081.
This commit is contained in:
Daniël de Kok 2024-06-20 07:56:16 +02:00 committed by GitHub
parent cdbf802860
commit f5a9837592
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 4 additions and 23 deletions

View File

@ -40,31 +40,12 @@ def _load_gqa(config, prefix: str, weights):
assert config.hidden_size % config.num_attention_heads == 0
assert config.num_attention_heads % weights.process_group.size() == 0
weight = weights.get_multi_weights_col(
return TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
quantize=config.quantize,
dim=0,
)
if config.quantize not in ["gptq", "awq", "marlin"]:
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
head_size = config.hidden_size // config.num_attention_heads
num_heads = config.num_attention_heads // weights.process_group.size()
num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
assert list(weight.shape) == [
(num_heads + 2 * num_key_value_heads) * head_size,
config.hidden_size,
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
w = [
weights.get_sharded(f"{p}.bias", dim=0)
for p in [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]
]
bias = torch.cat(w, dim=0).to(dtype=weights.dtype).to(device=weights.device)
return TensorParallelColumnLinear(
get_linear(weight, bias=bias, quantize=config.quantize)
weights=weights,
bias=True,
)