fix of use of unquantized weights in cohere GQA loading, also enable … (#2291)

fix of use of unquantized weights in cohere GQA loading, also enable the model in intel platform

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
Wang, Yi 2024-07-24 16:44:02 +08:00 committed by GitHub
parent 5ad39dd3c3
commit 8642250602
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 12 additions and 5 deletions

View File

@ -45,6 +45,7 @@ from text_generation_server.layers.layernorm import (
from text_generation_server.layers.rotary import ( from text_generation_server.layers.rotary import (
PositionRotaryEmbedding, PositionRotaryEmbedding,
) )
from text_generation_server.utils.weights import UnquantizedWeight
if SYSTEM == "cuda": if SYSTEM == "cuda":
import dropout_layer_norm import dropout_layer_norm
@ -83,6 +84,12 @@ class CohereRotary(PositionRotaryEmbedding):
# Inplace operation, updating query and key. # Inplace operation, updating query and key.
ops.rotary_embedding(query, key, head_size, cos, sin, False) ops.rotary_embedding(query, key, head_size, cos, sin, False)
elif SYSTEM == "ipex":
import intel_extension_for_pytorch as ipex
ipex.llm.functional.rotary_embedding(
query, key, sin, cos, query.size(-1), False
)
else: else:
raise ValueError( raise ValueError(
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction." "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
@ -99,7 +106,7 @@ class CohereLayerNorm(nn.Module):
self.eps = eps self.eps = eps
def forward(self, hidden_states): def forward(self, hidden_states):
if hidden_states.shape[-1] > 8192 or SYSTEM == "rocm": if hidden_states.shape[-1] > 8192 or SYSTEM != "cuda":
hidden_states = hidden_states.reshape( hidden_states = hidden_states.reshape(
-1, self.weight.shape[0], self.weight.shape[1] -1, self.weight.shape[0], self.weight.shape[1]
) )
@ -166,16 +173,16 @@ def _load_gqa(config, prefix: str, weights):
dim=0, dim=0,
) )
if config.quantize not in ["gptq", "awq", "marlin"]: if isinstance(weight, UnquantizedWeight):
weight = weight.to(dtype=weights.dtype).to(device=weights.device) weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)
head_size = config.hidden_size // config.num_attention_heads head_size = config.hidden_size // config.num_attention_heads
num_heads = config.num_attention_heads // weights.process_group.size() num_heads = config.num_attention_heads // weights.process_group.size()
num_key_value_heads = config.num_key_value_heads // weights.process_group.size() num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
assert list(weight.shape) == [ assert list(weight.weight.shape) == [
(num_heads + 2 * num_key_value_heads) * head_size, (num_heads + 2 * num_key_value_heads) * head_size,
config.hidden_size, config.hidden_size,
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
if config.attention_bias: if config.attention_bias:
w = [ w = [