Add basic FP8 KV cache support (#2603)

* Add basic FP8 KV cache support

This change adds rudimentary FP8 KV cache support. The support is
enabled by passing `--kv-cache-dtype fp8_e5m2` to the launcher. Doing so
uses this type for the KV cache. However support is still limited:

* Only the `fp8_e5m2` type is supported.
* The KV cache layout is the same as `float16`/`bfloat16` (HND).
* The FP8 KV cache is only supported for FlashInfer.
* Loading of scales is not yet supported.

* Fix Cargo.toml
This commit is contained in:
Daniël de Kok 2024-10-04 17:51:48 +02:00 committed by GitHub
parent 68103079f4
commit 2358c2bb54
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
33 changed files with 1015 additions and 192 deletions

14
Cargo.lock generated
View File

@ -4177,7 +4177,7 @@ dependencies = [
[[package]] [[package]]
name = "text-generation-backends-trtllm" name = "text-generation-backends-trtllm"
version = "2.3.1-dev0" version = "2.3.2-dev0"
dependencies = [ dependencies = [
"async-stream", "async-stream",
"async-trait", "async-trait",
@ -4200,7 +4200,7 @@ dependencies = [
[[package]] [[package]]
name = "text-generation-benchmark" name = "text-generation-benchmark"
version = "2.3.1-dev0" version = "2.3.2-dev0"
dependencies = [ dependencies = [
"average", "average",
"clap 4.5.18", "clap 4.5.18",
@ -4220,7 +4220,7 @@ dependencies = [
[[package]] [[package]]
name = "text-generation-client" name = "text-generation-client"
version = "2.3.1-dev0" version = "2.3.2-dev0"
dependencies = [ dependencies = [
"async-trait", "async-trait",
"base64 0.22.1", "base64 0.22.1",
@ -4238,7 +4238,7 @@ dependencies = [
[[package]] [[package]]
name = "text-generation-launcher" name = "text-generation-launcher"
version = "2.3.1-dev0" version = "2.3.2-dev0"
dependencies = [ dependencies = [
"clap 4.5.18", "clap 4.5.18",
"ctrlc", "ctrlc",
@ -4259,7 +4259,7 @@ dependencies = [
[[package]] [[package]]
name = "text-generation-router" name = "text-generation-router"
version = "2.3.1-dev0" version = "2.3.2-dev0"
dependencies = [ dependencies = [
"async-stream", "async-stream",
"async-trait", "async-trait",
@ -4308,7 +4308,7 @@ dependencies = [
[[package]] [[package]]
name = "text-generation-router-v2" name = "text-generation-router-v2"
version = "2.3.1-dev0" version = "2.3.2-dev0"
dependencies = [ dependencies = [
"async-stream", "async-stream",
"async-trait", "async-trait",
@ -4357,7 +4357,7 @@ dependencies = [
[[package]] [[package]]
name = "text-generation-router-v3" name = "text-generation-router-v3"
version = "2.3.1-dev0" version = "2.3.2-dev0"
dependencies = [ dependencies = [
"async-stream", "async-stream",
"async-trait", "async-trait",

View File

@ -89,6 +89,15 @@ Options:
[env: DTYPE=] [env: DTYPE=]
[possible values: float16, bfloat16] [possible values: float16, bfloat16]
```
## KV_CACHE_DTYPE
```shell
--kv-cache-dtype <KV_CACHE_DTYPE>
Specify the dtype for the key-value cache. When this option is not provided, the dtype of the model is used (typically `float16` or `bfloat16`). Currently the only supported value is `fp8_e5m2` on CUDA
[env: KV_CACHE_DTYPE=]
[possible values: fp8_e5m2]
``` ```
## TRUST_REMOTE_CODE ## TRUST_REMOTE_CODE
```shell ```shell

View File

@ -336,6 +336,7 @@ def launcher(event_loop):
use_flash_attention: bool = True, use_flash_attention: bool = True,
disable_grammar_support: bool = False, disable_grammar_support: bool = False,
dtype: Optional[str] = None, dtype: Optional[str] = None,
kv_cache_dtype: Optional[str] = None,
revision: Optional[str] = None, revision: Optional[str] = None,
max_input_length: Optional[int] = None, max_input_length: Optional[int] = None,
max_batch_prefill_tokens: Optional[int] = None, max_batch_prefill_tokens: Optional[int] = None,
@ -375,6 +376,9 @@ def launcher(event_loop):
if dtype is not None: if dtype is not None:
args.append("--dtype") args.append("--dtype")
args.append(dtype) args.append(dtype)
if kv_cache_dtype is not None:
args.append("--kv-cache-dtype")
args.append(kv_cache_dtype)
if revision is not None: if revision is not None:
args.append("--revision") args.append("--revision")
args.append(revision) args.append(revision)
@ -434,6 +438,7 @@ def launcher(event_loop):
use_flash_attention: bool = True, use_flash_attention: bool = True,
disable_grammar_support: bool = False, disable_grammar_support: bool = False,
dtype: Optional[str] = None, dtype: Optional[str] = None,
kv_cache_dtype: Optional[str] = None,
revision: Optional[str] = None, revision: Optional[str] = None,
max_input_length: Optional[int] = None, max_input_length: Optional[int] = None,
max_batch_prefill_tokens: Optional[int] = None, max_batch_prefill_tokens: Optional[int] = None,
@ -456,6 +461,9 @@ def launcher(event_loop):
if dtype is not None: if dtype is not None:
args.append("--dtype") args.append("--dtype")
args.append(dtype) args.append(dtype)
if kv_cache_dtype is not None:
args.append("--kv-cache-dtype")
args.append(kv_cache_dtype)
if revision is not None: if revision is not None:
args.append("--revision") args.append("--revision")
args.append(revision) args.append(revision)
@ -589,7 +597,6 @@ def generate_multi():
max_new_tokens: int, max_new_tokens: int,
seed: Optional[int] = None, seed: Optional[int] = None,
) -> List[Response]: ) -> List[Response]:
import numpy as np import numpy as np
arange = np.arange(len(prompts)) arange = np.arange(len(prompts))

View File

@ -0,0 +1,104 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 128000,
"logprob": null,
"text": "<|begin_of_text|>"
},
{
"id": 3923,
"logprob": -5.6328125,
"text": "What"
},
{
"id": 374,
"logprob": -1.2265625,
"text": " is"
},
{
"id": 5655,
"logprob": -9.1015625,
"text": " deep"
},
{
"id": 6975,
"logprob": -1.8085938,
"text": " learning"
},
{
"id": 30,
"logprob": -1.0439453,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 18682,
"logprob": -2.1992188,
"special": false,
"text": " Deep"
},
{
"id": 6975,
"logprob": -0.079956055,
"special": false,
"text": " learning"
},
{
"id": 374,
"logprob": -0.2763672,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.37548828,
"special": false,
"text": " a"
},
{
"id": 27084,
"logprob": -1.4628906,
"special": false,
"text": " subset"
},
{
"id": 315,
"logprob": -0.02885437,
"special": false,
"text": " of"
},
{
"id": 5780,
"logprob": -0.2565918,
"special": false,
"text": " machine"
},
{
"id": 6975,
"logprob": -0.0063438416,
"special": false,
"text": " learning"
},
{
"id": 430,
"logprob": -1.3056641,
"special": false,
"text": " that"
},
{
"id": 374,
"logprob": -1.6035156,
"special": false,
"text": " is"
}
],
"top_tokens": null
},
"generated_text": " Deep learning is a subset of machine learning that is"
}

View File

@ -0,0 +1,57 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "eos_token",
"generated_tokens": 3,
"prefill": [
{
"id": 128000,
"logprob": null,
"text": "<|begin_of_text|>"
},
{
"id": 374,
"logprob": -22.96875,
"text": " is"
},
{
"id": 5655,
"logprob": -10.71875,
"text": " deep"
},
{
"id": 6975,
"logprob": -2.6992188,
"text": " learning"
},
{
"id": 30,
"logprob": -4.8398438,
"text": "?"
}
],
"seed": 0,
"tokens": [
{
"id": 720,
"logprob": -0.4411621,
"special": false,
"text": " \n"
},
{
"id": 220,
"logprob": -0.35864258,
"special": false,
"text": " "
},
{
"id": 128001,
"logprob": 0.0,
"special": true,
"text": "<|end_of_text|>"
}
],
"top_tokens": null
},
"generated_text": "What is deep learning? \n "
}

View File

@ -0,0 +1,418 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 128000,
"logprob": null,
"text": "<|begin_of_text|>"
},
{
"id": 3923,
"logprob": -5.6328125,
"text": "What"
},
{
"id": 374,
"logprob": -1.2265625,
"text": " is"
},
{
"id": 5655,
"logprob": -9.1015625,
"text": " deep"
},
{
"id": 6975,
"logprob": -1.8085938,
"text": " learning"
},
{
"id": 30,
"logprob": -1.0439453,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 18682,
"logprob": -2.1992188,
"special": false,
"text": " Deep"
},
{
"id": 6975,
"logprob": -0.07897949,
"special": false,
"text": " learning"
},
{
"id": 374,
"logprob": -0.27734375,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.37402344,
"special": false,
"text": " a"
},
{
"id": 27084,
"logprob": -1.4511719,
"special": false,
"text": " subset"
},
{
"id": 315,
"logprob": -0.02909851,
"special": false,
"text": " of"
},
{
"id": 5780,
"logprob": -0.25854492,
"special": false,
"text": " machine"
},
{
"id": 6975,
"logprob": -0.0061798096,
"special": false,
"text": " learning"
},
{
"id": 430,
"logprob": -1.3046875,
"special": false,
"text": " that"
},
{
"id": 374,
"logprob": -1.5537109,
"special": false,
"text": " is"
}
],
"top_tokens": null
},
"generated_text": " Deep learning is a subset of machine learning that is"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 128000,
"logprob": null,
"text": "<|begin_of_text|>"
},
{
"id": 3923,
"logprob": -5.6328125,
"text": "What"
},
{
"id": 374,
"logprob": -1.2265625,
"text": " is"
},
{
"id": 5655,
"logprob": -9.1015625,
"text": " deep"
},
{
"id": 6975,
"logprob": -1.8085938,
"text": " learning"
},
{
"id": 30,
"logprob": -1.0439453,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 18682,
"logprob": -2.1992188,
"special": false,
"text": " Deep"
},
{
"id": 6975,
"logprob": -0.07897949,
"special": false,
"text": " learning"
},
{
"id": 374,
"logprob": -0.27734375,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.37402344,
"special": false,
"text": " a"
},
{
"id": 27084,
"logprob": -1.4511719,
"special": false,
"text": " subset"
},
{
"id": 315,
"logprob": -0.02909851,
"special": false,
"text": " of"
},
{
"id": 5780,
"logprob": -0.25854492,
"special": false,
"text": " machine"
},
{
"id": 6975,
"logprob": -0.0061798096,
"special": false,
"text": " learning"
},
{
"id": 430,
"logprob": -1.3046875,
"special": false,
"text": " that"
},
{
"id": 374,
"logprob": -1.5537109,
"special": false,
"text": " is"
}
],
"top_tokens": null
},
"generated_text": " Deep learning is a subset of machine learning that is"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 128000,
"logprob": null,
"text": "<|begin_of_text|>"
},
{
"id": 3923,
"logprob": -5.6328125,
"text": "What"
},
{
"id": 374,
"logprob": -1.2265625,
"text": " is"
},
{
"id": 5655,
"logprob": -9.1015625,
"text": " deep"
},
{
"id": 6975,
"logprob": -1.8085938,
"text": " learning"
},
{
"id": 30,
"logprob": -1.0439453,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 18682,
"logprob": -2.1992188,
"special": false,
"text": " Deep"
},
{
"id": 6975,
"logprob": -0.07897949,
"special": false,
"text": " learning"
},
{
"id": 374,
"logprob": -0.27734375,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.37402344,
"special": false,
"text": " a"
},
{
"id": 27084,
"logprob": -1.4511719,
"special": false,
"text": " subset"
},
{
"id": 315,
"logprob": -0.02909851,
"special": false,
"text": " of"
},
{
"id": 5780,
"logprob": -0.25854492,
"special": false,
"text": " machine"
},
{
"id": 6975,
"logprob": -0.0061798096,
"special": false,
"text": " learning"
},
{
"id": 430,
"logprob": -1.3046875,
"special": false,
"text": " that"
},
{
"id": 374,
"logprob": -1.5537109,
"special": false,
"text": " is"
}
],
"top_tokens": null
},
"generated_text": " Deep learning is a subset of machine learning that is"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 128000,
"logprob": null,
"text": "<|begin_of_text|>"
},
{
"id": 3923,
"logprob": -5.6328125,
"text": "What"
},
{
"id": 374,
"logprob": -1.2265625,
"text": " is"
},
{
"id": 5655,
"logprob": -9.1015625,
"text": " deep"
},
{
"id": 6975,
"logprob": -1.8085938,
"text": " learning"
},
{
"id": 30,
"logprob": -1.0439453,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 18682,
"logprob": -2.1992188,
"special": false,
"text": " Deep"
},
{
"id": 6975,
"logprob": -0.07897949,
"special": false,
"text": " learning"
},
{
"id": 374,
"logprob": -0.27734375,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.37402344,
"special": false,
"text": " a"
},
{
"id": 27084,
"logprob": -1.4511719,
"special": false,
"text": " subset"
},
{
"id": 315,
"logprob": -0.02909851,
"special": false,
"text": " of"
},
{
"id": 5780,
"logprob": -0.25854492,
"special": false,
"text": " machine"
},
{
"id": 6975,
"logprob": -0.0061798096,
"special": false,
"text": " learning"
},
{
"id": 430,
"logprob": -1.3046875,
"special": false,
"text": " that"
},
{
"id": 374,
"logprob": -1.5537109,
"special": false,
"text": " is"
}
],
"top_tokens": null
},
"generated_text": " Deep learning is a subset of machine learning that is"
}
]

View File

@ -0,0 +1,77 @@
import pytest
@pytest.fixture(scope="module")
def flash_llama_fp8_kv_cache_handle(launcher):
with launcher(
"meta-llama/Meta-Llama-3-8B", num_shard=2, kv_cache_dtype="fp8_e5m2"
) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_llama_fp8_kv_cache(flash_llama_fp8_kv_cache_handle):
await flash_llama_fp8_kv_cache_handle.health(300)
return flash_llama_fp8_kv_cache_handle.client
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_fp8_kv_cache(flash_llama_fp8_kv_cache, response_snapshot):
response = await flash_llama_fp8_kv_cache.generate(
"What is deep learning?", max_new_tokens=10, decoder_input_details=True
)
assert (
response.generated_text
== " Deep learning is a subset of machine learning that is"
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_fp8_kv_cache_all_params(
flash_llama_fp8_kv_cache, response_snapshot
):
response = await flash_llama_fp8_kv_cache.generate(
"What is deep learning?",
max_new_tokens=10,
repetition_penalty=1.2,
return_full_text=True,
stop_sequences=["test"],
temperature=0.5,
top_p=0.9,
top_k=10,
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True,
seed=0,
)
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_fp8_kv_cache_load(
flash_llama_fp8_kv_cache, generate_load, response_snapshot
):
responses = await generate_load(
flash_llama_fp8_kv_cache, "What is deep learning?", max_new_tokens=10, n=4
)
assert len(responses) == 4
assert (
responses[0].generated_text
== " Deep learning is a subset of machine learning that is"
)
assert all(
[r.generated_text == responses[0].generated_text for r in responses]
), f"Different messages : {[r.generated_text for r in responses]}"
assert responses == response_snapshot

View File

@ -301,6 +301,22 @@ impl std::fmt::Display for Dtype {
} }
} }
#[derive(Clone, Copy, Debug, ValueEnum)]
enum KVCacheDtype {
#[clap(name = "fp8_e5m2")]
Fp8e5m2,
}
impl std::fmt::Display for KVCacheDtype {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
KVCacheDtype::Fp8e5m2 => {
write!(f, "fp8_e5m2")
}
}
}
}
#[derive(Clone, Copy, Debug, ValueEnum)] #[derive(Clone, Copy, Debug, ValueEnum)]
enum RopeScaling { enum RopeScaling {
Linear, Linear,
@ -402,6 +418,12 @@ struct Args {
#[clap(long, env, value_enum)] #[clap(long, env, value_enum)]
dtype: Option<Dtype>, dtype: Option<Dtype>,
/// Specify the dtype for the key-value cache. When this option is not provided,
/// the dtype of the model is used (typically `float16` or `bfloat16`). Currently
/// the only supported value is `fp8_e5m2` on CUDA.
#[clap(long, env, value_enum)]
kv_cache_dtype: Option<KVCacheDtype>,
/// Whether you want to execute hub modelling code. Explicitly passing a `revision` is /// Whether you want to execute hub modelling code. Explicitly passing a `revision` is
/// encouraged when loading a model with custom code to ensure no malicious code has been /// encouraged when loading a model with custom code to ensure no malicious code has been
/// contributed in a newer revision. /// contributed in a newer revision.
@ -670,6 +692,7 @@ fn shard_manager(
quantize: Option<Quantization>, quantize: Option<Quantization>,
speculate: Option<usize>, speculate: Option<usize>,
dtype: Option<Dtype>, dtype: Option<Dtype>,
kv_cache_dtype: Option<KVCacheDtype>,
trust_remote_code: bool, trust_remote_code: bool,
uds_path: String, uds_path: String,
rank: usize, rank: usize,
@ -743,6 +766,11 @@ fn shard_manager(
shard_args.push(dtype.to_string()) shard_args.push(dtype.to_string())
} }
if let Some(kv_cache_dtype) = kv_cache_dtype {
shard_args.push("--kv-cache-dtype".to_string());
shard_args.push(kv_cache_dtype.to_string())
}
// Model optional revision // Model optional revision
if let Some(revision) = revision { if let Some(revision) = revision {
shard_args.push("--revision".to_string()); shard_args.push("--revision".to_string());
@ -1299,6 +1327,7 @@ fn spawn_shards(
let otlp_service_name = args.otlp_service_name.clone(); let otlp_service_name = args.otlp_service_name.clone();
let speculate = args.speculate; let speculate = args.speculate;
let dtype = args.dtype; let dtype = args.dtype;
let kv_cache_dtype = args.kv_cache_dtype;
let trust_remote_code = args.trust_remote_code; let trust_remote_code = args.trust_remote_code;
let master_port = args.master_port; let master_port = args.master_port;
let disable_custom_kernels = args.disable_custom_kernels; let disable_custom_kernels = args.disable_custom_kernels;
@ -1317,6 +1346,7 @@ fn spawn_shards(
quantize, quantize,
speculate, speculate,
dtype, dtype,
kv_cache_dtype,
trust_remote_code, trust_remote_code,
uds_path, uds_path,
rank, rank,

View File

@ -30,6 +30,10 @@ class Dtype(str, Enum):
bloat16 = "bfloat16" bloat16 = "bfloat16"
class KVCacheDtype(str, Enum):
fp8_e5m2 = "fp8_e5m2"
@app.command() @app.command()
def serve( def serve(
model_id: str, model_id: str,
@ -38,6 +42,7 @@ def serve(
quantize: Optional[Quantization] = None, quantize: Optional[Quantization] = None,
speculate: Optional[int] = None, speculate: Optional[int] = None,
dtype: Optional[Dtype] = None, dtype: Optional[Dtype] = None,
kv_cache_dtype: Optional[KVCacheDtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
uds_path: Path = "/tmp/text-generation-server", uds_path: Path = "/tmp/text-generation-server",
logger_level: str = "INFO", logger_level: str = "INFO",
@ -97,6 +102,7 @@ def serve(
# Downgrade enum into str for easier management later on # Downgrade enum into str for easier management later on
quantize = None if quantize is None else quantize.value quantize = None if quantize is None else quantize.value
dtype = None if dtype is None else dtype.value dtype = None if dtype is None else dtype.value
kv_cache_dtype = None if kv_cache_dtype is None else kv_cache_dtype.value
if dtype is not None and quantize not in { if dtype is not None and quantize not in {
None, None,
"bitsandbytes", "bitsandbytes",
@ -114,6 +120,7 @@ def serve(
quantize, quantize,
speculate, speculate,
dtype, dtype,
kv_cache_dtype,
trust_remote_code, trust_remote_code,
uds_path, uds_path,
max_input_tokens, max_input_tokens,

View File

@ -1,37 +1,40 @@
from text_generation_server.utils.import_utils import SYSTEM
import os import os
from text_generation_server.utils.import_utils import SYSTEM
from .common import Seqlen from .common import Seqlen
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false": if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
raise ImportError("`USE_FLASH_ATTENTION` is false.") raise ImportError("`USE_FLASH_ATTENTION` is false.")
if SYSTEM == "cuda": if SYSTEM == "cuda":
from .cuda import ( from .cuda import (
PREFILL_IN_KV_CACHE,
SUPPORTS_WINDOWING,
attention, attention,
paged_attention, paged_attention,
reshape_and_cache, reshape_and_cache,
SUPPORTS_WINDOWING,
PREFILL_IN_KV_CACHE,
) )
elif SYSTEM == "rocm": elif SYSTEM == "rocm":
from .rocm import ( from .rocm import (
PREFILL_IN_KV_CACHE,
SUPPORTS_WINDOWING,
attention, attention,
paged_attention, paged_attention,
reshape_and_cache, reshape_and_cache,
PREFILL_IN_KV_CACHE,
SUPPORTS_WINDOWING,
) )
elif SYSTEM == "ipex": elif SYSTEM == "ipex":
from .ipex import ( from .ipex import (
PREFILL_IN_KV_CACHE,
SUPPORTS_WINDOWING,
attention, attention,
paged_attention, paged_attention,
reshape_and_cache, reshape_and_cache,
PREFILL_IN_KV_CACHE,
SUPPORTS_WINDOWING,
) )
else: else:
raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention") raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
# KVCache needs `reshape_and_cache`, so ensure that it is defined already.
from .kv_cache import KVCache
__all__ = [ __all__ = [
"attention", "attention",
@ -39,5 +42,6 @@ __all__ = [
"reshape_and_cache", "reshape_and_cache",
"PREFILL_IN_KV_CACHE", "PREFILL_IN_KV_CACHE",
"SUPPORTS_WINDOWING", "SUPPORTS_WINDOWING",
"KVCache",
"Seqlen", "Seqlen",
] ]

View File

@ -355,3 +355,11 @@ else:
# have a configuration that requires flash-attention v1, which # have a configuration that requires flash-attention v1, which
# does not support block tables. # does not support block tables.
PREFILL_IN_KV_CACHE = ATTENTION != "paged" or V2 PREFILL_IN_KV_CACHE = ATTENTION != "paged" or V2
__all__ = [
"PREFILL_IN_KV_CACHE",
"SUPPORTS_WINDOWING",
"attention",
"paged_attention",
"reshape_and_cache",
]

View File

@ -80,3 +80,12 @@ def paged_attention(
None, None,
) )
return out return out
__all__ = [
"PREFILL_IN_KV_CACHE",
"SUPPORTS_WINDOWING",
"attention",
"paged_attention",
"reshape_and_cache",
]

View File

@ -0,0 +1,121 @@
from typing import Tuple
import torch
from text_generation_server.models.globals import ATTENTION, BLOCK_SIZE
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import reshape_and_cache
class KVCache:
"""
Key-value cache for attention layers.
"""
kv_cache: Tuple[torch.Tensor, torch.Tensor]
def __init__(
self,
*,
num_blocks: int,
num_heads: int,
head_size: int,
dtype: torch.dtype,
device: torch.device,
):
"""Construct the key-value cache for a layer."""
if (
dtype == torch.float8_e5m2
and ATTENTION != "flashinfer"
and SYSTEM != "cuda"
):
raise ValueError(
"float8_e5m2 KV cache is currently only supported for flashinfer on CUDA"
)
element_size = torch.tensor([], dtype=dtype).element_size()
if SYSTEM == "ipex" and device.type == "xpu":
x = 1
else:
x = BLOCK_SIZE // element_size
if ATTENTION in {"flashdecoding", "flashinfer"}:
self.kv_cache = (
torch.empty(
(num_blocks, BLOCK_SIZE, num_heads, head_size),
dtype=dtype,
device=device,
),
torch.empty(
(num_blocks, BLOCK_SIZE, num_heads, head_size),
dtype=dtype,
device=device,
),
)
elif SYSTEM == "ipex" and device == torch.device("cpu"):
self.kv_cache = (
torch.empty(
(num_blocks, num_heads, BLOCK_SIZE, head_size),
dtype=dtype,
device=device,
),
torch.empty(
(num_blocks, num_heads, BLOCK_SIZE, head_size),
dtype=dtype,
device=device,
),
)
else:
self.kv_cache = (
torch.zeros(
(num_blocks, num_heads, head_size // x, BLOCK_SIZE, x),
dtype=dtype,
device=device,
),
torch.zeros(
(num_blocks, num_heads, head_size, BLOCK_SIZE),
dtype=dtype,
device=device,
),
)
@property
def key(self):
"""Get the key cache."""
return self.kv_cache[0]
@property
def value(self):
"""Get the value cache."""
return self.kv_cache[1]
def store(
self,
*,
key: torch.Tensor,
value: torch.Tensor,
slots: torch.Tensor,
):
"""Store the key and value at the given slots."""
key_cache = self.kv_cache[0]
value_cache = self.kv_cache[1]
if ATTENTION in {"flashdecoding", "flashinfer"}:
# TODO: add scale
key = key.to(key_cache.dtype)
value = value.to(value_cache.dtype)
if key_cache.dtype == torch.float8_e5m2:
# Torch index_put does not support float8_e5m2 yet, so
# put as raw data instead.
key_cache = key_cache.view(torch.uint8)
value_cache = value_cache.view(torch.uint8)
key = key.view(torch.uint8)
value = value.view(torch.uint8)
shape = key_cache.shape
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
else:
reshape_and_cache(key, value, key_cache, value_cache, slots)

View File

@ -306,3 +306,11 @@ elif ENGINE == "triton":
else: else:
raise RuntimeError(f"Unknown attention engine {ENGINE}") raise RuntimeError(f"Unknown attention engine {ENGINE}")
__all__ = [
"PREFILL_IN_KV_CACHE",
"SUPPORTS_WINDOWING",
"attention",
"paged_attention",
"reshape_and_cache",
]

View File

@ -342,6 +342,7 @@ def get_model(
quantize: Optional[str], quantize: Optional[str],
speculate: Optional[int], speculate: Optional[int],
dtype: Optional[str], dtype: Optional[str],
kv_cache_dtype: Optional[str],
trust_remote_code: bool, trust_remote_code: bool,
max_input_tokens: int, max_input_tokens: int,
) -> Model: ) -> Model:
@ -403,6 +404,13 @@ def get_model(
else: else:
raise RuntimeError(f"Unknown dtype {dtype}") raise RuntimeError(f"Unknown dtype {dtype}")
if kv_cache_dtype is None:
kv_cache_dtype = dtype
elif kv_cache_dtype == "fp8_e5m2":
kv_cache_dtype = torch.float8_e5m2
else:
raise RuntimeError(f"Unknown kv_cache_dtype: {kv_cache_dtype}")
if speculate is not None: if speculate is not None:
set_speculate(speculate) set_speculate(speculate)
else: else:
@ -563,6 +571,7 @@ def get_model(
speculator=speculator, speculator=speculator,
default_dtype=torch.bfloat16, default_dtype=torch.bfloat16,
dtype=dtype, dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids, lora_adapter_ids=lora_adapter_ids,
config_class=DeepseekV2Config, config_class=DeepseekV2Config,
@ -617,6 +626,7 @@ def get_model(
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids, lora_adapter_ids=lora_adapter_ids,
aliases={"transformer.wte.weight": ["lm_head.weight"]}, aliases={"transformer.wte.weight": ["lm_head.weight"]},
@ -668,6 +678,7 @@ def get_model(
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids, lora_adapter_ids=lora_adapter_ids,
) )
@ -703,6 +714,7 @@ def get_model(
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids, lora_adapter_ids=lora_adapter_ids,
) )
@ -741,6 +753,7 @@ def get_model(
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids, lora_adapter_ids=lora_adapter_ids,
config_class=GPTNeoXConfig, config_class=GPTNeoXConfig,
@ -774,6 +787,7 @@ def get_model(
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids, lora_adapter_ids=lora_adapter_ids,
) )
@ -797,6 +811,7 @@ def get_model(
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids, lora_adapter_ids=lora_adapter_ids,
) )
@ -836,6 +851,7 @@ def get_model(
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids, lora_adapter_ids=lora_adapter_ids,
) )
@ -859,6 +875,7 @@ def get_model(
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
# Works better for these models # Works better for these models
default_dtype=torch.bfloat16, default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
@ -884,6 +901,7 @@ def get_model(
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
# Works better for these models # Works better for these models
default_dtype=torch.bfloat16, default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
@ -910,6 +928,7 @@ def get_model(
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids, lora_adapter_ids=lora_adapter_ids,
) )
@ -934,6 +953,7 @@ def get_model(
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
# Dbrx works better in bfloat16. # Dbrx works better in bfloat16.
default_dtype=torch.bfloat16, default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
@ -964,6 +984,7 @@ def get_model(
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
aliases={ aliases={
"lm_head.weight": ["transformer.word_embeddings.weight"], "lm_head.weight": ["transformer.word_embeddings.weight"],
"transformer.word_embeddings.weight": ["lm_head.weight"], "transformer.word_embeddings.weight": ["lm_head.weight"],
@ -982,6 +1003,7 @@ def get_model(
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
aliases={ aliases={
"lm_head.weight": ["transformer.word_embeddings.weight"], "lm_head.weight": ["transformer.word_embeddings.weight"],
"transformer.word_embeddings.weight": ["lm_head.weight"], "transformer.word_embeddings.weight": ["lm_head.weight"],
@ -1009,6 +1031,7 @@ def get_model(
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids, lora_adapter_ids=lora_adapter_ids,
) )
@ -1033,6 +1056,7 @@ def get_model(
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids, lora_adapter_ids=lora_adapter_ids,
) )
@ -1057,6 +1081,7 @@ def get_model(
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids, lora_adapter_ids=lora_adapter_ids,
) )
@ -1083,6 +1108,7 @@ def get_model(
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids, lora_adapter_ids=lora_adapter_ids,
) )
@ -1162,6 +1188,7 @@ def get_model(
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids, lora_adapter_ids=lora_adapter_ids,
# XXX: Extremely important to cap resolution in order to limit # XXX: Extremely important to cap resolution in order to limit
@ -1179,6 +1206,7 @@ def get_model(
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
# Works better for these models # Works better for these models
default_dtype=torch.bfloat16, default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
@ -1197,6 +1225,7 @@ def get_model(
quantize=quantize, quantize=quantize,
speculator=speculator, speculator=speculator,
dtype=dtype, dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
else: else:
@ -1269,6 +1298,7 @@ def get_model_with_lora_adapters(
quantize: Optional[str], quantize: Optional[str],
speculate: Optional[int], speculate: Optional[int],
dtype: Optional[str], dtype: Optional[str],
kv_cache_dtype: Optional[str],
trust_remote_code: bool, trust_remote_code: bool,
max_input_tokens: int, max_input_tokens: int,
adapter_to_index: Dict[str, int], adapter_to_index: Dict[str, int],
@ -1282,6 +1312,7 @@ def get_model_with_lora_adapters(
quantize, quantize,
speculate, speculate,
dtype, dtype,
kv_cache_dtype,
trust_remote_code, trust_remote_code,
max_input_tokens, max_input_tokens,
) )

View File

@ -28,7 +28,6 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import ( from text_generation_server.layers.attention import (
paged_attention, paged_attention,
attention, attention,
reshape_and_cache,
Seqlen, Seqlen,
) )
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
@ -291,15 +290,15 @@ class FlashCohereAttention(torch.nn.Module):
self.rotary_emb(query, key, cos, sin) self.rotary_emb(query, key, cos, sin)
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) kv_cache.store(key=key, value=value, slots=slots)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
kv_cache[0] if PREFILL_IN_KV_CACHE else key, kv_cache.key if PREFILL_IN_KV_CACHE else key,
kv_cache[1] if PREFILL_IN_KV_CACHE else value, kv_cache.value if PREFILL_IN_KV_CACHE else value,
seqlen, seqlen,
block_tables, block_tables,
self.softmax_scale, self.softmax_scale,
@ -308,8 +307,8 @@ class FlashCohereAttention(torch.nn.Module):
else: else:
attn_output = paged_attention( attn_output = paged_attention(
query, query,
kv_cache[0], kv_cache.key,
kv_cache[1], kv_cache.value,
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables, block_tables,

View File

@ -28,7 +28,6 @@ if SYSTEM != "ipex":
from text_generation_server.layers.attention import ( from text_generation_server.layers.attention import (
paged_attention, paged_attention,
attention, attention,
reshape_and_cache,
Seqlen, Seqlen,
PREFILL_IN_KV_CACHE, PREFILL_IN_KV_CACHE,
) )
@ -330,15 +329,15 @@ class DbrxAttention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
seqlen, seqlen,
block_tables, block_tables,
self.softmax_scale, self.softmax_scale,
@ -347,8 +346,8 @@ class DbrxAttention(torch.nn.Module):
else: else:
attn_output = paged_attention( attn_output = paged_attention(
query, query,
kv_cache[0], kv_cache.key,
kv_cache[1], kv_cache.value,
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables, block_tables,

View File

@ -33,7 +33,6 @@ from text_generation_server.layers.attention import (
Seqlen, Seqlen,
attention, attention,
paged_attention, paged_attention,
reshape_and_cache,
) )
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.layernorm import FastRMSNorm from text_generation_server.layers.layernorm import FastRMSNorm
@ -321,15 +320,15 @@ class DeepseekV2Attention(torch.nn.Module):
value, (0, self.head_pad_size - self.value_head_size), value=0 value, (0, self.head_pad_size - self.value_head_size), value=0
) )
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) kv_cache.store(key=key, value=value, slots=slots)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
kv_cache[0] if PREFILL_IN_KV_CACHE else key, kv_cache.key if PREFILL_IN_KV_CACHE else key,
kv_cache[1] if PREFILL_IN_KV_CACHE else value, kv_cache.value if PREFILL_IN_KV_CACHE else value,
seqlen, seqlen,
block_tables, block_tables,
self.softmax_scale, self.softmax_scale,
@ -338,8 +337,8 @@ class DeepseekV2Attention(torch.nn.Module):
else: else:
attn_output = paged_attention( attn_output = paged_attention(
query, query,
kv_cache[0], kv_cache.key,
kv_cache[1], kv_cache.value,
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables, block_tables,

View File

@ -28,7 +28,6 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import ( from text_generation_server.layers.attention import (
paged_attention, paged_attention,
attention, attention,
reshape_and_cache,
Seqlen, Seqlen,
) )
from text_generation_server.layers import ( from text_generation_server.layers import (
@ -253,15 +252,15 @@ class FlashGemma2Attention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
seqlen, seqlen,
block_tables, block_tables,
self.softmax_scale, self.softmax_scale,
@ -273,8 +272,8 @@ class FlashGemma2Attention(torch.nn.Module):
else: else:
attn_output = paged_attention( attn_output = paged_attention(
query, query,
kv_cache[0], kv_cache.key,
kv_cache[1], kv_cache.value,
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables, block_tables,

View File

@ -28,7 +28,6 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import ( from text_generation_server.layers.attention import (
paged_attention, paged_attention,
attention, attention,
reshape_and_cache,
Seqlen, Seqlen,
PREFILL_IN_KV_CACHE, PREFILL_IN_KV_CACHE,
) )
@ -224,15 +223,15 @@ class FlashGemmaAttention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
seqlen, seqlen,
block_tables, block_tables,
self.softmax_scale, self.softmax_scale,
@ -242,8 +241,8 @@ class FlashGemmaAttention(torch.nn.Module):
else: else:
attn_output = paged_attention( attn_output = paged_attention(
query, query,
kv_cache[0], kv_cache.key,
kv_cache[1], kv_cache.value,
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables, block_tables,

View File

@ -28,7 +28,6 @@ from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.attention import ( from text_generation_server.layers.attention import (
paged_attention, paged_attention,
attention, attention,
reshape_and_cache,
Seqlen, Seqlen,
) )
from text_generation_server.layers import ( from text_generation_server.layers import (
@ -224,15 +223,15 @@ class FlashGPT2Attention(torch.nn.Module):
key = key.view(-1, self.num_heads, self.head_size) key = key.view(-1, self.num_heads, self.head_size)
value = value.view(-1, self.num_heads, self.head_size) value = value.view(-1, self.num_heads, self.head_size)
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) kv_cache.store(key=key, value=value, slots=slots)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
kv_cache[0] if PREFILL_IN_KV_CACHE else key, kv_cache.key if PREFILL_IN_KV_CACHE else key,
kv_cache[1] if PREFILL_IN_KV_CACHE else value, kv_cache.value if PREFILL_IN_KV_CACHE else value,
seqlen, seqlen,
block_tables, block_tables,
self.softmax_scale, self.softmax_scale,
@ -241,8 +240,8 @@ class FlashGPT2Attention(torch.nn.Module):
else: else:
attn_output = paged_attention( attn_output = paged_attention(
query, query,
kv_cache[0], kv_cache.key,
kv_cache[1], kv_cache.value,
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables, block_tables,

View File

@ -28,7 +28,6 @@ from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import ( from text_generation_server.layers.attention import (
paged_attention, paged_attention,
attention, attention,
reshape_and_cache,
Seqlen, Seqlen,
) )
from text_generation_server.layers import ( from text_generation_server.layers import (
@ -186,15 +185,15 @@ class FlashGPTJAttention(torch.nn.Module):
else: else:
self.rotary_emb(query, key, cos, sin) self.rotary_emb(query, key, cos, sin)
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) kv_cache.store(key=key, value=value, slots=slots)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
kv_cache[0] if PREFILL_IN_KV_CACHE else key, kv_cache.key if PREFILL_IN_KV_CACHE else key,
kv_cache[1] if PREFILL_IN_KV_CACHE else value, kv_cache.value if PREFILL_IN_KV_CACHE else value,
seqlen, seqlen,
block_tables, block_tables,
self.softmax_scale, self.softmax_scale,
@ -203,8 +202,8 @@ class FlashGPTJAttention(torch.nn.Module):
else: else:
attn_output = paged_attention( attn_output = paged_attention(
query, query,
kv_cache[0], kv_cache.key,
kv_cache[1], kv_cache.value,
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables, block_tables,

View File

@ -27,13 +27,12 @@ import torch.distributed
from torch import nn from torch import nn
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE, KVCache
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import ( from text_generation_server.layers.attention import (
paged_attention, paged_attention,
attention, attention,
reshape_and_cache,
Seqlen, Seqlen,
) )
from text_generation_server.layers import ( from text_generation_server.layers import (
@ -202,7 +201,7 @@ class FlashLlamaAttention(torch.nn.Module):
cos, cos,
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache: KVCache,
block_tables, block_tables,
slots, slots,
seqlen, seqlen,
@ -222,15 +221,15 @@ class FlashLlamaAttention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
seqlen, seqlen,
block_tables, block_tables,
self.softmax_scale, self.softmax_scale,
@ -239,8 +238,8 @@ class FlashLlamaAttention(torch.nn.Module):
else: else:
attn_output = paged_attention( attn_output = paged_attention(
query, query,
kv_cache[0], kv_cache.key,
kv_cache[1], kv_cache.value,
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables, block_tables,

View File

@ -30,7 +30,6 @@ from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import ( from text_generation_server.layers.attention import (
paged_attention, paged_attention,
attention, attention,
reshape_and_cache,
Seqlen, Seqlen,
) )
from text_generation_server.layers import ( from text_generation_server.layers import (
@ -210,17 +209,15 @@ class MistralAttention(torch.nn.Module):
else: else:
kv_to_cache = kv kv_to_cache = kv
reshape_and_cache( kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots)
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0], kv_cache.key if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1], kv_cache.value if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
seqlen, seqlen,
block_tables, block_tables,
self.softmax_scale, self.softmax_scale,
@ -230,8 +227,8 @@ class MistralAttention(torch.nn.Module):
else: else:
attn_output = paged_attention( attn_output = paged_attention(
query, query,
kv_cache[0], kv_cache.key,
kv_cache[1], kv_cache.value,
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables, block_tables,

View File

@ -37,7 +37,6 @@ from text_generation_server.layers.attention import (
Seqlen, Seqlen,
attention, attention,
paged_attention, paged_attention,
reshape_and_cache,
) )
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.layernorm import FastRMSNorm from text_generation_server.layers.layernorm import FastRMSNorm
@ -258,17 +257,15 @@ class MixtralAttention(torch.nn.Module):
else: else:
kv_to_cache = kv kv_to_cache = kv
reshape_and_cache( kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots)
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0], kv_cache.key if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1], kv_cache.value if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
seqlen, seqlen,
block_tables, block_tables,
self.softmax_scale, self.softmax_scale,
@ -278,8 +275,8 @@ class MixtralAttention(torch.nn.Module):
else: else:
attn_output = paged_attention( attn_output = paged_attention(
query, query,
kv_cache[0], kv_cache.key,
kv_cache[1], kv_cache.value,
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables, block_tables,

View File

@ -29,7 +29,6 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import ( from text_generation_server.layers.attention import (
paged_attention, paged_attention,
attention, attention,
reshape_and_cache,
Seqlen, Seqlen,
) )
from text_generation_server.layers import ( from text_generation_server.layers import (
@ -165,15 +164,15 @@ class FlashNeoxAttention(torch.nn.Module):
qkv[:, 0] = torch.cat((query_rot, query_pass), dim=-1) qkv[:, 0] = torch.cat((query_rot, query_pass), dim=-1)
qkv[:, 1] = torch.cat((key_rot, key_pass), dim=-1) qkv[:, 1] = torch.cat((key_rot, key_pass), dim=-1)
reshape_and_cache(qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots) kv_cache.store(key=qkv[:, 1], value=qkv[:, 2], slots=slots)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # flash attention
attn_output = attention( attn_output = attention(
qkv[:, 0], qkv[:, 0],
kv_cache[0] if PREFILL_IN_KV_CACHE else qkv[:, 1], kv_cache.key if PREFILL_IN_KV_CACHE else qkv[:, 1],
kv_cache[1] if PREFILL_IN_KV_CACHE else qkv[:, 2], kv_cache.value if PREFILL_IN_KV_CACHE else qkv[:, 2],
seqlen, seqlen,
block_tables, block_tables,
self.softmax_scale, self.softmax_scale,
@ -182,8 +181,8 @@ class FlashNeoxAttention(torch.nn.Module):
else: else:
attn_output = paged_attention( attn_output = paged_attention(
qkv[:, 0], qkv[:, 0],
kv_cache[0], kv_cache.key,
kv_cache[1], kv_cache.value,
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables, block_tables,

View File

@ -9,7 +9,6 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import ( from text_generation_server.layers.attention import (
paged_attention, paged_attention,
attention, attention,
reshape_and_cache,
Seqlen, Seqlen,
) )
from text_generation_server.layers import ( from text_generation_server.layers import (
@ -188,14 +187,14 @@ class FlashPhiAttention(torch.nn.Module):
) )
# Reshape key and value and cache # Reshape key and value and cache
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
attn_output = attention( attn_output = attention(
query, query,
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
seqlen, seqlen,
block_tables, block_tables,
self.softmax_scale, self.softmax_scale,
@ -204,8 +203,8 @@ class FlashPhiAttention(torch.nn.Module):
else: else:
attn_output = paged_attention( attn_output = paged_attention(
query, query,
kv_cache[0], kv_cache.key,
kv_cache[1], kv_cache.value,
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables, block_tables,

View File

@ -8,7 +8,6 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import ( from text_generation_server.layers.attention import (
paged_attention, paged_attention,
attention, attention,
reshape_and_cache,
Seqlen, Seqlen,
) )
from text_generation_server.layers import ( from text_generation_server.layers import (
@ -128,17 +127,15 @@ class Qwen2Attention(torch.nn.Module):
else: else:
kv_to_cache = kv kv_to_cache = kv
reshape_and_cache( kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots)
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0], kv_cache.key if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1], kv_cache.value if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
seqlen, seqlen,
block_tables, block_tables,
self.softmax_scale, self.softmax_scale,
@ -148,8 +145,8 @@ class Qwen2Attention(torch.nn.Module):
else: else:
attn_output = paged_attention( attn_output = paged_attention(
query, query,
kv_cache[0], kv_cache.key,
kv_cache[1], kv_cache.value,
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables, block_tables,

View File

@ -18,7 +18,6 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.attention import ( from text_generation_server.layers.attention import (
attention, attention,
paged_attention, paged_attention,
reshape_and_cache,
Seqlen, Seqlen,
) )
@ -200,15 +199,15 @@ class FlashRWAttention(torch.nn.Module):
# Inplace rotary # Inplace rotary
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
seqlen, seqlen,
block_tables, block_tables,
self.softmax_scale, self.softmax_scale,
@ -217,8 +216,8 @@ class FlashRWAttention(torch.nn.Module):
else: else:
attn_output = paged_attention( attn_output = paged_attention(
query, query,
kv_cache[0], kv_cache.key,
kv_cache[1], kv_cache.value,
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables, block_tables,
@ -312,12 +311,8 @@ class FlashRWLargeAttention(torch.nn.Module):
# Inplace rotary # Inplace rotary
self.rotary_emb(query, torch.select(kv, dim=2, index=0), cos, sin) self.rotary_emb(query, torch.select(kv, dim=2, index=0), cos, sin)
reshape_and_cache( kv_cache.store(
kv[:, :, 0].contiguous(), key=kv[:, :, 0].contiguous(), value=kv[:, :, 1].contiguous(), slots=slots
kv[:, :, 1].contiguous(),
kv_cache[0],
kv_cache[1],
slots,
) )
# Prefill # Prefill
@ -325,8 +320,8 @@ class FlashRWLargeAttention(torch.nn.Module):
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, :, 0].contiguous(), kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, :, 0].contiguous(),
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, :, 1].contiguous(), kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, :, 1].contiguous(),
seqlen, seqlen,
block_tables, block_tables,
self.softmax_scale, self.softmax_scale,
@ -335,8 +330,8 @@ class FlashRWLargeAttention(torch.nn.Module):
else: else:
attn_output = paged_attention( attn_output = paged_attention(
query, query,
kv_cache[0], kv_cache.key,
kv_cache[1], kv_cache.value,
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables, block_tables,

View File

@ -8,7 +8,6 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import ( from text_generation_server.layers.attention import (
paged_attention, paged_attention,
attention, attention,
reshape_and_cache,
Seqlen, Seqlen,
) )
from text_generation_server.layers import ( from text_generation_server.layers import (
@ -284,17 +283,15 @@ class FlashMQAttention(torch.nn.Module):
query = query.view(-1, self.num_heads, self.head_size) query = query.view(-1, self.num_heads, self.head_size)
key_value = key_value.view(-1, 2, 1, self.head_size) key_value = key_value.view(-1, 2, 1, self.head_size)
reshape_and_cache( kv_cache.store(key=key_value[:, 0], value=key_value[:, 1], slots=slots)
key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots
)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
kv_cache[0] if PREFILL_IN_KV_CACHE else key_value[:, 0], kv_cache.key if PREFILL_IN_KV_CACHE else key_value[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else key_value[:, 1], kv_cache.value if PREFILL_IN_KV_CACHE else key_value[:, 1],
seqlen, seqlen,
block_tables, block_tables,
self.softmax_scale, self.softmax_scale,
@ -303,8 +300,8 @@ class FlashMQAttention(torch.nn.Module):
else: else:
attn_output = paged_attention( attn_output = paged_attention(
query, query,
kv_cache[0], kv_cache.key,
kv_cache[1], kv_cache.value,
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables, block_tables,

View File

@ -29,7 +29,6 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import ( from text_generation_server.layers.attention import (
paged_attention, paged_attention,
attention, attention,
reshape_and_cache,
Seqlen, Seqlen,
) )
from text_generation_server.layers import ( from text_generation_server.layers import (
@ -233,17 +232,15 @@ class Starcoder2Attention(torch.nn.Module):
else: else:
kv_to_cache = kv kv_to_cache = kv
reshape_and_cache( kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots)
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query,
kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0], kv_cache.key if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1], kv_cache.value if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
seqlen, seqlen,
block_tables, block_tables,
self.softmax_scale, self.softmax_scale,
@ -253,8 +250,8 @@ class Starcoder2Attention(torch.nn.Module):
else: else:
attn_output = paged_attention( attn_output = paged_attention(
query, query,
kv_cache[0], kv_cache.key,
kv_cache[1], kv_cache.value,
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables, block_tables,

View File

@ -46,7 +46,7 @@ from text_generation_server.models.globals import (
TGI_WIGGLE_ROOM, TGI_WIGGLE_ROOM,
get_adapter_to_index, get_adapter_to_index,
) )
from text_generation_server.layers.attention import Seqlen from text_generation_server.layers.attention import KVCache, Seqlen
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
from text_generation_server.utils.dist import MEMORY_FRACTION from text_generation_server.utils.dist import MEMORY_FRACTION
from text_generation_server.utils.quantization import get_loader from text_generation_server.utils.quantization import get_loader
@ -937,6 +937,7 @@ class FlashCausalLM(Model):
# Deepseek V2 uses different QK and V dims. # Deepseek V2 uses different QK and V dims.
head_size: Optional[int] = None, head_size: Optional[int] = None,
skip_special_tokens: bool = True, skip_special_tokens: bool = True,
kv_cache_dtype: Optional[torch.dtype] = None,
): ):
self.quantize = quantize self.quantize = quantize
self.process_group, rank, world_size = initialize_torch_distributed() self.process_group, rank, world_size = initialize_torch_distributed()
@ -1034,6 +1035,7 @@ class FlashCausalLM(Model):
self.cuda_graphs = {} self.cuda_graphs = {}
self.kv_cache = [] self.kv_cache = []
self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
if ATTENTION == "flashinfer": if ATTENTION == "flashinfer":
from text_generation_server.layers.attention.flashinfer import ( from text_generation_server.layers.attention.flashinfer import (
@ -1083,58 +1085,13 @@ class FlashCausalLM(Model):
): ):
self.kv_cache = [] self.kv_cache = []
empty_cache() empty_cache()
element_size = torch.tensor([], dtype=dtype).element_size()
if SYSTEM == "ipex" and device.type == "xpu":
x = 1
else:
x = BLOCK_SIZE // element_size
if ATTENTION in {"flashdecoding", "flashinfer"}:
self.kv_cache = [ self.kv_cache = [
( KVCache(
torch.empty( num_blocks=num_blocks,
(num_blocks, BLOCK_SIZE, num_heads, head_size), num_heads=num_heads,
head_size=head_size,
dtype=dtype, dtype=dtype,
device=device, device=device,
),
torch.empty(
(num_blocks, BLOCK_SIZE, num_heads, head_size),
dtype=dtype,
device=device,
),
)
for _ in range(num_layers)
]
elif SYSTEM == "ipex" and device == torch.device("cpu"):
self.kv_cache = [
(
torch.empty(
(num_blocks, num_heads, BLOCK_SIZE, head_size),
dtype=dtype,
device=device,
),
torch.empty(
(num_blocks, num_heads, BLOCK_SIZE, head_size),
dtype=dtype,
device=device,
),
)
for _ in range(num_layers)
]
else:
self.kv_cache = [
(
torch.zeros(
(num_blocks, num_heads, head_size // x, BLOCK_SIZE, x),
dtype=dtype,
device=device,
),
torch.zeros(
(num_blocks, num_heads, head_size, BLOCK_SIZE),
dtype=dtype,
device=device,
),
) )
for _ in range(num_layers) for _ in range(num_layers)
] ]
@ -1258,7 +1215,7 @@ class FlashCausalLM(Model):
self.num_layers, self.num_layers,
self.num_kv_heads, self.num_kv_heads,
self.head_size, self.head_size,
self.dtype, self.kv_cache_dtype,
self.device, self.device,
) )
max_bt = batch.max_blocks max_bt = batch.max_blocks
@ -1277,7 +1234,7 @@ class FlashCausalLM(Model):
# Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm) # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
# Calculate the number of blocks that can be allocated with the free memory # Calculate the number of blocks that can be allocated with the free memory
dtype_size = torch.tensor([], dtype=self.dtype).element_size() dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size()
cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
@ -1291,6 +1248,8 @@ class FlashCausalLM(Model):
+ batch_num_blocks + batch_num_blocks
) )
log_master(logger.info, f"KV-cache blocks: {num_blocks}, size: {BLOCK_SIZE}")
del batch del batch
self.init_kv_cache( self.init_kv_cache(
@ -1298,7 +1257,7 @@ class FlashCausalLM(Model):
self.num_layers, self.num_layers,
self.num_kv_heads, self.num_kv_heads,
self.head_size, self.head_size,
self.dtype, self.kv_cache_dtype,
self.device, self.device,
) )

View File

@ -205,6 +205,7 @@ def serve(
quantize: Optional[str], quantize: Optional[str],
speculate: Optional[int], speculate: Optional[int],
dtype: Optional[str], dtype: Optional[str],
kv_cache_dtype: Optional[str],
trust_remote_code: bool, trust_remote_code: bool,
uds_path: Path, uds_path: Path,
max_input_tokens: int, max_input_tokens: int,
@ -217,6 +218,7 @@ def serve(
quantize: Optional[str] = None, quantize: Optional[str] = None,
speculate: Optional[int] = None, speculate: Optional[int] = None,
dtype: Optional[str] = None, dtype: Optional[str] = None,
kv_cache_dtype: Optional[str] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
unix_socket_template = "unix://{}-{}" unix_socket_template = "unix://{}-{}"
@ -240,6 +242,7 @@ def serve(
quantize, quantize,
speculate, speculate,
dtype, dtype,
kv_cache_dtype,
trust_remote_code, trust_remote_code,
max_input_tokens, max_input_tokens,
adapter_to_index, adapter_to_index,
@ -286,6 +289,7 @@ def serve(
quantize, quantize,
speculate, speculate,
dtype, dtype,
kv_cache_dtype,
trust_remote_code, trust_remote_code,
) )
) )