diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index c7b29d13..ab10dee1 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -45,6 +45,7 @@ from text_generation_server.layers.layernorm import ( from text_generation_server.layers.rotary import ( PositionRotaryEmbedding, ) +from text_generation_server.utils.weights import UnquantizedWeight if SYSTEM == "cuda": import dropout_layer_norm @@ -83,6 +84,12 @@ class CohereRotary(PositionRotaryEmbedding): # Inplace operation, updating query and key. ops.rotary_embedding(query, key, head_size, cos, sin, False) + elif SYSTEM == "ipex": + import intel_extension_for_pytorch as ipex + + ipex.llm.functional.rotary_embedding( + query, key, sin, cos, query.size(-1), False + ) else: raise ValueError( "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction." @@ -99,7 +106,7 @@ class CohereLayerNorm(nn.Module): self.eps = eps def forward(self, hidden_states): - if hidden_states.shape[-1] > 8192 or SYSTEM == "rocm": + if hidden_states.shape[-1] > 8192 or SYSTEM != "cuda": hidden_states = hidden_states.reshape( -1, self.weight.shape[0], self.weight.shape[1] ) @@ -166,16 +173,16 @@ def _load_gqa(config, prefix: str, weights): dim=0, ) - if config.quantize not in ["gptq", "awq", "marlin"]: - weight = weight.to(dtype=weights.dtype).to(device=weights.device) + if isinstance(weight, UnquantizedWeight): + weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device) head_size = config.hidden_size // config.num_attention_heads num_heads = config.num_attention_heads // weights.process_group.size() num_key_value_heads = config.num_key_value_heads // weights.process_group.size() - assert list(weight.shape) == [ + assert list(weight.weight.shape) == [ (num_heads + 2 * num_key_value_heads) * head_size, config.hidden_size, - ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" + ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" if config.attention_bias: w = [