Improve support for GPUs with capability < 8 (#2575)

* Improve support for GPUs with capability < 8

- For models that cannot use flashinfer, fall back to flash-attn v1 + paged
  attention on GPUs with a compute capability older than 8 (see the sketch below).
- Disable prefix caching when using paged attention.
- When using flash-attn v1, pass the key/value, rather than the
  cache, since v1 cannot use block tables.

* nix: add flash-attn-v1 to the server environment

* Move disabling prefix caching into the block of exceptions

* Capability as `usize`s
Daniël de Kok, 2024-09-27 16:19:42 +02:00, committed by GitHub
parent 0aa66d693a
commit 5b6b74e21d
30 changed files with 205 additions and 116 deletions
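
In short: the launcher now probes the GPU compute capability (launcher/src/gpu.rs below) and, for models that flashinfer cannot serve, falls back to paged attention instead of flashdecoding on pre-Ampere GPUs, disabling prefix caching in that case. A rough Python sketch of the policy, for illustration only; the function name and arguments are hypothetical and the real logic is the Rust resolve_attention changes in this diff:

    def resolve_attention_sketch(compute_capability, model_supports_flashinfer,
                                 attention=None, prefix_caching=None):
        # Explicit ATTENTION / USE_PREFIX_CACHING settings (passed in here as
        # `attention` / `prefix_caching`) always win over the fallbacks below.
        major = compute_capability[0] if compute_capability else None

        # Compute capability < 8 (pre-Ampere) cannot run flashdecoding, so the
        # fallback there is paged attention served by flash-attn v1.
        fallback = "paged" if major is not None and major < 8 else "flashdecoding"

        if not model_supports_flashinfer:
            if attention is None:
                attention = fallback
            if fallback == "paged" and prefix_caching is None:
                # Prefix caching is not supported together with paged attention.
                prefix_caching = "0"

        # Defaults when nothing forced a choice.
        return prefix_caching or "true", attention or "flashinfer"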

Cargo.lock (generated)

@@ -4243,6 +4243,7 @@ dependencies = [
 "hf-hub",
 "nix 0.28.0",
 "once_cell",
+"pyo3",
 "reqwest",
 "serde",
 "serde_json",

@@ -33,6 +33,7 @@ metrics = { version = "0.23.0" }
 metrics-exporter-prometheus = { version = "0.15.1", features = [] }
 minijinja = { version = "2.2.0", features = ["json"] }
 minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
+pyo3 = { version = "0.22.2", features = ["auto-initialize"] }

 [profile.release]
 incremental = true

@@ -978,16 +978,16 @@
 "nixpkgs": "nixpkgs_6"
 },
 "locked": {
-"lastModified": 1726743157,
-"narHash": "sha256-7OczwJsA47o+aUftMwkoh8R31DlNSl2FgRjqE8zAggk=",
-"owner": "danieldk",
-"repo": "tgi-nix",
-"rev": "bcc9fd01cf81bc42cebb999a736a377adfa8942f",
+"lastModified": 1727353315,
+"narHash": "sha256-yZovq/6P8Z199r7e+NbTXyCqRgK6grRkLxYHWHnHckI=",
+"owner": "huggingface",
+"repo": "text-generation-inference-nix",
+"rev": "1d42c4125ebafb87707118168995675cc5050b9d",
 "type": "github"
 },
 "original": {
-"owner": "danieldk",
-"repo": "tgi-nix",
+"owner": "huggingface",
+"repo": "text-generation-inference-nix",
 "type": "github"
 }
 }

@@ -5,7 +5,7 @@
 inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
 };
 nix-filter.url = "github:numtide/nix-filter";
-tgi-nix.url = "github:danieldk/tgi-nix";
+tgi-nix.url = "github:huggingface/text-generation-inference-nix";
 nixpkgs.follows = "tgi-nix/nixpkgs";
 flake-utils.url = "github:numtide/flake-utils";
 rust-overlay = {
@@ -132,49 +132,12 @@
 pre-commit
 ruff
 ]);
 };
-impure = mkShell {
-buildInputs =
-[
-openssl.dev
-pkg-config
-(rust-bin.stable.latest.default.override {
-extensions = [
-"rust-analyzer"
-"rust-src"
-];
-})
-protobuf
-]
-++ (with python3.pkgs; [
-venvShellHook
-docker
-pip
-ipdb
-click
-pyright
-pytest
-pytest-asyncio
-redocly
-ruff
-syrupy
-]);
-inputsFrom = [ server ];
-venvDir = "./.venv";
-postVenvCreation = ''
-unset SOURCE_DATE_EPOCH
-( cd server ; python -m pip install --no-dependencies -e . )
-( cd clients/python ; python -m pip install --no-dependencies -e . )
-'';
-postShellHook = ''
-unset SOURCE_DATE_EPOCH
-export PATH=$PATH:~/.cargo/bin
-'';
+impure = callPackage ./nix/impure-shell.nix { inherit server; };
+
+impure-flash-attn-v1 = callPackage ./nix/impure-shell.nix {
+server = server.override { flash-attn = python3.pkgs.flash-attn-v1; };
 };
 };

@@ -12,6 +12,7 @@ ctrlc = { version = "3.4.1", features = ["termination"] }
 hf-hub = "0.3.2"
 nix = { version = "0.28.0", features = ["signal"] }
 once_cell = "1.19.0"
+pyo3 = { workspace = true }
 serde = { version = "1.0.188", features = ["derive"] }
 serde_json = "1.0.107"
 thiserror = "1.0.59"

launcher/src/gpu.rs (new file)

use std::sync::LazyLock;

pub static COMPUTE_CAPABILITY: LazyLock<Option<(usize, usize)>> =
    LazyLock::new(get_cuda_capability);

fn get_cuda_capability() -> Option<(usize, usize)> {
    use pyo3::prelude::*;

    let py_get_capability = |py: Python| -> PyResult<(isize, isize)> {
        let torch = py.import_bound("torch.cuda")?;
        let get_device_capability = torch.getattr("get_device_capability")?;
        get_device_capability.call0()?.extract()
    };

    match pyo3::Python::with_gil(py_get_capability) {
        Ok((major, minor)) if major < 0 || minor < 0 => {
            tracing::warn!("Ignoring negative GPU compute capabilities: {major}.{minor}");
            None
        }
        Ok((major, minor)) => Some((major as usize, minor as usize)),
        Err(err) => {
            tracing::warn!("Cannot determine GPU compute capability: {}", err);
            None
        }
    }
}
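
For reference, the Python call that gpu.rs makes through pyo3 is torch.cuda.get_device_capability(), which returns the (major, minor) compute capability of the current CUDA device:

    import torch

    # The launcher's gpu.rs queries this through pyo3: the (major, minor)
    # compute capability of the default CUDA device, e.g. (7, 5) on a T4
    # or (8, 0) on an A100.
    if torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability()
        print(f"compute capability: {major}.{minor}")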

@@ -26,6 +26,7 @@ use thiserror::Error;
 use tracing_subscriber::{filter::LevelFilter, EnvFilter};

 mod env_runtime;
+mod gpu;

 fn get_config(
 model_id: &str,
@@ -65,6 +66,7 @@ fn get_config(
 }

 fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
+let compute_capability = *gpu::COMPUTE_CAPABILITY;
 let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
 let mut attention: Option<String> = std::env::var("ATTENTION").ok();
 if let Some(config) = config {
@@ -77,6 +79,13 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
 prefix_caching = Some("0".to_string());
 }
 }
+
+let fallback_attention = if matches!(compute_capability, Some((major, _)) if major < 8) {
+"paged"
+} else {
+"flashdecoding"
+};
+
 match config.head_dim {
 Some(h) if h == 64 || h == 128 || h == 256 => {
 if lora_adapters.is_some() && prefix_caching.is_none() {
@@ -89,10 +98,14 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
 // flashinfer ?
 if attention.is_none() {
 tracing::info!(
-"Forcing flash decoding because model {} requires it",
+"Forcing attention to '{fallback_attention}' because model {} requires it",
 config.model_type.as_ref().unwrap()
 );
-attention = Some("flashdecoding".to_string());
+attention = Some(fallback_attention.to_string());
+}
+if fallback_attention == "paged" && prefix_caching.is_none() {
+tracing::info!("Disabling prefix caching because it is not supported with 'paged' attention");
+prefix_caching = Some("0".to_string());
 }
 }
 Some("t5") => {}
@@ -101,8 +114,8 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
 }
 _ => {
 if attention.is_none() {
-tracing::info!("Forcing flash decoding because head dim is not supported by flashinfer, also disabling prefix caching");
-attention = Some("flashdecoding".to_string());
+tracing::info!("Forcing attention to '{fallback_attention}' because head dim is not supported by flashinfer, also disabling prefix caching");
+attention = Some(fallback_attention.to_string());
 }
 if prefix_caching.is_none() {
 prefix_caching = Some("0".to_string());
@@ -110,8 +123,10 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
 }
 }
 }
+let prefix_caching = prefix_caching.unwrap_or("true".to_string());
 let attention = attention.unwrap_or("flashinfer".to_string());
-let prefix_caching = prefix_caching.unwrap_or("true".to_string());
 (prefix_caching, attention)
 }
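
How the resolved pair reaches the Python server is not part of this hunk; since the resolver itself reads USE_PREFIX_CACHING and ATTENTION from the environment, a plausible mental model (an assumption, not shown in this diff) is that the server consumes the same variables:

    import os

    # Assumed wiring, mirroring the variable names the resolver reads above;
    # the exact hand-off to the Python server is not shown in this diff.
    ATTENTION = os.environ.get("ATTENTION", "flashinfer")
    PREFIX_CACHING = os.environ.get("USE_PREFIX_CACHING", "true")

    print(f"attention={ATTENTION}, prefix_caching={PREFIX_CACHING}")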

nix/impure-shell.nix (new file)

{
  mkShell,
  openssl,
  pkg-config,
  protobuf,
  python3,
  pyright,
  redocly,
  ruff,
  rust-bin,
  server,
}:

mkShell {
  buildInputs =
    [
      openssl.dev
      pkg-config
      (rust-bin.stable.latest.default.override {
        extensions = [
          "rust-analyzer"
          "rust-src"
        ];
      })
      protobuf
      pyright
      redocly
      ruff
    ]
    ++ (with python3.pkgs; [
      venvShellHook
      docker
      pip
      ipdb
      click
      pytest
      pytest-asyncio
      syrupy
    ]);

  inputsFrom = [ server ];

  venvDir = "./.venv";
  postVenvCreation = ''
    unset SOURCE_DATE_EPOCH
    ( cd server ; python -m pip install --no-dependencies -e . )
    ( cd clients/python ; python -m pip install --no-dependencies -e . )
  '';
  postShellHook = ''
    unset SOURCE_DATE_EPOCH
    export PATH=$PATH:~/.cargo/bin
  '';
}

@@ -13,6 +13,7 @@
 flash-attn,
 flash-attn-layer-norm,
 flash-attn-rotary,
+flash-attn-v1,
 grpc-interceptor,
 grpcio-reflection,
 grpcio-status,

@@ -61,7 +61,7 @@ uuid = { version = "1.9.1", default-features = false, features = [
 ] }
 csv = "1.3.0"
 ureq = "=2.9"
-pyo3 = { version = "0.22.2", features = ["auto-initialize"] }
+pyo3 = { workspace = true }

 [build-dependencies]

@@ -11,11 +11,24 @@ if SYSTEM == "cuda":
 paged_attention,
 reshape_and_cache,
 SUPPORTS_WINDOWING,
+PREFILL_IN_KV_CACHE,
 )
 elif SYSTEM == "rocm":
-from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
+from .rocm import (
+attention,
+paged_attention,
+reshape_and_cache,
+PREFILL_IN_KV_CACHE,
+SUPPORTS_WINDOWING,
+)
 elif SYSTEM == "ipex":
-from .ipex import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
+from .ipex import (
+attention,
+paged_attention,
+reshape_and_cache,
+PREFILL_IN_KV_CACHE,
+SUPPORTS_WINDOWING,
+)
 else:
 raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
@@ -24,6 +37,7 @@ __all__ = [
 "attention",
 "paged_attention",
 "reshape_and_cache",
+"PREFILL_IN_KV_CACHE",
 "SUPPORTS_WINDOWING",
 "Seqlen",
 ]

@@ -287,16 +287,14 @@ elif V2:
 else:

 def attention(
-q,
-k,
-v,
-key_cache: torch.Tensor,
-value_cache: torch.Tensor,
-cu_seqlens,
-max_s,
-softmax_scale,
-window_size_left=-1,
-causal=None,
+q: torch.Tensor,
+k: torch.Tensor,
+v: torch.Tensor,
+seqlen: Seqlen,
+block_tables: torch.Tensor,
+softmax_scale: float,
+window_size_left: int = -1,
+causal: bool = True,
 softcap=None,
 ):
 if window_size_left != -1:
@@ -338,16 +336,22 @@ else:
 k,
 v,
 out,
-cu_seqlens,
-cu_seqlens,
-max_s,
-max_s,
+seqlen.cu_seqlen_q,
+seqlen.cu_seqlen_q,
+seqlen.max_q,
+seqlen.max_k,
 0.0,
 softmax_scale,
 False,
-True,
+causal,
 False,
 0,
 None,
 )
 return out
+
+# Prefill in the cache with every kind of attention, unless we
+# have a configuration that requires flash-attention v1, which
+# does not support block tables.
+PREFILL_IN_KV_CACHE = ATTENTION != "paged" or V2
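
The new PREFILL_IN_KV_CACHE flag is what all the model files below switch on: with flash-attn v2 or flashinfer the prefill pass can attend directly over the paged KV cache, while flash-attn v1 has no block-table support and must be handed the freshly computed key/value tensors instead. A minimal sketch of the selection the models inline (the helper name is hypothetical):

    def prefill_kv(kv_cache, key, value, prefill_in_kv_cache):
        """Pick what prefill attention should read: the paged KV cache when the
        backend can use block tables, otherwise the raw key/value tensors
        (flash-attn v1 cannot read a block-table cache)."""
        if prefill_in_kv_cache:
            return kv_cache[0], kv_cache[1]
        return key, value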

@@ -5,6 +5,7 @@ from text_generation_server.layers.attention import Seqlen
 from typing import Optional

 SUPPORTS_WINDOWING = False
+PREFILL_IN_KV_CACHE = False


 def attention(

@@ -13,6 +13,9 @@ _PARTITION_SIZE = 512
 use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"}
 ENGINE = "triton" if use_triton else "ck"

+PREFILL_IN_KV_CACHE = False
+
 try:
 from vllm._C import cache_ops
 except Exception as e:
@@ -156,7 +159,6 @@ if ENGINE != "triton":
 "or install flash attention with `cd server && make install install-flash-attention`"
 ) from e
 else:
-
 for idx in range(torch.cuda.device_count()):
 name = torch.cuda.get_device_name(idx)
 if "MI210" not in name and "MI250" not in name:

@@ -39,6 +39,7 @@ from text_generation_server.layers import (
 SpeculativeHead,
 get_linear,
 )
+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.layernorm import (
 FastLayerNorm,
 )
@@ -297,8 +298,8 @@ class FlashCohereAttention(torch.nn.Module):
 # flash attention
 attn_output = attention(
 query,
-kv_cache[0] if SYSTEM != "ipex" else key,
-kv_cache[1] if SYSTEM != "ipex" else value,
+kv_cache[0] if PREFILL_IN_KV_CACHE else key,
+kv_cache[1] if PREFILL_IN_KV_CACHE else value,
 seqlen,
 block_tables,
 self.softmax_scale,

@@ -30,6 +30,7 @@ from text_generation_server.layers.attention import (
 attention,
 reshape_and_cache,
 Seqlen,
+PREFILL_IN_KV_CACHE,
 )
 from text_generation_server.layers import (
 FastLinear,
@@ -336,8 +337,8 @@ class DbrxAttention(torch.nn.Module):
 # flash attention
 attn_output = attention(
 query,
-kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
-kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
+kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
+kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
 seqlen,
 block_tables,
 self.softmax_scale,

@@ -35,6 +35,7 @@ from text_generation_server.layers.attention import (
 paged_attention,
 reshape_and_cache,
 )
+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.layernorm import FastRMSNorm
 from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
 from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
@@ -327,8 +328,8 @@ class DeepseekV2Attention(torch.nn.Module):
 # flash attention
 attn_output = attention(
 query,
-kv_cache[0] if SYSTEM != "ipex" else key,
-kv_cache[1] if SYSTEM != "ipex" else value,
+kv_cache[0] if PREFILL_IN_KV_CACHE else key,
+kv_cache[1] if PREFILL_IN_KV_CACHE else value,
 seqlen,
 block_tables,
 self.softmax_scale,

@@ -25,7 +25,6 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
-from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
 paged_attention,
 attention,
@@ -41,6 +40,7 @@ from text_generation_server.layers import (
 TensorParallelMultiAdapterLinear,
 TensorParallelAdapterRowLinear,
 )
+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
 FastRMSNorm,
@@ -260,8 +260,8 @@ class FlashGemma2Attention(torch.nn.Module):
 # flash attention
 attn_output = attention(
 query,
-kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
-kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
+kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
+kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
 seqlen,
 block_tables,
 self.softmax_scale,

@@ -25,12 +25,12 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
-from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
 paged_attention,
 attention,
 reshape_and_cache,
 Seqlen,
+PREFILL_IN_KV_CACHE,
 )
 from text_generation_server.layers import (
 TensorParallelRowLinear,
@@ -231,8 +231,8 @@ class FlashGemmaAttention(torch.nn.Module):
 # flash attention
 attn_output = attention(
 query,
-kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
-kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
+kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
+kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
 seqlen,
 block_tables,
 self.softmax_scale,

@@ -24,7 +24,7 @@ import torch.distributed
 from torch import nn
 from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple
-from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.attention import (
 paged_attention,
 attention,
@@ -231,8 +231,8 @@ class FlashGPT2Attention(torch.nn.Module):
 # flash attention
 attn_output = attention(
 query,
-kv_cache[0] if SYSTEM != "ipex" else key,
-kv_cache[1] if SYSTEM != "ipex" else value,
+kv_cache[0] if PREFILL_IN_KV_CACHE else key,
+kv_cache[1] if PREFILL_IN_KV_CACHE else value,
 seqlen,
 block_tables,
 self.softmax_scale,

@@ -38,6 +38,7 @@ from text_generation_server.layers import (
 SpeculativeHead,
 get_linear,
 )
+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.rotary import (
 PositionRotaryEmbedding,
 )
@@ -192,8 +193,8 @@ class FlashGPTJAttention(torch.nn.Module):
 # flash attention
 attn_output = attention(
 query,
-kv_cache[0] if SYSTEM != "ipex" else key,
-kv_cache[1] if SYSTEM != "ipex" else value,
+kv_cache[0] if PREFILL_IN_KV_CACHE else key,
+kv_cache[1] if PREFILL_IN_KV_CACHE else value,
 seqlen,
 block_tables,
 self.softmax_scale,

@@ -27,6 +27,7 @@ import torch.distributed
 from torch import nn
 from transformers.activations import ACT2FN

+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
 paged_attention,
@@ -220,8 +221,8 @@ class FlashLlamaAttention(torch.nn.Module):
 # flash attention
 attn_output = attention(
 query,
-kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
-kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
+kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
+kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
 seqlen,
 block_tables,
 self.softmax_scale,

@@ -41,6 +41,7 @@ from text_generation_server.layers import (
 TensorParallelMultiAdapterLinear,
 TensorParallelAdapterRowLinear,
 )
+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
 FastRMSNorm,
@@ -218,8 +219,8 @@ class MistralAttention(torch.nn.Module):
 # flash attention
 attn_output = attention(
 query,
-kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
-kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
+kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
+kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
 seqlen,
 block_tables,
 self.softmax_scale,

@@ -39,10 +39,10 @@ from text_generation_server.layers.attention import (
 paged_attention,
 reshape_and_cache,
 )
+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.layernorm import FastRMSNorm
 from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
-from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.weights import UnquantizedWeight
@@ -267,8 +267,8 @@ class MixtralAttention(torch.nn.Module):
 # flash attention
 attn_output = attention(
 query,
-kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
-kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
+kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
+kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
 seqlen,
 block_tables,
 self.softmax_scale,

@@ -26,7 +26,6 @@ from transformers.activations import ACT2FN
 from transformers.modeling_utils import PreTrainedModel
 from transformers.models.gpt_neox import GPTNeoXConfig as TransformersGPTNeoXConfig
 from typing import Optional, List, Tuple
-from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
 paged_attention,
 attention,
@@ -40,6 +39,7 @@ from text_generation_server.layers import (
 SpeculativeHead,
 get_linear,
 )
+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.layernorm import (
 FastLayerNorm,
 )
@@ -172,8 +172,8 @@ class FlashNeoxAttention(torch.nn.Module):
 # flash attention
 attn_output = attention(
 qkv[:, 0],
-kv_cache[0] if SYSTEM != "ipex" else qkv[:, 1],
-kv_cache[1] if SYSTEM != "ipex" else qkv[:, 2],
+kv_cache[0] if PREFILL_IN_KV_CACHE else qkv[:, 1],
+kv_cache[1] if PREFILL_IN_KV_CACHE else qkv[:, 2],
 seqlen,
 block_tables,
 self.softmax_scale,

@@ -19,13 +19,13 @@ from text_generation_server.layers import (
 SpeculativeHead,
 get_linear,
 )
+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.layernorm import (
 FastLayerNorm,
 )
 from text_generation_server.layers.rotary import (
 PositionRotaryEmbedding,
 )
-from text_generation_server.utils.import_utils import SYSTEM


 class PhiConfig(PretrainedConfig):
@@ -194,8 +194,8 @@ class FlashPhiAttention(torch.nn.Module):
 if cu_seqlen_prefill is not None:
 attn_output = attention(
 query,
-kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
-kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
+kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
+kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
 seqlen,
 block_tables,
 self.softmax_scale,

@@ -17,11 +17,11 @@ from text_generation_server.layers import (
 TensorParallelEmbedding,
 SpeculativeHead,
 )
+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
 FastRMSNorm,
 )
-from text_generation_server.utils.import_utils import SYSTEM


 def load_attention(config, prefix, weights):
@@ -137,8 +137,8 @@ class Qwen2Attention(torch.nn.Module):
 # flash attention
 attn_output = attention(
 query,
-kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
-kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
+kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
+kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
 seqlen,
 block_tables,
 self.softmax_scale,

@@ -5,7 +5,6 @@ import torch.distributed
 from torch import nn
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_utils import PreTrainedModel
-from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers import (
 SpeculativeHead,
 TensorParallelColumnLinear,
@@ -13,6 +12,7 @@ from text_generation_server.layers import (
 TensorParallelRowLinear,
 get_linear,
 )
+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.layernorm import FastLayerNorm
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.attention import (
@@ -207,8 +207,8 @@ class FlashRWAttention(torch.nn.Module):
 # flash attention
 attn_output = attention(
 query,
-kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
-kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
+kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
+kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
 seqlen,
 block_tables,
 self.softmax_scale,
@@ -325,8 +325,8 @@ class FlashRWLargeAttention(torch.nn.Module):
 # flash attention
 attn_output = attention(
 query,
-kv_cache[0] if SYSTEM != "ipex" else kv[:, :, 0].contiguous(),
-kv_cache[1] if SYSTEM != "ipex" else kv[:, :, 1].contiguous(),
+kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, :, 0].contiguous(),
+kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, :, 1].contiguous(),
 seqlen,
 block_tables,
 self.softmax_scale,

@@ -18,11 +18,11 @@ from text_generation_server.layers import (
 TensorParallelEmbedding,
 get_linear,
 )
+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.gptq import GPTQWeightsLoader
 from text_generation_server.layers.layernorm import (
 FastLayerNorm,
 )
-from text_generation_server.utils.import_utils import SYSTEM


 def load_multi_mqa(
@@ -293,8 +293,8 @@ class FlashMQAttention(torch.nn.Module):
 # flash attention
 attn_output = attention(
 query,
-kv_cache[0] if SYSTEM != "ipex" else key_value[:, 0],
-kv_cache[1] if SYSTEM != "ipex" else key_value[:, 1],
+kv_cache[0] if PREFILL_IN_KV_CACHE else key_value[:, 0],
+kv_cache[1] if PREFILL_IN_KV_CACHE else key_value[:, 1],
 seqlen,
 block_tables,
 self.softmax_scale,

@@ -39,6 +39,7 @@ from text_generation_server.layers import (
 SpeculativeHead,
 get_linear,
 )
+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.layernorm import (
 FastLayerNorm,
 FastRMSNorm,
@@ -47,7 +48,6 @@ from text_generation_server.layers.rotary import (
 PositionRotaryEmbedding,
 )
 from text_generation_server.utils.weights import UnquantizedWeight
-from text_generation_server.utils.import_utils import SYSTEM


 class Starcoder2Config(PretrainedConfig):
@@ -242,8 +242,8 @@ class Starcoder2Attention(torch.nn.Module):
 # flash attention
 attn_output = attention(
 query,
-kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
-kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
+kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
+kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
 seqlen,
 block_tables,
 self.softmax_scale,