More tensor cores. (#2558)
* More tensor cores. * Fixing the logic. * Gemma is modified by this.
This commit is contained in:
parent
c032280b17
commit
dd8691b7c5
|
@ -24,13 +24,13 @@
|
||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
{
|
||||||
"id": 1736,
|
"id": 1736,
|
||||||
"logprob": -2.03125,
|
"logprob": -2.109375,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " form"
|
"text": " form"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 109,
|
"id": 109,
|
||||||
"logprob": -1.8671875,
|
"logprob": -1.90625,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n\n"
|
"text": "\n\n"
|
||||||
},
|
},
|
||||||
|
@ -42,48 +42,48 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2121,
|
"id": 2121,
|
||||||
"logprob": -1.8125,
|
"logprob": -1.796875,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " test"
|
"text": " test"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3853,
|
"id": 3853,
|
||||||
"logprob": -0.24121094,
|
"logprob": -0.24511719,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " request"
|
"text": " request"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 1736,
|
"id": 1736,
|
||||||
"logprob": -0.100097656,
|
"logprob": -0.09326172,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " form"
|
"text": " form"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 603,
|
"id": 603,
|
||||||
"logprob": -0.9453125,
|
"logprob": -0.95703125,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " is"
|
"text": " is"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 476,
|
"id": 1671,
|
||||||
"logprob": -1.703125,
|
"logprob": -1.5859375,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " a"
|
"text": " used"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 4551,
|
"id": 577,
|
||||||
"logprob": -2.453125,
|
"logprob": -0.39257812,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " document"
|
"text": " to"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 674,
|
"id": 3853,
|
||||||
"logprob": -0.796875,
|
"logprob": -1.25,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " that"
|
"text": " request"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"top_tokens": null
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": " form\n\nThe test request form is a document that"
|
"generated_text": " form\n\nThe test request form is used to request"
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,12 +11,12 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2015,
|
"id": 2015,
|
||||||
"logprob": -9.640625,
|
"logprob": -9.6484375,
|
||||||
"text": "Test"
|
"text": "Test"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3853,
|
"id": 3853,
|
||||||
"logprob": -10.375,
|
"logprob": -10.3671875,
|
||||||
"text": " request"
|
"text": " request"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
@ -24,19 +24,19 @@
|
||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
{
|
||||||
"id": 604,
|
"id": 604,
|
||||||
"logprob": -0.2824707,
|
"logprob": -0.28271484,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " for"
|
"text": " for"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 573,
|
"id": 573,
|
||||||
"logprob": -0.19030762,
|
"logprob": -0.18493652,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " the"
|
"text": " the"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 16819,
|
"id": 16819,
|
||||||
"logprob": -1.4892578,
|
"logprob": -1.4804688,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " detection"
|
"text": " detection"
|
||||||
},
|
},
|
||||||
|
@ -46,44 +46,44 @@
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " of"
|
"text": " of"
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"id": 573,
|
|
||||||
"logprob": -2.0195312,
|
|
||||||
"special": false,
|
|
||||||
"text": " the"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 8566,
|
|
||||||
"logprob": 0.0,
|
|
||||||
"special": false,
|
|
||||||
"text": " presence"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 689,
|
|
||||||
"logprob": -0.16491699,
|
|
||||||
"special": false,
|
|
||||||
"text": " or"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 14862,
|
|
||||||
"logprob": 0.0,
|
|
||||||
"special": false,
|
|
||||||
"text": " absence"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 576,
|
|
||||||
"logprob": -0.9946289,
|
|
||||||
"special": false,
|
|
||||||
"text": " of"
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"id": 671,
|
"id": 671,
|
||||||
"logprob": -0.5263672,
|
"logprob": -2.1738281,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " an"
|
"text": " an"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 24646,
|
||||||
|
"logprob": -3.0449219,
|
||||||
|
"special": false,
|
||||||
|
"text": " RNA"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 12369,
|
||||||
|
"logprob": -0.19299316,
|
||||||
|
"special": false,
|
||||||
|
"text": " virus"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 575,
|
||||||
|
"logprob": -0.10632324,
|
||||||
|
"special": false,
|
||||||
|
"text": " in"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 6022,
|
||||||
|
"logprob": -0.98095703,
|
||||||
|
"special": false,
|
||||||
|
"text": " patients"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1064,
|
||||||
|
"logprob": -1.3095703,
|
||||||
|
"special": false,
|
||||||
|
"text": " who"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"top_tokens": null
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": "Test request for the detection of the presence or absence of an"
|
"generated_text": "Test request for the detection of an RNA virus in patients who"
|
||||||
}
|
}
|
||||||
|
|
|
@ -152,11 +152,13 @@ def create_decode_state(
|
||||||
):
|
):
|
||||||
"""Create a decode state."""
|
"""Create a decode state."""
|
||||||
workspace_buffer = get_workspace(device)
|
workspace_buffer = get_workspace(device)
|
||||||
|
num_groups = num_heads // num_kv_heads
|
||||||
return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
|
return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
|
||||||
workspace_buffer,
|
workspace_buffer,
|
||||||
kv_layout="NHD",
|
kv_layout="NHD",
|
||||||
use_cuda_graph=False,
|
use_cuda_graph=False,
|
||||||
use_tensor_cores=num_heads // num_kv_heads > 4,
|
# Taken from https://github.com/flashinfer-ai/flashinfer/blob/33ef95700981ba70f4cab63b8931e562bc795b21/python/flashinfer/decode.py#L57-L60
|
||||||
|
use_tensor_cores=num_groups not in [1, 2, 4, 8],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -175,6 +177,7 @@ def create_decode_state_cuda_graphs(
|
||||||
therefore stored as part of the state.
|
therefore stored as part of the state.
|
||||||
"""
|
"""
|
||||||
workspace_buffer = get_workspace(device)
|
workspace_buffer = get_workspace(device)
|
||||||
|
num_groups = num_heads // num_kv_heads
|
||||||
return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
|
return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
|
||||||
workspace_buffer,
|
workspace_buffer,
|
||||||
kv_layout="NHD",
|
kv_layout="NHD",
|
||||||
|
@ -182,7 +185,8 @@ def create_decode_state_cuda_graphs(
|
||||||
paged_kv_indices_buffer=block_tables,
|
paged_kv_indices_buffer=block_tables,
|
||||||
paged_kv_indptr_buffer=block_tables_ptr,
|
paged_kv_indptr_buffer=block_tables_ptr,
|
||||||
paged_kv_last_page_len_buffer=last_page_len,
|
paged_kv_last_page_len_buffer=last_page_len,
|
||||||
use_tensor_cores=num_heads // num_kv_heads > 4,
|
# Taken from https://github.com/flashinfer-ai/flashinfer/blob/33ef95700981ba70f4cab63b8931e562bc795b21/python/flashinfer/decode.py#L57-L60
|
||||||
|
use_tensor_cores=num_groups not in [1, 2, 4, 8],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue