From 26e5037de43c28b6663d819dc621a616311a1ee5 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Mon, 26 Aug 2024 18:27:28 +0200
Subject: [PATCH] This seems to be working.

---
 backends/v3/src/radix.rs                          | 66 +++++++++++++------
 launcher/src/main.rs                              |  6 +-
 .../models/flash_causal_lm.py                     | 11 ++--
 .../text_generation_server/models/globals.py      |  4 +-
 4 files changed, 54 insertions(+), 33 deletions(-)

diff --git a/backends/v3/src/radix.rs b/backends/v3/src/radix.rs
index be24b67b..c8e8d05b 100644
--- a/backends/v3/src/radix.rs
+++ b/backends/v3/src/radix.rs
@@ -158,34 +158,40 @@ impl Allocator for RadixAllocator {
         if let Some(prefill_tokens) = allocation.prefill_tokens {
             let prefill_tokens = prefill_tokens.as_slice();

-            assert_eq!(prefill_tokens.len() % self.block_size as usize, 0);
             // If there are prefill tokens that did not come from the cache,
             // add them to the cache.
             if prefill_tokens.len() > allocation.cached_prefix_len {
-                let prefix_len = self
-                    .cache_blocks
-                    .insert(
-                        prefill_tokens,
-                        &blocks[..prefill_tokens.len() / self.block_size as usize],
-                    )
-                    // Unwrap, failing is a programming error.
-                    .expect("Failed to store prefill tokens");
+                let aligned =
+                    (prefill_tokens.len() / self.block_size as usize) * self.block_size as usize;
+                if aligned > 0 {
+                    let prefix_len = self
+                        .cache_blocks
+                        .insert(
+                            &prefill_tokens[..aligned],
+                            &blocks[..aligned / self.block_size as usize],
+                        )
+                        // Unwrap, failing is a programming error.
+                        .expect("Failed to store prefill tokens");

-                // We can have a prefill with the following structure:
-                //
-                //     |---| From the prefix cache.
-                // A B C D E F G
-                //|--------| Found in the trie during insertion.
-                //
-                // This means that while processing this request there was a
-                // partially overlapping request that had A..=E in its
-                // prefill. In this case we need to free the blocks D E.
-                self.free_blocks
-                    .extend(&blocks[allocation.cached_prefix_len..prefix_len]);
+                    // We can have a prefill with the following structure:
+                    //
+                    //     |---| From the prefix cache.
+                    // A B C D E F G
+                    //|--------| Found in the trie during insertion.
+                    //
+                    // This means that while processing this request there was a
+                    // partially overlapping request that had A..=E in its
+                    // prefill. In this case we need to free the blocks D E.
+                    self.free_blocks.extend(
+                        &blocks[allocation.cached_prefix_len / self.block_size as usize
+                            ..prefix_len / self.block_size as usize],
+                    );
+                }
             }

             // Free non-prefill blocks.
-            self.free_blocks.extend(&blocks[prefill_tokens.len()..]);
+            self.free_blocks
+                .extend(&blocks[prefill_tokens.len() / self.block_size as usize..]);
         } else {
             self.free_blocks.extend(blocks);
         }
@@ -605,6 +611,24 @@ mod tests {
         assert_eq!(allocation.prefix_len, 4);
     }

+    #[test]
+    fn allocator_block_size_non_aligned() {
+        let mut cache = RadixAllocator::new(2, 12, None);
+        let allocation = cache.allocate(7, Some(Arc::new(vec![0, 1, 2]))).unwrap();
+        assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
+        assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]);
+        assert_eq!(allocation.prefix_len, 0);
+        cache.free(
+            allocation.blocks[..allocation.blocks.len() - 1].to_vec(),
+            allocation.allocation_id,
+        );
+
+        let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
+        assert_eq!(allocation.blocks, vec![8, 9, 6, 7]);
+        assert_eq!(allocation.slots, vec![16, 17, 18, 19, 12, 13, 14, 15]);
+        assert_eq!(allocation.prefix_len, 4);
+    }
+
     #[test]
     fn allocator_reuses_prefixes() {
         let mut cache = RadixAllocator::new(1, 12, None);
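A note on the arithmetic in the radix.rs hunk above: instead of asserting that the prefill is block-aligned, free() now rounds the prefill length down to a whole number of blocks, inserts only those fully-filled blocks into the radix trie, and returns everything past the last full prefill block to the free list. The sketch below is illustrative only and not part of the patch; split_prefill_blocks is a hypothetical helper, whereas the real code works directly on self.cache_blocks, self.free_blocks and the allocation.

// Illustrative sketch of the block-alignment split performed in free().
fn split_prefill_blocks(
    prefill_len: usize,
    block_size: usize,
    blocks: &[u32],
) -> (Vec<u32>, Vec<u32>) {
    // Round the prefill down to a whole number of blocks: only fully-filled
    // blocks are candidates for insertion into the radix trie.
    let aligned = (prefill_len / block_size) * block_size;
    let cacheable = blocks[..aligned / block_size].to_vec();
    // Everything past the last fully-filled prefill block (the partially
    // filled block, plus any decode blocks) goes back to the free list.
    let freeable = blocks[prefill_len / block_size..].to_vec();
    (cacheable, freeable)
}

fn main() {
    // 7 prefill tokens with block_size = 2 (the shape exercised by the new
    // allocator_block_size_non_aligned test): three full blocks are cacheable,
    // the half-filled fourth block is returned to the free list.
    let (cacheable, freeable) = split_prefill_blocks(7, 2, &[8, 9, 10, 11]);
    assert_eq!(cacheable, vec![8, 9, 10]);
    assert_eq!(freeable, vec![11]);
}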
diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index cc1d518e..557b3f8c 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -83,7 +83,7 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
                     "Forcing flash decoding because model {} requires it",
                     config.model_type.as_ref().unwrap()
                 );
-                prefix_caching = Some("0".to_string());
+                prefix_caching = Some("1".to_string());
                 attention = Some("flashdecoding".to_string());
             }
         }
@@ -93,7 +93,7 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
         _ => {
             if prefix_caching.is_none() {
                 tracing::info!("Forcing flash decoding because head dim is not supported by flashinfer, also disabling prefix caching");
-                prefix_caching = Some("0".to_string());
+                prefix_caching = Some("1".to_string());
                 attention = Some("flashdecoding".to_string());
             }
         }
@@ -1000,7 +1000,7 @@ impl TryFrom<&[u8]> for PythonLogMessage {
 }

 fn log_lines(mut bufread: BufReader) {
-    let mut buffer = vec![0u8; 4096];
+    let mut buffer = vec![0u8; 8 * 4096];
     let mut stdout = std::io::stdout();
     loop {
         let n = bufread.read(&mut buffer);
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 3d962bed..968eaf1d 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -265,13 +265,10 @@ class FlashCausalLMBatch(Batch):

             orig_input_length = len(tokenized_input)

-            if ATTENTION == "flashinfer":
-                prefix_len = r.prefix_len
-                if prefix_len == orig_input_length:
-                    assert prefix_len > 0
-                    prefix_len -= 1
-            else:
-                prefix_len = 0
+            prefix_len = r.prefix_len
+            if prefix_len == orig_input_length:
+                assert prefix_len > 0
+                prefix_len -= 1

             prefix_ids.append(tokenized_input[:prefix_len])
             tokenized_input = tokenized_input[prefix_len:]
diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py
index aaed2475..1f9544a6 100644
--- a/server/text_generation_server/models/globals.py
+++ b/server/text_generation_server/models/globals.py
@@ -14,8 +14,8 @@ assert (
 ), f"Attention is not valid {ATTENTION}, expected {_expected}"
 log_master(logger.info, f"Using Attention = {ATTENTION}")

-if PREFIX_CACHING and ATTENTION != "flashinfer":
-    raise RuntimeError("Prefix caching is only supported with flashinfer")
+# if PREFIX_CACHING and ATTENTION != "flashinfer":
+#     raise RuntimeError("Prefix caching is only supported with flashinfer")

 MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
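A note on the flash_causal_lm.py hunk: the flashinfer-only gating is removed, and the same clamp is now always applied to the cached prefix length, so that when the prefix cache covers the entire input at least one token is still left for the prefill step to process. Below is a minimal sketch of that rule, written in Rust for consistency with the sketch above; clamp_prefix_len is a hypothetical name, and the real code operates on r.prefix_len and tokenized_input in Python.

// Illustrative sketch of the prefix-length clamp from the Python hunk above.
fn clamp_prefix_len(prefix_len: usize, input_len: usize) -> usize {
    if prefix_len == input_len {
        // Mirrors the `assert prefix_len > 0` in the Python code: a fully
        // cached input keeps its last token out of the prefix so prefill
        // still has something to compute.
        assert!(prefix_len > 0);
        prefix_len - 1
    } else {
        prefix_len
    }
}

fn main() {
    assert_eq!(clamp_prefix_len(4, 4), 3); // fully cached input: keep one token uncached
    assert_eq!(clamp_prefix_len(2, 7), 2); // partially cached input: unchanged
}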