Hotfix: fix of use of unquantized weights in Mixtral GQA loading (#2269)

* Update idefics_causal_lm.py

Fix syntax issues

* fix dbrx & opt model prefix bug

* Hotfix: fix of use of unquantized weights in Mixtral GQA loading
icyboy™ 2024-07-22 17:31:00 +08:00 committed by GitHub
parent f3435bab8c
commit 4e4207224e
1 changed file with 5 additions and 4 deletions


@@ -52,6 +52,7 @@ from text_generation_server.layers.layernorm import (
 from text_generation_server.layers.rotary import (
     PositionRotaryEmbedding,
 )
+from text_generation_server.utils.weights import UnquantizedWeight


 class MixtralConfig(PretrainedConfig):
@@ -138,16 +139,16 @@ def _load_gqa(config, prefix: str, weights):
         dim=0,
     )

-    if config.quantize not in ["gptq", "awq", "marlin"]:
-        weight = weight.to(dtype=weights.dtype).to(device=weights.device)
+    if isinstance(weight, UnquantizedWeight):
+        weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)

     head_size = config.hidden_size // config.num_attention_heads
     num_heads = config.num_attention_heads // weights.process_group.size()
     num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
-    assert list(weight.shape) == [
+    assert list(weight.weight.shape) == [
         (num_heads + 2 * num_key_value_heads) * head_size,
         config.hidden_size,
-    ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
+    ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"

     return TensorParallelColumnLinear(get_linear(weight, bias=None))
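For context, here is a minimal, self-contained sketch of the pattern this commit switches to. The simplified UnquantizedWeight dataclass and the cast_unquantized helper below are illustrative stand-ins, not the actual text_generation_server code; only the .weight attribute and the isinstance dispatch are taken from the diff above.

# Minimal sketch (assumed simplification, not the real implementation):
# the weight returned by the loader is a wrapper object whose concrete
# type depends on the quantization scheme. UnquantizedWeight holds the
# raw tensor in its .weight attribute, so dtype/device casts and shape
# checks must go through weight.weight rather than weight itself.
from dataclasses import dataclass

import torch


@dataclass
class UnquantizedWeight:  # stand-in for text_generation_server.utils.weights.UnquantizedWeight
    weight: torch.Tensor


def cast_unquantized(weight, dtype: torch.dtype, device: torch.device):
    # Only the plain-tensor wrapper may be freely cast; quantized wrappers
    # (GPTQ, AWQ, Marlin, ...) carry packed data and scales that must not be
    # blindly converted, which is why dispatching on the wrapper type is more
    # robust than checking config.quantize against a hard-coded list of names.
    if isinstance(weight, UnquantizedWeight):
        weight.weight = weight.weight.to(dtype=dtype).to(device=device)
    return weight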