diff --git a/server/text_generation_server/models/custom_modeling/mllama.py b/server/text_generation_server/models/custom_modeling/mllama.py
index 73536bd6..6e091a74 100644
--- a/server/text_generation_server/models/custom_modeling/mllama.py
+++ b/server/text_generation_server/models/custom_modeling/mllama.py
@@ -19,7 +19,12 @@ from typing import Optional, Tuple, List
 import torch
 import torch.utils.checkpoint
 from torch import nn
-import flash_attn_2_cuda
+from text_generation_server.utils.import_utils import SYSTEM
+
+if SYSTEM == "ipex":
+    import intel_extension_for_pytorch as ipex
+else:
+    import flash_attn_2_cuda
 
 from transformers.activations import ACT2FN
 import torch.nn.functional as F
@@ -698,29 +703,60 @@ class MllamaTextCrossAttention(nn.Module):
         # logger.info(
         #     f"Q: {query_states.shape} -K {key_states.shape} - V{value_states.shape}"
         # )
-        attn_output = flash_attn_2_cuda.varlen_fwd(
-            query_states,
-            key_states,
-            value_states,
-            None,
-            cu_seqlen_q,
-            cu_seqlen_k,
-            None,
-            None,
-            None,  # block_tables
-            None,
-            max_q,
-            max_k,
-            0.0,
-            self.softmax_scale,
-            False,
-            causal,  # Causal
-            -1,  # window_size_left,
-            -1,
-            0.0,  # softcap
-            False,
-            None,
-        )[0]
+        if SYSTEM == "ipex":
+            attn_output = torch.empty_like(query_states)
+            ipex.llm.functional.varlen_attention(
+                (
+                    query_states.contiguous()
+                    if query_states.device.type == "xpu"
+                    else query_states
+                ),
+                (
+                    key_states.contiguous()
+                    if key_states.device.type == "xpu"
+                    else key_states
+                ),
+                (
+                    value_states.contiguous()
+                    if value_states.device.type == "xpu"
+                    else value_states
+                ),
+                attn_output,
+                cu_seqlen_q,
+                cu_seqlen_k,
+                max_q,
+                max_k,
+                0.0,
+                self.softmax_scale,
+                False,
+                causal,
+                False,
+                None,
+            )
+        else:
+            attn_output = flash_attn_2_cuda.varlen_fwd(
+                query_states,
+                key_states,
+                value_states,
+                None,
+                cu_seqlen_q,
+                cu_seqlen_k,
+                None,
+                None,
+                None,  # block_tables
+                None,
+                max_q,
+                max_k,
+                0.0,
+                self.softmax_scale,
+                False,
+                causal,  # Causal
+                -1,  # window_size_left,
+                -1,
+                0.0,  # softcap
+                False,
+                None,
+            )[0]
 
         attn_output = self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
         return attn_output
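
For context on the `.contiguous()` guards added in the IPEX branch: the patch only copies `query_states`, `key_states`, and `value_states` on `xpu` devices, presumably so the kernel receives densely packed buffers. The snippet below is a minimal, CPU-only sketch (the fused buffer and its shape are made up for illustration and are not taken from this patch) of how a sliced view ends up non-contiguous and what `.contiguous()` does about it.

```python
import torch

# Hypothetical fused buffer: [tokens, (q|k|v), heads, head_size].
fused_qkv = torch.randn(8, 3, 4, 64)

# Slicing one projection out of the fused buffer returns a strided view
# whose memory still interleaves k and v, so it is not contiguous.
query_states = fused_qkv[:, 0]
print(query_states.is_contiguous())                # False

# .contiguous() materializes a densely packed copy, which is the kind of
# input the patch passes to the XPU attention call.
print(query_states.contiguous().is_contiguous())   # True (at the cost of a copy)
```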