"""Attention kernels backed by IPEX (Intel Extension for PyTorch)."""

from typing import Optional

import intel_extension_for_pytorch as ipex
import torch

from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
from text_generation_server.models.flash_causal_lm import BLOCK_SIZE
from text_generation_server.layers.attention import Seqlen

SUPPORTS_WINDOWING = False
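# Sliding-window attention is not implemented by this backend; window sizes
# are validated upstream, at model load.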


def attention(
    *,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: KVCache,
    kv_scales: KVScales,
    seqlen: Seqlen,
    block_tables: torch.Tensor,
    softmax_scale: float,
    window_size_left: int = -1,
    causal: bool = True,
    softcap: Optional[float] = None,
):
    """Run variable-length attention with IPEX's varlen kernel."""
    if softcap is not None:
        raise NotImplementedError("softcap is not available in IPEX")
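
    # The kernel writes its result into a preallocated output buffer.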
    out = torch.empty_like(query)

    # No need to check window_size_left here (windowing is unsupported); it is
    # already validated ahead of time, at model load.
    ipex.llm.functional.varlen_attention(
        # The XPU path requires contiguous inputs; other devices skip the copy.
        query.contiguous() if query.device.type == "xpu" else query,
        key.contiguous() if key.device.type == "xpu" else key,
        value.contiguous() if value.device.type == "xpu" else value,
        out,
        # The same cumulative offsets and max length are used for both query
        # and key.
        seqlen.cu_seqlen_q,
        seqlen.cu_seqlen_q,
        seqlen.max_q,
        seqlen.max_q,
        0.0,  # pdropout: no dropout at inference
        softmax_scale,
        False,  # zero_tensors
        causal,  # is_causal
        False,  # return_softmax
        None,  # gen_ (RNG state), unused without dropout
    )

    return out
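
# A hypothetical prefill call; `q`, `k`, `v`, `cache`, `scales`, `tables`,
# `seqlen`, and `head_dim` are assumed to be prepared by the model code:
#
#     out = attention(
#         query=q,  # packed as [total_tokens, num_heads, head_dim]
#         key=k,
#         value=v,
#         kv_cache=cache,
#         kv_scales=scales,
#         seqlen=seqlen,  # cumulative per-sequence offsets live here
#         block_tables=tables,
#         softmax_scale=head_dim**-0.5,
#     )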


def paged_attention(
    query: torch.Tensor,
    kv_cache: KVCache,
    kv_head_mapping: torch.Tensor,
    softmax_scale: float,
    block_tables: torch.Tensor,
    seqlen: Seqlen,
    max_s: int,
    *,
    kv_scales: KVScales,
    softcap: Optional[float] = None,
):
    """Run single-query attention against the paged KV cache (decode)."""
    if softcap is not None:
        raise NotImplementedError("softcap is not available in IPEX")

    out = torch.empty_like(query)
    # Total positions to attend per sequence: new input tokens plus the
    # previously cached prefix.
    input_lengths = seqlen.input_lengths + seqlen.cache_lengths
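
    # K/V are gathered through block_tables, which maps each sequence to its
    # BLOCK_SIZE-token cache blocks, so a sequence's cache need not be
    # contiguous in memory.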
    ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
        out,
        query,
        kv_cache.key,
        kv_cache.value,
        kv_head_mapping,
        softmax_scale,
        block_tables,
        input_lengths,
        BLOCK_SIZE,
        max_s,
        None,  # alibi_slopes
    )

    return out
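
# A hypothetical decode-step call (one new token per sequence); `q`, `cache`,
# `head_mapping`, `tables`, `scales`, `seqlen`, `head_dim`, and `max_seq_len`
# are assumed to come from the caller's batch state:
#
#     out = paged_attention(
#         q,  # [num_seqs, num_heads, head_dim]
#         cache,
#         head_mapping,  # maps query heads to shared KV heads (GQA/MQA)
#         softmax_scale=head_dim**-0.5,
#         block_tables=tables,
#         seqlen=seqlen,
#         max_s=max_seq_len,
#         kv_scales=scales,
#     )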


__all__ = [
    "SUPPORTS_WINDOWING",
    "attention",
    "paged_attention",
]