Merge remote-tracking branch 'upstream/main' into rocm_6.2_updates

commit 473d9a892d
@@ -4243,6 +4243,7 @@ dependencies = [
  "hf-hub",
  "nix 0.28.0",
  "once_cell",
+ "pyo3",
  "reqwest",
  "serde",
  "serde_json",

@@ -33,6 +33,7 @@ metrics = { version = "0.23.0" }
 metrics-exporter-prometheus = { version = "0.15.1", features = [] }
 minijinja = { version = "2.2.0", features = ["json"] }
 minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
+pyo3 = { version = "0.22.2", features = ["auto-initialize"] }

 [profile.release]
 incremental = true

flake.lock (14 changed lines)
@@ -978,16 +978,16 @@
       "nixpkgs": "nixpkgs_6"
     },
     "locked": {
-      "lastModified": 1726743157,
-      "narHash": "sha256-7OczwJsA47o+aUftMwkoh8R31DlNSl2FgRjqE8zAggk=",
-      "owner": "danieldk",
-      "repo": "tgi-nix",
-      "rev": "bcc9fd01cf81bc42cebb999a736a377adfa8942f",
+      "lastModified": 1727353315,
+      "narHash": "sha256-yZovq/6P8Z199r7e+NbTXyCqRgK6grRkLxYHWHnHckI=",
+      "owner": "huggingface",
+      "repo": "text-generation-inference-nix",
+      "rev": "1d42c4125ebafb87707118168995675cc5050b9d",
       "type": "github"
     },
     "original": {
-      "owner": "danieldk",
-      "repo": "tgi-nix",
+      "owner": "huggingface",
+      "repo": "text-generation-inference-nix",
       "type": "github"
     }
   }

flake.nix (45 changed lines)
@@ -5,7 +5,7 @@
       inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
     };
     nix-filter.url = "github:numtide/nix-filter";
-    tgi-nix.url = "github:danieldk/tgi-nix";
+    tgi-nix.url = "github:huggingface/text-generation-inference-nix";
     nixpkgs.follows = "tgi-nix/nixpkgs";
     flake-utils.url = "github:numtide/flake-utils";
     rust-overlay = {

@@ -132,49 +132,12 @@
               pre-commit
               ruff
             ]);
           };

-          impure = mkShell {
-            buildInputs =
-              [
-                openssl.dev
-                pkg-config
-                (rust-bin.stable.latest.default.override {
-                  extensions = [
-                    "rust-analyzer"
-                    "rust-src"
-                  ];
-                })
-                protobuf
-              ]
-              ++ (with python3.pkgs; [
-                venvShellHook
-                docker
-                pip
-                ipdb
-                click
-                pyright
-                pytest
-                pytest-asyncio
-                redocly
-                ruff
-                syrupy
-              ]);
-
-            inputsFrom = [ server ];
-
-            venvDir = "./.venv";
-
-            postVenvCreation = ''
-              unset SOURCE_DATE_EPOCH
-              ( cd server ; python -m pip install --no-dependencies -e . )
-              ( cd clients/python ; python -m pip install --no-dependencies -e . )
-            '';
-            postShellHook = ''
-              unset SOURCE_DATE_EPOCH
-              export PATH=$PATH:~/.cargo/bin
-            '';
-          };
+          impure = callPackage ./nix/impure-shell.nix { inherit server; };
+
+          impure-flash-attn-v1 = callPackage ./nix/impure-shell.nix {
+            server = server.override { flash-attn = python3.pkgs.flash-attn-v1; };
+          };
         };

@@ -12,6 +12,7 @@ ctrlc = { version = "3.4.1", features = ["termination"] }
 hf-hub = "0.3.2"
 nix = { version = "0.28.0", features = ["signal"] }
 once_cell = "1.19.0"
+pyo3 = { workspace = true }
 serde = { version = "1.0.188", features = ["derive"] }
 serde_json = "1.0.107"
 thiserror = "1.0.59"

@@ -0,0 +1,26 @@
+use std::sync::LazyLock;
+
+pub static COMPUTE_CAPABILITY: LazyLock<Option<(usize, usize)>> =
+    LazyLock::new(get_cuda_capability);
+
+fn get_cuda_capability() -> Option<(usize, usize)> {
+    use pyo3::prelude::*;
+
+    let py_get_capability = |py: Python| -> PyResult<(isize, isize)> {
+        let torch = py.import_bound("torch.cuda")?;
+        let get_device_capability = torch.getattr("get_device_capability")?;
+        get_device_capability.call0()?.extract()
+    };
+
+    match pyo3::Python::with_gil(py_get_capability) {
+        Ok((major, minor)) if major < 0 || minor < 0 => {
+            tracing::warn!("Ignoring negative GPU compute capabilities: {major}.{minor}");
+            None
+        }
+        Ok((major, minor)) => Some((major as usize, minor as usize)),
+        Err(err) => {
+            tracing::warn!("Cannot determine GPU compute capability: {}", err);
+            None
+        }
+    }
+}
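
Note on the new module above: it resolves the GPU's CUDA compute capability once, lazily, by calling `torch.cuda.get_device_capability()` through pyo3 (hence the workspace-level `pyo3` entry with the `auto-initialize` feature and the launcher's `pyo3 = { workspace = true }` dependency added above). A minimal standalone sketch of the same pattern, assuming pyo3 0.22 with `auto-initialize` enabled and a Python environment that has torch installed (not part of the diff):

    // Standalone sketch: read torch.cuda.get_device_capability() from Rust via pyo3,
    // yielding None if Python, torch, or a CUDA device is unavailable.
    use pyo3::prelude::*;

    fn main() {
        let capability: Option<(isize, isize)> = Python::with_gil(|py| {
            py.import_bound("torch.cuda")
                .and_then(|m| m.getattr("get_device_capability"))
                .and_then(|f| f.call0())
                .and_then(|v| v.extract())
                .ok()
        });
        println!("compute capability: {capability:?}");
    }
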
@@ -26,6 +26,7 @@ use thiserror::Error;
 use tracing_subscriber::{filter::LevelFilter, EnvFilter};

 mod env_runtime;
+mod gpu;

 fn get_config(
     model_id: &str,

@@ -65,6 +66,7 @@ fn get_config(
 }

 fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
+    let compute_capability = *gpu::COMPUTE_CAPABILITY;
     let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
     let mut attention: Option<String> = std::env::var("ATTENTION").ok();
     if let Some(config) = config {

@@ -77,6 +79,13 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
                 prefix_caching = Some("0".to_string());
             }
         }
+
+        let fallback_attention = if matches!(compute_capability, Some((major, _)) if major < 8) {
+            "paged"
+        } else {
+            "flashdecoding"
+        };
+
         match config.head_dim {
             Some(h) if h == 64 || h == 128 || h == 256 => {
                 if lora_adapters.is_some() && prefix_caching.is_none() {

@@ -89,10 +98,14 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
                         // flashinfer ?
                         if attention.is_none() {
                             tracing::info!(
-                                "Forcing flash decoding because model {} requires it",
+                                "Forcing attention to '{fallback_attention}' because model {} requires it",
                                 config.model_type.as_ref().unwrap()
                             );
-                            attention = Some("flashdecoding".to_string());
+                            attention = Some(fallback_attention.to_string());
                         }
+                        if fallback_attention == "paged" && prefix_caching.is_none() {
+                            tracing::info!("Disabling prefix caching because it is not supported with 'paged' attention");
+                            prefix_caching = Some("0".to_string());
+                        }
                     }
                     Some("t5") => {}

@@ -101,8 +114,8 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
             }
             _ => {
                 if attention.is_none() {
-                    tracing::info!("Forcing flash decoding because head dim is not supported by flashinfer, also disabling prefix caching");
-                    attention = Some("flashdecoding".to_string());
+                    tracing::info!("Forcing attention to '{fallback_attention}' because head dim is not supported by flashinfer, also disabling prefix caching");
+                    attention = Some(fallback_attention.to_string());
                 }
                 if prefix_caching.is_none() {
                     prefix_caching = Some("0".to_string());

@@ -110,8 +123,10 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
             }
         }
     }
-    let prefix_caching = prefix_caching.unwrap_or("true".to_string());
+
     let attention = attention.unwrap_or("flashinfer".to_string());
+    let prefix_caching = prefix_caching.unwrap_or("true".to_string());
+
     (prefix_caching, attention)
 }

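Note on the `resolve_attention` changes above: the fallback attention implementation now depends on the detected compute capability. Cards reporting a major version below 8 (pre-Ampere) fall back to 'paged' attention, and prefix caching is then disabled because, per the new log message, it is not supported with 'paged' attention; newer or undetected cards keep the previous 'flashdecoding' fallback. A small sketch of just the selection rule (the function name and shape are illustrative, not the launcher's actual API):

    // Illustrative sketch of the fallback rule introduced in the hunks above.
    fn fallback_attention(compute_capability: Option<(usize, usize)>) -> &'static str {
        match compute_capability {
            // Pre-Ampere GPUs (compute capability < 8.x) fall back to paged attention.
            Some((major, _)) if major < 8 => "paged",
            // Capability >= 8.0, or no capability detected, keeps the flashdecoding fallback.
            _ => "flashdecoding",
        }
    }

    fn main() {
        assert_eq!(fallback_attention(Some((7, 5))), "paged"); // e.g. a Turing-class card
        assert_eq!(fallback_attention(Some((9, 0))), "flashdecoding");
        assert_eq!(fallback_attention(None), "flashdecoding");
    }
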
@@ -0,0 +1,54 @@
+{
+  mkShell,
+  openssl,
+  pkg-config,
+  protobuf,
+  python3,
+  pyright,
+  redocly,
+  ruff,
+  rust-bin,
+  server,
+}:
+
+mkShell {
+  buildInputs =
+    [
+      openssl.dev
+      pkg-config
+      (rust-bin.stable.latest.default.override {
+        extensions = [
+          "rust-analyzer"
+          "rust-src"
+        ];
+      })
+      protobuf
+      pyright
+      redocly
+      ruff
+    ]
+    ++ (with python3.pkgs; [
+      venvShellHook
+      docker
+      pip
+      ipdb
+      click
+      pytest
+      pytest-asyncio
+      syrupy
+    ]);
+
+  inputsFrom = [ server ];
+
+  venvDir = "./.venv";
+
+  postVenvCreation = ''
+    unset SOURCE_DATE_EPOCH
+    ( cd server ; python -m pip install --no-dependencies -e . )
+    ( cd clients/python ; python -m pip install --no-dependencies -e . )
+  '';
+  postShellHook = ''
+    unset SOURCE_DATE_EPOCH
+    export PATH=$PATH:~/.cargo/bin
+  '';
+}

@@ -13,6 +13,7 @@
   flash-attn,
   flash-attn-layer-norm,
   flash-attn-rotary,
+  flash-attn-v1,
   grpc-interceptor,
   grpcio-reflection,
   grpcio-status,

@@ -61,7 +61,7 @@ uuid = { version = "1.9.1", default-features = false, features = [
 ] }
 csv = "1.3.0"
 ureq = "=2.9"
-pyo3 = { version = "0.22.2", features = ["auto-initialize"] }
+pyo3 = { workspace = true }


 [build-dependencies]

@@ -18,16 +18,16 @@ elif SYSTEM == "rocm":
         attention,
         paged_attention,
         reshape_and_cache,
-        SUPPORTS_WINDOWING,
         PREFILL_IN_KV_CACHE,
+        SUPPORTS_WINDOWING,
     )
 elif SYSTEM == "ipex":
     from .ipex import (
         attention,
         paged_attention,
         reshape_and_cache,
-        SUPPORTS_WINDOWING,
         PREFILL_IN_KV_CACHE,
+        SUPPORTS_WINDOWING,
     )
 else:
     raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")

@@ -37,7 +37,7 @@ __all__ = [
     "attention",
     "paged_attention",
     "reshape_and_cache",
-    "SUPPORTS_WINDOWING",
     "PREFILL_IN_KV_CACHE",
+    "SUPPORTS_WINDOWING",
     "Seqlen",
 ]

@@ -287,16 +287,14 @@ elif V2:
 else:

     def attention(
-        q,
-        k,
-        v,
-        key_cache: torch.Tensor,
-        value_cache: torch.Tensor,
-        cu_seqlens,
-        max_s,
-        softmax_scale,
-        window_size_left=-1,
-        causal=None,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        seqlen: Seqlen,
+        block_tables: torch.Tensor,
+        softmax_scale: float,
+        window_size_left: int = -1,
+        causal: bool = True,
         softcap=None,
     ):
         if window_size_left != -1:

@@ -338,14 +336,14 @@ else:
             k,
             v,
             out,
-            cu_seqlens,
-            cu_seqlens,
-            max_s,
-            max_s,
+            seqlen.cu_seqlen_q,
+            seqlen.cu_seqlen_q,
+            seqlen.max_q,
+            seqlen.max_k,
             0.0,
             softmax_scale,
             False,
-            True,
+            causal,
             False,
             0,
             None,

@@ -215,7 +215,6 @@ if ENGINE != "triton":
             "or install flash attention with `cd server && make install install-flash-attention`"
         ) from e
 else:
-
     for idx in range(torch.cuda.device_count()):
         name = torch.cuda.get_device_name(idx)
         if "MI210" not in name and "MI250" not in name:

@@ -18,13 +18,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 import torch
 import torch.distributed

 from torch import nn
 from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple
+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,

@@ -27,8 +27,8 @@ import torch.distributed
 from torch import nn
 from transformers.activations import ACT2FN

-from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,