Pr 2337 ci branch (#2379)
* hotfix: fix xpu crash brought by code refine. torch.xpu rely on import ipex Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * reable gemma2 in xpu Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * fix in regression in ipex flashattention Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> --------- Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> Co-authored-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent
689b1abbf6
commit
2ca5980634
|
@ -2,6 +2,7 @@ import intel_extension_for_pytorch as ipex
|
|||
import torch
|
||||
from text_generation_server.models.flash_causal_lm import BLOCK_SIZE
|
||||
from text_generation_server.layers.attention import Seqlen
|
||||
from typing import Optional
|
||||
|
||||
SUPPORTS_WINDOWING = False
|
||||
|
||||
|
@ -15,11 +16,12 @@ def attention(
|
|||
softmax_scale,
|
||||
window_size_left=-1,
|
||||
causal=True,
|
||||
softcap: Optional[float] = None,
|
||||
):
|
||||
out = torch.empty_like(q)
|
||||
|
||||
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
|
||||
return ipex.llm.functional.varlen_attention(
|
||||
ipex.llm.functional.varlen_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
|
@ -36,6 +38,8 @@ def attention(
|
|||
None,
|
||||
)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def reshape_and_cache(
|
||||
key: torch.Tensor,
|
||||
|
@ -58,6 +62,7 @@ def paged_attention(
|
|||
block_tables: torch.Tensor,
|
||||
seqlen: Seqlen,
|
||||
max_s: int,
|
||||
softcap: Optional[float] = None,
|
||||
):
|
||||
out = torch.empty_like(query)
|
||||
ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
|
||||
|
|
|
@ -56,6 +56,8 @@ elif torch.version.cuda is not None and torch.cuda.is_available():
|
|||
get_free_memory = get_cuda_free_memory
|
||||
elif is_ipex_available():
|
||||
SYSTEM = "ipex"
|
||||
import intel_extension_for_pytorch # noqa: F401
|
||||
|
||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
empty_cache = torch.xpu.empty_cache
|
||||
synchronize = torch.xpu.synchronize
|
||||
|
|
Loading…
Reference in New Issue