Add basic FP8 KV cache support (#2603)
* Add basic FP8 KV cache support

  This change adds rudimentary FP8 KV cache support. The support is enabled by
  passing `--kv-cache-dtype fp8_e5m2` to the launcher. Doing so uses this type
  for the KV cache. However support is still limited:

  * Only the `fp8_e5m2` type is supported.
  * The KV cache layout is the same as `float16`/`bfloat16` (HND).
  * The FP8 KV cache is only supported for FlashInfer.
  * Loading of scales is not yet supported.

* Fix Cargo.toml
Parent: 68103079f4 · Commit: 2358c2bb54
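For reference, the new option is consumed by `text-generation-launcher`. A minimal launch sketch in Python, assuming the launcher is on `PATH`; the model id, shard count, and port are illustrative (the model matches the new integration test) and are not mandated by the commit:

```python
# Hypothetical launch of TGI with the new FP8 KV cache flag added in this commit.
import subprocess

subprocess.run(
    [
        "text-generation-launcher",
        "--model-id", "meta-llama/Meta-Llama-3-8B",  # example model, as in the new test
        "--num-shard", "2",                          # illustrative shard count
        "--kv-cache-dtype", "fp8_e5m2",              # the new option
        "--port", "8080",                            # illustrative port
    ],
    check=True,
)
```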
@@ -4177,7 +4177,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-backends-trtllm"
-version = "2.3.1-dev0"
+version = "2.3.2-dev0"
 dependencies = [
  "async-stream",
  "async-trait",
@@ -4200,7 +4200,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-benchmark"
-version = "2.3.1-dev0"
+version = "2.3.2-dev0"
 dependencies = [
  "average",
  "clap 4.5.18",
@@ -4220,7 +4220,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-client"
-version = "2.3.1-dev0"
+version = "2.3.2-dev0"
 dependencies = [
  "async-trait",
  "base64 0.22.1",
@@ -4238,7 +4238,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-launcher"
-version = "2.3.1-dev0"
+version = "2.3.2-dev0"
 dependencies = [
  "clap 4.5.18",
  "ctrlc",
@@ -4259,7 +4259,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-router"
-version = "2.3.1-dev0"
+version = "2.3.2-dev0"
 dependencies = [
  "async-stream",
  "async-trait",
@@ -4308,7 +4308,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-router-v2"
-version = "2.3.1-dev0"
+version = "2.3.2-dev0"
 dependencies = [
  "async-stream",
  "async-trait",
@@ -4357,7 +4357,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-router-v3"
-version = "2.3.1-dev0"
+version = "2.3.2-dev0"
 dependencies = [
  "async-stream",
  "async-trait",
@@ -89,6 +89,15 @@ Options:
           [env: DTYPE=]
           [possible values: float16, bfloat16]
 
+```
+## KV_CACHE_DTYPE
+```shell
+      --kv-cache-dtype <KV_CACHE_DTYPE>
+          Specify the dtype for the key-value cache. When this option is not provided, the dtype of the model is used (typically `float16` or `bfloat16`). Currently the only supported value is `fp8_e5m2` on CUDA
+
+          [env: KV_CACHE_DTYPE=]
+          [possible values: fp8_e5m2]
+
 ```
 ## TRUST_REMOTE_CODE
 ```shell
@@ -336,6 +336,7 @@ def launcher(event_loop):
        use_flash_attention: bool = True,
        disable_grammar_support: bool = False,
        dtype: Optional[str] = None,
+       kv_cache_dtype: Optional[str] = None,
        revision: Optional[str] = None,
        max_input_length: Optional[int] = None,
        max_batch_prefill_tokens: Optional[int] = None,
@@ -375,6 +376,9 @@ def launcher(event_loop):
        if dtype is not None:
            args.append("--dtype")
            args.append(dtype)
+       if kv_cache_dtype is not None:
+           args.append("--kv-cache-dtype")
+           args.append(kv_cache_dtype)
        if revision is not None:
            args.append("--revision")
            args.append(revision)
@@ -434,6 +438,7 @@ def launcher(event_loop):
        use_flash_attention: bool = True,
        disable_grammar_support: bool = False,
        dtype: Optional[str] = None,
+       kv_cache_dtype: Optional[str] = None,
        revision: Optional[str] = None,
        max_input_length: Optional[int] = None,
        max_batch_prefill_tokens: Optional[int] = None,
@@ -456,6 +461,9 @@ def launcher(event_loop):
        if dtype is not None:
            args.append("--dtype")
            args.append(dtype)
+       if kv_cache_dtype is not None:
+           args.append("--kv-cache-dtype")
+           args.append(kv_cache_dtype)
        if revision is not None:
            args.append("--revision")
            args.append(revision)
@@ -589,7 +597,6 @@ def generate_multi():
        max_new_tokens: int,
        seed: Optional[int] = None,
    ) -> List[Response]:
-
        import numpy as np
 
        arange = np.arange(len(prompts))
@@ -0,0 +1,104 @@
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "length",
    "generated_tokens": 10,
    "prefill": [
      { "id": 128000, "logprob": null, "text": "<|begin_of_text|>" },
      { "id": 3923, "logprob": -5.6328125, "text": "What" },
      { "id": 374, "logprob": -1.2265625, "text": " is" },
      { "id": 5655, "logprob": -9.1015625, "text": " deep" },
      { "id": 6975, "logprob": -1.8085938, "text": " learning" },
      { "id": 30, "logprob": -1.0439453, "text": "?" }
    ],
    "seed": null,
    "tokens": [
      { "id": 18682, "logprob": -2.1992188, "special": false, "text": " Deep" },
      { "id": 6975, "logprob": -0.079956055, "special": false, "text": " learning" },
      { "id": 374, "logprob": -0.2763672, "special": false, "text": " is" },
      { "id": 264, "logprob": -0.37548828, "special": false, "text": " a" },
      { "id": 27084, "logprob": -1.4628906, "special": false, "text": " subset" },
      { "id": 315, "logprob": -0.02885437, "special": false, "text": " of" },
      { "id": 5780, "logprob": -0.2565918, "special": false, "text": " machine" },
      { "id": 6975, "logprob": -0.0063438416, "special": false, "text": " learning" },
      { "id": 430, "logprob": -1.3056641, "special": false, "text": " that" },
      { "id": 374, "logprob": -1.6035156, "special": false, "text": " is" }
    ],
    "top_tokens": null
  },
  "generated_text": " Deep learning is a subset of machine learning that is"
}
@@ -0,0 +1,57 @@
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "eos_token",
    "generated_tokens": 3,
    "prefill": [
      { "id": 128000, "logprob": null, "text": "<|begin_of_text|>" },
      { "id": 374, "logprob": -22.96875, "text": " is" },
      { "id": 5655, "logprob": -10.71875, "text": " deep" },
      { "id": 6975, "logprob": -2.6992188, "text": " learning" },
      { "id": 30, "logprob": -4.8398438, "text": "?" }
    ],
    "seed": 0,
    "tokens": [
      { "id": 720, "logprob": -0.4411621, "special": false, "text": " \n" },
      { "id": 220, "logprob": -0.35864258, "special": false, "text": " " },
      { "id": 128001, "logprob": 0.0, "special": true, "text": "<|end_of_text|>" }
    ],
    "top_tokens": null
  },
  "generated_text": "What is deep learning? \n "
}
@@ -0,0 +1,418 @@
[
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        { "id": 128000, "logprob": null, "text": "<|begin_of_text|>" },
        { "id": 3923, "logprob": -5.6328125, "text": "What" },
        { "id": 374, "logprob": -1.2265625, "text": " is" },
        { "id": 5655, "logprob": -9.1015625, "text": " deep" },
        { "id": 6975, "logprob": -1.8085938, "text": " learning" },
        { "id": 30, "logprob": -1.0439453, "text": "?" }
      ],
      "seed": null,
      "tokens": [
        { "id": 18682, "logprob": -2.1992188, "special": false, "text": " Deep" },
        { "id": 6975, "logprob": -0.07897949, "special": false, "text": " learning" },
        { "id": 374, "logprob": -0.27734375, "special": false, "text": " is" },
        { "id": 264, "logprob": -0.37402344, "special": false, "text": " a" },
        { "id": 27084, "logprob": -1.4511719, "special": false, "text": " subset" },
        { "id": 315, "logprob": -0.02909851, "special": false, "text": " of" },
        { "id": 5780, "logprob": -0.25854492, "special": false, "text": " machine" },
        { "id": 6975, "logprob": -0.0061798096, "special": false, "text": " learning" },
        { "id": 430, "logprob": -1.3046875, "special": false, "text": " that" },
        { "id": 374, "logprob": -1.5537109, "special": false, "text": " is" }
      ],
      "top_tokens": null
    },
    "generated_text": " Deep learning is a subset of machine learning that is"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        { "id": 128000, "logprob": null, "text": "<|begin_of_text|>" },
        { "id": 3923, "logprob": -5.6328125, "text": "What" },
        { "id": 374, "logprob": -1.2265625, "text": " is" },
        { "id": 5655, "logprob": -9.1015625, "text": " deep" },
        { "id": 6975, "logprob": -1.8085938, "text": " learning" },
        { "id": 30, "logprob": -1.0439453, "text": "?" }
      ],
      "seed": null,
      "tokens": [
        { "id": 18682, "logprob": -2.1992188, "special": false, "text": " Deep" },
        { "id": 6975, "logprob": -0.07897949, "special": false, "text": " learning" },
        { "id": 374, "logprob": -0.27734375, "special": false, "text": " is" },
        { "id": 264, "logprob": -0.37402344, "special": false, "text": " a" },
        { "id": 27084, "logprob": -1.4511719, "special": false, "text": " subset" },
        { "id": 315, "logprob": -0.02909851, "special": false, "text": " of" },
        { "id": 5780, "logprob": -0.25854492, "special": false, "text": " machine" },
        { "id": 6975, "logprob": -0.0061798096, "special": false, "text": " learning" },
        { "id": 430, "logprob": -1.3046875, "special": false, "text": " that" },
        { "id": 374, "logprob": -1.5537109, "special": false, "text": " is" }
      ],
      "top_tokens": null
    },
    "generated_text": " Deep learning is a subset of machine learning that is"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        { "id": 128000, "logprob": null, "text": "<|begin_of_text|>" },
        { "id": 3923, "logprob": -5.6328125, "text": "What" },
        { "id": 374, "logprob": -1.2265625, "text": " is" },
        { "id": 5655, "logprob": -9.1015625, "text": " deep" },
        { "id": 6975, "logprob": -1.8085938, "text": " learning" },
        { "id": 30, "logprob": -1.0439453, "text": "?" }
      ],
      "seed": null,
      "tokens": [
        { "id": 18682, "logprob": -2.1992188, "special": false, "text": " Deep" },
        { "id": 6975, "logprob": -0.07897949, "special": false, "text": " learning" },
        { "id": 374, "logprob": -0.27734375, "special": false, "text": " is" },
        { "id": 264, "logprob": -0.37402344, "special": false, "text": " a" },
        { "id": 27084, "logprob": -1.4511719, "special": false, "text": " subset" },
        { "id": 315, "logprob": -0.02909851, "special": false, "text": " of" },
        { "id": 5780, "logprob": -0.25854492, "special": false, "text": " machine" },
        { "id": 6975, "logprob": -0.0061798096, "special": false, "text": " learning" },
        { "id": 430, "logprob": -1.3046875, "special": false, "text": " that" },
        { "id": 374, "logprob": -1.5537109, "special": false, "text": " is" }
      ],
      "top_tokens": null
    },
    "generated_text": " Deep learning is a subset of machine learning that is"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        { "id": 128000, "logprob": null, "text": "<|begin_of_text|>" },
        { "id": 3923, "logprob": -5.6328125, "text": "What" },
        { "id": 374, "logprob": -1.2265625, "text": " is" },
        { "id": 5655, "logprob": -9.1015625, "text": " deep" },
        { "id": 6975, "logprob": -1.8085938, "text": " learning" },
        { "id": 30, "logprob": -1.0439453, "text": "?" }
      ],
      "seed": null,
      "tokens": [
        { "id": 18682, "logprob": -2.1992188, "special": false, "text": " Deep" },
        { "id": 6975, "logprob": -0.07897949, "special": false, "text": " learning" },
        { "id": 374, "logprob": -0.27734375, "special": false, "text": " is" },
        { "id": 264, "logprob": -0.37402344, "special": false, "text": " a" },
        { "id": 27084, "logprob": -1.4511719, "special": false, "text": " subset" },
        { "id": 315, "logprob": -0.02909851, "special": false, "text": " of" },
        { "id": 5780, "logprob": -0.25854492, "special": false, "text": " machine" },
        { "id": 6975, "logprob": -0.0061798096, "special": false, "text": " learning" },
        { "id": 430, "logprob": -1.3046875, "special": false, "text": " that" },
        { "id": 374, "logprob": -1.5537109, "special": false, "text": " is" }
      ],
      "top_tokens": null
    },
    "generated_text": " Deep learning is a subset of machine learning that is"
  }
]
@@ -0,0 +1,77 @@
import pytest


@pytest.fixture(scope="module")
def flash_llama_fp8_kv_cache_handle(launcher):
    with launcher(
        "meta-llama/Meta-Llama-3-8B", num_shard=2, kv_cache_dtype="fp8_e5m2"
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_llama_fp8_kv_cache(flash_llama_fp8_kv_cache_handle):
    await flash_llama_fp8_kv_cache_handle.health(300)
    return flash_llama_fp8_kv_cache_handle.client


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_fp8_kv_cache(flash_llama_fp8_kv_cache, response_snapshot):
    response = await flash_llama_fp8_kv_cache.generate(
        "What is deep learning?", max_new_tokens=10, decoder_input_details=True
    )

    assert (
        response.generated_text
        == " Deep learning is a subset of machine learning that is"
    )
    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_fp8_kv_cache_all_params(
    flash_llama_fp8_kv_cache, response_snapshot
):
    response = await flash_llama_fp8_kv_cache.generate(
        "What is deep learning?",
        max_new_tokens=10,
        repetition_penalty=1.2,
        return_full_text=True,
        stop_sequences=["test"],
        temperature=0.5,
        top_p=0.9,
        top_k=10,
        truncate=5,
        typical_p=0.9,
        watermark=True,
        decoder_input_details=True,
        seed=0,
    )

    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_fp8_kv_cache_load(
    flash_llama_fp8_kv_cache, generate_load, response_snapshot
):
    responses = await generate_load(
        flash_llama_fp8_kv_cache, "What is deep learning?", max_new_tokens=10, n=4
    )

    assert len(responses) == 4
    assert (
        responses[0].generated_text
        == " Deep learning is a subset of machine learning that is"
    )
    assert all(
        [r.generated_text == responses[0].generated_text for r in responses]
    ), f"Different messages : {[r.generated_text for r in responses]}"
    assert responses == response_snapshot
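The assertions above can be reproduced against a running server. A minimal sketch, assuming the `text-generation` Python client package and a server started with `--kv-cache-dtype fp8_e5m2` listening on an illustrative local URL:

```python
# Query a server that was launched with the FP8 KV cache; the URL is illustrative.
from text_generation import Client

client = Client("http://127.0.0.1:8080")
response = client.generate("What is deep learning?", max_new_tokens=10)

print(response.generated_text)            # e.g. " Deep learning is a subset of machine learning that is"
print(response.details.generated_tokens)  # 10
```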
@@ -301,6 +301,22 @@ impl std::fmt::Display for Dtype {
     }
 }
 
+#[derive(Clone, Copy, Debug, ValueEnum)]
+enum KVCacheDtype {
+    #[clap(name = "fp8_e5m2")]
+    Fp8e5m2,
+}
+
+impl std::fmt::Display for KVCacheDtype {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            KVCacheDtype::Fp8e5m2 => {
+                write!(f, "fp8_e5m2")
+            }
+        }
+    }
+}
+
 #[derive(Clone, Copy, Debug, ValueEnum)]
 enum RopeScaling {
     Linear,
@@ -402,6 +418,12 @@ struct Args {
     #[clap(long, env, value_enum)]
     dtype: Option<Dtype>,
 
+    /// Specify the dtype for the key-value cache. When this option is not provided,
+    /// the dtype of the model is used (typically `float16` or `bfloat16`). Currently
+    /// the only supported value is `fp8_e5m2` on CUDA.
+    #[clap(long, env, value_enum)]
+    kv_cache_dtype: Option<KVCacheDtype>,
+
     /// Whether you want to execute hub modelling code. Explicitly passing a `revision` is
     /// encouraged when loading a model with custom code to ensure no malicious code has been
     /// contributed in a newer revision.
@@ -670,6 +692,7 @@ fn shard_manager(
     quantize: Option<Quantization>,
     speculate: Option<usize>,
     dtype: Option<Dtype>,
+    kv_cache_dtype: Option<KVCacheDtype>,
     trust_remote_code: bool,
     uds_path: String,
     rank: usize,
@@ -743,6 +766,11 @@ fn shard_manager(
         shard_args.push(dtype.to_string())
     }
 
+    if let Some(kv_cache_dtype) = kv_cache_dtype {
+        shard_args.push("--kv-cache-dtype".to_string());
+        shard_args.push(kv_cache_dtype.to_string())
+    }
+
     // Model optional revision
     if let Some(revision) = revision {
         shard_args.push("--revision".to_string());
@@ -1299,6 +1327,7 @@ fn spawn_shards(
     let otlp_service_name = args.otlp_service_name.clone();
     let speculate = args.speculate;
     let dtype = args.dtype;
+    let kv_cache_dtype = args.kv_cache_dtype;
    let trust_remote_code = args.trust_remote_code;
    let master_port = args.master_port;
    let disable_custom_kernels = args.disable_custom_kernels;
@@ -1317,6 +1346,7 @@ fn spawn_shards(
                quantize,
                speculate,
                dtype,
+               kv_cache_dtype,
                trust_remote_code,
                uds_path,
                rank,
@@ -30,6 +30,10 @@ class Dtype(str, Enum):
     bloat16 = "bfloat16"
 
 
+class KVCacheDtype(str, Enum):
+    fp8_e5m2 = "fp8_e5m2"
+
+
 @app.command()
 def serve(
     model_id: str,
@@ -38,6 +42,7 @@ def serve(
     quantize: Optional[Quantization] = None,
     speculate: Optional[int] = None,
     dtype: Optional[Dtype] = None,
+    kv_cache_dtype: Optional[KVCacheDtype] = None,
     trust_remote_code: bool = False,
     uds_path: Path = "/tmp/text-generation-server",
     logger_level: str = "INFO",
@@ -97,6 +102,7 @@ def serve(
         # Downgrade enum into str for easier management later on
         quantize = None if quantize is None else quantize.value
         dtype = None if dtype is None else dtype.value
+        kv_cache_dtype = None if kv_cache_dtype is None else kv_cache_dtype.value
         if dtype is not None and quantize not in {
             None,
             "bitsandbytes",
@@ -114,6 +120,7 @@ def serve(
             quantize,
             speculate,
             dtype,
+            kv_cache_dtype,
             trust_remote_code,
             uds_path,
             max_input_tokens,
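With the launcher change above, each shard process now receives `--kv-cache-dtype` on its command line. A sketch of the equivalent standalone server invocation, with an example model id and the other shard arguments omitted, so this is illustrative rather than a complete shard command:

```python
# Illustrative equivalent of what shard_manager now builds; rank, uds-path and
# the remaining shard arguments are omitted for brevity.
import subprocess

subprocess.run(
    [
        "text-generation-server",
        "serve",
        "meta-llama/Meta-Llama-3-8B",    # example model id
        "--kv-cache-dtype", "fp8_e5m2",  # forwarded from the launcher
    ],
    check=True,
)
```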
@@ -1,37 +1,40 @@
-from text_generation_server.utils.import_utils import SYSTEM
 import os
 
+from text_generation_server.utils.import_utils import SYSTEM
+
 from .common import Seqlen
 
 if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
     raise ImportError("`USE_FLASH_ATTENTION` is false.")
 if SYSTEM == "cuda":
     from .cuda import (
+        PREFILL_IN_KV_CACHE,
+        SUPPORTS_WINDOWING,
         attention,
         paged_attention,
         reshape_and_cache,
-        SUPPORTS_WINDOWING,
-        PREFILL_IN_KV_CACHE,
     )
 elif SYSTEM == "rocm":
     from .rocm import (
+        PREFILL_IN_KV_CACHE,
+        SUPPORTS_WINDOWING,
         attention,
         paged_attention,
         reshape_and_cache,
-        PREFILL_IN_KV_CACHE,
-        SUPPORTS_WINDOWING,
     )
 elif SYSTEM == "ipex":
     from .ipex import (
+        PREFILL_IN_KV_CACHE,
+        SUPPORTS_WINDOWING,
         attention,
         paged_attention,
         reshape_and_cache,
-        PREFILL_IN_KV_CACHE,
-        SUPPORTS_WINDOWING,
     )
 else:
     raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
 
+# KVCache needs `reshape_and_cache`, so ensure that it is defined already.
+from .kv_cache import KVCache
 
 __all__ = [
     "attention",
@@ -39,5 +42,6 @@ __all__ = [
     "reshape_and_cache",
     "PREFILL_IN_KV_CACHE",
     "SUPPORTS_WINDOWING",
+    "KVCache",
     "Seqlen",
 ]
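After this change, `KVCache` is re-exported next to the existing attention primitives. A minimal import sketch, assuming a system where one of the cuda/rocm/ipex backends can load:

```python
# The attention package now exposes KVCache alongside the existing primitives.
from text_generation_server.layers.attention import (
    KVCache,
    attention,
    paged_attention,
    reshape_and_cache,
)
```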
@@ -355,3 +355,11 @@ else:
     # have a configuration that requires flash-attention v1, which
     # does not support block tables.
     PREFILL_IN_KV_CACHE = ATTENTION != "paged" or V2
+
+__all__ = [
+    "PREFILL_IN_KV_CACHE",
+    "SUPPORTS_WINDOWING",
+    "attention",
+    "paged_attention",
+    "reshape_and_cache",
+]
@@ -80,3 +80,12 @@ def paged_attention(
         None,
     )
     return out
+
+
+__all__ = [
+    "PREFILL_IN_KV_CACHE",
+    "SUPPORTS_WINDOWING",
+    "attention",
+    "paged_attention",
+    "reshape_and_cache",
+]
@@ -0,0 +1,121 @@
from typing import Tuple

import torch
from text_generation_server.models.globals import ATTENTION, BLOCK_SIZE
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import reshape_and_cache


class KVCache:
    """
    Key-value cache for attention layers.
    """

    kv_cache: Tuple[torch.Tensor, torch.Tensor]

    def __init__(
        self,
        *,
        num_blocks: int,
        num_heads: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ):
        """Construct the key-value cache for a layer."""

        if (
            dtype == torch.float8_e5m2
            and ATTENTION != "flashinfer"
            and SYSTEM != "cuda"
        ):
            raise ValueError(
                "float8_e5m2 KV cache is currently only supported for flashinfer on CUDA"
            )

        element_size = torch.tensor([], dtype=dtype).element_size()
        if SYSTEM == "ipex" and device.type == "xpu":
            x = 1
        else:
            x = BLOCK_SIZE // element_size

        if ATTENTION in {"flashdecoding", "flashinfer"}:
            self.kv_cache = (
                torch.empty(
                    (num_blocks, BLOCK_SIZE, num_heads, head_size),
                    dtype=dtype,
                    device=device,
                ),
                torch.empty(
                    (num_blocks, BLOCK_SIZE, num_heads, head_size),
                    dtype=dtype,
                    device=device,
                ),
            )
        elif SYSTEM == "ipex" and device == torch.device("cpu"):
            self.kv_cache = (
                torch.empty(
                    (num_blocks, num_heads, BLOCK_SIZE, head_size),
                    dtype=dtype,
                    device=device,
                ),
                torch.empty(
                    (num_blocks, num_heads, BLOCK_SIZE, head_size),
                    dtype=dtype,
                    device=device,
                ),
            )
        else:
            self.kv_cache = (
                torch.zeros(
                    (num_blocks, num_heads, head_size // x, BLOCK_SIZE, x),
                    dtype=dtype,
                    device=device,
                ),
                torch.zeros(
                    (num_blocks, num_heads, head_size, BLOCK_SIZE),
                    dtype=dtype,
                    device=device,
                ),
            )

    @property
    def key(self):
        """Get the key cache."""

        return self.kv_cache[0]

    @property
    def value(self):
        """Get the value cache."""

        return self.kv_cache[1]

    def store(
        self,
        *,
        key: torch.Tensor,
        value: torch.Tensor,
        slots: torch.Tensor,
    ):
        """Store the key and value at the given slots."""

        key_cache = self.kv_cache[0]
        value_cache = self.kv_cache[1]

        if ATTENTION in {"flashdecoding", "flashinfer"}:
            # TODO: add scale
            key = key.to(key_cache.dtype)
            value = value.to(value_cache.dtype)
            if key_cache.dtype == torch.float8_e5m2:
                # Torch index_put does not support float8_e5m2 yet, so
                # put as raw data instead.
                key_cache = key_cache.view(torch.uint8)
                value_cache = value_cache.view(torch.uint8)
                key = key.view(torch.uint8)
                value = value.view(torch.uint8)
            shape = key_cache.shape
            key_cache.view(-1, shape[-2], shape[-1])[slots] = key
            value_cache.view(-1, shape[-2], shape[-1])[slots] = value
        else:
            reshape_and_cache(key, value, key_cache, value_cache, slots)
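A usage sketch for the new class, assuming the server's globals select the FlashInfer backend (the only one supporting the fp8 cache) on CUDA; block count, head count, and head size are illustrative:

```python
# Illustrative only: sizes are made up, and ATTENTION/BLOCK_SIZE are read from
# the server's globals at import time (assumed here to be the flashinfer backend).
import torch

from text_generation_server.layers.attention import KVCache

cache = KVCache(
    num_blocks=128,
    num_heads=8,
    head_size=128,
    dtype=torch.float8_e5m2,      # the new fp8_e5m2 cache type
    device=torch.device("cuda"),
)

# Store keys/values for four tokens at the given slots; for float8_e5m2 the
# write goes through a uint8 view because index_put does not support fp8 yet.
key = torch.randn(4, 8, 128, dtype=torch.float16, device="cuda")
value = torch.randn(4, 8, 128, dtype=torch.float16, device="cuda")
slots = torch.tensor([0, 1, 2, 3], device="cuda")
cache.store(key=key, value=value, slots=slots)

assert cache.key.dtype == torch.float8_e5m2
```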
@@ -306,3 +306,11 @@ elif ENGINE == "triton":
 
 else:
     raise RuntimeError(f"Unknown attention engine {ENGINE}")
+
+__all__ = [
+    "PREFILL_IN_KV_CACHE",
+    "SUPPORTS_WINDOWING",
+    "attention",
+    "paged_attention",
+    "reshape_and_cache",
+]
@@ -342,6 +342,7 @@ def get_model(
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
+   kv_cache_dtype: Optional[str],
    trust_remote_code: bool,
    max_input_tokens: int,
 ) -> Model:
@@ -403,6 +404,13 @@ def get_model(
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")
 
+   if kv_cache_dtype is None:
+       kv_cache_dtype = dtype
+   elif kv_cache_dtype == "fp8_e5m2":
+       kv_cache_dtype = torch.float8_e5m2
+   else:
+       raise RuntimeError(f"Unknown kv_cache_dtype: {kv_cache_dtype}")
+
    if speculate is not None:
        set_speculate(speculate)
    else:
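The string-to-dtype resolution added above, restated as a standalone helper for clarity; `resolve_kv_cache_dtype` is a hypothetical name, not something defined in the repository:

```python
# Hypothetical helper mirroring the new logic in get_model: fall back to the
# model dtype, map "fp8_e5m2" to torch.float8_e5m2, reject anything else.
from typing import Optional

import torch


def resolve_kv_cache_dtype(
    kv_cache_dtype: Optional[str], model_dtype: torch.dtype
) -> torch.dtype:
    if kv_cache_dtype is None:
        return model_dtype
    if kv_cache_dtype == "fp8_e5m2":
        return torch.float8_e5m2
    raise RuntimeError(f"Unknown kv_cache_dtype: {kv_cache_dtype}")


# Example: resolve_kv_cache_dtype("fp8_e5m2", torch.float16) -> torch.float8_e5m2
```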
@@ -563,6 +571,7 @@ def get_model(
            speculator=speculator,
            default_dtype=torch.bfloat16,
            dtype=dtype,
+           kv_cache_dtype=kv_cache_dtype,
            trust_remote_code=trust_remote_code,
            lora_adapter_ids=lora_adapter_ids,
            config_class=DeepseekV2Config,

The same one-line addition — `kv_cache_dtype=kv_cache_dtype,` inserted directly after `dtype=dtype,` in each model constructor call inside `get_model` — is repeated in the hunks at @@ -617,6 +626,7 @@, @@ -668,6 +678,7 @@, @@ -703,6 +714,7 @@, @@ -741,6 +753,7 @@, @@ -774,6 +787,7 @@, @@ -797,6 +811,7 @@, @@ -836,6 +851,7 @@, @@ -859,6 +875,7 @@, @@ -884,6 +901,7 @@, @@ -910,6 +928,7 @@, @@ -934,6 +953,7 @@, @@ -964,6 +984,7 @@, @@ -982,6 +1003,7 @@, @@ -1009,6 +1031,7 @@, @@ -1033,6 +1056,7 @@, @@ -1057,6 +1081,7 @@, @@ -1083,6 +1108,7 @@, @@ -1162,6 +1188,7 @@, @@ -1179,6 +1206,7 @@, and @@ -1197,6 +1225,7 @@.

@@ -1269,6 +1298,7 @@ def get_model_with_lora_adapters(
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
+   kv_cache_dtype: Optional[str],
    trust_remote_code: bool,
    max_input_tokens: int,
    adapter_to_index: Dict[str, int],
@@ -1282,6 +1312,7 @@ def get_model_with_lora_adapters(
        quantize,
        speculate,
        dtype,
+       kv_cache_dtype,
        trust_remote_code,
        max_input_tokens,
    )
@@ -28,7 +28,6 @@ from typing import Optional, List, Tuple
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
-    reshape_and_cache,
     Seqlen,
 )
 from text_generation_server.utils.import_utils import SYSTEM
@@ -291,15 +290,15 @@ class FlashCohereAttention(torch.nn.Module):
 
         self.rotary_emb(query, key, cos, sin)
 
-        reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
+        kv_cache.store(key=key, value=value, slots=slots)
 
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if PREFILL_IN_KV_CACHE else key,
-                kv_cache[1] if PREFILL_IN_KV_CACHE else value,
+                kv_cache.key if PREFILL_IN_KV_CACHE else key,
+                kv_cache.value if PREFILL_IN_KV_CACHE else value,
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -308,8 +307,8 @@ class FlashCohereAttention(torch.nn.Module):
         else:
             attn_output = paged_attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache.key,
+                kv_cache.value,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
@@ -28,7 +28,6 @@ if SYSTEM != "ipex":
     from text_generation_server.layers.attention import (
         paged_attention,
         attention,
-        reshape_and_cache,
         Seqlen,
         PREFILL_IN_KV_CACHE,
     )
@@ -330,15 +329,15 @@ class DbrxAttention(torch.nn.Module):
 
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
 
-        reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
+        kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)
 
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
-                kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
+                kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
+                kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -347,8 +346,8 @@ class DbrxAttention(torch.nn.Module):
         else:
             attn_output = paged_attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache.key,
+                kv_cache.value,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
@@ -33,7 +33,6 @@ from text_generation_server.layers.attention import (
     Seqlen,
     attention,
     paged_attention,
-    reshape_and_cache,
 )
 from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.layernorm import FastRMSNorm
@@ -321,15 +320,15 @@ class DeepseekV2Attention(torch.nn.Module):
             value, (0, self.head_pad_size - self.value_head_size), value=0
         )
 
-        reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
+        kv_cache.store(key=key, value=value, slots=slots)
 
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if PREFILL_IN_KV_CACHE else key,
-                kv_cache[1] if PREFILL_IN_KV_CACHE else value,
+                kv_cache.key if PREFILL_IN_KV_CACHE else key,
+                kv_cache.value if PREFILL_IN_KV_CACHE else value,
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -338,8 +337,8 @@ class DeepseekV2Attention(torch.nn.Module):
         else:
             attn_output = paged_attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache.key,
+                kv_cache.value,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
@@ -28,7 +28,6 @@ from typing import Optional, List, Tuple
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
-    reshape_and_cache,
     Seqlen,
 )
 from text_generation_server.layers import (
@@ -253,15 +252,15 @@ class FlashGemma2Attention(torch.nn.Module):
 
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
 
-        reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
+        kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)
 
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
-                kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
+                kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
+                kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -273,8 +272,8 @@ class FlashGemma2Attention(torch.nn.Module):
         else:
             attn_output = paged_attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache.key,
+                kv_cache.value,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
@@ -28,7 +28,6 @@ from typing import Optional, List, Tuple
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
-    reshape_and_cache,
     Seqlen,
     PREFILL_IN_KV_CACHE,
 )
@@ -224,15 +223,15 @@ class FlashGemmaAttention(torch.nn.Module):
 
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
 
-        reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
+        kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)
 
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
-                kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
+                kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
+                kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -242,8 +241,8 @@ class FlashGemmaAttention(torch.nn.Module):
         else:
             attn_output = paged_attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache.key,
+                kv_cache.value,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
@@ -28,7 +28,6 @@ from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
-    reshape_and_cache,
     Seqlen,
 )
 from text_generation_server.layers import (
@@ -224,15 +223,15 @@ class FlashGPT2Attention(torch.nn.Module):
         key = key.view(-1, self.num_heads, self.head_size)
         value = value.view(-1, self.num_heads, self.head_size)
 
-        reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
+        kv_cache.store(key=key, value=value, slots=slots)
 
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if PREFILL_IN_KV_CACHE else key,
-                kv_cache[1] if PREFILL_IN_KV_CACHE else value,
+                kv_cache.key if PREFILL_IN_KV_CACHE else key,
+                kv_cache.value if PREFILL_IN_KV_CACHE else value,
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -241,8 +240,8 @@ class FlashGPT2Attention(torch.nn.Module):
         else:
             attn_output = paged_attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache.key,
+                kv_cache.value,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
@@ -28,7 +28,6 @@ from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
-    reshape_and_cache,
     Seqlen,
 )
 from text_generation_server.layers import (
@@ -186,15 +185,15 @@ class FlashGPTJAttention(torch.nn.Module):
         else:
             self.rotary_emb(query, key, cos, sin)

-        reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
+        kv_cache.store(key=key, value=value, slots=slots)

         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if PREFILL_IN_KV_CACHE else key,
-                kv_cache[1] if PREFILL_IN_KV_CACHE else value,
+                kv_cache.key if PREFILL_IN_KV_CACHE else key,
+                kv_cache.value if PREFILL_IN_KV_CACHE else value,
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -203,8 +202,8 @@ class FlashGPTJAttention(torch.nn.Module):
         else:
             attn_output = paged_attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache.key,
+                kv_cache.value,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
@@ -27,13 +27,12 @@ import torch.distributed
 from torch import nn
 from transformers.activations import ACT2FN

-from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
+from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE, KVCache
 from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
-    reshape_and_cache,
     Seqlen,
 )
 from text_generation_server.layers import (
@@ -202,7 +201,7 @@ class FlashLlamaAttention(torch.nn.Module):
         cos,
         sin,
         cu_seqlen_prefill,
-        kv_cache,
+        kv_cache: KVCache,
         block_tables,
         slots,
         seqlen,
@@ -222,15 +221,15 @@ class FlashLlamaAttention(torch.nn.Module):

         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

-        reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
+        kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)

         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
-                kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
+                kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
+                kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -239,8 +238,8 @@ class FlashLlamaAttention(torch.nn.Module):
         else:
             attn_output = paged_attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache.key,
+                kv_cache.value,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
@@ -30,7 +30,6 @@ from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
-    reshape_and_cache,
     Seqlen,
 )
 from text_generation_server.layers import (
@@ -210,17 +209,15 @@ class MistralAttention(torch.nn.Module):
         else:
             kv_to_cache = kv

-        reshape_and_cache(
-            kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
-        )
+        kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots)

         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
-                kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
+                kv_cache.key if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
+                kv_cache.value if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -230,8 +227,8 @@ class MistralAttention(torch.nn.Module):
         else:
             attn_output = paged_attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache.key,
+                kv_cache.value,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
@@ -37,7 +37,6 @@ from text_generation_server.layers.attention import (
     Seqlen,
     attention,
     paged_attention,
-    reshape_and_cache,
 )
 from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.layernorm import FastRMSNorm
@@ -258,17 +257,15 @@ class MixtralAttention(torch.nn.Module):
         else:
             kv_to_cache = kv

-        reshape_and_cache(
-            kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
-        )
+        kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots)

         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
-                kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
+                kv_cache.key if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
+                kv_cache.value if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -278,8 +275,8 @@ class MixtralAttention(torch.nn.Module):
         else:
             attn_output = paged_attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache.key,
+                kv_cache.value,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
@@ -29,7 +29,6 @@ from typing import Optional, List, Tuple
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
-    reshape_and_cache,
     Seqlen,
 )
 from text_generation_server.layers import (
@@ -165,15 +164,15 @@ class FlashNeoxAttention(torch.nn.Module):
         qkv[:, 0] = torch.cat((query_rot, query_pass), dim=-1)
         qkv[:, 1] = torch.cat((key_rot, key_pass), dim=-1)

-        reshape_and_cache(qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots)
+        kv_cache.store(key=qkv[:, 1], value=qkv[:, 2], slots=slots)

         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
             attn_output = attention(
                 qkv[:, 0],
-                kv_cache[0] if PREFILL_IN_KV_CACHE else qkv[:, 1],
-                kv_cache[1] if PREFILL_IN_KV_CACHE else qkv[:, 2],
+                kv_cache.key if PREFILL_IN_KV_CACHE else qkv[:, 1],
+                kv_cache.value if PREFILL_IN_KV_CACHE else qkv[:, 2],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -182,8 +181,8 @@ class FlashNeoxAttention(torch.nn.Module):
         else:
             attn_output = paged_attention(
                 qkv[:, 0],
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache.key,
+                kv_cache.value,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
@@ -9,7 +9,6 @@ from typing import Optional, List, Tuple
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
-    reshape_and_cache,
     Seqlen,
 )
 from text_generation_server.layers import (
@@ -188,14 +187,14 @@ class FlashPhiAttention(torch.nn.Module):
             )

         # Reshape key and value and cache
-        reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
+        kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)

         # Prefill
         if cu_seqlen_prefill is not None:
             attn_output = attention(
                 query,
-                kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
-                kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
+                kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
+                kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -204,8 +203,8 @@ class FlashPhiAttention(torch.nn.Module):
         else:
             attn_output = paged_attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache.key,
+                kv_cache.value,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
@@ -8,7 +8,6 @@ from typing import Optional, List, Tuple
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
-    reshape_and_cache,
     Seqlen,
 )
 from text_generation_server.layers import (
@@ -128,17 +127,15 @@ class Qwen2Attention(torch.nn.Module):
         else:
             kv_to_cache = kv

-        reshape_and_cache(
-            kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
-        )
+        kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots)

         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
-                kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
+                kv_cache.key if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
+                kv_cache.value if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -148,8 +145,8 @@ class Qwen2Attention(torch.nn.Module):
         else:
             attn_output = paged_attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache.key,
+                kv_cache.value,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
@@ -18,7 +18,6 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.attention import (
     attention,
     paged_attention,
-    reshape_and_cache,
     Seqlen,
 )

@@ -200,15 +199,15 @@ class FlashRWAttention(torch.nn.Module):
         # Inplace rotary
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

-        reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
+        kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)

         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
-                kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
+                kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
+                kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -217,8 +216,8 @@ class FlashRWAttention(torch.nn.Module):
         else:
             attn_output = paged_attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache.key,
+                kv_cache.value,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
@@ -312,12 +311,8 @@ class FlashRWLargeAttention(torch.nn.Module):
         # Inplace rotary
         self.rotary_emb(query, torch.select(kv, dim=2, index=0), cos, sin)

-        reshape_and_cache(
-            kv[:, :, 0].contiguous(),
-            kv[:, :, 1].contiguous(),
-            kv_cache[0],
-            kv_cache[1],
-            slots,
-        )
+        kv_cache.store(
+            key=kv[:, :, 0].contiguous(), value=kv[:, :, 1].contiguous(), slots=slots
+        )

         # Prefill
@@ -325,8 +320,8 @@ class FlashRWLargeAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, :, 0].contiguous(),
-                kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, :, 1].contiguous(),
+                kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, :, 0].contiguous(),
+                kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, :, 1].contiguous(),
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -335,8 +330,8 @@ class FlashRWLargeAttention(torch.nn.Module):
         else:
             attn_output = paged_attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache.key,
+                kv_cache.value,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
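In the `FlashRWLargeAttention` hunks above, the key and value slices come out of a grouped `(num_tokens, num_groups, 2, head_size)` tensor, so they are non-contiguous views and are materialised with `.contiguous()` before being handed to the cache. A small self-contained illustration of why; the shapes are made up for the example:

```python
import torch

# Grouped KV laid out as (num_tokens, num_groups, 2, head_size),
# the shape sliced in the FlashRWLargeAttention hunk above.
kv = torch.randn(4, 8, 2, 64)

key = kv[:, :, 0]
print(key.shape)                          # torch.Size([4, 8, 64])
print(key.is_contiguous())                # False: slicing dim 2 leaves a stride gap
print(key.contiguous().is_contiguous())   # True: dense copy suitable for the cache kernel
```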
@@ -8,7 +8,6 @@ from typing import Optional, List, Tuple
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
-    reshape_and_cache,
     Seqlen,
 )
 from text_generation_server.layers import (
@@ -284,17 +283,15 @@ class FlashMQAttention(torch.nn.Module):
         query = query.view(-1, self.num_heads, self.head_size)
         key_value = key_value.view(-1, 2, 1, self.head_size)

-        reshape_and_cache(
-            key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots
-        )
+        kv_cache.store(key=key_value[:, 0], value=key_value[:, 1], slots=slots)

         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if PREFILL_IN_KV_CACHE else key_value[:, 0],
-                kv_cache[1] if PREFILL_IN_KV_CACHE else key_value[:, 1],
+                kv_cache.key if PREFILL_IN_KV_CACHE else key_value[:, 0],
+                kv_cache.value if PREFILL_IN_KV_CACHE else key_value[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -303,8 +300,8 @@ class FlashMQAttention(torch.nn.Module):
         else:
             attn_output = paged_attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache.key,
+                kv_cache.value,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
@@ -29,7 +29,6 @@ from typing import Optional, List, Tuple
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
-    reshape_and_cache,
     Seqlen,
 )
 from text_generation_server.layers import (
@@ -233,17 +232,15 @@ class Starcoder2Attention(torch.nn.Module):
         else:
             kv_to_cache = kv

-        reshape_and_cache(
-            kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
-        )
+        kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots)

         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
-                kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
+                kv_cache.key if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
+                kv_cache.value if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -253,8 +250,8 @@ class Starcoder2Attention(torch.nn.Module):
         else:
             attn_output = paged_attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache.key,
+                kv_cache.value,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
@@ -46,7 +46,7 @@ from text_generation_server.models.globals import (
     TGI_WIGGLE_ROOM,
     get_adapter_to_index,
 )
-from text_generation_server.layers.attention import Seqlen
+from text_generation_server.layers.attention import KVCache, Seqlen
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
 from text_generation_server.utils.dist import MEMORY_FRACTION
 from text_generation_server.utils.quantization import get_loader
@@ -937,6 +937,7 @@ class FlashCausalLM(Model):
         # Deepseek V2 uses different QK and V dims.
         head_size: Optional[int] = None,
         skip_special_tokens: bool = True,
+        kv_cache_dtype: Optional[torch.dtype] = None,
     ):
         self.quantize = quantize
         self.process_group, rank, world_size = initialize_torch_distributed()
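The new `kv_cache_dtype` argument is optional and, as the next hunk shows, falls back to the model dtype when left unset. A trivial sketch of that resolution rule; the helper name is an assumption, not a function in the codebase:

```python
from typing import Optional

import torch


def resolve_kv_cache_dtype(
    dtype: torch.dtype, kv_cache_dtype: Optional[torch.dtype]
) -> torch.dtype:
    # Same rule as the added line
    # `self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype`.
    return dtype if kv_cache_dtype is None else kv_cache_dtype


assert resolve_kv_cache_dtype(torch.bfloat16, None) is torch.bfloat16
assert resolve_kv_cache_dtype(torch.bfloat16, torch.float8_e5m2) is torch.float8_e5m2
```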
@@ -1034,6 +1035,7 @@ class FlashCausalLM(Model):

         self.cuda_graphs = {}
         self.kv_cache = []
+        self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype

         if ATTENTION == "flashinfer":
             from text_generation_server.layers.attention.flashinfer import (
@@ -1083,58 +1085,13 @@ class FlashCausalLM(Model):
     ):
         self.kv_cache = []
         empty_cache()

-        element_size = torch.tensor([], dtype=dtype).element_size()
-        if SYSTEM == "ipex" and device.type == "xpu":
-            x = 1
-        else:
-            x = BLOCK_SIZE // element_size
-
-        if ATTENTION in {"flashdecoding", "flashinfer"}:
-            self.kv_cache = [
-                (
-                    torch.empty(
-                        (num_blocks, BLOCK_SIZE, num_heads, head_size),
-                        dtype=dtype,
-                        device=device,
-                    ),
-                    torch.empty(
-                        (num_blocks, BLOCK_SIZE, num_heads, head_size),
-                        dtype=dtype,
-                        device=device,
-                    ),
-                )
-                for _ in range(num_layers)
-            ]
-        elif SYSTEM == "ipex" and device == torch.device("cpu"):
-            self.kv_cache = [
-                (
-                    torch.empty(
-                        (num_blocks, num_heads, BLOCK_SIZE, head_size),
-                        dtype=dtype,
-                        device=device,
-                    ),
-                    torch.empty(
-                        (num_blocks, num_heads, BLOCK_SIZE, head_size),
-                        dtype=dtype,
-                        device=device,
-                    ),
-                )
-                for _ in range(num_layers)
-            ]
-        else:
-            self.kv_cache = [
-                (
-                    torch.zeros(
-                        (num_blocks, num_heads, head_size // x, BLOCK_SIZE, x),
-                        dtype=dtype,
-                        device=device,
-                    ),
-                    torch.zeros(
-                        (num_blocks, num_heads, head_size, BLOCK_SIZE),
-                        dtype=dtype,
-                        device=device,
-                    ),
-                )
-                for _ in range(num_layers)
-            ]
+        self.kv_cache = [
+            KVCache(
+                num_blocks=num_blocks,
+                num_heads=num_heads,
+                head_size=head_size,
+                dtype=dtype,
+                device=device,
+            )
+            for _ in range(num_layers)
+        ]
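The single `KVCache(...)` constructor call above replaces three layout-specific allocation branches. The function below is a standalone sketch of what the flashdecoding/flashinfer path now allocates per layer; `allocate_kv_cache` and `BLOCK_SIZE = 16` are assumptions for illustration, and in the actual change each (key, value) pair lives behind the `KVCache` object rather than a bare tuple:

```python
import torch

BLOCK_SIZE = 16  # assumption: TGI's paged-attention block size


def allocate_kv_cache(
    *,
    num_blocks: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    dtype: torch.dtype,
    device: torch.device,
) -> list:
    # One (key, value) pair per layer in HND layout, allocated in the KV-cache
    # dtype, which may now differ from the model dtype (e.g. torch.float8_e5m2).
    return [
        (
            torch.empty(
                (num_blocks, BLOCK_SIZE, num_heads, head_size),
                dtype=dtype,
                device=device,
            ),
            torch.empty(
                (num_blocks, BLOCK_SIZE, num_heads, head_size),
                dtype=dtype,
                device=device,
            ),
        )
        for _ in range(num_layers)
    ]


# Example: an FP8 cache for a small hypothetical model shape.
cache = allocate_kv_cache(
    num_blocks=64,
    num_layers=2,
    num_heads=8,
    head_size=128,
    dtype=torch.float8_e5m2,
    device=torch.device("cpu"),
)
```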
@@ -1258,7 +1215,7 @@ class FlashCausalLM(Model):
             self.num_layers,
             self.num_kv_heads,
             self.head_size,
-            self.dtype,
+            self.kv_cache_dtype,
             self.device,
         )
         max_bt = batch.max_blocks
@@ -1277,7 +1234,7 @@ class FlashCausalLM(Model):

         # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
         # Calculate the number of blocks that can be allocated with the free memory
-        dtype_size = torch.tensor([], dtype=self.dtype).element_size()
+        dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size()
         cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
         total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size

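Because the block-size arithmetic now uses the element size of `self.kv_cache_dtype` rather than the model dtype, an FP8 cache halves the per-block footprint relative to a 16-bit cache, so roughly twice as many blocks fit in the same free-memory budget. A worked example with a made-up model shape; `cache_block_bytes` is not a function from the codebase:

```python
import torch

BLOCK_SIZE = 16        # assumption: tokens per paged-attention block
num_layers = 32        # example model shape, not taken from the diff
num_kv_heads = 8
head_size = 128


def cache_block_bytes(kv_cache_dtype: torch.dtype) -> int:
    # Same arithmetic as the hunk above.
    dtype_size = torch.tensor([], dtype=kv_cache_dtype).element_size()
    cache_block_size = BLOCK_SIZE * num_kv_heads * head_size
    # "* 2" accounts for storing both keys and values.
    return num_layers * cache_block_size * 2 * dtype_size


print(cache_block_bytes(torch.bfloat16))     # 2097152 bytes per block
print(cache_block_bytes(torch.float8_e5m2))  # 1048576 bytes per block
```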
@@ -1291,6 +1248,8 @@ class FlashCausalLM(Model):
             + batch_num_blocks
         )

+        log_master(logger.info, f"KV-cache blocks: {num_blocks}, size: {BLOCK_SIZE}")
+
         del batch

         self.init_kv_cache(
@@ -1298,7 +1257,7 @@ class FlashCausalLM(Model):
             self.num_layers,
             self.num_kv_heads,
             self.head_size,
-            self.dtype,
+            self.kv_cache_dtype,
             self.device,
         )

@@ -205,6 +205,7 @@ def serve(
     quantize: Optional[str],
     speculate: Optional[int],
     dtype: Optional[str],
+    kv_cache_dtype: Optional[str],
     trust_remote_code: bool,
     uds_path: Path,
     max_input_tokens: int,
@@ -217,6 +218,7 @@ def serve(
         quantize: Optional[str] = None,
         speculate: Optional[int] = None,
         dtype: Optional[str] = None,
+        kv_cache_dtype: Optional[str] = None,
        trust_remote_code: bool = False,
     ):
         unix_socket_template = "unix://{}-{}"
@@ -240,6 +242,7 @@ def serve(
             quantize,
             speculate,
             dtype,
+            kv_cache_dtype,
             trust_remote_code,
             max_input_tokens,
             adapter_to_index,
@@ -286,6 +289,7 @@ def serve(
             quantize,
             speculate,
             dtype,
+            kv_cache_dtype,
             trust_remote_code,
         )
     )
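`serve` only threads the raw `--kv-cache-dtype` string through to model construction; somewhere downstream it has to be turned into a `torch.dtype`. The mapping below is a hedged sketch of that step, consistent with only `fp8_e5m2` being accepted for now; the function name is an assumption, not the actual helper used in the server:

```python
from typing import Optional

import torch


def parse_kv_cache_dtype(kv_cache_dtype: Optional[str]) -> Optional[torch.dtype]:
    # None means "use the model dtype", matching the launcher default.
    if kv_cache_dtype is None:
        return None
    if kv_cache_dtype == "fp8_e5m2":
        # Requires a CUDA device and an attention backend (FlashInfer here)
        # that can read an FP8 KV cache.
        return torch.float8_e5m2
    raise RuntimeError(f"Unknown KV cache dtype: {kv_cache_dtype}")
```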