More tensor cores. (#2558)

* More tensor cores.

* Fixing the logic.

* Gemma is modified by this.
This commit is contained in:
Nicolas Patry 2024-09-24 23:57:26 +02:00 committed by GitHub
parent c032280b17
commit dd8691b7c5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 59 additions and 55 deletions

View File

@ -24,13 +24,13 @@
"tokens": [
{
"id": 1736,
"logprob": -2.03125,
"logprob": -2.109375,
"special": false,
"text": " form"
},
{
"id": 109,
"logprob": -1.8671875,
"logprob": -1.90625,
"special": false,
"text": "\n\n"
},
@ -42,48 +42,48 @@
},
{
"id": 2121,
"logprob": -1.8125,
"logprob": -1.796875,
"special": false,
"text": " test"
},
{
"id": 3853,
"logprob": -0.24121094,
"logprob": -0.24511719,
"special": false,
"text": " request"
},
{
"id": 1736,
"logprob": -0.100097656,
"logprob": -0.09326172,
"special": false,
"text": " form"
},
{
"id": 603,
"logprob": -0.9453125,
"logprob": -0.95703125,
"special": false,
"text": " is"
},
{
"id": 476,
"logprob": -1.703125,
"id": 1671,
"logprob": -1.5859375,
"special": false,
"text": " a"
"text": " used"
},
{
"id": 4551,
"logprob": -2.453125,
"id": 577,
"logprob": -0.39257812,
"special": false,
"text": " document"
"text": " to"
},
{
"id": 674,
"logprob": -0.796875,
"id": 3853,
"logprob": -1.25,
"special": false,
"text": " that"
"text": " request"
}
],
"top_tokens": null
},
"generated_text": " form\n\nThe test request form is a document that"
"generated_text": " form\n\nThe test request form is used to request"
}

View File

@ -11,12 +11,12 @@
},
{
"id": 2015,
"logprob": -9.640625,
"logprob": -9.6484375,
"text": "Test"
},
{
"id": 3853,
"logprob": -10.375,
"logprob": -10.3671875,
"text": " request"
}
],
@ -24,19 +24,19 @@
"tokens": [
{
"id": 604,
"logprob": -0.2824707,
"logprob": -0.28271484,
"special": false,
"text": " for"
},
{
"id": 573,
"logprob": -0.19030762,
"logprob": -0.18493652,
"special": false,
"text": " the"
},
{
"id": 16819,
"logprob": -1.4892578,
"logprob": -1.4804688,
"special": false,
"text": " detection"
},
@ -46,44 +46,44 @@
"special": false,
"text": " of"
},
{
"id": 573,
"logprob": -2.0195312,
"special": false,
"text": " the"
},
{
"id": 8566,
"logprob": 0.0,
"special": false,
"text": " presence"
},
{
"id": 689,
"logprob": -0.16491699,
"special": false,
"text": " or"
},
{
"id": 14862,
"logprob": 0.0,
"special": false,
"text": " absence"
},
{
"id": 576,
"logprob": -0.9946289,
"special": false,
"text": " of"
},
{
"id": 671,
"logprob": -0.5263672,
"logprob": -2.1738281,
"special": false,
"text": " an"
},
{
"id": 24646,
"logprob": -3.0449219,
"special": false,
"text": " RNA"
},
{
"id": 12369,
"logprob": -0.19299316,
"special": false,
"text": " virus"
},
{
"id": 575,
"logprob": -0.10632324,
"special": false,
"text": " in"
},
{
"id": 6022,
"logprob": -0.98095703,
"special": false,
"text": " patients"
},
{
"id": 1064,
"logprob": -1.3095703,
"special": false,
"text": " who"
}
],
"top_tokens": null
},
"generated_text": "Test request for the detection of the presence or absence of an"
"generated_text": "Test request for the detection of an RNA virus in patients who"
}

View File

@ -152,11 +152,13 @@ def create_decode_state(
):
"""Create a decode state."""
workspace_buffer = get_workspace(device)
num_groups = num_heads // num_kv_heads
return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer,
kv_layout="NHD",
use_cuda_graph=False,
use_tensor_cores=num_heads // num_kv_heads > 4,
# Taken from https://github.com/flashinfer-ai/flashinfer/blob/33ef95700981ba70f4cab63b8931e562bc795b21/python/flashinfer/decode.py#L57-L60
use_tensor_cores=num_groups not in [1, 2, 4, 8],
)
@ -175,6 +177,7 @@ def create_decode_state_cuda_graphs(
therefore stored as part of the state.
"""
workspace_buffer = get_workspace(device)
num_groups = num_heads // num_kv_heads
return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer,
kv_layout="NHD",
@ -182,7 +185,8 @@ def create_decode_state_cuda_graphs(
paged_kv_indices_buffer=block_tables,
paged_kv_indptr_buffer=block_tables_ptr,
paged_kv_last_page_len_buffer=last_page_len,
use_tensor_cores=num_heads // num_kv_heads > 4,
# Taken from https://github.com/flashinfer-ai/flashinfer/blob/33ef95700981ba70f4cab63b8931e562bc795b21/python/flashinfer/decode.py#L57-L60
use_tensor_cores=num_groups not in [1, 2, 4, 8],
)