fix of use of unquantized weights in cohere GQA loading, also enable … (#2291)
fix of use of unquantized weights in cohere GQA loading, also enable the model in intel platform Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent
5ad39dd3c3
commit
8642250602
|
@ -45,6 +45,7 @@ from text_generation_server.layers.layernorm import (
|
|||
from text_generation_server.layers.rotary import (
|
||||
PositionRotaryEmbedding,
|
||||
)
|
||||
from text_generation_server.utils.weights import UnquantizedWeight
|
||||
|
||||
if SYSTEM == "cuda":
|
||||
import dropout_layer_norm
|
||||
|
@ -83,6 +84,12 @@ class CohereRotary(PositionRotaryEmbedding):
|
|||
|
||||
# Inplace operation, updating query and key.
|
||||
ops.rotary_embedding(query, key, head_size, cos, sin, False)
|
||||
elif SYSTEM == "ipex":
|
||||
import intel_extension_for_pytorch as ipex
|
||||
|
||||
ipex.llm.functional.rotary_embedding(
|
||||
query, key, sin, cos, query.size(-1), False
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
|
||||
|
@ -99,7 +106,7 @@ class CohereLayerNorm(nn.Module):
|
|||
self.eps = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
if hidden_states.shape[-1] > 8192 or SYSTEM == "rocm":
|
||||
if hidden_states.shape[-1] > 8192 or SYSTEM != "cuda":
|
||||
hidden_states = hidden_states.reshape(
|
||||
-1, self.weight.shape[0], self.weight.shape[1]
|
||||
)
|
||||
|
@ -166,16 +173,16 @@ def _load_gqa(config, prefix: str, weights):
|
|||
dim=0,
|
||||
)
|
||||
|
||||
if config.quantize not in ["gptq", "awq", "marlin"]:
|
||||
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
|
||||
if isinstance(weight, UnquantizedWeight):
|
||||
weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)
|
||||
|
||||
head_size = config.hidden_size // config.num_attention_heads
|
||||
num_heads = config.num_attention_heads // weights.process_group.size()
|
||||
num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
|
||||
assert list(weight.shape) == [
|
||||
assert list(weight.weight.shape) == [
|
||||
(num_heads + 2 * num_key_value_heads) * head_size,
|
||||
config.hidden_size,
|
||||
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
|
||||
], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
|
||||
|
||||
if config.attention_bias:
|
||||
w = [
|
||||
|
|
Loading…
Reference in New Issue