Hotfix: fix use of unquantized weights in Mixtral GQA loading (#2269)
* Update idefics_causal_lm.py: fix syntax issues
* Fix dbrx & opt model prefix bug
* Hotfix: fix use of unquantized weights in Mixtral GQA loading
This commit is contained in:
parent f3435bab8c
commit 4e4207224e
@@ -52,6 +52,7 @@ from text_generation_server.layers.layernorm import (
 from text_generation_server.layers.rotary import (
     PositionRotaryEmbedding,
 )
+from text_generation_server.utils.weights import UnquantizedWeight
 
 
 class MixtralConfig(PretrainedConfig):
@@ -138,16 +139,16 @@ def _load_gqa(config, prefix: str, weights):
         dim=0,
     )
 
-    if config.quantize not in ["gptq", "awq", "marlin"]:
-        weight = weight.to(dtype=weights.dtype).to(device=weights.device)
+    if isinstance(weight, UnquantizedWeight):
+        weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)
 
     head_size = config.hidden_size // config.num_attention_heads
     num_heads = config.num_attention_heads // weights.process_group.size()
     num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
-    assert list(weight.shape) == [
+    assert list(weight.weight.shape) == [
         (num_heads + 2 * num_key_value_heads) * head_size,
         config.hidden_size,
-    ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
+    ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
 
     return TensorParallelColumnLinear(get_linear(weight, bias=None))
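The core of the Mixtral fix: `_load_gqa` used to decide whether to cast the fused QKV weight to the model dtype/device by checking `config.quantize` against a hard-coded list, even though the weight it receives is now a wrapper object such as `UnquantizedWeight` (as the new `isinstance` check and the `weight.weight` accesses show). Dispatching on the wrapper type and casting the wrapped tensor makes the intent explicit and leaves quantized wrappers untouched. Below is a minimal, self-contained sketch of that pattern; `FakeQuantizedWeight`, `cast_if_unquantized`, and the toy tensors are illustrative stand-ins, not TGI code.

```python
from dataclasses import dataclass

import torch


@dataclass
class UnquantizedWeight:
    """Toy stand-in for text_generation_server.utils.weights.UnquantizedWeight."""
    weight: torch.Tensor


@dataclass
class FakeQuantizedWeight:
    """Hypothetical quantized wrapper; its packed tensors must not be cast."""
    qweight: torch.Tensor
    scales: torch.Tensor


def cast_if_unquantized(weight, dtype, device):
    # The old check keyed off config.quantize not being in a hard-coded list
    # ("gptq", "awq", "marlin"), so any quantizer outside that list was cast
    # as if the weight were a plain tensor. Dispatching on the wrapper type
    # only touches genuinely unquantized weights.
    if isinstance(weight, UnquantizedWeight):
        weight.weight = weight.weight.to(dtype=dtype).to(device=device)
    return weight


if __name__ == "__main__":
    w = UnquantizedWeight(weight=torch.randn(8, 8, dtype=torch.float32))
    cast_if_unquantized(w, dtype=torch.float16, device="cpu")
    assert w.weight.dtype == torch.float16  # cast applied to wrapped tensor

    q = FakeQuantizedWeight(
        qweight=torch.zeros(8, 8, dtype=torch.int32),
        scales=torch.ones(8, dtype=torch.float16),
    )
    cast_if_unquantized(q, dtype=torch.float16, device="cpu")
    assert q.qweight.dtype == torch.int32  # quantized wrapper left untouched
```

The same reasoning explains the assert changes in the hunk above: once the loader returns a wrapper, the shape to validate lives on `weight.weight`, not on the wrapper itself.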