Add basic FP8 KV cache support (#2603)
* Add basic FP8 KV cache support This change adds rudimentary FP8 KV cache support. The support is enabled by passing `--kv-cache-dtype fp8_e5m2` to the launcher. Doing so uses this type for the KV cache. However support is still limited: * Only the `fp8_e5m2` type is supported. * The KV cache layout is the same as `float16`/`bfloat16` (HND). * The FP8 KV cache is only supported for FlashInfer. * Loading of scales is not yet supported. * Fix Cargo.toml
This commit is contained in:
parent
68103079f4
commit
2358c2bb54
|
@ -4177,7 +4177,7 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "text-generation-backends-trtllm"
|
||||
version = "2.3.1-dev0"
|
||||
version = "2.3.2-dev0"
|
||||
dependencies = [
|
||||
"async-stream",
|
||||
"async-trait",
|
||||
|
@ -4200,7 +4200,7 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "text-generation-benchmark"
|
||||
version = "2.3.1-dev0"
|
||||
version = "2.3.2-dev0"
|
||||
dependencies = [
|
||||
"average",
|
||||
"clap 4.5.18",
|
||||
|
@ -4220,7 +4220,7 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "text-generation-client"
|
||||
version = "2.3.1-dev0"
|
||||
version = "2.3.2-dev0"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"base64 0.22.1",
|
||||
|
@ -4238,7 +4238,7 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "text-generation-launcher"
|
||||
version = "2.3.1-dev0"
|
||||
version = "2.3.2-dev0"
|
||||
dependencies = [
|
||||
"clap 4.5.18",
|
||||
"ctrlc",
|
||||
|
@ -4259,7 +4259,7 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "text-generation-router"
|
||||
version = "2.3.1-dev0"
|
||||
version = "2.3.2-dev0"
|
||||
dependencies = [
|
||||
"async-stream",
|
||||
"async-trait",
|
||||
|
@ -4308,7 +4308,7 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "text-generation-router-v2"
|
||||
version = "2.3.1-dev0"
|
||||
version = "2.3.2-dev0"
|
||||
dependencies = [
|
||||
"async-stream",
|
||||
"async-trait",
|
||||
|
@ -4357,7 +4357,7 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "text-generation-router-v3"
|
||||
version = "2.3.1-dev0"
|
||||
version = "2.3.2-dev0"
|
||||
dependencies = [
|
||||
"async-stream",
|
||||
"async-trait",
|
||||
|
|
|
@ -89,6 +89,15 @@ Options:
|
|||
[env: DTYPE=]
|
||||
[possible values: float16, bfloat16]
|
||||
|
||||
```
|
||||
## KV_CACHE_DTYPE
|
||||
```shell
|
||||
--kv-cache-dtype <KV_CACHE_DTYPE>
|
||||
Specify the dtype for the key-value cache. When this option is not provided, the dtype of the model is used (typically `float16` or `bfloat16`). Currently the only supported value is `fp8_e5m2` on CUDA
|
||||
|
||||
[env: KV_CACHE_DTYPE=]
|
||||
[possible values: fp8_e5m2]
|
||||
|
||||
```
|
||||
## TRUST_REMOTE_CODE
|
||||
```shell
|
||||
|
|
|
@ -336,6 +336,7 @@ def launcher(event_loop):
|
|||
use_flash_attention: bool = True,
|
||||
disable_grammar_support: bool = False,
|
||||
dtype: Optional[str] = None,
|
||||
kv_cache_dtype: Optional[str] = None,
|
||||
revision: Optional[str] = None,
|
||||
max_input_length: Optional[int] = None,
|
||||
max_batch_prefill_tokens: Optional[int] = None,
|
||||
|
@ -375,6 +376,9 @@ def launcher(event_loop):
|
|||
if dtype is not None:
|
||||
args.append("--dtype")
|
||||
args.append(dtype)
|
||||
if kv_cache_dtype is not None:
|
||||
args.append("--kv-cache-dtype")
|
||||
args.append(kv_cache_dtype)
|
||||
if revision is not None:
|
||||
args.append("--revision")
|
||||
args.append(revision)
|
||||
|
@ -434,6 +438,7 @@ def launcher(event_loop):
|
|||
use_flash_attention: bool = True,
|
||||
disable_grammar_support: bool = False,
|
||||
dtype: Optional[str] = None,
|
||||
kv_cache_dtype: Optional[str] = None,
|
||||
revision: Optional[str] = None,
|
||||
max_input_length: Optional[int] = None,
|
||||
max_batch_prefill_tokens: Optional[int] = None,
|
||||
|
@ -456,6 +461,9 @@ def launcher(event_loop):
|
|||
if dtype is not None:
|
||||
args.append("--dtype")
|
||||
args.append(dtype)
|
||||
if kv_cache_dtype is not None:
|
||||
args.append("--kv-cache-dtype")
|
||||
args.append(kv_cache_dtype)
|
||||
if revision is not None:
|
||||
args.append("--revision")
|
||||
args.append(revision)
|
||||
|
@ -589,7 +597,6 @@ def generate_multi():
|
|||
max_new_tokens: int,
|
||||
seed: Optional[int] = None,
|
||||
) -> List[Response]:
|
||||
|
||||
import numpy as np
|
||||
|
||||
arange = np.arange(len(prompts))
|
||||
|
|
|
@ -0,0 +1,104 @@
|
|||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 128000,
|
||||
"logprob": null,
|
||||
"text": "<|begin_of_text|>"
|
||||
},
|
||||
{
|
||||
"id": 3923,
|
||||
"logprob": -5.6328125,
|
||||
"text": "What"
|
||||
},
|
||||
{
|
||||
"id": 374,
|
||||
"logprob": -1.2265625,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 5655,
|
||||
"logprob": -9.1015625,
|
||||
"text": " deep"
|
||||
},
|
||||
{
|
||||
"id": 6975,
|
||||
"logprob": -1.8085938,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 30,
|
||||
"logprob": -1.0439453,
|
||||
"text": "?"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 18682,
|
||||
"logprob": -2.1992188,
|
||||
"special": false,
|
||||
"text": " Deep"
|
||||
},
|
||||
{
|
||||
"id": 6975,
|
||||
"logprob": -0.079956055,
|
||||
"special": false,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 374,
|
||||
"logprob": -0.2763672,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 264,
|
||||
"logprob": -0.37548828,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 27084,
|
||||
"logprob": -1.4628906,
|
||||
"special": false,
|
||||
"text": " subset"
|
||||
},
|
||||
{
|
||||
"id": 315,
|
||||
"logprob": -0.02885437,
|
||||
"special": false,
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 5780,
|
||||
"logprob": -0.2565918,
|
||||
"special": false,
|
||||
"text": " machine"
|
||||
},
|
||||
{
|
||||
"id": 6975,
|
||||
"logprob": -0.0063438416,
|
||||
"special": false,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 430,
|
||||
"logprob": -1.3056641,
|
||||
"special": false,
|
||||
"text": " that"
|
||||
},
|
||||
{
|
||||
"id": 374,
|
||||
"logprob": -1.6035156,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": " Deep learning is a subset of machine learning that is"
|
||||
}
|
|
@ -0,0 +1,57 @@
|
|||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "eos_token",
|
||||
"generated_tokens": 3,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 128000,
|
||||
"logprob": null,
|
||||
"text": "<|begin_of_text|>"
|
||||
},
|
||||
{
|
||||
"id": 374,
|
||||
"logprob": -22.96875,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 5655,
|
||||
"logprob": -10.71875,
|
||||
"text": " deep"
|
||||
},
|
||||
{
|
||||
"id": 6975,
|
||||
"logprob": -2.6992188,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 30,
|
||||
"logprob": -4.8398438,
|
||||
"text": "?"
|
||||
}
|
||||
],
|
||||
"seed": 0,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 720,
|
||||
"logprob": -0.4411621,
|
||||
"special": false,
|
||||
"text": " \n"
|
||||
},
|
||||
{
|
||||
"id": 220,
|
||||
"logprob": -0.35864258,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 128001,
|
||||
"logprob": 0.0,
|
||||
"special": true,
|
||||
"text": "<|end_of_text|>"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "What is deep learning? \n "
|
||||
}
|
|
@ -0,0 +1,418 @@
|
|||
[
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 128000,
|
||||
"logprob": null,
|
||||
"text": "<|begin_of_text|>"
|
||||
},
|
||||
{
|
||||
"id": 3923,
|
||||
"logprob": -5.6328125,
|
||||
"text": "What"
|
||||
},
|
||||
{
|
||||
"id": 374,
|
||||
"logprob": -1.2265625,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 5655,
|
||||
"logprob": -9.1015625,
|
||||
"text": " deep"
|
||||
},
|
||||
{
|
||||
"id": 6975,
|
||||
"logprob": -1.8085938,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 30,
|
||||
"logprob": -1.0439453,
|
||||
"text": "?"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 18682,
|
||||
"logprob": -2.1992188,
|
||||
"special": false,
|
||||
"text": " Deep"
|
||||
},
|
||||
{
|
||||
"id": 6975,
|
||||
"logprob": -0.07897949,
|
||||
"special": false,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 374,
|
||||
"logprob": -0.27734375,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 264,
|
||||
"logprob": -0.37402344,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 27084,
|
||||
"logprob": -1.4511719,
|
||||
"special": false,
|
||||
"text": " subset"
|
||||
},
|
||||
{
|
||||
"id": 315,
|
||||
"logprob": -0.02909851,
|
||||
"special": false,
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 5780,
|
||||
"logprob": -0.25854492,
|
||||
"special": false,
|
||||
"text": " machine"
|
||||
},
|
||||
{
|
||||
"id": 6975,
|
||||
"logprob": -0.0061798096,
|
||||
"special": false,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 430,
|
||||
"logprob": -1.3046875,
|
||||
"special": false,
|
||||
"text": " that"
|
||||
},
|
||||
{
|
||||
"id": 374,
|
||||
"logprob": -1.5537109,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": " Deep learning is a subset of machine learning that is"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 128000,
|
||||
"logprob": null,
|
||||
"text": "<|begin_of_text|>"
|
||||
},
|
||||
{
|
||||
"id": 3923,
|
||||
"logprob": -5.6328125,
|
||||
"text": "What"
|
||||
},
|
||||
{
|
||||
"id": 374,
|
||||
"logprob": -1.2265625,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 5655,
|
||||
"logprob": -9.1015625,
|
||||
"text": " deep"
|
||||
},
|
||||
{
|
||||
"id": 6975,
|
||||
"logprob": -1.8085938,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 30,
|
||||
"logprob": -1.0439453,
|
||||
"text": "?"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 18682,
|
||||
"logprob": -2.1992188,
|
||||
"special": false,
|
||||
"text": " Deep"
|
||||
},
|
||||
{
|
||||
"id": 6975,
|
||||
"logprob": -0.07897949,
|
||||
"special": false,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 374,
|
||||
"logprob": -0.27734375,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 264,
|
||||
"logprob": -0.37402344,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 27084,
|
||||
"logprob": -1.4511719,
|
||||
"special": false,
|
||||
"text": " subset"
|
||||
},
|
||||
{
|
||||
"id": 315,
|
||||
"logprob": -0.02909851,
|
||||
"special": false,
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 5780,
|
||||
"logprob": -0.25854492,
|
||||
"special": false,
|
||||
"text": " machine"
|
||||
},
|
||||
{
|
||||
"id": 6975,
|
||||
"logprob": -0.0061798096,
|
||||
"special": false,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 430,
|
||||
"logprob": -1.3046875,
|
||||
"special": false,
|
||||
"text": " that"
|
||||
},
|
||||
{
|
||||
"id": 374,
|
||||
"logprob": -1.5537109,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": " Deep learning is a subset of machine learning that is"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 128000,
|
||||
"logprob": null,
|
||||
"text": "<|begin_of_text|>"
|
||||
},
|
||||
{
|
||||
"id": 3923,
|
||||
"logprob": -5.6328125,
|
||||
"text": "What"
|
||||
},
|
||||
{
|
||||
"id": 374,
|
||||
"logprob": -1.2265625,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 5655,
|
||||
"logprob": -9.1015625,
|
||||
"text": " deep"
|
||||
},
|
||||
{
|
||||
"id": 6975,
|
||||
"logprob": -1.8085938,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 30,
|
||||
"logprob": -1.0439453,
|
||||
"text": "?"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 18682,
|
||||
"logprob": -2.1992188,
|
||||
"special": false,
|
||||
"text": " Deep"
|
||||
},
|
||||
{
|
||||
"id": 6975,
|
||||
"logprob": -0.07897949,
|
||||
"special": false,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 374,
|
||||
"logprob": -0.27734375,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 264,
|
||||
"logprob": -0.37402344,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 27084,
|
||||
"logprob": -1.4511719,
|
||||
"special": false,
|
||||
"text": " subset"
|
||||
},
|
||||
{
|
||||
"id": 315,
|
||||
"logprob": -0.02909851,
|
||||
"special": false,
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 5780,
|
||||
"logprob": -0.25854492,
|
||||
"special": false,
|
||||
"text": " machine"
|
||||
},
|
||||
{
|
||||
"id": 6975,
|
||||
"logprob": -0.0061798096,
|
||||
"special": false,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 430,
|
||||
"logprob": -1.3046875,
|
||||
"special": false,
|
||||
"text": " that"
|
||||
},
|
||||
{
|
||||
"id": 374,
|
||||
"logprob": -1.5537109,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": " Deep learning is a subset of machine learning that is"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 128000,
|
||||
"logprob": null,
|
||||
"text": "<|begin_of_text|>"
|
||||
},
|
||||
{
|
||||
"id": 3923,
|
||||
"logprob": -5.6328125,
|
||||
"text": "What"
|
||||
},
|
||||
{
|
||||
"id": 374,
|
||||
"logprob": -1.2265625,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 5655,
|
||||
"logprob": -9.1015625,
|
||||
"text": " deep"
|
||||
},
|
||||
{
|
||||
"id": 6975,
|
||||
"logprob": -1.8085938,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 30,
|
||||
"logprob": -1.0439453,
|
||||
"text": "?"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 18682,
|
||||
"logprob": -2.1992188,
|
||||
"special": false,
|
||||
"text": " Deep"
|
||||
},
|
||||
{
|
||||
"id": 6975,
|
||||
"logprob": -0.07897949,
|
||||
"special": false,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 374,
|
||||
"logprob": -0.27734375,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 264,
|
||||
"logprob": -0.37402344,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 27084,
|
||||
"logprob": -1.4511719,
|
||||
"special": false,
|
||||
"text": " subset"
|
||||
},
|
||||
{
|
||||
"id": 315,
|
||||
"logprob": -0.02909851,
|
||||
"special": false,
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 5780,
|
||||
"logprob": -0.25854492,
|
||||
"special": false,
|
||||
"text": " machine"
|
||||
},
|
||||
{
|
||||
"id": 6975,
|
||||
"logprob": -0.0061798096,
|
||||
"special": false,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 430,
|
||||
"logprob": -1.3046875,
|
||||
"special": false,
|
||||
"text": " that"
|
||||
},
|
||||
{
|
||||
"id": 374,
|
||||
"logprob": -1.5537109,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": " Deep learning is a subset of machine learning that is"
|
||||
}
|
||||
]
|
|
@ -0,0 +1,77 @@
|
|||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def flash_llama_fp8_kv_cache_handle(launcher):
|
||||
with launcher(
|
||||
"meta-llama/Meta-Llama-3-8B", num_shard=2, kv_cache_dtype="fp8_e5m2"
|
||||
) as handle:
|
||||
yield handle
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def flash_llama_fp8_kv_cache(flash_llama_fp8_kv_cache_handle):
|
||||
await flash_llama_fp8_kv_cache_handle.health(300)
|
||||
return flash_llama_fp8_kv_cache_handle.client
|
||||
|
||||
|
||||
@pytest.mark.release
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_fp8_kv_cache(flash_llama_fp8_kv_cache, response_snapshot):
|
||||
response = await flash_llama_fp8_kv_cache.generate(
|
||||
"What is deep learning?", max_new_tokens=10, decoder_input_details=True
|
||||
)
|
||||
|
||||
assert (
|
||||
response.generated_text
|
||||
== " Deep learning is a subset of machine learning that is"
|
||||
)
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.release
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_fp8_kv_cache_all_params(
|
||||
flash_llama_fp8_kv_cache, response_snapshot
|
||||
):
|
||||
response = await flash_llama_fp8_kv_cache.generate(
|
||||
"What is deep learning?",
|
||||
max_new_tokens=10,
|
||||
repetition_penalty=1.2,
|
||||
return_full_text=True,
|
||||
stop_sequences=["test"],
|
||||
temperature=0.5,
|
||||
top_p=0.9,
|
||||
top_k=10,
|
||||
truncate=5,
|
||||
typical_p=0.9,
|
||||
watermark=True,
|
||||
decoder_input_details=True,
|
||||
seed=0,
|
||||
)
|
||||
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.release
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_fp8_kv_cache_load(
|
||||
flash_llama_fp8_kv_cache, generate_load, response_snapshot
|
||||
):
|
||||
responses = await generate_load(
|
||||
flash_llama_fp8_kv_cache, "What is deep learning?", max_new_tokens=10, n=4
|
||||
)
|
||||
|
||||
assert len(responses) == 4
|
||||
assert (
|
||||
responses[0].generated_text
|
||||
== " Deep learning is a subset of machine learning that is"
|
||||
)
|
||||
assert all(
|
||||
[r.generated_text == responses[0].generated_text for r in responses]
|
||||
), f"Different messages : {[r.generated_text for r in responses]}"
|
||||
assert responses == response_snapshot
|
|
@ -301,6 +301,22 @@ impl std::fmt::Display for Dtype {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||
enum KVCacheDtype {
|
||||
#[clap(name = "fp8_e5m2")]
|
||||
Fp8e5m2,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for KVCacheDtype {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
KVCacheDtype::Fp8e5m2 => {
|
||||
write!(f, "fp8_e5m2")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||
enum RopeScaling {
|
||||
Linear,
|
||||
|
@ -402,6 +418,12 @@ struct Args {
|
|||
#[clap(long, env, value_enum)]
|
||||
dtype: Option<Dtype>,
|
||||
|
||||
/// Specify the dtype for the key-value cache. When this option is not provided,
|
||||
/// the dtype of the model is used (typically `float16` or `bfloat16`). Currently
|
||||
/// the only supported value is `fp8_e5m2` on CUDA.
|
||||
#[clap(long, env, value_enum)]
|
||||
kv_cache_dtype: Option<KVCacheDtype>,
|
||||
|
||||
/// Whether you want to execute hub modelling code. Explicitly passing a `revision` is
|
||||
/// encouraged when loading a model with custom code to ensure no malicious code has been
|
||||
/// contributed in a newer revision.
|
||||
|
@ -670,6 +692,7 @@ fn shard_manager(
|
|||
quantize: Option<Quantization>,
|
||||
speculate: Option<usize>,
|
||||
dtype: Option<Dtype>,
|
||||
kv_cache_dtype: Option<KVCacheDtype>,
|
||||
trust_remote_code: bool,
|
||||
uds_path: String,
|
||||
rank: usize,
|
||||
|
@ -743,6 +766,11 @@ fn shard_manager(
|
|||
shard_args.push(dtype.to_string())
|
||||
}
|
||||
|
||||
if let Some(kv_cache_dtype) = kv_cache_dtype {
|
||||
shard_args.push("--kv-cache-dtype".to_string());
|
||||
shard_args.push(kv_cache_dtype.to_string())
|
||||
}
|
||||
|
||||
// Model optional revision
|
||||
if let Some(revision) = revision {
|
||||
shard_args.push("--revision".to_string());
|
||||
|
@ -1299,6 +1327,7 @@ fn spawn_shards(
|
|||
let otlp_service_name = args.otlp_service_name.clone();
|
||||
let speculate = args.speculate;
|
||||
let dtype = args.dtype;
|
||||
let kv_cache_dtype = args.kv_cache_dtype;
|
||||
let trust_remote_code = args.trust_remote_code;
|
||||
let master_port = args.master_port;
|
||||
let disable_custom_kernels = args.disable_custom_kernels;
|
||||
|
@ -1317,6 +1346,7 @@ fn spawn_shards(
|
|||
quantize,
|
||||
speculate,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
trust_remote_code,
|
||||
uds_path,
|
||||
rank,
|
||||
|
|
|
@ -30,6 +30,10 @@ class Dtype(str, Enum):
|
|||
bloat16 = "bfloat16"
|
||||
|
||||
|
||||
class KVCacheDtype(str, Enum):
|
||||
fp8_e5m2 = "fp8_e5m2"
|
||||
|
||||
|
||||
@app.command()
|
||||
def serve(
|
||||
model_id: str,
|
||||
|
@ -38,6 +42,7 @@ def serve(
|
|||
quantize: Optional[Quantization] = None,
|
||||
speculate: Optional[int] = None,
|
||||
dtype: Optional[Dtype] = None,
|
||||
kv_cache_dtype: Optional[KVCacheDtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
uds_path: Path = "/tmp/text-generation-server",
|
||||
logger_level: str = "INFO",
|
||||
|
@ -97,6 +102,7 @@ def serve(
|
|||
# Downgrade enum into str for easier management later on
|
||||
quantize = None if quantize is None else quantize.value
|
||||
dtype = None if dtype is None else dtype.value
|
||||
kv_cache_dtype = None if kv_cache_dtype is None else kv_cache_dtype.value
|
||||
if dtype is not None and quantize not in {
|
||||
None,
|
||||
"bitsandbytes",
|
||||
|
@ -114,6 +120,7 @@ def serve(
|
|||
quantize,
|
||||
speculate,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
trust_remote_code,
|
||||
uds_path,
|
||||
max_input_tokens,
|
||||
|
|
|
@ -1,37 +1,40 @@
|
|||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
import os
|
||||
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
from .common import Seqlen
|
||||
|
||||
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
|
||||
raise ImportError("`USE_FLASH_ATTENTION` is false.")
|
||||
if SYSTEM == "cuda":
|
||||
from .cuda import (
|
||||
PREFILL_IN_KV_CACHE,
|
||||
SUPPORTS_WINDOWING,
|
||||
attention,
|
||||
paged_attention,
|
||||
reshape_and_cache,
|
||||
SUPPORTS_WINDOWING,
|
||||
PREFILL_IN_KV_CACHE,
|
||||
)
|
||||
elif SYSTEM == "rocm":
|
||||
from .rocm import (
|
||||
PREFILL_IN_KV_CACHE,
|
||||
SUPPORTS_WINDOWING,
|
||||
attention,
|
||||
paged_attention,
|
||||
reshape_and_cache,
|
||||
PREFILL_IN_KV_CACHE,
|
||||
SUPPORTS_WINDOWING,
|
||||
)
|
||||
elif SYSTEM == "ipex":
|
||||
from .ipex import (
|
||||
PREFILL_IN_KV_CACHE,
|
||||
SUPPORTS_WINDOWING,
|
||||
attention,
|
||||
paged_attention,
|
||||
reshape_and_cache,
|
||||
PREFILL_IN_KV_CACHE,
|
||||
SUPPORTS_WINDOWING,
|
||||
)
|
||||
else:
|
||||
raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
|
||||
|
||||
# KVCache needs `reshape_and_cache`, so ensure that it is defined already.
|
||||
from .kv_cache import KVCache
|
||||
|
||||
__all__ = [
|
||||
"attention",
|
||||
|
@ -39,5 +42,6 @@ __all__ = [
|
|||
"reshape_and_cache",
|
||||
"PREFILL_IN_KV_CACHE",
|
||||
"SUPPORTS_WINDOWING",
|
||||
"KVCache",
|
||||
"Seqlen",
|
||||
]
|
||||
|
|
|
@ -355,3 +355,11 @@ else:
|
|||
# have a configuration that requires flash-attention v1, which
|
||||
# does not support block tables.
|
||||
PREFILL_IN_KV_CACHE = ATTENTION != "paged" or V2
|
||||
|
||||
__all__ = [
|
||||
"PREFILL_IN_KV_CACHE",
|
||||
"SUPPORTS_WINDOWING",
|
||||
"attention",
|
||||
"paged_attention",
|
||||
"reshape_and_cache",
|
||||
]
|
||||
|
|
|
@ -80,3 +80,12 @@ def paged_attention(
|
|||
None,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
__all__ = [
|
||||
"PREFILL_IN_KV_CACHE",
|
||||
"SUPPORTS_WINDOWING",
|
||||
"attention",
|
||||
"paged_attention",
|
||||
"reshape_and_cache",
|
||||
]
|
||||
|
|
|
@ -0,0 +1,121 @@
|
|||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from text_generation_server.models.globals import ATTENTION, BLOCK_SIZE
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.layers.attention import reshape_and_cache
|
||||
|
||||
|
||||
class KVCache:
|
||||
"""
|
||||
Key-value cache for attention layers.
|
||||
"""
|
||||
|
||||
kv_cache: Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
num_blocks: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
):
|
||||
"""Construct the key-value cache for a layer."""
|
||||
|
||||
if (
|
||||
dtype == torch.float8_e5m2
|
||||
and ATTENTION != "flashinfer"
|
||||
and SYSTEM != "cuda"
|
||||
):
|
||||
raise ValueError(
|
||||
"float8_e5m2 KV cache is currently only supported for flashinfer on CUDA"
|
||||
)
|
||||
|
||||
element_size = torch.tensor([], dtype=dtype).element_size()
|
||||
if SYSTEM == "ipex" and device.type == "xpu":
|
||||
x = 1
|
||||
else:
|
||||
x = BLOCK_SIZE // element_size
|
||||
|
||||
if ATTENTION in {"flashdecoding", "flashinfer"}:
|
||||
self.kv_cache = (
|
||||
torch.empty(
|
||||
(num_blocks, BLOCK_SIZE, num_heads, head_size),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
torch.empty(
|
||||
(num_blocks, BLOCK_SIZE, num_heads, head_size),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
)
|
||||
elif SYSTEM == "ipex" and device == torch.device("cpu"):
|
||||
self.kv_cache = (
|
||||
torch.empty(
|
||||
(num_blocks, num_heads, BLOCK_SIZE, head_size),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
torch.empty(
|
||||
(num_blocks, num_heads, BLOCK_SIZE, head_size),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.kv_cache = (
|
||||
torch.zeros(
|
||||
(num_blocks, num_heads, head_size // x, BLOCK_SIZE, x),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
torch.zeros(
|
||||
(num_blocks, num_heads, head_size, BLOCK_SIZE),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
)
|
||||
|
||||
@property
|
||||
def key(self):
|
||||
"""Get the key cache."""
|
||||
|
||||
return self.kv_cache[0]
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
"""Get the value cache."""
|
||||
|
||||
return self.kv_cache[1]
|
||||
|
||||
def store(
|
||||
self,
|
||||
*,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
):
|
||||
"""Store the key and value at the given slots."""
|
||||
|
||||
key_cache = self.kv_cache[0]
|
||||
value_cache = self.kv_cache[1]
|
||||
|
||||
if ATTENTION in {"flashdecoding", "flashinfer"}:
|
||||
# TODO: add scale
|
||||
key = key.to(key_cache.dtype)
|
||||
value = value.to(value_cache.dtype)
|
||||
if key_cache.dtype == torch.float8_e5m2:
|
||||
# Torch index_put does not support float8_e5m2 yet, so
|
||||
# put as raw data instead.
|
||||
key_cache = key_cache.view(torch.uint8)
|
||||
value_cache = value_cache.view(torch.uint8)
|
||||
key = key.view(torch.uint8)
|
||||
value = value.view(torch.uint8)
|
||||
shape = key_cache.shape
|
||||
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
|
||||
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
|
||||
else:
|
||||
reshape_and_cache(key, value, key_cache, value_cache, slots)
|
|
@ -306,3 +306,11 @@ elif ENGINE == "triton":
|
|||
|
||||
else:
|
||||
raise RuntimeError(f"Unknown attention engine {ENGINE}")
|
||||
|
||||
__all__ = [
|
||||
"PREFILL_IN_KV_CACHE",
|
||||
"SUPPORTS_WINDOWING",
|
||||
"attention",
|
||||
"paged_attention",
|
||||
"reshape_and_cache",
|
||||
]
|
||||
|
|
|
@ -342,6 +342,7 @@ def get_model(
|
|||
quantize: Optional[str],
|
||||
speculate: Optional[int],
|
||||
dtype: Optional[str],
|
||||
kv_cache_dtype: Optional[str],
|
||||
trust_remote_code: bool,
|
||||
max_input_tokens: int,
|
||||
) -> Model:
|
||||
|
@ -403,6 +404,13 @@ def get_model(
|
|||
else:
|
||||
raise RuntimeError(f"Unknown dtype {dtype}")
|
||||
|
||||
if kv_cache_dtype is None:
|
||||
kv_cache_dtype = dtype
|
||||
elif kv_cache_dtype == "fp8_e5m2":
|
||||
kv_cache_dtype = torch.float8_e5m2
|
||||
else:
|
||||
raise RuntimeError(f"Unknown kv_cache_dtype: {kv_cache_dtype}")
|
||||
|
||||
if speculate is not None:
|
||||
set_speculate(speculate)
|
||||
else:
|
||||
|
@ -563,6 +571,7 @@ def get_model(
|
|||
speculator=speculator,
|
||||
default_dtype=torch.bfloat16,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
config_class=DeepseekV2Config,
|
||||
|
@ -617,6 +626,7 @@ def get_model(
|
|||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
aliases={"transformer.wte.weight": ["lm_head.weight"]},
|
||||
|
@ -668,6 +678,7 @@ def get_model(
|
|||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
|
@ -703,6 +714,7 @@ def get_model(
|
|||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
|
@ -741,6 +753,7 @@ def get_model(
|
|||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
config_class=GPTNeoXConfig,
|
||||
|
@ -774,6 +787,7 @@ def get_model(
|
|||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
|
@ -797,6 +811,7 @@ def get_model(
|
|||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
|
@ -836,6 +851,7 @@ def get_model(
|
|||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
|
@ -859,6 +875,7 @@ def get_model(
|
|||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
# Works better for these models
|
||||
default_dtype=torch.bfloat16,
|
||||
trust_remote_code=trust_remote_code,
|
||||
|
@ -884,6 +901,7 @@ def get_model(
|
|||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
# Works better for these models
|
||||
default_dtype=torch.bfloat16,
|
||||
trust_remote_code=trust_remote_code,
|
||||
|
@ -910,6 +928,7 @@ def get_model(
|
|||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
|
@ -934,6 +953,7 @@ def get_model(
|
|||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
# Dbrx works better in bfloat16.
|
||||
default_dtype=torch.bfloat16,
|
||||
trust_remote_code=trust_remote_code,
|
||||
|
@ -964,6 +984,7 @@ def get_model(
|
|||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
aliases={
|
||||
"lm_head.weight": ["transformer.word_embeddings.weight"],
|
||||
"transformer.word_embeddings.weight": ["lm_head.weight"],
|
||||
|
@ -982,6 +1003,7 @@ def get_model(
|
|||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
aliases={
|
||||
"lm_head.weight": ["transformer.word_embeddings.weight"],
|
||||
"transformer.word_embeddings.weight": ["lm_head.weight"],
|
||||
|
@ -1009,6 +1031,7 @@ def get_model(
|
|||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
|
@ -1033,6 +1056,7 @@ def get_model(
|
|||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
|
@ -1057,6 +1081,7 @@ def get_model(
|
|||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
|
@ -1083,6 +1108,7 @@ def get_model(
|
|||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
|
@ -1162,6 +1188,7 @@ def get_model(
|
|||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
# XXX: Extremely important to cap resolution in order to limit
|
||||
|
@ -1179,6 +1206,7 @@ def get_model(
|
|||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
# Works better for these models
|
||||
default_dtype=torch.bfloat16,
|
||||
trust_remote_code=trust_remote_code,
|
||||
|
@ -1197,6 +1225,7 @@ def get_model(
|
|||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
else:
|
||||
|
@ -1269,6 +1298,7 @@ def get_model_with_lora_adapters(
|
|||
quantize: Optional[str],
|
||||
speculate: Optional[int],
|
||||
dtype: Optional[str],
|
||||
kv_cache_dtype: Optional[str],
|
||||
trust_remote_code: bool,
|
||||
max_input_tokens: int,
|
||||
adapter_to_index: Dict[str, int],
|
||||
|
@ -1282,6 +1312,7 @@ def get_model_with_lora_adapters(
|
|||
quantize,
|
||||
speculate,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
trust_remote_code,
|
||||
max_input_tokens,
|
||||
)
|
||||
|
|
|
@ -28,7 +28,6 @@ from typing import Optional, List, Tuple
|
|||
from text_generation_server.layers.attention import (
|
||||
paged_attention,
|
||||
attention,
|
||||
reshape_and_cache,
|
||||
Seqlen,
|
||||
)
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
@ -291,15 +290,15 @@ class FlashCohereAttention(torch.nn.Module):
|
|||
|
||||
self.rotary_emb(query, key, cos, sin)
|
||||
|
||||
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
|
||||
kv_cache.store(key=key, value=value, slots=slots)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache[0] if PREFILL_IN_KV_CACHE else key,
|
||||
kv_cache[1] if PREFILL_IN_KV_CACHE else value,
|
||||
kv_cache.key if PREFILL_IN_KV_CACHE else key,
|
||||
kv_cache.value if PREFILL_IN_KV_CACHE else value,
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
|
@ -308,8 +307,8 @@ class FlashCohereAttention(torch.nn.Module):
|
|||
else:
|
||||
attn_output = paged_attention(
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
kv_cache.key,
|
||||
kv_cache.value,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
|
|
|
@ -28,7 +28,6 @@ if SYSTEM != "ipex":
|
|||
from text_generation_server.layers.attention import (
|
||||
paged_attention,
|
||||
attention,
|
||||
reshape_and_cache,
|
||||
Seqlen,
|
||||
PREFILL_IN_KV_CACHE,
|
||||
)
|
||||
|
@ -330,15 +329,15 @@ class DbrxAttention(torch.nn.Module):
|
|||
|
||||
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
||||
|
||||
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
|
||||
kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
|
||||
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
|
||||
kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
|
||||
kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
|
@ -347,8 +346,8 @@ class DbrxAttention(torch.nn.Module):
|
|||
else:
|
||||
attn_output = paged_attention(
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
kv_cache.key,
|
||||
kv_cache.value,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
|
|
|
@ -33,7 +33,6 @@ from text_generation_server.layers.attention import (
|
|||
Seqlen,
|
||||
attention,
|
||||
paged_attention,
|
||||
reshape_and_cache,
|
||||
)
|
||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
||||
from text_generation_server.layers.layernorm import FastRMSNorm
|
||||
|
@ -321,15 +320,15 @@ class DeepseekV2Attention(torch.nn.Module):
|
|||
value, (0, self.head_pad_size - self.value_head_size), value=0
|
||||
)
|
||||
|
||||
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
|
||||
kv_cache.store(key=key, value=value, slots=slots)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache[0] if PREFILL_IN_KV_CACHE else key,
|
||||
kv_cache[1] if PREFILL_IN_KV_CACHE else value,
|
||||
kv_cache.key if PREFILL_IN_KV_CACHE else key,
|
||||
kv_cache.value if PREFILL_IN_KV_CACHE else value,
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
|
@ -338,8 +337,8 @@ class DeepseekV2Attention(torch.nn.Module):
|
|||
else:
|
||||
attn_output = paged_attention(
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
kv_cache.key,
|
||||
kv_cache.value,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
|
|
|
@ -28,7 +28,6 @@ from typing import Optional, List, Tuple
|
|||
from text_generation_server.layers.attention import (
|
||||
paged_attention,
|
||||
attention,
|
||||
reshape_and_cache,
|
||||
Seqlen,
|
||||
)
|
||||
from text_generation_server.layers import (
|
||||
|
@ -253,15 +252,15 @@ class FlashGemma2Attention(torch.nn.Module):
|
|||
|
||||
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
||||
|
||||
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
|
||||
kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
|
||||
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
|
||||
kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
|
||||
kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
|
@ -273,8 +272,8 @@ class FlashGemma2Attention(torch.nn.Module):
|
|||
else:
|
||||
attn_output = paged_attention(
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
kv_cache.key,
|
||||
kv_cache.value,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
|
|
|
@ -28,7 +28,6 @@ from typing import Optional, List, Tuple
|
|||
from text_generation_server.layers.attention import (
|
||||
paged_attention,
|
||||
attention,
|
||||
reshape_and_cache,
|
||||
Seqlen,
|
||||
PREFILL_IN_KV_CACHE,
|
||||
)
|
||||
|
@ -224,15 +223,15 @@ class FlashGemmaAttention(torch.nn.Module):
|
|||
|
||||
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
||||
|
||||
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
|
||||
kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
|
||||
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
|
||||
kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
|
||||
kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
|
@ -242,8 +241,8 @@ class FlashGemmaAttention(torch.nn.Module):
|
|||
else:
|
||||
attn_output = paged_attention(
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
kv_cache.key,
|
||||
kv_cache.value,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
|
|
|
@ -28,7 +28,6 @@ from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
|||
from text_generation_server.layers.attention import (
|
||||
paged_attention,
|
||||
attention,
|
||||
reshape_and_cache,
|
||||
Seqlen,
|
||||
)
|
||||
from text_generation_server.layers import (
|
||||
|
@ -224,15 +223,15 @@ class FlashGPT2Attention(torch.nn.Module):
|
|||
key = key.view(-1, self.num_heads, self.head_size)
|
||||
value = value.view(-1, self.num_heads, self.head_size)
|
||||
|
||||
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
|
||||
kv_cache.store(key=key, value=value, slots=slots)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache[0] if PREFILL_IN_KV_CACHE else key,
|
||||
kv_cache[1] if PREFILL_IN_KV_CACHE else value,
|
||||
kv_cache.key if PREFILL_IN_KV_CACHE else key,
|
||||
kv_cache.value if PREFILL_IN_KV_CACHE else value,
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
|
@ -241,8 +240,8 @@ class FlashGPT2Attention(torch.nn.Module):
|
|||
else:
|
||||
attn_output = paged_attention(
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
kv_cache.key,
|
||||
kv_cache.value,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
|
|
|
@ -28,7 +28,6 @@ from text_generation_server.utils.import_utils import SYSTEM
|
|||
from text_generation_server.layers.attention import (
|
||||
paged_attention,
|
||||
attention,
|
||||
reshape_and_cache,
|
||||
Seqlen,
|
||||
)
|
||||
from text_generation_server.layers import (
|
||||
|
@ -186,15 +185,15 @@ class FlashGPTJAttention(torch.nn.Module):
|
|||
else:
|
||||
self.rotary_emb(query, key, cos, sin)
|
||||
|
||||
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
|
||||
kv_cache.store(key=key, value=value, slots=slots)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache[0] if PREFILL_IN_KV_CACHE else key,
|
||||
kv_cache[1] if PREFILL_IN_KV_CACHE else value,
|
||||
kv_cache.key if PREFILL_IN_KV_CACHE else key,
|
||||
kv_cache.value if PREFILL_IN_KV_CACHE else value,
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
|
@ -203,8 +202,8 @@ class FlashGPTJAttention(torch.nn.Module):
|
|||
else:
|
||||
attn_output = paged_attention(
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
kv_cache.key,
|
||||
kv_cache.value,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
|
|
|
@ -27,13 +27,12 @@ import torch.distributed
|
|||
from torch import nn
|
||||
from transformers.activations import ACT2FN
|
||||
|
||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE, KVCache
|
||||
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.layers.attention import (
|
||||
paged_attention,
|
||||
attention,
|
||||
reshape_and_cache,
|
||||
Seqlen,
|
||||
)
|
||||
from text_generation_server.layers import (
|
||||
|
@ -202,7 +201,7 @@ class FlashLlamaAttention(torch.nn.Module):
|
|||
cos,
|
||||
sin,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
kv_cache: KVCache,
|
||||
block_tables,
|
||||
slots,
|
||||
seqlen,
|
||||
|
@ -222,15 +221,15 @@ class FlashLlamaAttention(torch.nn.Module):
|
|||
|
||||
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
||||
|
||||
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
|
||||
kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
|
||||
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
|
||||
kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
|
||||
kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
|
@ -239,8 +238,8 @@ class FlashLlamaAttention(torch.nn.Module):
|
|||
else:
|
||||
attn_output = paged_attention(
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
kv_cache.key,
|
||||
kv_cache.value,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
|
|
|
@ -30,7 +30,6 @@ from text_generation_server.utils.import_utils import SYSTEM
|
|||
from text_generation_server.layers.attention import (
|
||||
paged_attention,
|
||||
attention,
|
||||
reshape_and_cache,
|
||||
Seqlen,
|
||||
)
|
||||
from text_generation_server.layers import (
|
||||
|
@ -210,17 +209,15 @@ class MistralAttention(torch.nn.Module):
|
|||
else:
|
||||
kv_to_cache = kv
|
||||
|
||||
reshape_and_cache(
|
||||
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
|
||||
)
|
||||
kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
|
||||
kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
|
||||
kv_cache.key if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
|
||||
kv_cache.value if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
|
@ -230,8 +227,8 @@ class MistralAttention(torch.nn.Module):
|
|||
else:
|
||||
attn_output = paged_attention(
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
kv_cache.key,
|
||||
kv_cache.value,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
|
|
|
@ -37,7 +37,6 @@ from text_generation_server.layers.attention import (
|
|||
Seqlen,
|
||||
attention,
|
||||
paged_attention,
|
||||
reshape_and_cache,
|
||||
)
|
||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
||||
from text_generation_server.layers.layernorm import FastRMSNorm
|
||||
|
@ -258,17 +257,15 @@ class MixtralAttention(torch.nn.Module):
|
|||
else:
|
||||
kv_to_cache = kv
|
||||
|
||||
reshape_and_cache(
|
||||
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
|
||||
)
|
||||
kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
|
||||
kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
|
||||
kv_cache.key if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
|
||||
kv_cache.value if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
|
@ -278,8 +275,8 @@ class MixtralAttention(torch.nn.Module):
|
|||
else:
|
||||
attn_output = paged_attention(
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
kv_cache.key,
|
||||
kv_cache.value,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
|
|
|
@ -29,7 +29,6 @@ from typing import Optional, List, Tuple
|
|||
from text_generation_server.layers.attention import (
|
||||
paged_attention,
|
||||
attention,
|
||||
reshape_and_cache,
|
||||
Seqlen,
|
||||
)
|
||||
from text_generation_server.layers import (
|
||||
|
@ -165,15 +164,15 @@ class FlashNeoxAttention(torch.nn.Module):
|
|||
qkv[:, 0] = torch.cat((query_rot, query_pass), dim=-1)
|
||||
qkv[:, 1] = torch.cat((key_rot, key_pass), dim=-1)
|
||||
|
||||
reshape_and_cache(qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots)
|
||||
kv_cache.store(key=qkv[:, 1], value=qkv[:, 2], slots=slots)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
qkv[:, 0],
|
||||
kv_cache[0] if PREFILL_IN_KV_CACHE else qkv[:, 1],
|
||||
kv_cache[1] if PREFILL_IN_KV_CACHE else qkv[:, 2],
|
||||
kv_cache.key if PREFILL_IN_KV_CACHE else qkv[:, 1],
|
||||
kv_cache.value if PREFILL_IN_KV_CACHE else qkv[:, 2],
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
|
@ -182,8 +181,8 @@ class FlashNeoxAttention(torch.nn.Module):
|
|||
else:
|
||||
attn_output = paged_attention(
|
||||
qkv[:, 0],
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
kv_cache.key,
|
||||
kv_cache.value,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
|
|
|
@ -9,7 +9,6 @@ from typing import Optional, List, Tuple
|
|||
from text_generation_server.layers.attention import (
|
||||
paged_attention,
|
||||
attention,
|
||||
reshape_and_cache,
|
||||
Seqlen,
|
||||
)
|
||||
from text_generation_server.layers import (
|
||||
|
@ -188,14 +187,14 @@ class FlashPhiAttention(torch.nn.Module):
|
|||
)
|
||||
|
||||
# Reshape key and value and cache
|
||||
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
|
||||
kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
|
||||
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
|
||||
kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
|
||||
kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
|
@ -204,8 +203,8 @@ class FlashPhiAttention(torch.nn.Module):
|
|||
else:
|
||||
attn_output = paged_attention(
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
kv_cache.key,
|
||||
kv_cache.value,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
|
|
|
@ -8,7 +8,6 @@ from typing import Optional, List, Tuple
|
|||
from text_generation_server.layers.attention import (
|
||||
paged_attention,
|
||||
attention,
|
||||
reshape_and_cache,
|
||||
Seqlen,
|
||||
)
|
||||
from text_generation_server.layers import (
|
||||
|
@ -128,17 +127,15 @@ class Qwen2Attention(torch.nn.Module):
|
|||
else:
|
||||
kv_to_cache = kv
|
||||
|
||||
reshape_and_cache(
|
||||
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
|
||||
)
|
||||
kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
|
||||
kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
|
||||
kv_cache.key if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
|
||||
kv_cache.value if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
|
@ -148,8 +145,8 @@ class Qwen2Attention(torch.nn.Module):
|
|||
else:
|
||||
attn_output = paged_attention(
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
kv_cache.key,
|
||||
kv_cache.value,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
|
|
|
@ -18,7 +18,6 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
|||
from text_generation_server.layers.attention import (
|
||||
attention,
|
||||
paged_attention,
|
||||
reshape_and_cache,
|
||||
Seqlen,
|
||||
)
|
||||
|
||||
|
@ -200,15 +199,15 @@ class FlashRWAttention(torch.nn.Module):
|
|||
# Inplace rotary
|
||||
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
||||
|
||||
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
|
||||
kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
|
||||
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
|
||||
kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
|
||||
kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
|
@ -217,8 +216,8 @@ class FlashRWAttention(torch.nn.Module):
|
|||
else:
|
||||
attn_output = paged_attention(
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
kv_cache.key,
|
||||
kv_cache.value,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
|
@ -312,12 +311,8 @@ class FlashRWLargeAttention(torch.nn.Module):
|
|||
# Inplace rotary
|
||||
self.rotary_emb(query, torch.select(kv, dim=2, index=0), cos, sin)
|
||||
|
||||
reshape_and_cache(
|
||||
kv[:, :, 0].contiguous(),
|
||||
kv[:, :, 1].contiguous(),
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
slots,
|
||||
kv_cache.store(
|
||||
key=kv[:, :, 0].contiguous(), value=kv[:, :, 1].contiguous(), slots=slots
|
||||
)
|
||||
|
||||
# Prefill
|
||||
|
@ -325,8 +320,8 @@ class FlashRWLargeAttention(torch.nn.Module):
|
|||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, :, 0].contiguous(),
|
||||
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, :, 1].contiguous(),
|
||||
kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, :, 0].contiguous(),
|
||||
kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, :, 1].contiguous(),
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
|
@ -335,8 +330,8 @@ class FlashRWLargeAttention(torch.nn.Module):
|
|||
else:
|
||||
attn_output = paged_attention(
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
kv_cache.key,
|
||||
kv_cache.value,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
|
|
|
@ -8,7 +8,6 @@ from typing import Optional, List, Tuple
|
|||
from text_generation_server.layers.attention import (
|
||||
paged_attention,
|
||||
attention,
|
||||
reshape_and_cache,
|
||||
Seqlen,
|
||||
)
|
||||
from text_generation_server.layers import (
|
||||
|
@ -284,17 +283,15 @@ class FlashMQAttention(torch.nn.Module):
|
|||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
key_value = key_value.view(-1, 2, 1, self.head_size)
|
||||
|
||||
reshape_and_cache(
|
||||
key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots
|
||||
)
|
||||
kv_cache.store(key=key_value[:, 0], value=key_value[:, 1], slots=slots)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache[0] if PREFILL_IN_KV_CACHE else key_value[:, 0],
|
||||
kv_cache[1] if PREFILL_IN_KV_CACHE else key_value[:, 1],
|
||||
kv_cache.key if PREFILL_IN_KV_CACHE else key_value[:, 0],
|
||||
kv_cache.value if PREFILL_IN_KV_CACHE else key_value[:, 1],
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
|
@ -303,8 +300,8 @@ class FlashMQAttention(torch.nn.Module):
|
|||
else:
|
||||
attn_output = paged_attention(
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
kv_cache.key,
|
||||
kv_cache.value,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
|
|
|
@ -29,7 +29,6 @@ from typing import Optional, List, Tuple
|
|||
from text_generation_server.layers.attention import (
|
||||
paged_attention,
|
||||
attention,
|
||||
reshape_and_cache,
|
||||
Seqlen,
|
||||
)
|
||||
from text_generation_server.layers import (
|
||||
|
@ -233,17 +232,15 @@ class Starcoder2Attention(torch.nn.Module):
|
|||
else:
|
||||
kv_to_cache = kv
|
||||
|
||||
reshape_and_cache(
|
||||
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
|
||||
)
|
||||
kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
|
||||
kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
|
||||
kv_cache.key if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
|
||||
kv_cache.value if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
|
@ -253,8 +250,8 @@ class Starcoder2Attention(torch.nn.Module):
|
|||
else:
|
||||
attn_output = paged_attention(
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
kv_cache.key,
|
||||
kv_cache.value,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
|
|
|
@ -46,7 +46,7 @@ from text_generation_server.models.globals import (
|
|||
TGI_WIGGLE_ROOM,
|
||||
get_adapter_to_index,
|
||||
)
|
||||
from text_generation_server.layers.attention import Seqlen
|
||||
from text_generation_server.layers.attention import KVCache, Seqlen
|
||||
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
|
||||
from text_generation_server.utils.dist import MEMORY_FRACTION
|
||||
from text_generation_server.utils.quantization import get_loader
|
||||
|
@ -937,6 +937,7 @@ class FlashCausalLM(Model):
|
|||
# Deepseek V2 uses different QK and V dims.
|
||||
head_size: Optional[int] = None,
|
||||
skip_special_tokens: bool = True,
|
||||
kv_cache_dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
self.quantize = quantize
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
|
@ -1034,6 +1035,7 @@ class FlashCausalLM(Model):
|
|||
|
||||
self.cuda_graphs = {}
|
||||
self.kv_cache = []
|
||||
self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
|
||||
|
||||
if ATTENTION == "flashinfer":
|
||||
from text_generation_server.layers.attention.flashinfer import (
|
||||
|
@ -1083,61 +1085,16 @@ class FlashCausalLM(Model):
|
|||
):
|
||||
self.kv_cache = []
|
||||
empty_cache()
|
||||
|
||||
element_size = torch.tensor([], dtype=dtype).element_size()
|
||||
if SYSTEM == "ipex" and device.type == "xpu":
|
||||
x = 1
|
||||
else:
|
||||
x = BLOCK_SIZE // element_size
|
||||
|
||||
if ATTENTION in {"flashdecoding", "flashinfer"}:
|
||||
self.kv_cache = [
|
||||
(
|
||||
torch.empty(
|
||||
(num_blocks, BLOCK_SIZE, num_heads, head_size),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
torch.empty(
|
||||
(num_blocks, BLOCK_SIZE, num_heads, head_size),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
elif SYSTEM == "ipex" and device == torch.device("cpu"):
|
||||
self.kv_cache = [
|
||||
(
|
||||
torch.empty(
|
||||
(num_blocks, num_heads, BLOCK_SIZE, head_size),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
torch.empty(
|
||||
(num_blocks, num_heads, BLOCK_SIZE, head_size),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
else:
|
||||
self.kv_cache = [
|
||||
(
|
||||
torch.zeros(
|
||||
(num_blocks, num_heads, head_size // x, BLOCK_SIZE, x),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
torch.zeros(
|
||||
(num_blocks, num_heads, head_size, BLOCK_SIZE),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
self.kv_cache = [
|
||||
KVCache(
|
||||
num_blocks=num_blocks,
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
|
||||
def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
|
||||
input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
|
||||
|
@ -1258,7 +1215,7 @@ class FlashCausalLM(Model):
|
|||
self.num_layers,
|
||||
self.num_kv_heads,
|
||||
self.head_size,
|
||||
self.dtype,
|
||||
self.kv_cache_dtype,
|
||||
self.device,
|
||||
)
|
||||
max_bt = batch.max_blocks
|
||||
|
@ -1277,7 +1234,7 @@ class FlashCausalLM(Model):
|
|||
|
||||
# Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
|
||||
# Calculate the number of blocks that can be allocated with the free memory
|
||||
dtype_size = torch.tensor([], dtype=self.dtype).element_size()
|
||||
dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size()
|
||||
cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
|
||||
total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
|
||||
|
||||
|
@ -1291,6 +1248,8 @@ class FlashCausalLM(Model):
|
|||
+ batch_num_blocks
|
||||
)
|
||||
|
||||
log_master(logger.info, f"KV-cache blocks: {num_blocks}, size: {BLOCK_SIZE}")
|
||||
|
||||
del batch
|
||||
|
||||
self.init_kv_cache(
|
||||
|
@ -1298,7 +1257,7 @@ class FlashCausalLM(Model):
|
|||
self.num_layers,
|
||||
self.num_kv_heads,
|
||||
self.head_size,
|
||||
self.dtype,
|
||||
self.kv_cache_dtype,
|
||||
self.device,
|
||||
)
|
||||
|
||||
|
|
|
@ -205,6 +205,7 @@ def serve(
|
|||
quantize: Optional[str],
|
||||
speculate: Optional[int],
|
||||
dtype: Optional[str],
|
||||
kv_cache_dtype: Optional[str],
|
||||
trust_remote_code: bool,
|
||||
uds_path: Path,
|
||||
max_input_tokens: int,
|
||||
|
@ -217,6 +218,7 @@ def serve(
|
|||
quantize: Optional[str] = None,
|
||||
speculate: Optional[int] = None,
|
||||
dtype: Optional[str] = None,
|
||||
kv_cache_dtype: Optional[str] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
unix_socket_template = "unix://{}-{}"
|
||||
|
@ -240,6 +242,7 @@ def serve(
|
|||
quantize,
|
||||
speculate,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
trust_remote_code,
|
||||
max_input_tokens,
|
||||
adapter_to_index,
|
||||
|
@ -286,6 +289,7 @@ def serve(
|
|||
quantize,
|
||||
speculate,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
trust_remote_code,
|
||||
)
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue