Improve support for GPUs with capability < 8 (#2575)

* Improve support for GPUs with capability < 8:
  - For models that cannot use flashinfer, use flash-attn v1 + paged attention on GPUs with a compute capability older than 8.
  - Disable prefix caching when using paged attention.
  - When using flash-attn v1, pass the key/value tensors rather than the cache, since v1 cannot use block tables.
* nix: add flash-attn-v1 to the server environment
* Move disabling prefix caching into the block of exceptions
* Capability as `usize`s
parent 0aa66d693a
commit 5b6b74e21d
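The change, in outline: the launcher probes the CUDA compute capability through PyTorch, and when a model cannot use flashinfer it now falls back to paged attention (flash-attn v1) instead of flashdecoding on GPUs older than capability 8, also disabling prefix caching in that case because flash-attn v1 has no block-table support. The snippet below is a minimal Rust sketch of that selection rule only, not part of the diff; `pick_fallback` and its `capability` argument are hypothetical stand-ins for the pyo3-based probe and the fuller resolve_attention logic shown further down.

// Minimal sketch of the fallback rule described above; not part of the diff.
// `capability` stands in for the pyo3/torch probe this commit adds.
fn pick_fallback(capability: Option<(usize, usize)>) -> (&'static str, bool) {
    // Pre-Ampere GPUs (compute capability < 8) cannot run flash-attn v2 /
    // flashdecoding, so fall back to flash-attn v1 + paged attention.
    let attention = match capability {
        Some((major, _)) if major < 8 => "paged",
        _ => "flashdecoding",
    };
    // Prefix caching needs block tables, which flash-attn v1 does not support.
    let prefix_caching = attention != "paged";
    (attention, prefix_caching)
}

fn main() {
    // A T4 (capability 7.5) gets paged attention with prefix caching disabled.
    assert_eq!(pick_fallback(Some((7, 5))), ("paged", false));
    // An A100 (capability 8.0) keeps the flashdecoding fallback and prefix caching.
    assert_eq!(pick_fallback(Some((8, 0))), ("flashdecoding", true));
}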
@@ -4243,6 +4243,7 @@ dependencies = [
  "hf-hub",
  "nix 0.28.0",
  "once_cell",
+ "pyo3",
  "reqwest",
  "serde",
  "serde_json",
@ -33,6 +33,7 @@ metrics = { version = "0.23.0" }
|
||||||
metrics-exporter-prometheus = { version = "0.15.1", features = [] }
|
metrics-exporter-prometheus = { version = "0.15.1", features = [] }
|
||||||
minijinja = { version = "2.2.0", features = ["json"] }
|
minijinja = { version = "2.2.0", features = ["json"] }
|
||||||
minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
|
minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
|
||||||
|
pyo3 = { version = "0.22.2", features = ["auto-initialize"] }
|
||||||
|
|
||||||
[profile.release]
|
[profile.release]
|
||||||
incremental = true
|
incremental = true
|
||||||
|
|
flake.lock (14 changed lines)

@@ -978,16 +978,16 @@
        "nixpkgs": "nixpkgs_6"
      },
      "locked": {
-       "lastModified": 1726743157,
-       "narHash": "sha256-7OczwJsA47o+aUftMwkoh8R31DlNSl2FgRjqE8zAggk=",
-       "owner": "danieldk",
-       "repo": "tgi-nix",
-       "rev": "bcc9fd01cf81bc42cebb999a736a377adfa8942f",
+       "lastModified": 1727353315,
+       "narHash": "sha256-yZovq/6P8Z199r7e+NbTXyCqRgK6grRkLxYHWHnHckI=",
+       "owner": "huggingface",
+       "repo": "text-generation-inference-nix",
+       "rev": "1d42c4125ebafb87707118168995675cc5050b9d",
        "type": "github"
      },
      "original": {
-       "owner": "danieldk",
-       "repo": "tgi-nix",
+       "owner": "huggingface",
+       "repo": "text-generation-inference-nix",
        "type": "github"
      }
    }
flake.nix (45 changed lines)

@@ -5,7 +5,7 @@
       inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
     };
     nix-filter.url = "github:numtide/nix-filter";
-    tgi-nix.url = "github:danieldk/tgi-nix";
+    tgi-nix.url = "github:huggingface/text-generation-inference-nix";
     nixpkgs.follows = "tgi-nix/nixpkgs";
     flake-utils.url = "github:numtide/flake-utils";
     rust-overlay = {
@@ -132,49 +132,12 @@
             pre-commit
             ruff
           ]);

        };

-       impure = mkShell {
-         buildInputs =
-           [
-             openssl.dev
-             pkg-config
-             (rust-bin.stable.latest.default.override {
-               extensions = [
-                 "rust-analyzer"
-                 "rust-src"
-               ];
-             })
-             protobuf
-           ]
-           ++ (with python3.pkgs; [
-             venvShellHook
-             docker
-             pip
-             ipdb
-             click
-             pyright
-             pytest
-             pytest-asyncio
-             redocly
-             ruff
-             syrupy
-           ]);
-
-         inputsFrom = [ server ];
-
-         venvDir = "./.venv";
-
-         postVenvCreation = ''
-           unset SOURCE_DATE_EPOCH
-           ( cd server ; python -m pip install --no-dependencies -e . )
-           ( cd clients/python ; python -m pip install --no-dependencies -e . )
-         '';
-         postShellHook = ''
-           unset SOURCE_DATE_EPOCH
-           export PATH=$PATH:~/.cargo/bin
-         '';
+       impure = callPackage ./nix/impure-shell.nix { inherit server; };
+
+       impure-flash-attn-v1 = callPackage ./nix/impure-shell.nix {
+         server = server.override { flash-attn = python3.pkgs.flash-attn-v1; };
        };
      };

@ -12,6 +12,7 @@ ctrlc = { version = "3.4.1", features = ["termination"] }
|
||||||
hf-hub = "0.3.2"
|
hf-hub = "0.3.2"
|
||||||
nix = { version = "0.28.0", features = ["signal"] }
|
nix = { version = "0.28.0", features = ["signal"] }
|
||||||
once_cell = "1.19.0"
|
once_cell = "1.19.0"
|
||||||
|
pyo3 = { workspace = true }
|
||||||
serde = { version = "1.0.188", features = ["derive"] }
|
serde = { version = "1.0.188", features = ["derive"] }
|
||||||
serde_json = "1.0.107"
|
serde_json = "1.0.107"
|
||||||
thiserror = "1.0.59"
|
thiserror = "1.0.59"
|
||||||
|
|
|
@@ -0,0 +1,26 @@
+use std::sync::LazyLock;
+
+pub static COMPUTE_CAPABILITY: LazyLock<Option<(usize, usize)>> =
+    LazyLock::new(get_cuda_capability);
+
+fn get_cuda_capability() -> Option<(usize, usize)> {
+    use pyo3::prelude::*;
+
+    let py_get_capability = |py: Python| -> PyResult<(isize, isize)> {
+        let torch = py.import_bound("torch.cuda")?;
+        let get_device_capability = torch.getattr("get_device_capability")?;
+        get_device_capability.call0()?.extract()
+    };
+
+    match pyo3::Python::with_gil(py_get_capability) {
+        Ok((major, minor)) if major < 0 || minor < 0 => {
+            tracing::warn!("Ignoring negative GPU compute capabilities: {major}.{minor}");
+            None
+        }
+        Ok((major, minor)) => Some((major as usize, minor as usize)),
+        Err(err) => {
+            tracing::warn!("Cannot determine GPU compute capability: {}", err);
+            None
+        }
+    }
+}
@@ -26,6 +26,7 @@ use thiserror::Error;
 use tracing_subscriber::{filter::LevelFilter, EnvFilter};
 
 mod env_runtime;
+mod gpu;
 
 fn get_config(
     model_id: &str,
@@ -65,6 +66,7 @@ fn get_config(
 }
 
 fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
+    let compute_capability = *gpu::COMPUTE_CAPABILITY;
     let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
     let mut attention: Option<String> = std::env::var("ATTENTION").ok();
     if let Some(config) = config {
@@ -77,6 +79,13 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
                 prefix_caching = Some("0".to_string());
             }
         }
+
+        let fallback_attention = if matches!(compute_capability, Some((major, _)) if major < 8) {
+            "paged"
+        } else {
+            "flashdecoding"
+        };
+
         match config.head_dim {
             Some(h) if h == 64 || h == 128 || h == 256 => {
                 if lora_adapters.is_some() && prefix_caching.is_none() {
@@ -89,10 +98,14 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
                         // flashinfer ?
                         if attention.is_none() {
                             tracing::info!(
-                                "Forcing flash decoding because model {} requires it",
+                                "Forcing attention to '{fallback_attention}' because model {} requires it",
                                 config.model_type.as_ref().unwrap()
                             );
-                            attention = Some("flashdecoding".to_string());
+                            attention = Some(fallback_attention.to_string());
+                        }
+                        if fallback_attention == "paged" && prefix_caching.is_none() {
+                            tracing::info!("Disabling prefix caching because it is not supported with 'paged' attention");
+                            prefix_caching = Some("0".to_string());
                         }
                     }
                     Some("t5") => {}
@@ -101,8 +114,8 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
             }
             _ => {
                 if attention.is_none() {
-                    tracing::info!("Forcing flash decoding because head dim is not supported by flashinfer, also disabling prefix caching");
-                    attention = Some("flashdecoding".to_string());
+                    tracing::info!("Forcing attention to '{fallback_attention}' because head dim is not supported by flashinfer, also disabling prefix caching");
+                    attention = Some(fallback_attention.to_string());
                 }
                 if prefix_caching.is_none() {
                     prefix_caching = Some("0".to_string());
@@ -110,8 +123,10 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
             }
         }
     }
-    let prefix_caching = prefix_caching.unwrap_or("true".to_string());
     let attention = attention.unwrap_or("flashinfer".to_string());
+    let prefix_caching = prefix_caching.unwrap_or("true".to_string());
+
     (prefix_caching, attention)
 }
 
@@ -0,0 +1,54 @@
+{
+  mkShell,
+  openssl,
+  pkg-config,
+  protobuf,
+  python3,
+  pyright,
+  redocly,
+  ruff,
+  rust-bin,
+  server,
+}:
+
+mkShell {
+  buildInputs =
+    [
+      openssl.dev
+      pkg-config
+      (rust-bin.stable.latest.default.override {
+        extensions = [
+          "rust-analyzer"
+          "rust-src"
+        ];
+      })
+      protobuf
+      pyright
+      redocly
+      ruff
+    ]
+    ++ (with python3.pkgs; [
+      venvShellHook
+      docker
+      pip
+      ipdb
+      click
+      pytest
+      pytest-asyncio
+      syrupy
+    ]);
+
+  inputsFrom = [ server ];
+
+  venvDir = "./.venv";
+
+  postVenvCreation = ''
+    unset SOURCE_DATE_EPOCH
+    ( cd server ; python -m pip install --no-dependencies -e . )
+    ( cd clients/python ; python -m pip install --no-dependencies -e . )
+  '';
+  postShellHook = ''
+    unset SOURCE_DATE_EPOCH
+    export PATH=$PATH:~/.cargo/bin
+  '';
+}
@@ -13,6 +13,7 @@
   flash-attn,
   flash-attn-layer-norm,
   flash-attn-rotary,
+  flash-attn-v1,
   grpc-interceptor,
   grpcio-reflection,
   grpcio-status,
@ -61,7 +61,7 @@ uuid = { version = "1.9.1", default-features = false, features = [
|
||||||
] }
|
] }
|
||||||
csv = "1.3.0"
|
csv = "1.3.0"
|
||||||
ureq = "=2.9"
|
ureq = "=2.9"
|
||||||
pyo3 = { version = "0.22.2", features = ["auto-initialize"] }
|
pyo3 = { workspace = true }
|
||||||
|
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
|
|
|
@ -11,11 +11,24 @@ if SYSTEM == "cuda":
|
||||||
paged_attention,
|
paged_attention,
|
||||||
reshape_and_cache,
|
reshape_and_cache,
|
||||||
SUPPORTS_WINDOWING,
|
SUPPORTS_WINDOWING,
|
||||||
|
PREFILL_IN_KV_CACHE,
|
||||||
)
|
)
|
||||||
elif SYSTEM == "rocm":
|
elif SYSTEM == "rocm":
|
||||||
from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
|
from .rocm import (
|
||||||
|
attention,
|
||||||
|
paged_attention,
|
||||||
|
reshape_and_cache,
|
||||||
|
PREFILL_IN_KV_CACHE,
|
||||||
|
SUPPORTS_WINDOWING,
|
||||||
|
)
|
||||||
elif SYSTEM == "ipex":
|
elif SYSTEM == "ipex":
|
||||||
from .ipex import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
|
from .ipex import (
|
||||||
|
attention,
|
||||||
|
paged_attention,
|
||||||
|
reshape_and_cache,
|
||||||
|
PREFILL_IN_KV_CACHE,
|
||||||
|
SUPPORTS_WINDOWING,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
|
raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
|
||||||
|
|
||||||
|
@ -24,6 +37,7 @@ __all__ = [
|
||||||
"attention",
|
"attention",
|
||||||
"paged_attention",
|
"paged_attention",
|
||||||
"reshape_and_cache",
|
"reshape_and_cache",
|
||||||
|
"PREFILL_IN_KV_CACHE",
|
||||||
"SUPPORTS_WINDOWING",
|
"SUPPORTS_WINDOWING",
|
||||||
"Seqlen",
|
"Seqlen",
|
||||||
]
|
]
|
||||||
|
|
|
@@ -287,16 +287,14 @@ elif V2:
 else:
 
     def attention(
-        q,
-        k,
-        v,
-        key_cache: torch.Tensor,
-        value_cache: torch.Tensor,
-        cu_seqlens,
-        max_s,
-        softmax_scale,
-        window_size_left=-1,
-        causal=None,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        seqlen: Seqlen,
+        block_tables: torch.Tensor,
+        softmax_scale: float,
+        window_size_left: int = -1,
+        causal: bool = True,
         softcap=None,
     ):
         if window_size_left != -1:
@@ -338,16 +336,22 @@ else:
             k,
             v,
             out,
-            cu_seqlens,
-            cu_seqlens,
-            max_s,
-            max_s,
+            seqlen.cu_seqlen_q,
+            seqlen.cu_seqlen_q,
+            seqlen.max_q,
+            seqlen.max_k,
             0.0,
             softmax_scale,
             False,
-            True,
+            causal,
             False,
             0,
             None,
         )
         return out
 
 
+# Prefill in the cache with every kind of attention, unless we
+# have a configuration that requires flash-attention v1, which
+# does not support block tables.
+PREFILL_IN_KV_CACHE = ATTENTION != "paged" or V2
@@ -5,6 +5,7 @@ from text_generation_server.layers.attention import Seqlen
 from typing import Optional
 
 SUPPORTS_WINDOWING = False
+PREFILL_IN_KV_CACHE = False
 
 
 def attention(
@@ -13,6 +13,9 @@ _PARTITION_SIZE = 512
 use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"}
 ENGINE = "triton" if use_triton else "ck"
 
+
+PREFILL_IN_KV_CACHE = False
+
 try:
     from vllm._C import cache_ops
 except Exception as e:
@@ -156,7 +159,6 @@ if ENGINE != "triton":
             "or install flash attention with `cd server && make install install-flash-attention`"
         ) from e
     else:
-
         for idx in range(torch.cuda.device_count()):
             name = torch.cuda.get_device_name(idx)
             if "MI210" not in name and "MI250" not in name:
@@ -39,6 +39,7 @@ from text_generation_server.layers import (
     SpeculativeHead,
     get_linear,
 )
+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,
 )
@@ -297,8 +298,8 @@ class FlashCohereAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else key,
-                kv_cache[1] if SYSTEM != "ipex" else value,
+                kv_cache[0] if PREFILL_IN_KV_CACHE else key,
+                kv_cache[1] if PREFILL_IN_KV_CACHE else value,
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -30,6 +30,7 @@ from text_generation_server.layers.attention import (
     attention,
     reshape_and_cache,
     Seqlen,
+    PREFILL_IN_KV_CACHE,
 )
 from text_generation_server.layers import (
     FastLinear,
@@ -336,8 +337,8 @@ class DbrxAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
+                kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
+                kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -35,6 +35,7 @@ from text_generation_server.layers.attention import (
     paged_attention,
     reshape_and_cache,
 )
+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.layernorm import FastRMSNorm
 from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
 from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
@@ -327,8 +328,8 @@ class DeepseekV2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else key,
-                kv_cache[1] if SYSTEM != "ipex" else value,
+                kv_cache[0] if PREFILL_IN_KV_CACHE else key,
+                kv_cache[1] if PREFILL_IN_KV_CACHE else value,
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -25,7 +25,6 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
-from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -41,6 +40,7 @@ from text_generation_server.layers import (
     TensorParallelMultiAdapterLinear,
     TensorParallelAdapterRowLinear,
 )
+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
@@ -260,8 +260,8 @@ class FlashGemma2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
+                kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
+                kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -25,12 +25,12 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
-from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
     Seqlen,
+    PREFILL_IN_KV_CACHE,
 )
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -231,8 +231,8 @@ class FlashGemmaAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
+                kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
+                kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -24,7 +24,7 @@ import torch.distributed
 from torch import nn
 from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple
-from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -231,8 +231,8 @@ class FlashGPT2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else key,
-                kv_cache[1] if SYSTEM != "ipex" else value,
+                kv_cache[0] if PREFILL_IN_KV_CACHE else key,
+                kv_cache[1] if PREFILL_IN_KV_CACHE else value,
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -38,6 +38,7 @@ from text_generation_server.layers import (
     SpeculativeHead,
     get_linear,
 )
+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.rotary import (
     PositionRotaryEmbedding,
 )
@@ -192,8 +193,8 @@ class FlashGPTJAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else key,
-                kv_cache[1] if SYSTEM != "ipex" else value,
+                kv_cache[0] if PREFILL_IN_KV_CACHE else key,
+                kv_cache[1] if PREFILL_IN_KV_CACHE else value,
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -27,6 +27,7 @@ import torch.distributed
 from torch import nn
 from transformers.activations import ACT2FN
 
+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
@@ -220,8 +221,8 @@ class FlashLlamaAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
+                kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
+                kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -41,6 +41,7 @@ from text_generation_server.layers import (
     TensorParallelMultiAdapterLinear,
     TensorParallelAdapterRowLinear,
 )
+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
@@ -218,8 +219,8 @@ class MistralAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
+                kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
+                kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -39,10 +39,10 @@ from text_generation_server.layers.attention import (
     paged_attention,
     reshape_and_cache,
 )
+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.layernorm import FastRMSNorm
 from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
-from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.weights import UnquantizedWeight
 
 
@@ -267,8 +267,8 @@ class MixtralAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
+                kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
+                kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -26,7 +26,6 @@ from transformers.activations import ACT2FN
 from transformers.modeling_utils import PreTrainedModel
 from transformers.models.gpt_neox import GPTNeoXConfig as TransformersGPTNeoXConfig
 from typing import Optional, List, Tuple
-from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -40,6 +39,7 @@ from text_generation_server.layers import (
     SpeculativeHead,
     get_linear,
 )
+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,
 )
@@ -172,8 +172,8 @@ class FlashNeoxAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 qkv[:, 0],
-                kv_cache[0] if SYSTEM != "ipex" else qkv[:, 1],
-                kv_cache[1] if SYSTEM != "ipex" else qkv[:, 2],
+                kv_cache[0] if PREFILL_IN_KV_CACHE else qkv[:, 1],
+                kv_cache[1] if PREFILL_IN_KV_CACHE else qkv[:, 2],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -19,13 +19,13 @@ from text_generation_server.layers import (
     SpeculativeHead,
     get_linear,
 )
+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,
 )
 from text_generation_server.layers.rotary import (
     PositionRotaryEmbedding,
 )
-from text_generation_server.utils.import_utils import SYSTEM
 
 
 class PhiConfig(PretrainedConfig):
@@ -194,8 +194,8 @@ class FlashPhiAttention(torch.nn.Module):
         if cu_seqlen_prefill is not None:
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
+                kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
+                kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -17,11 +17,11 @@ from text_generation_server.layers import (
     TensorParallelEmbedding,
     SpeculativeHead,
 )
+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
-from text_generation_server.utils.import_utils import SYSTEM
 
 
 def load_attention(config, prefix, weights):
@@ -137,8 +137,8 @@ class Qwen2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
+                kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
+                kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -5,7 +5,6 @@ import torch.distributed
 from torch import nn
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_utils import PreTrainedModel
-from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers import (
     SpeculativeHead,
     TensorParallelColumnLinear,
@@ -13,6 +12,7 @@ from text_generation_server.layers import (
     TensorParallelRowLinear,
     get_linear,
 )
+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.layernorm import FastLayerNorm
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.attention import (
@@ -207,8 +207,8 @@ class FlashRWAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
+                kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
+                kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -325,8 +325,8 @@ class FlashRWLargeAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else kv[:, :, 0].contiguous(),
-                kv_cache[1] if SYSTEM != "ipex" else kv[:, :, 1].contiguous(),
+                kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, :, 0].contiguous(),
+                kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, :, 1].contiguous(),
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -18,11 +18,11 @@ from text_generation_server.layers import (
     TensorParallelEmbedding,
     get_linear,
 )
+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.gptq import GPTQWeightsLoader
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,
 )
-from text_generation_server.utils.import_utils import SYSTEM
 
 
 def load_multi_mqa(
@@ -293,8 +293,8 @@ class FlashMQAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else key_value[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else key_value[:, 1],
+                kv_cache[0] if PREFILL_IN_KV_CACHE else key_value[:, 0],
+                kv_cache[1] if PREFILL_IN_KV_CACHE else key_value[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -39,6 +39,7 @@ from text_generation_server.layers import (
     SpeculativeHead,
     get_linear,
 )
+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,
     FastRMSNorm,
@@ -47,7 +48,6 @@ from text_generation_server.layers.rotary import (
     PositionRotaryEmbedding,
 )
 from text_generation_server.utils.weights import UnquantizedWeight
-from text_generation_server.utils.import_utils import SYSTEM
 
 
 class Starcoder2Config(PretrainedConfig):
@@ -242,8 +242,8 @@ class Starcoder2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
+                kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
+                kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,