Initial radix cache tests

And fix issues.
This commit is contained in:
Daniël de Kok 2024-08-06 10:50:18 +00:00
parent 6486887b43
commit 2a255ad719
2 changed files with 84 additions and 5 deletions

View File

@ -230,6 +230,8 @@ impl RadixAllocator {
allocation_id: 0,
allocations: HashMap::new(),
cache_blocks: RadixTrie::new(),
// Block 0 is reserved for health checks.
free_blocks: (1..n_blocks).collect(),
}
}
@ -248,7 +250,10 @@ impl RadixAllocator {
}
if self.free_blocks.len() >= n_blocks_needed {
Some(self.free_blocks.split_off(n_blocks_needed))
Some(
self.free_blocks
.split_off(self.free_blocks.len() - n_blocks_needed),
)
} else {
None
}
@ -316,9 +321,21 @@ impl Allocator for RadixAllocator {
// If there are prefill tokens that did not come from the cache,
// add them to the cache.
if prefill_tokens.len() > allocation.cached_prefix_len {
// TODO: check if the prefill tokens are already in the cache???
self.cache_blocks
let prefix_len = self
.cache_blocks
.insert(prefill_tokens, &blocks[..prefill_tokens.len()]);
// We can have a prefill with the following structure:
//
// |---| From the prefix cache.
// A B C D E F G
//|--------| Found in the trie during insertion.
//
// This means that while processing this request there was a
// partially overlapping request that had A..=E in its
// prefill. In this case we need to free the blocks D E.
self.free_blocks
.extend(&blocks[allocation.cached_prefix_len..prefix_len]);
}
// Free non-prefill blocks.
@ -334,3 +351,60 @@ struct RadixAllocation {
cached_prefix_len: usize,
prefill_tokens: Option<Arc<Vec<u32>>>,
}
#[cfg(test)]
mod tests {
use std::{rc::Rc, sync::Arc};
use super::{Allocator, RadixAllocator};
#[test]
fn test_prefix_cache() {
let mut cache = RadixAllocator::new(1, 12, None);
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
assert_eq!(allocation.0, vec![4, 5, 6, 7, 8, 9, 10, 11]);
assert_eq!(allocation.1, allocation.0);
assert_eq!(allocation.2, 0);
cache.free(allocation.0, allocation.3);
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
assert_eq!(allocation.0, vec![4, 5, 6, 7, 8, 9, 10, 11]);
assert_eq!(allocation.2, 4);
}
#[test]
fn test_older_prefixes_are_collected_first() {
let mut cache = RadixAllocator::new(1, 7, None);
let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
assert_eq!(allocation1.0, vec![3, 4, 5, 6]);
assert_eq!(allocation1.2, 0);
let allocation2 = cache.allocate(2, Some(Arc::new(vec![4, 5]))).unwrap();
assert_eq!(allocation2.0, vec![1, 2]);
assert_eq!(allocation2.2, 0);
cache.free(allocation1.0, allocation1.3);
cache.free(allocation2.0, allocation2.3);
// We should get the blocks of the first allocation, since they are more recent.
let allocation3 = cache.allocate(4, Some(Arc::new(vec![6, 7, 8, 9]))).unwrap();
assert_eq!(allocation3.0, vec![3, 4, 5, 6]);
assert_eq!(allocation3.2, 0);
}
#[test]
fn correctly_free_when_fully_overlapping_prefills_in_flight() {
let mut cache = RadixAllocator::new(1, 10, None);
let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
let allocation2 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
cache.free(allocation2.0, allocation2.3);
cache.free(allocation1.0, allocation1.3);
let allocation3 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
assert_eq!(allocation3.2, 4);
// 10 blocks, of which 1 reserved for health checks, 4 for the cached blocks.
assert_eq!(cache.free_blocks.len(), 5);
}
}

View File

@ -118,6 +118,11 @@ impl RadixTrie {
node.ref_count += 1;
}
/// Insert a prefill along with its blocks.
///
/// This method returns the length of the prefix that was already
/// in the trie. E.g. if the length is 10, this means that for
/// the first 10 elements of the tree **the blocks are not updated**.
pub fn insert(&mut self, key: &[u32], blocks: &[u32]) -> usize {
self.time += 1;
self.insert_(self.root, key, blocks)
@ -152,10 +157,10 @@ impl RadixTrie {
let child_id = self.split_node(child_id, shared_prefix_len);
let key = &key[shared_prefix_len..];
let blocks = &blocks[shared_prefix_len..];
self.insert_(child_id, key, blocks)
shared_prefix_len + self.insert_(child_id, key, blocks)
} else {
self.add_node(node_id, key, blocks);
key.len()
0
}
}