From f55278de2d87b0a60cff81ffe1d5f8ec5312e5c7 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Sat, 17 Aug 2024 12:04:21 +0200 Subject: [PATCH] Allowing window_left_size (dummy version). --- backends/v3/src/radix.rs | 12 +++++++++--- .../text_generation_server/layers/attention/cuda.py | 3 ++- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/backends/v3/src/radix.rs b/backends/v3/src/radix.rs index 5bac1a31..c9ac12c2 100644 --- a/backends/v3/src/radix.rs +++ b/backends/v3/src/radix.rs @@ -16,6 +16,11 @@ pub struct RadixAllocator { /// Blocks that are immediately available for allocation. free_blocks: Vec, + + #[allow(dead_code)] + // This isn't used because the prefix need to match without the windowing + // mecanism. This at worst is overallocating, not necessarily being wrong. + window_size: Option, } impl RadixAllocator { @@ -25,9 +30,9 @@ impl RadixAllocator { "Radix tree allocator only works with block_size=1, was: {}", block_size ); - if window_size.is_some() { - unimplemented!("Window size not supported in the prefix-caching block allocator yet"); - } + // if window_size.is_some() { + // unimplemented!("Window size not supported in the prefix-caching block allocator yet"); + // } RadixAllocator { allocation_id: 0, @@ -36,6 +41,7 @@ impl RadixAllocator { // Block 0 is reserved for health checks. free_blocks: (1..n_blocks).collect(), + window_size, } } diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py index b3b7ea4f..7c415804 100644 --- a/server/text_generation_server/layers/attention/cuda.py +++ b/server/text_generation_server/layers/attention/cuda.py @@ -233,7 +233,7 @@ if ATTENTION == "flashinfer": causal=True, softcap=0.0, ): - assert window_size_left == -1, "Windowing is not supported with flash infer" + # assert window_size_left == -1, "Windowing is not supported with flash infer" from text_generation_server.layers.attention.flashinfer import ( prefill_with_paged_kv_state, ) @@ -244,6 +244,7 @@ if ATTENTION == "flashinfer": paged_kv_cache=(key_cache, value_cache), logits_soft_cap=softcap, sm_scale=softmax_scale, + window_left=window_size_left, ) elif V2: