fix(server): fix deepseekv2 loading (#2266)
parent 53ec0b790b
commit f3435bab8c
@@ -34,7 +34,6 @@ from text_generation_server.layers.attention.common import Seqlen
 from text_generation_server.layers.layernorm import FastRMSNorm
 from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
 from text_generation_server.utils.import_utils import SYSTEM
-from text_generation_server.utils.log import log_once
 from text_generation_server.utils.weights import Weights
 from torch import nn
 from transformers.activations import ACT2FN
@@ -240,7 +239,6 @@ class DeepseekV2Attention(torch.nn.Module):
                 if config.attention_bias
                 else None
             ),
-            quantize=config.quantize,
         )
         self.q_a_layernorm = FastRMSNorm.load(
             prefix=f"{prefix}.q_a_layernorm",
@@ -261,7 +259,6 @@ class DeepseekV2Attention(torch.nn.Module):
                 if config.attention_bias
                 else None
             ),
-            quantize=config.quantize,
         )

         self.kv_a_layernorm = FastRMSNorm.load(
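Taken together, the change removes the now-unused `log_once` import and drops the `quantize=config.quantize` keyword from the two projection-loading calls in `DeepseekV2Attention` that precede the `q_a_layernorm` and `kv_a_layernorm` `FastRMSNorm.load(...)` calls. The hunks only show the tail of each call, so the sketch below is hypothetical: `load_projection` stands in for whatever loader the real code uses, on the assumption that DeepSeek-V2 loading failed because the callee no longer accepts a `quantize` keyword.

    # Hypothetical stand-in for the projection loader after the refactor:
    # like the assumed real callee, it has no `quantize` parameter.
    def load_projection(weight, bias=None):
        return weight, bias

    attention_bias = False  # stands in for config.attention_bias

    # Before this commit the call still passed the stale keyword, which fails with
    # "TypeError: load_projection() got an unexpected keyword argument 'quantize'":
    # load_projection(weight="w", bias=None, quantize="awq")

    # After this commit the keyword is simply dropped and loading succeeds:
    proj = load_projection(
        weight="w",
        bias=("b" if attention_bias else None),
    )
    print(proj)  # ('w', None)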