enable mllama in intel platform (#2610)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent
0da4df4b96
commit
57f9685dc3
|
@ -19,7 +19,12 @@ from typing import Optional, Tuple, List
|
|||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
import flash_attn_2_cuda
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
if SYSTEM == "ipex":
|
||||
import intel_extension_for_pytorch as ipex
|
||||
else:
|
||||
import flash_attn_2_cuda
|
||||
|
||||
from transformers.activations import ACT2FN
|
||||
import torch.nn.functional as F
|
||||
|
@ -698,6 +703,37 @@ class MllamaTextCrossAttention(nn.Module):
|
|||
# logger.info(
|
||||
# f"Q: {query_states.shape} -K {key_states.shape} - V{value_states.shape}"
|
||||
# )
|
||||
if SYSTEM == "ipex":
|
||||
attn_output = torch.empty_like(query_states)
|
||||
ipex.llm.functional.varlen_attention(
|
||||
(
|
||||
query_states.contiguous()
|
||||
if query_states.device.type == "xpu"
|
||||
else query_states
|
||||
),
|
||||
(
|
||||
key_states.contiguous()
|
||||
if key_states.device.type == "xpu"
|
||||
else key_states
|
||||
),
|
||||
(
|
||||
value_states.contiguous()
|
||||
if value_states.device.type == "xpu"
|
||||
else value_states
|
||||
),
|
||||
attn_output,
|
||||
cu_seqlen_q,
|
||||
cu_seqlen_k,
|
||||
max_q,
|
||||
max_k,
|
||||
0.0,
|
||||
self.softmax_scale,
|
||||
False,
|
||||
causal,
|
||||
False,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
attn_output = flash_attn_2_cuda.varlen_fwd(
|
||||
query_states,
|
||||
key_states,
|
||||
|
|
Loading…
Reference in New Issue