This seems to be working.

Author: Nicolas Patry, 2024-08-26 18:27:28 +02:00
Commit: 26e5037de4 (parent: f5182c188c)
4 changed files with 54 additions and 33 deletions


@@ -158,34 +158,40 @@ impl Allocator for RadixAllocator {
         if let Some(prefill_tokens) = allocation.prefill_tokens {
             let prefill_tokens = prefill_tokens.as_slice();
-            assert_eq!(prefill_tokens.len() % self.block_size as usize, 0);

             // If there are prefill tokens that did not come from the cache,
             // add them to the cache.
             if prefill_tokens.len() > allocation.cached_prefix_len {
-                let prefix_len = self
-                    .cache_blocks
-                    .insert(
-                        prefill_tokens,
-                        &blocks[..prefill_tokens.len() / self.block_size as usize],
-                    )
-                    // Unwrap, failing is a programming error.
-                    .expect("Failed to store prefill tokens");
+                let aligned =
+                    (prefill_tokens.len() / self.block_size as usize) * self.block_size as usize;
+                if aligned > 0 {
+                    let prefix_len = self
+                        .cache_blocks
+                        .insert(
+                            &prefill_tokens[..aligned],
+                            &blocks[..aligned / self.block_size as usize],
+                        )
+                        // Unwrap, failing is a programming error.
+                        .expect("Failed to store prefill tokens");

-                // We can have a prefill with the following structure:
-                //
-                // |---| From the prefix cache.
-                // A B C D E F G
-                //|--------| Found in the trie during insertion.
-                //
-                // This means that while processing this request there was a
-                // partially overlapping request that had A..=E in its
-                // prefill. In this case we need to free the blocks D E.
-                self.free_blocks
-                    .extend(&blocks[allocation.cached_prefix_len..prefix_len]);
+                    // We can have a prefill with the following structure:
+                    //
+                    // |---| From the prefix cache.
+                    // A B C D E F G
+                    //|--------| Found in the trie during insertion.
+                    //
+                    // This means that while processing this request there was a
+                    // partially overlapping request that had A..=E in its
+                    // prefill. In this case we need to free the blocks D E.
+                    self.free_blocks.extend(
+                        &blocks[allocation.cached_prefix_len / self.block_size as usize
+                            ..prefix_len / self.block_size as usize],
+                    );
+                }
             }

             // Free non-prefill blocks.
-            self.free_blocks.extend(&blocks[prefill_tokens.len()..]);
+            self.free_blocks
+                .extend(&blocks[prefill_tokens.len() / self.block_size as usize..]);
         } else {
             self.free_blocks.extend(blocks);
         }
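In short, the allocator hunk rounds the prefill down to a whole number of blocks before inserting it into the radix trie, and converts token offsets to block indices before freeing. The standalone sketch below (hypothetical helpers, not the actual RadixAllocator API) walks through those two conversions with the same numbers as the new test that follows.

    // Hypothetical helpers mirroring the arithmetic in the hunk above; a sketch,
    // not part of the real allocator.
    fn aligned_prefill_len(prefill_len: usize, block_size: usize) -> usize {
        // Round down to a multiple of the block size; only whole blocks are cached.
        (prefill_len / block_size) * block_size
    }

    fn token_offset_to_block(offset: usize, block_size: usize) -> usize {
        // Token offsets are converted to block indices before slicing `blocks`.
        offset / block_size
    }

    fn main() {
        // block_size = 2 with a 3-token prefill (as in allocator_block_size_non_aligned):
        // only the first 2 tokens are block-aligned, so only one block is cached.
        let block_size = 2;
        let aligned = aligned_prefill_len(3, block_size);
        assert_eq!(aligned, 2);

        // If the trie already held the first 2 tokens (cached_prefix_len = 0 and
        // insert returns prefix_len = 2), the overlapping block 0 is freed.
        let free_range = token_offset_to_block(0, block_size)..token_offset_to_block(2, block_size);
        assert_eq!(free_range, 0..1);
        println!("aligned = {aligned}, block range to free = {free_range:?}");
    }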
@@ -605,6 +611,24 @@ mod tests {
         assert_eq!(allocation.prefix_len, 4);
     }

+    #[test]
+    fn allocator_block_size_non_aligned() {
+        let mut cache = RadixAllocator::new(2, 12, None);
+        let allocation = cache.allocate(7, Some(Arc::new(vec![0, 1, 2]))).unwrap();
+        assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
+        assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]);
+        assert_eq!(allocation.prefix_len, 0);
+        cache.free(
+            allocation.blocks[..allocation.blocks.len() - 1].to_vec(),
+            allocation.allocation_id,
+        );
+
+        let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
+        assert_eq!(allocation.blocks, vec![8, 9, 6, 7]);
+        assert_eq!(allocation.slots, vec![16, 17, 18, 19, 12, 13, 14, 15]);
+        assert_eq!(allocation.prefix_len, 4);
+    }
+
     #[test]
     fn allocator_reuses_prefixes() {
         let mut cache = RadixAllocator::new(1, 12, None);
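The first allocation in the new test asks for 7 tokens with block_size = 2, so it needs ceil(7 / 2) = 4 blocks, and blocks [8, 9, 10, 11] map to slots 16..=22 once the per-block slot ranges are truncated to 7 entries. A small sketch of that arithmetic, assuming block b covers slots b * block_size .. (b + 1) * block_size (slots_for is an illustrative helper, not the allocator's code):

    // Sketch of the slot arithmetic the test asserts.
    fn slots_for(blocks: &[u32], block_size: u32, tokens: usize) -> Vec<u32> {
        blocks
            .iter()
            // Expand each block into its contiguous slot range...
            .flat_map(|b| b * block_size..(b + 1) * block_size)
            // ...then keep only as many slots as requested tokens.
            .take(tokens)
            .collect()
    }

    fn main() {
        let block_size = 2u32;
        let tokens = 7usize;
        // 7 tokens with 2 slots per block need ceil(7 / 2) = 4 blocks.
        let blocks_needed = (tokens + block_size as usize - 1) / block_size as usize;
        assert_eq!(blocks_needed, 4);
        assert_eq!(
            slots_for(&[8, 9, 10, 11], block_size, tokens),
            vec![16, 17, 18, 19, 20, 21, 22]
        );
        println!("blocks needed = {blocks_needed}");
    }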


@@ -83,7 +83,7 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
                     "Forcing flash decoding because model {} requires it",
                     config.model_type.as_ref().unwrap()
                 );
-                prefix_caching = Some("0".to_string());
+                prefix_caching = Some("1".to_string());
                 attention = Some("flashdecoding".to_string());
             }
         }
@@ -93,7 +93,7 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
             _ => {
                 if prefix_caching.is_none() {
                     tracing::info!("Forcing flash decoding because head dim is not supported by flashinfer, also disabling prefix caching");
-                    prefix_caching = Some("0".to_string());
+                    prefix_caching = Some("1".to_string());
                     attention = Some("flashdecoding".to_string());
                 }
             }
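Both launcher hunks flip the fallback for prefix caching from "0" to "1": in the paths where flashdecoding is forced (unsupported model type or head dim for flashinfer), prefix caching now defaults to on instead of off. A simplified sketch of that fallback behaviour, with a made-up resolve_defaults helper standing in for the real resolve_attention, which also consults the model config and LoRA adapters:

    // Simplified, hypothetical stand-in for the launcher's fallback logic.
    fn resolve_defaults(
        prefix_caching: Option<String>,
        attention: Option<String>,
    ) -> (String, String) {
        // Explicit user settings win; otherwise fall back to flashdecoding with
        // prefix caching enabled ("1"), matching the new defaults above.
        let prefix_caching = prefix_caching.unwrap_or_else(|| "1".to_string());
        let attention = attention.unwrap_or_else(|| "flashdecoding".to_string());
        (prefix_caching, attention)
    }

    fn main() {
        // Nothing set by the user: both values fall back to the new defaults.
        assert_eq!(
            resolve_defaults(None, None),
            ("1".to_string(), "flashdecoding".to_string())
        );
        // An explicit "0" still disables prefix caching.
        assert_eq!(
            resolve_defaults(Some("0".to_string()), None),
            ("0".to_string(), "flashdecoding".to_string())
        );
    }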
@@ -1000,7 +1000,7 @@ impl TryFrom<&[u8]> for PythonLogMessage {
 }
 fn log_lines<R: Sized + Read>(mut bufread: BufReader<R>) {
-    let mut buffer = vec![0u8; 4096];
+    let mut buffer = vec![0u8; 8 * 4096];
     let mut stdout = std::io::stdout();
     loop {
         let n = bufread.read(&mut buffer);
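The last launcher hunk only grows the scratch buffer used when forwarding the Python server's log output from 4 KiB to 32 KiB, so heavy logging needs fewer read calls. A self-contained sketch of the surrounding read loop, assuming the goal is simply to drain a reader into stdout; the real log_lines appears to parse the bytes as PythonLogMessage values rather than forwarding them verbatim:

    use std::io::{BufReader, Read, Write};

    // Illustrative copy loop with the same fixed-buffer pattern as log_lines.
    fn drain_to_stdout<R: Read>(reader: R) -> std::io::Result<()> {
        let mut bufread = BufReader::new(reader);
        // Larger scratch buffer => fewer read calls when the child logs heavily.
        let mut buffer = vec![0u8; 8 * 4096];
        let mut stdout = std::io::stdout();
        loop {
            let n = bufread.read(&mut buffer)?;
            if n == 0 {
                break; // EOF
            }
            stdout.write_all(&buffer[..n])?;
        }
        Ok(())
    }

    fn main() -> std::io::Result<()> {
        drain_to_stdout(&b"a line of child-process output\n"[..])
    }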


@@ -265,13 +265,10 @@ class FlashCausalLMBatch(Batch):

             orig_input_length = len(tokenized_input)

-            if ATTENTION == "flashinfer":
-                prefix_len = r.prefix_len
-                if prefix_len == orig_input_length:
-                    assert prefix_len > 0
-                    prefix_len -= 1
-            else:
-                prefix_len = 0
+            prefix_len = r.prefix_len
+            if prefix_len == orig_input_length:
+                assert prefix_len > 0
+                prefix_len -= 1

             prefix_ids.append(tokenized_input[:prefix_len])
             tokenized_input = tokenized_input[prefix_len:]
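The Python hunk drops the flashinfer-only gating: prefix_len now always comes from the request, with the same safeguard as before that a fully cached input is trimmed by one token, presumably so at least one token is still run through the model. A tiny sketch of that rule as a standalone function (hypothetical helper, not the server's code):

    // Hypothetical helper expressing the clamping rule from the hunk above.
    fn clamp_prefix_len(prefix_len: usize, input_len: usize) -> usize {
        if prefix_len == input_len {
            // Never let the cached prefix cover the whole input.
            assert!(prefix_len > 0);
            prefix_len - 1
        } else {
            prefix_len
        }
    }

    fn main() {
        // Full prefix hit on a 4-token input: keep one token out of the prefix.
        assert_eq!(clamp_prefix_len(4, 4), 3);
        // Partial hit is passed through unchanged.
        assert_eq!(clamp_prefix_len(2, 4), 2);
    }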


@@ -14,8 +14,8 @@ assert (
 ), f"Attention is not valid {ATTENTION}, expected {_expected}"
 log_master(logger.info, f"Using Attention = {ATTENTION}")

-if PREFIX_CACHING and ATTENTION != "flashinfer":
-    raise RuntimeError("Prefix caching is only supported with flashinfer")
+# if PREFIX_CACHING and ATTENTION != "flashinfer":
+#     raise RuntimeError("Prefix caching is only supported with flashinfer")

 MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None