More tensor cores. (#2558)

* More tensor cores.

* Fixing the logic.

* Gemma is modified by this.
This commit is contained in:
Nicolas Patry 2024-09-24 23:57:26 +02:00 committed by GitHub
parent c032280b17
commit dd8691b7c5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 59 additions and 55 deletions

View File

@ -24,13 +24,13 @@
"tokens": [ "tokens": [
{ {
"id": 1736, "id": 1736,
"logprob": -2.03125, "logprob": -2.109375,
"special": false, "special": false,
"text": " form" "text": " form"
}, },
{ {
"id": 109, "id": 109,
"logprob": -1.8671875, "logprob": -1.90625,
"special": false, "special": false,
"text": "\n\n" "text": "\n\n"
}, },
@ -42,48 +42,48 @@
}, },
{ {
"id": 2121, "id": 2121,
"logprob": -1.8125, "logprob": -1.796875,
"special": false, "special": false,
"text": " test" "text": " test"
}, },
{ {
"id": 3853, "id": 3853,
"logprob": -0.24121094, "logprob": -0.24511719,
"special": false, "special": false,
"text": " request" "text": " request"
}, },
{ {
"id": 1736, "id": 1736,
"logprob": -0.100097656, "logprob": -0.09326172,
"special": false, "special": false,
"text": " form" "text": " form"
}, },
{ {
"id": 603, "id": 603,
"logprob": -0.9453125, "logprob": -0.95703125,
"special": false, "special": false,
"text": " is" "text": " is"
}, },
{ {
"id": 476, "id": 1671,
"logprob": -1.703125, "logprob": -1.5859375,
"special": false, "special": false,
"text": " a" "text": " used"
}, },
{ {
"id": 4551, "id": 577,
"logprob": -2.453125, "logprob": -0.39257812,
"special": false, "special": false,
"text": " document" "text": " to"
}, },
{ {
"id": 674, "id": 3853,
"logprob": -0.796875, "logprob": -1.25,
"special": false, "special": false,
"text": " that" "text": " request"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": " form\n\nThe test request form is a document that" "generated_text": " form\n\nThe test request form is used to request"
} }

View File

@ -11,12 +11,12 @@
}, },
{ {
"id": 2015, "id": 2015,
"logprob": -9.640625, "logprob": -9.6484375,
"text": "Test" "text": "Test"
}, },
{ {
"id": 3853, "id": 3853,
"logprob": -10.375, "logprob": -10.3671875,
"text": " request" "text": " request"
} }
], ],
@ -24,19 +24,19 @@
"tokens": [ "tokens": [
{ {
"id": 604, "id": 604,
"logprob": -0.2824707, "logprob": -0.28271484,
"special": false, "special": false,
"text": " for" "text": " for"
}, },
{ {
"id": 573, "id": 573,
"logprob": -0.19030762, "logprob": -0.18493652,
"special": false, "special": false,
"text": " the" "text": " the"
}, },
{ {
"id": 16819, "id": 16819,
"logprob": -1.4892578, "logprob": -1.4804688,
"special": false, "special": false,
"text": " detection" "text": " detection"
}, },
@ -46,44 +46,44 @@
"special": false, "special": false,
"text": " of" "text": " of"
}, },
{
"id": 573,
"logprob": -2.0195312,
"special": false,
"text": " the"
},
{
"id": 8566,
"logprob": 0.0,
"special": false,
"text": " presence"
},
{
"id": 689,
"logprob": -0.16491699,
"special": false,
"text": " or"
},
{
"id": 14862,
"logprob": 0.0,
"special": false,
"text": " absence"
},
{
"id": 576,
"logprob": -0.9946289,
"special": false,
"text": " of"
},
{ {
"id": 671, "id": 671,
"logprob": -0.5263672, "logprob": -2.1738281,
"special": false, "special": false,
"text": " an" "text": " an"
},
{
"id": 24646,
"logprob": -3.0449219,
"special": false,
"text": " RNA"
},
{
"id": 12369,
"logprob": -0.19299316,
"special": false,
"text": " virus"
},
{
"id": 575,
"logprob": -0.10632324,
"special": false,
"text": " in"
},
{
"id": 6022,
"logprob": -0.98095703,
"special": false,
"text": " patients"
},
{
"id": 1064,
"logprob": -1.3095703,
"special": false,
"text": " who"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": "Test request for the detection of the presence or absence of an" "generated_text": "Test request for the detection of an RNA virus in patients who"
} }

View File

@ -152,11 +152,13 @@ def create_decode_state(
): ):
"""Create a decode state.""" """Create a decode state."""
workspace_buffer = get_workspace(device) workspace_buffer = get_workspace(device)
num_groups = num_heads // num_kv_heads
return flashinfer.BatchDecodeWithPagedKVCacheWrapper( return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer, workspace_buffer,
kv_layout="NHD", kv_layout="NHD",
use_cuda_graph=False, use_cuda_graph=False,
use_tensor_cores=num_heads // num_kv_heads > 4, # Taken from https://github.com/flashinfer-ai/flashinfer/blob/33ef95700981ba70f4cab63b8931e562bc795b21/python/flashinfer/decode.py#L57-L60
use_tensor_cores=num_groups not in [1, 2, 4, 8],
) )
@ -175,6 +177,7 @@ def create_decode_state_cuda_graphs(
therefore stored as part of the state. therefore stored as part of the state.
""" """
workspace_buffer = get_workspace(device) workspace_buffer = get_workspace(device)
num_groups = num_heads // num_kv_heads
return flashinfer.BatchDecodeWithPagedKVCacheWrapper( return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer, workspace_buffer,
kv_layout="NHD", kv_layout="NHD",
@ -182,7 +185,8 @@ def create_decode_state_cuda_graphs(
paged_kv_indices_buffer=block_tables, paged_kv_indices_buffer=block_tables,
paged_kv_indptr_buffer=block_tables_ptr, paged_kv_indptr_buffer=block_tables_ptr,
paged_kv_last_page_len_buffer=last_page_len, paged_kv_last_page_len_buffer=last_page_len,
use_tensor_cores=num_heads // num_kv_heads > 4, # Taken from https://github.com/flashinfer-ai/flashinfer/blob/33ef95700981ba70f4cab63b8931e562bc795b21/python/flashinfer/decode.py#L57-L60
use_tensor_cores=num_groups not in [1, 2, 4, 8],
) )