Hotfix: fix use of unquantized weights in Gemma GQA loading (#2255)
parent ba291dad9f
commit 80adb5be16
First changed file (Gemma2 model, Gemma2Config):

@@ -42,6 +42,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
+from text_generation_server.utils.weights import UnquantizedWeight


 class Gemma2Config(PretrainedConfig):
@@ -144,16 +145,16 @@ def _load_gqa(config, prefix: str, weights):
         dim=0,
     )

-    if config.quantize not in ["gptq", "awq", "marlin"]:
-        weight = weight.to(dtype=weights.dtype).to(device=weights.device)
+    if isinstance(weight, UnquantizedWeight):
+        weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)

     head_size = config.head_dim
     num_heads = config.num_attention_heads // weights.process_group.size()
     num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
-    assert list(weight.shape) == [
+    assert list(weight.weight.shape) == [
         (num_heads + 2 * num_key_value_heads) * head_size,
         config.hidden_size,
-    ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
+    ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"

     return TensorParallelColumnLinear(get_linear(weight, bias=None))
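Both hunks make the same change: _load_gqa used to decide whether the fused q/k/v weight was unquantized by comparing config.quantize against a hard-coded list of quantizer names, and then treated weight as a bare tensor. Unquantized weights returned by the loader are wrapped in an UnquantizedWeight container (imported here from text_generation_server.utils.weights), so the code now dispatches on the type and performs the dtype/device cast and the shape assertion through weight.weight. The snippet below is a minimal sketch of such a wrapper, only to make the diff easier to read; the real class may carry additional fields and helper methods, and only the .weight tensor attribute is relied on by this fix.

# Minimal sketch of the wrapper the hotfix dispatches on (assumption: the real
# UnquantizedWeight in text_generation_server.utils.weights may define more
# methods; only the .weight attribute matters for this diff).
from dataclasses import dataclass

import torch


@dataclass
class UnquantizedWeight:
    # The raw, unquantized parameter tensor. Quantized checkpoints (GPTQ, AWQ,
    # Marlin, ...) come back as different weight classes, which is why an
    # isinstance check replaces the old config.quantize string comparison.
    weight: torch.Tensor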
Second changed file (Gemma model, GemmaConfig), with the same two hunks:

@@ -42,6 +42,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
+from text_generation_server.utils.weights import UnquantizedWeight


 class GemmaConfig(PretrainedConfig):
@@ -144,16 +145,16 @@ def _load_gqa(config, prefix: str, weights):
         dim=0,
     )

-    if config.quantize not in ["gptq", "awq", "marlin"]:
-        weight = weight.to(dtype=weights.dtype).to(device=weights.device)
+    if isinstance(weight, UnquantizedWeight):
+        weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)

     head_size = config.head_dim
     num_heads = config.num_attention_heads // weights.process_group.size()
     num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
-    assert list(weight.shape) == [
+    assert list(weight.weight.shape) == [
         (num_heads + 2 * num_key_value_heads) * head_size,
         config.hidden_size,
-    ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
+    ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"

     return TensorParallelColumnLinear(get_linear(weight, bias=None))
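As a quick illustration of the patched branch, the snippet below mimics what _load_gqa now does with an unquantized weight: cast the wrapped tensor to the target dtype/device and validate the fused q/k/v shape. It reuses the UnquantizedWeight sketch above; the head counts and sizes are made-up placeholders, not values taken from the repository.

# Illustration only: exercises the same isinstance + weight.weight pattern the
# hotfix introduces, using hypothetical Gemma-like dimensions.
import torch

num_heads, num_key_value_heads, head_size, hidden_size = 8, 1, 256, 2048

weight = UnquantizedWeight(
    torch.empty((num_heads + 2 * num_key_value_heads) * head_size, hidden_size)
)

if isinstance(weight, UnquantizedWeight):
    # Same cast the diff performs, applied to the wrapped tensor rather than
    # to the wrapper object itself.
    weight.weight = weight.weight.to(dtype=torch.bfloat16).to(device="cpu")

assert list(weight.weight.shape) == [
    (num_heads + 2 * num_key_value_heads) * head_size,
    hidden_size,
]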