From 5b6b74e21d6cfa961afe3338fc5cfd45fa357b50 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 27 Sep 2024 16:19:42 +0200 Subject: [PATCH] Improve support for GPUs with capability < 8 (#2575) * Improve support for GPUs with capability < 8 - For models that cannot use flashinfer, use flash-attn v1 + paged attention for models with a compute capability older than 8. - Disable prefix caching when using paged attention. - When using flash-attn v1, pass the key/value, rather than the cache, since v1 cannot use block tables. * nix: add flash-attn-v1 to the server environment * Move disabling prefix caching into the block of exceptions * Capability as `usize`s --- Cargo.lock | 1 + Cargo.toml | 1 + flake.lock | 14 ++--- flake.nix | 45 ++-------------- launcher/Cargo.toml | 1 + launcher/src/gpu.rs | 26 +++++++++ launcher/src/main.rs | 25 +++++++-- nix/impure-shell.nix | 54 +++++++++++++++++++ nix/server.nix | 1 + router/Cargo.toml | 2 +- .../layers/attention/__init__.py | 18 ++++++- .../layers/attention/cuda.py | 34 ++++++------ .../layers/attention/ipex.py | 1 + .../layers/attention/rocm.py | 4 +- .../custom_modeling/flash_cohere_modeling.py | 5 +- .../custom_modeling/flash_dbrx_modeling.py | 5 +- .../flash_deepseek_v2_modeling.py | 5 +- .../custom_modeling/flash_gemma2_modeling.py | 6 +-- .../custom_modeling/flash_gemma_modeling.py | 6 +-- .../custom_modeling/flash_gpt2_modeling.py | 6 +-- .../custom_modeling/flash_gptj_modeling.py | 5 +- .../custom_modeling/flash_llama_modeling.py | 5 +- .../custom_modeling/flash_mistral_modeling.py | 5 +- .../custom_modeling/flash_mixtral_modeling.py | 6 +-- .../custom_modeling/flash_neox_modeling.py | 6 +-- .../custom_modeling/flash_phi_modeling.py | 6 +-- .../custom_modeling/flash_qwen2_modeling.py | 6 +-- .../custom_modeling/flash_rw_modeling.py | 10 ++-- .../flash_santacoder_modeling.py | 6 +-- .../flash_starcoder2_modeling.py | 6 +-- 30 files changed, 205 insertions(+), 116 deletions(-) create mode 100644 launcher/src/gpu.rs create mode 100644 nix/impure-shell.nix diff --git a/Cargo.lock b/Cargo.lock index e535004e..6796212f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4243,6 +4243,7 @@ dependencies = [ "hf-hub", "nix 0.28.0", "once_cell", + "pyo3", "reqwest", "serde", "serde_json", diff --git a/Cargo.toml b/Cargo.toml index 032dc857..a783fadb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,6 +33,7 @@ metrics = { version = "0.23.0" } metrics-exporter-prometheus = { version = "0.15.1", features = [] } minijinja = { version = "2.2.0", features = ["json"] } minijinja-contrib = { version = "2.0.2", features = ["pycompat"] } +pyo3 = { version = "0.22.2", features = ["auto-initialize"] } [profile.release] incremental = true diff --git a/flake.lock b/flake.lock index d811be5e..14e23b77 100644 --- a/flake.lock +++ b/flake.lock @@ -978,16 +978,16 @@ "nixpkgs": "nixpkgs_6" }, "locked": { - "lastModified": 1726743157, - "narHash": "sha256-7OczwJsA47o+aUftMwkoh8R31DlNSl2FgRjqE8zAggk=", - "owner": "danieldk", - "repo": "tgi-nix", - "rev": "bcc9fd01cf81bc42cebb999a736a377adfa8942f", + "lastModified": 1727353315, + "narHash": "sha256-yZovq/6P8Z199r7e+NbTXyCqRgK6grRkLxYHWHnHckI=", + "owner": "huggingface", + "repo": "text-generation-inference-nix", + "rev": "1d42c4125ebafb87707118168995675cc5050b9d", "type": "github" }, "original": { - "owner": "danieldk", - "repo": "tgi-nix", + "owner": "huggingface", + "repo": "text-generation-inference-nix", "type": "github" } } diff --git a/flake.nix b/flake.nix index 260b2554..1b396453 100644 --- 
a/flake.nix +++ b/flake.nix @@ -5,7 +5,7 @@ inputs.nixpkgs.follows = "tgi-nix/nixpkgs"; }; nix-filter.url = "github:numtide/nix-filter"; - tgi-nix.url = "github:danieldk/tgi-nix"; + tgi-nix.url = "github:huggingface/text-generation-inference-nix"; nixpkgs.follows = "tgi-nix/nixpkgs"; flake-utils.url = "github:numtide/flake-utils"; rust-overlay = { @@ -132,49 +132,12 @@ pre-commit ruff ]); - }; - impure = mkShell { - buildInputs = - [ - openssl.dev - pkg-config - (rust-bin.stable.latest.default.override { - extensions = [ - "rust-analyzer" - "rust-src" - ]; - }) - protobuf - ] - ++ (with python3.pkgs; [ - venvShellHook - docker - pip - ipdb - click - pyright - pytest - pytest-asyncio - redocly - ruff - syrupy - ]); + impure = callPackage ./nix/impure-shell.nix { inherit server; }; - inputsFrom = [ server ]; - - venvDir = "./.venv"; - - postVenvCreation = '' - unset SOURCE_DATE_EPOCH - ( cd server ; python -m pip install --no-dependencies -e . ) - ( cd clients/python ; python -m pip install --no-dependencies -e . ) - ''; - postShellHook = '' - unset SOURCE_DATE_EPOCH - export PATH=$PATH:~/.cargo/bin - ''; + impure-flash-attn-v1 = callPackage ./nix/impure-shell.nix { + server = server.override { flash-attn = python3.pkgs.flash-attn-v1; }; }; }; diff --git a/launcher/Cargo.toml b/launcher/Cargo.toml index eb219423..033a9a04 100644 --- a/launcher/Cargo.toml +++ b/launcher/Cargo.toml @@ -12,6 +12,7 @@ ctrlc = { version = "3.4.1", features = ["termination"] } hf-hub = "0.3.2" nix = { version = "0.28.0", features = ["signal"] } once_cell = "1.19.0" +pyo3 = { workspace = true } serde = { version = "1.0.188", features = ["derive"] } serde_json = "1.0.107" thiserror = "1.0.59" diff --git a/launcher/src/gpu.rs b/launcher/src/gpu.rs new file mode 100644 index 00000000..755d246a --- /dev/null +++ b/launcher/src/gpu.rs @@ -0,0 +1,26 @@ +use std::sync::LazyLock; + +pub static COMPUTE_CAPABILITY: LazyLock<Option<(usize, usize)>> = + LazyLock::new(get_cuda_capability); + +fn get_cuda_capability() -> Option<(usize, usize)> { + use pyo3::prelude::*; + + let py_get_capability = |py: Python| -> PyResult<(isize, isize)> { + let torch = py.import_bound("torch.cuda")?; + let get_device_capability = torch.getattr("get_device_capability")?; + get_device_capability.call0()?.extract() + }; + + match pyo3::Python::with_gil(py_get_capability) { + Ok((major, minor)) if major < 0 || minor < 0 => { + tracing::warn!("Ignoring negative GPU compute capabilities: {major}.{minor}"); + None + } + Ok((major, minor)) => Some((major as usize, minor as usize)), + Err(err) => { + tracing::warn!("Cannot determine GPU compute capability: {}", err); + None + } + } +} diff --git a/launcher/src/main.rs b/launcher/src/main.rs index deb18478..583220a6 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -26,6 +26,7 @@ use thiserror::Error; use tracing_subscriber::{filter::LevelFilter, EnvFilter}; mod env_runtime; +mod gpu; fn get_config( model_id: &str, @@ -65,6 +66,7 @@ fn get_config( } fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) { + let compute_capability = *gpu::COMPUTE_CAPABILITY; let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok(); let mut attention: Option<String> = std::env::var("ATTENTION").ok(); if let Some(config) = config { @@ -77,6 +79,13 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> prefix_caching = Some("0".to_string()); } } + + let fallback_attention = if matches!(compute_capability, Some((major, _)) if major < 8) { + "paged" + } else { + "flashdecoding" + }; +
match config.head_dim { Some(h) if h == 64 || h == 128 || h == 256 => { if lora_adapters.is_some() && prefix_caching.is_none() { @@ -89,10 +98,14 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> // flashinfer ? if attention.is_none() { tracing::info!( - "Forcing flash decoding because model {} requires it", + "Forcing attention to '{fallback_attention}' because model {} requires it", config.model_type.as_ref().unwrap() ); - attention = Some("flashdecoding".to_string()); + attention = Some(fallback_attention.to_string()); + } + if fallback_attention == "paged" && prefix_caching.is_none() { + tracing::info!("Disabling prefix caching because it is not supported with 'paged' attention"); + prefix_caching = Some("0".to_string()); } } Some("t5") => {} @@ -101,8 +114,8 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> } _ => { if attention.is_none() { - tracing::info!("Forcing flash decoding because head dim is not supported by flashinfer, also disabling prefix caching"); - attention = Some("flashdecoding".to_string()); + tracing::info!("Forcing attention to '{fallback_attention}' because head dim is not supported by flashinfer, also disabling prefix caching"); + attention = Some(fallback_attention.to_string()); } if prefix_caching.is_none() { prefix_caching = Some("0".to_string()); @@ -110,8 +123,10 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> } } } - let prefix_caching = prefix_caching.unwrap_or("true".to_string()); + let attention = attention.unwrap_or("flashinfer".to_string()); + let prefix_caching = prefix_caching.unwrap_or("true".to_string()); + (prefix_caching, attention) } diff --git a/nix/impure-shell.nix b/nix/impure-shell.nix new file mode 100644 index 00000000..a4dad4ba --- /dev/null +++ b/nix/impure-shell.nix @@ -0,0 +1,54 @@ +{ + mkShell, + openssl, + pkg-config, + protobuf, + python3, + pyright, + redocly, + ruff, + rust-bin, + server, +}: + +mkShell { + buildInputs = + [ + openssl.dev + pkg-config + (rust-bin.stable.latest.default.override { + extensions = [ + "rust-analyzer" + "rust-src" + ]; + }) + protobuf + pyright + redocly + ruff + ] + ++ (with python3.pkgs; [ + venvShellHook + docker + pip + ipdb + click + pytest + pytest-asyncio + syrupy + ]); + + inputsFrom = [ server ]; + + venvDir = "./.venv"; + + postVenvCreation = '' + unset SOURCE_DATE_EPOCH + ( cd server ; python -m pip install --no-dependencies -e . ) + ( cd clients/python ; python -m pip install --no-dependencies -e .
) + ''; + postShellHook = '' + unset SOURCE_DATE_EPOCH + export PATH=$PATH:~/.cargo/bin + ''; +} diff --git a/nix/server.nix b/nix/server.nix index 5921da7f..7406d563 100644 --- a/nix/server.nix +++ b/nix/server.nix @@ -13,6 +13,7 @@ flash-attn, flash-attn-layer-norm, flash-attn-rotary, + flash-attn-v1, grpc-interceptor, grpcio-reflection, grpcio-status, diff --git a/router/Cargo.toml b/router/Cargo.toml index 6a752db6..83d85327 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -61,7 +61,7 @@ uuid = { version = "1.9.1", default-features = false, features = [ ] } csv = "1.3.0" ureq = "=2.9" -pyo3 = { version = "0.22.2", features = ["auto-initialize"] } +pyo3 = { workspace = true } [build-dependencies] diff --git a/server/text_generation_server/layers/attention/__init__.py b/server/text_generation_server/layers/attention/__init__.py index 56fc5319..4f2b9807 100644 --- a/server/text_generation_server/layers/attention/__init__.py +++ b/server/text_generation_server/layers/attention/__init__.py @@ -11,11 +11,24 @@ if SYSTEM == "cuda": paged_attention, reshape_and_cache, SUPPORTS_WINDOWING, + PREFILL_IN_KV_CACHE, ) elif SYSTEM == "rocm": - from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING + from .rocm import ( + attention, + paged_attention, + reshape_and_cache, + PREFILL_IN_KV_CACHE, + SUPPORTS_WINDOWING, + ) elif SYSTEM == "ipex": - from .ipex import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING + from .ipex import ( + attention, + paged_attention, + reshape_and_cache, + PREFILL_IN_KV_CACHE, + SUPPORTS_WINDOWING, + ) else: raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention") @@ -24,6 +37,7 @@ __all__ = [ "attention", "paged_attention", "reshape_and_cache", + "PREFILL_IN_KV_CACHE", "SUPPORTS_WINDOWING", "Seqlen", ] diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py index 4b588b5c..51af928d 100644 --- a/server/text_generation_server/layers/attention/cuda.py +++ b/server/text_generation_server/layers/attention/cuda.py @@ -287,16 +287,14 @@ elif V2: else: def attention( - q, - k, - v, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - cu_seqlens, - max_s, - softmax_scale, - window_size_left=-1, - causal=None, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seqlen: Seqlen, + block_tables: torch.Tensor, + softmax_scale: float, + window_size_left: int = -1, + causal: bool = True, softcap=None, ): if window_size_left != -1: @@ -338,16 +336,22 @@ else: k, v, out, - cu_seqlens, - cu_seqlens, - max_s, - max_s, + seqlen.cu_seqlen_q, + seqlen.cu_seqlen_q, + seqlen.max_q, + seqlen.max_k, 0.0, softmax_scale, False, - True, + causal, False, 0, None, ) return out + + +# Prefill in the cache with every kind of attention, unless we +# have a configuration that requires flash-attention v1, which +# does not support block tables. 
+PREFILL_IN_KV_CACHE = ATTENTION != "paged" or V2 diff --git a/server/text_generation_server/layers/attention/ipex.py b/server/text_generation_server/layers/attention/ipex.py index d0eadc75..657c90af 100644 --- a/server/text_generation_server/layers/attention/ipex.py +++ b/server/text_generation_server/layers/attention/ipex.py @@ -5,6 +5,7 @@ from text_generation_server.layers.attention import Seqlen from typing import Optional SUPPORTS_WINDOWING = False +PREFILL_IN_KV_CACHE = False def attention( diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py index 16ce8d2b..9f24ac98 100644 --- a/server/text_generation_server/layers/attention/rocm.py +++ b/server/text_generation_server/layers/attention/rocm.py @@ -13,6 +13,9 @@ _PARTITION_SIZE = 512 use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"} ENGINE = "triton" if use_triton else "ck" + +PREFILL_IN_KV_CACHE = False + try: from vllm._C import cache_ops except Exception as e: @@ -156,7 +159,6 @@ if ENGINE != "triton": "or install flash attention with `cd server && make install install-flash-attention`" ) from e else: - for idx in range(torch.cuda.device_count()): name = torch.cuda.get_device_name(idx) if "MI210" not in name and "MI250" not in name: diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index 374ccb10..30656038 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -39,6 +39,7 @@ from text_generation_server.layers import ( SpeculativeHead, get_linear, ) +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.layernorm import ( FastLayerNorm, ) @@ -297,8 +298,8 @@ class FlashCohereAttention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0] if SYSTEM != "ipex" else key, - kv_cache[1] if SYSTEM != "ipex" else value, + kv_cache[0] if PREFILL_IN_KV_CACHE else key, + kv_cache[1] if PREFILL_IN_KV_CACHE else value, seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index 0dc88098..1137a453 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -30,6 +30,7 @@ from text_generation_server.layers.attention import ( attention, reshape_and_cache, Seqlen, + PREFILL_IN_KV_CACHE, ) from text_generation_server.layers import ( FastLinear, @@ -336,8 +337,8 @@ class DbrxAttention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0] if SYSTEM != "ipex" else kv[:, 0], - kv_cache[1] if SYSTEM != "ipex" else kv[:, 1], + kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], + kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py index 2ca7cc24..ac191ec3 100644 --- a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py +++ 
b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py @@ -35,6 +35,7 @@ from text_generation_server.layers.attention import ( paged_attention, reshape_and_cache, ) +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.layernorm import FastRMSNorm from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale @@ -327,8 +328,8 @@ class DeepseekV2Attention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0] if SYSTEM != "ipex" else key, - kv_cache[1] if SYSTEM != "ipex" else value, + kv_cache[0] if PREFILL_IN_KV_CACHE else key, + kv_cache[1] if PREFILL_IN_KV_CACHE else value, seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index 887e187e..7a3d60c9 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -25,7 +25,6 @@ from torch import nn from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple -from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers.attention import ( paged_attention, attention, @@ -41,6 +40,7 @@ from text_generation_server.layers import ( TensorParallelMultiAdapterLinear, TensorParallelAdapterRowLinear, ) +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( FastRMSNorm, @@ -260,8 +260,8 @@ class FlashGemma2Attention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0] if SYSTEM != "ipex" else kv[:, 0], - kv_cache[1] if SYSTEM != "ipex" else kv[:, 1], + kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], + kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 77ae4b35..4c1be6f6 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -25,12 +25,12 @@ from torch import nn from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple -from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers.attention import ( paged_attention, attention, reshape_and_cache, Seqlen, + PREFILL_IN_KV_CACHE, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -231,8 +231,8 @@ class FlashGemmaAttention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0] if SYSTEM != "ipex" else kv[:, 0], - kv_cache[1] if SYSTEM != "ipex" else kv[:, 1], + kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], + kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py 
b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index 411c4ce1..44c015cf 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -24,7 +24,7 @@ import torch.distributed from torch import nn from transformers.activations import ACT2FN from typing import Optional, List, Tuple -from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.attention import ( paged_attention, attention, @@ -231,8 +231,8 @@ class FlashGPT2Attention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0] if SYSTEM != "ipex" else key, - kv_cache[1] if SYSTEM != "ipex" else value, + kv_cache[0] if PREFILL_IN_KV_CACHE else key, + kv_cache[1] if PREFILL_IN_KV_CACHE else value, seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py index ef071d46..aca97004 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py @@ -38,6 +38,7 @@ from text_generation_server.layers import ( SpeculativeHead, get_linear, ) +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.rotary import ( PositionRotaryEmbedding, ) @@ -192,8 +193,8 @@ class FlashGPTJAttention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0] if SYSTEM != "ipex" else key, - kv_cache[1] if SYSTEM != "ipex" else value, + kv_cache[0] if PREFILL_IN_KV_CACHE else key, + kv_cache[1] if PREFILL_IN_KV_CACHE else value, seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 7d639e35..758e39aa 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -27,6 +27,7 @@ import torch.distributed from torch import nn from transformers.activations import ACT2FN +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers.attention import ( paged_attention, @@ -220,8 +221,8 @@ class FlashLlamaAttention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0] if SYSTEM != "ipex" else kv[:, 0], - kv_cache[1] if SYSTEM != "ipex" else kv[:, 1], + kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], + kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index cdd23796..3e16d371 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -41,6 +41,7 @@ from text_generation_server.layers import ( TensorParallelMultiAdapterLinear, TensorParallelAdapterRowLinear, ) +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from 
text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( FastRMSNorm, @@ -218,8 +219,8 @@ class MistralAttention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0], - kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1], + kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0], + kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1], seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index 02da1384..5836d30a 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -39,10 +39,10 @@ from text_generation_server.layers.attention import ( paged_attention, reshape_and_cache, ) +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.layernorm import FastRMSNorm from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer from text_generation_server.layers.rotary import PositionRotaryEmbedding -from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.weights import UnquantizedWeight @@ -267,8 +267,8 @@ class MixtralAttention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0], - kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1], + kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0], + kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1], seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 454e45eb..ad4e382f 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -26,7 +26,6 @@ from transformers.activations import ACT2FN from transformers.modeling_utils import PreTrainedModel from transformers.models.gpt_neox import GPTNeoXConfig as TransformersGPTNeoXConfig from typing import Optional, List, Tuple -from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers.attention import ( paged_attention, attention, @@ -40,6 +39,7 @@ from text_generation_server.layers import ( SpeculativeHead, get_linear, ) +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.layernorm import ( FastLayerNorm, ) @@ -172,8 +172,8 @@ class FlashNeoxAttention(torch.nn.Module): # flash attention attn_output = attention( qkv[:, 0], - kv_cache[0] if SYSTEM != "ipex" else qkv[:, 1], - kv_cache[1] if SYSTEM != "ipex" else qkv[:, 2], + kv_cache[0] if PREFILL_IN_KV_CACHE else qkv[:, 1], + kv_cache[1] if PREFILL_IN_KV_CACHE else qkv[:, 2], seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index e2d9bbbc..2a0dc606 100644 --- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -19,13 
+19,13 @@ from text_generation_server.layers import ( SpeculativeHead, get_linear, ) +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.layernorm import ( FastLayerNorm, ) from text_generation_server.layers.rotary import ( PositionRotaryEmbedding, ) -from text_generation_server.utils.import_utils import SYSTEM class PhiConfig(PretrainedConfig): @@ -194,8 +194,8 @@ class FlashPhiAttention(torch.nn.Module): if cu_seqlen_prefill is not None: attn_output = attention( query, - kv_cache[0] if SYSTEM != "ipex" else kv[:, 0], - kv_cache[1] if SYSTEM != "ipex" else kv[:, 1], + kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], + kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index 999b72e7..02c788d3 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -17,11 +17,11 @@ from text_generation_server.layers import ( TensorParallelEmbedding, SpeculativeHead, ) +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( FastRMSNorm, ) -from text_generation_server.utils.import_utils import SYSTEM def load_attention(config, prefix, weights): @@ -137,8 +137,8 @@ class Qwen2Attention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0], - kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1], + kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0], + kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1], seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index edc54c09..6671d85e 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -5,7 +5,6 @@ import torch.distributed from torch import nn from transformers.configuration_utils import PretrainedConfig from transformers.modeling_utils import PreTrainedModel -from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers import ( SpeculativeHead, TensorParallelColumnLinear, @@ -13,6 +12,7 @@ from text_generation_server.layers import ( TensorParallelRowLinear, get_linear, ) +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.layernorm import FastLayerNorm from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.attention import ( @@ -207,8 +207,8 @@ class FlashRWAttention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0] if SYSTEM != "ipex" else kv[:, 0], - kv_cache[1] if SYSTEM != "ipex" else kv[:, 1], + kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], + kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], seqlen, block_tables, self.softmax_scale, @@ -325,8 +325,8 @@ class FlashRWLargeAttention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0] if SYSTEM != "ipex" else kv[:, :, 0].contiguous(), - 
kv_cache[1] if SYSTEM != "ipex" else kv[:, :, 1].contiguous(), + kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, :, 0].contiguous(), + kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, :, 1].contiguous(), seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index f97b4409..43eb9687 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -18,11 +18,11 @@ from text_generation_server.layers import ( TensorParallelEmbedding, get_linear, ) +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.gptq import GPTQWeightsLoader from text_generation_server.layers.layernorm import ( FastLayerNorm, ) -from text_generation_server.utils.import_utils import SYSTEM def load_multi_mqa( @@ -293,8 +293,8 @@ class FlashMQAttention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0] if SYSTEM != "ipex" else key_value[:, 0], - kv_cache[1] if SYSTEM != "ipex" else key_value[:, 1], + kv_cache[0] if PREFILL_IN_KV_CACHE else key_value[:, 0], + kv_cache[1] if PREFILL_IN_KV_CACHE else key_value[:, 1], seqlen, block_tables, self.softmax_scale, diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index 6aa7fa21..4975cf22 100644 --- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -39,6 +39,7 @@ from text_generation_server.layers import ( SpeculativeHead, get_linear, ) +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.layernorm import ( FastLayerNorm, FastRMSNorm, @@ -47,7 +48,6 @@ from text_generation_server.layers.rotary import ( PositionRotaryEmbedding, ) from text_generation_server.utils.weights import UnquantizedWeight -from text_generation_server.utils.import_utils import SYSTEM class Starcoder2Config(PretrainedConfig): @@ -242,8 +242,8 @@ class Starcoder2Attention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0], - kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1], + kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0], + kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1], seqlen, block_tables, self.softmax_scale,
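
A note for readers of this patch: the new launcher/src/gpu.rs asks PyTorch for the device's compute capability through pyo3, and resolve_attention() then falls back to 'paged' attention instead of 'flashdecoding' whenever the major capability is below 8. The sketch below restates that decision in plain Python; it is illustrative only, the function names are mine, and it assumes a PyTorch install (CUDA or not), not code from the patch.

    # Illustrative sketch of the launcher's fallback logic
    # (launcher/src/gpu.rs + resolve_attention); not part of the patch.
    import torch

    def cuda_compute_capability():
        """Mirror gpu::COMPUTE_CAPABILITY: (major, minor) or None if unknown."""
        if not torch.cuda.is_available():
            return None
        major, minor = torch.cuda.get_device_capability()
        if major < 0 or minor < 0:
            # The Rust code also discards negative values before casting to usize.
            return None
        return major, minor

    def fallback_attention():
        """Pre-Ampere GPUs (capability < 8) cannot run flash-attn v2, so the
        launcher falls back to paged attention; otherwise to flashdecoding."""
        capability = cuda_compute_capability()
        if capability is not None and capability[0] < 8:
            return "paged"
        return "flashdecoding"

    if __name__ == "__main__":
        print(cuda_compute_capability(), fallback_attention())

When 'paged' ends up being chosen, the launcher also disables prefix caching (unless USE_PREFIX_CACHING was set explicitly), because prefix caching is not supported with 'paged' attention.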
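Most of the Python churn is mechanical: every flash model used to test SYSTEM != "ipex" to decide whether prefill attention reads from the KV cache or from the freshly computed key/value tensors, and the patch centralises that decision in a single PREFILL_IN_KV_CACHE flag (on CUDA it is ATTENTION != "paged" or V2, i.e. false only when paged attention runs on flash-attn v1, which has no block-table support). A schematic, self-contained sketch of the selection the models now perform; the prefill_kv helper and the tensor shapes are made up for illustration, the real flag is imported from text_generation_server.layers.attention:

    # Schematic illustration of the PREFILL_IN_KV_CACHE selection done in the
    # flash_*_modeling.py files; the real attention() call also takes
    # seqlen/block_tables/softmax_scale arguments as shown in the diff.
    import torch

    # On CUDA the patch defines this as: ATTENTION != "paged" or V2
    PREFILL_IN_KV_CACHE = True

    def prefill_kv(kv_cache, key, value):
        """Return the key/value tensors handed to attention() during prefill."""
        if PREFILL_IN_KV_CACHE:
            # flash-attn v2 / flashinfer: read through the paged cache.
            return kv_cache[0], kv_cache[1]
        # flash-attn v1 (or ipex/rocm): pass the raw tensors, no block tables.
        return key, value

    kv_cache = (torch.zeros(16, 8, 64), torch.zeros(16, 8, 64))  # made-up shapes
    key, value = torch.randn(4, 8, 64), torch.randn(4, 8, 64)
    k, v = prefill_kv(kv_cache, key, value)
    print(k.shape, v.shape)

The ipex and rocm backends hard-code PREFILL_IN_KV_CACHE = False, which is what makes the old SYSTEM != "ipex" special case unnecessary.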
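Finally, the non-V2 branch of attention() in layers/attention/cuda.py now accepts the same Seqlen/block_tables interface as the V2 and flashinfer branches, so callers never need to know which kernel sits behind it. The stub below only restates that interface and the capability checks visible in the diff; the actual kernel call (flash_attn_cuda.fwd) is elided, the Seqlen fields shown are just the ones the patch uses, and the exact error messages are assumptions.

    # Documentation stub for the flash-attn v1 attention() wrapper after the
    # patch: it mirrors the argument validation shown in the diff but returns
    # an empty tensor instead of calling flash_attn_cuda.fwd(...).
    from dataclasses import dataclass
    from typing import Optional

    import torch

    @dataclass
    class Seqlen:
        cu_seqlen_q: torch.Tensor  # cumulative query sequence lengths
        max_q: int                 # longest query in the batch
        max_k: int                 # longest key/value sequence in the batch

    def attention(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        seqlen: Seqlen,
        block_tables: torch.Tensor,  # accepted for parity with v2; unused by v1
        softmax_scale: float,
        window_size_left: int = -1,
        causal: bool = True,
        softcap: Optional[float] = None,
    ) -> torch.Tensor:
        if window_size_left != -1:
            raise NotImplementedError("window attention is only available with flash-attn v2")
        if softcap is not None:
            raise NotImplementedError("softcap is only available with flash-attn v2")
        out = torch.empty_like(q)
        # Real code: flash_attn_cuda.fwd(q, k, v, out, seqlen.cu_seqlen_q,
        #            seqlen.cu_seqlen_q, seqlen.max_q, seqlen.max_k, 0.0,
        #            softmax_scale, False, causal, False, 0, None)
        return out

    q = k = v = torch.randn(4, 8, 64)
    seqlen = Seqlen(cu_seqlen_q=torch.tensor([0, 4], dtype=torch.int32), max_q=4, max_k=4)
    print(attention(q, k, v, seqlen, torch.zeros(1, 1, dtype=torch.int32), 0.125).shape)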