Fix use of unquantized weights in Cohere GQA loading; also enable the model on the Intel platform (#2291)

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

parent 5ad39dd3c3
commit 8642250602
@@ -45,6 +45,7 @@ from text_generation_server.layers.layernorm import (
 from text_generation_server.layers.rotary import (
     PositionRotaryEmbedding,
 )
+from text_generation_server.utils.weights import UnquantizedWeight

 if SYSTEM == "cuda":
     import dropout_layer_norm
@@ -83,6 +84,12 @@ class CohereRotary(PositionRotaryEmbedding):

             # Inplace operation, updating query and key.
             ops.rotary_embedding(query, key, head_size, cos, sin, False)
+        elif SYSTEM == "ipex":
+            import intel_extension_for_pytorch as ipex
+
+            ipex.llm.functional.rotary_embedding(
+                query, key, sin, cos, query.size(-1), False
+            )
         else:
             raise ValueError(
                 "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
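For context, the new IPEX branch mirrors what the fused CUDA/ROCm kernels already do. Below is a minimal pure-PyTorch sketch of an in-place rotary update in the interleaved (GPT-J-style) layout that CohereRotary uses elsewhere in this file; the tensor shapes and the layout choice are assumptions for illustration, not read off this diff.

import torch

def rotary_interleaved_(query, key, cos, sin):
    # In-place rotary update over interleaved (even, odd) channel pairs.
    # Assumed shapes: query/key [num_tokens, num_heads, head_size],
    # cos/sin [num_tokens, head_size // 2].
    c = cos.unsqueeze(1)  # broadcast over the heads dimension
    s = sin.unsqueeze(1)
    for x in (query, key):
        # Clone the slices: the first in-place write would otherwise
        # corrupt the values the second write still needs to read.
        x1 = x[..., 0::2].clone()
        x2 = x[..., 1::2].clone()
        x[..., 0::2] = x1 * c - x2 * s
        x[..., 1::2] = x1 * s + x2 * c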
@@ -99,7 +106,7 @@ class CohereLayerNorm(nn.Module):
         self.eps = eps

     def forward(self, hidden_states):
-        if hidden_states.shape[-1] > 8192 or SYSTEM == "rocm":
+        if hidden_states.shape[-1] > 8192 or SYSTEM != "cuda":
             hidden_states = hidden_states.reshape(
                 -1, self.weight.shape[0], self.weight.shape[1]
             )
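The widened condition routes every non-CUDA system (ROCm, IPEX, CPU), not just ROCm, through the plain-PyTorch branch instead of the fused dropout_layer_norm kernel. A minimal sketch of what that branch computes, assuming Cohere's bias-free LayerNorm applied per head, with a weight of shape [num_heads, head_size] as the reshape above implies:

import torch

def cohere_layer_norm(hidden_states, weight, eps):
    # Plain-PyTorch fallback: normalize each head independently, no bias.
    # Assumed shapes: hidden_states [num_tokens, num_heads * head_size],
    # weight [num_heads, head_size].
    input_dtype = hidden_states.dtype
    x = hidden_states.reshape(-1, weight.shape[0], weight.shape[1]).to(torch.float32)
    x = x - x.mean(-1, keepdim=True)
    variance = x.pow(2).mean(-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    x = weight.to(torch.float32) * x
    return x.view(-1, weight.shape[0] * weight.shape[1]).to(input_dtype)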
@@ -166,16 +173,16 @@ def _load_gqa(config, prefix: str, weights):
         dim=0,
     )

-    if config.quantize not in ["gptq", "awq", "marlin"]:
-        weight = weight.to(dtype=weights.dtype).to(device=weights.device)
+    if isinstance(weight, UnquantizedWeight):
+        weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)

     head_size = config.hidden_size // config.num_attention_heads
     num_heads = config.num_attention_heads // weights.process_group.size()
     num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
-    assert list(weight.shape) == [
+    assert list(weight.weight.shape) == [
         (num_heads + 2 * num_key_value_heads) * head_size,
         config.hidden_size,
-    ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
+    ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"

     if config.attention_bias:
         w = [
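This hunk is the core fix: the old code decided "unquantized" by string-matching config.quantize against a hard-coded list, and then called .to() on what the loader now returns, a weight wrapper object rather than a bare tensor. Asking the object itself via isinstance stays correct as new quantizers are added, and the dtype/device cast has to target the wrapped .weight tensor. A minimal sketch of the pattern; the simplified UnquantizedWeight dataclass and the cast_if_unquantized helper below are illustrative stand-ins, not the repo's actual definitions:

import torch
from dataclasses import dataclass

@dataclass
class UnquantizedWeight:
    # Stand-in for text_generation_server.utils.weights.UnquantizedWeight:
    # a thin wrapper whose .weight is a plain floating-point tensor.
    weight: torch.Tensor

def cast_if_unquantized(weight, dtype, device):
    # Only a plain fp tensor is safe to cast and move wholesale; quantized
    # wrappers (GPTQ, AWQ, Marlin, ...) carry packed integer tensors plus
    # scales/zero-points that must not be blindly converted.
    if isinstance(weight, UnquantizedWeight):
        weight.weight = weight.weight.to(dtype=dtype).to(device=device)
    return weight

As a worked example of the shape the assert expects: with, say, hidden_size=8192, 64 attention heads, 8 KV heads, and no tensor parallelism, head_size is 8192 // 64 = 128 and the fused QKV weight must be [(64 + 2 * 8) * 128, 8192] = [10240, 8192].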