import intel_extension_for_pytorch as ipex
import torch

SUPPORTS_WINDOWING = False


def attention(
    q,
    k,
    v,
    out,
    cu_seqlens,
    max_s,
    softmax_scale,
    window_size_left=-1,
):
    # We do not need to check window_size_left (windowing is not supported) here:
    # it is already validated ahead of time at model load.
    return ipex.llm.functional.varlen_attention(
        q,
        k,
        v,
        out,
        cu_seqlens,
        cu_seqlens,
        max_s,
        max_s,
        0.0,  # dropout probability
        softmax_scale,
        False,  # zero_tensors
        True,  # is_causal
        False,  # return_softmax
        None,  # generator
    )


def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slots: torch.Tensor,
    kv_cache_dtype: str = "auto",
    kv_scale: float = 1.0,
):
    # kv_cache_dtype and kv_scale are accepted for interface compatibility but
    # are not used by the IPEX kernel.
    ipex.llm.modules.PagedAttention.reshape_and_cache(
        key, value, key_cache, value_cache, slots
    )


def paged_attention(
    out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    kv_head_mapping: torch.Tensor,
    softmax_scale: float,
    block_tables: torch.Tensor,
    input_lengths: torch.Tensor,
    max_s: int,
    kv_cache_dtype: str = "auto",
    kv_scale: float = 1.0,
):
    query = query.contiguous()
    # The paged KV block size is encoded in the last dimension of the value cache.
    block_size = value_cache.shape[3]
    return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
        out,
        query,
        key_cache,
        value_cache,
        kv_head_mapping,
        softmax_scale,
        block_tables,
        input_lengths,
        block_size,
        max_s,
        None,
    )
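

# --- Illustrative usage (hedged sketch, not part of the original module) ---
# The shapes below are assumptions about the varlen layout: q/k/v/out as
# [total_tokens, num_heads, head_size] and cu_seqlens as int32 cumulative
# sequence lengths. Running this requires an IPEX-enabled environment.
if __name__ == "__main__":
    num_heads, head_size = 8, 64  # hypothetical model dimensions
    seq_lens = [3, 5]  # two sequences packed into one batch (hypothetical)
    total_tokens = sum(seq_lens)

    q = torch.randn(total_tokens, num_heads, head_size)
    k = torch.randn(total_tokens, num_heads, head_size)
    v = torch.randn(total_tokens, num_heads, head_size)
    out = torch.empty_like(q)

    # Cumulative sequence lengths: [0, 3, 8] for lengths [3, 5].
    cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32)

    attention(
        q,
        k,
        v,
        out,
        cu_seqlens,
        max_s=max(seq_lens),
        softmax_scale=head_size ** -0.5,
    )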