parent
6486887b43
commit
2a255ad719
|
@ -230,6 +230,8 @@ impl RadixAllocator {
|
||||||
allocation_id: 0,
|
allocation_id: 0,
|
||||||
allocations: HashMap::new(),
|
allocations: HashMap::new(),
|
||||||
cache_blocks: RadixTrie::new(),
|
cache_blocks: RadixTrie::new(),
|
||||||
|
|
||||||
|
// Block 0 is reserved for health checks.
|
||||||
free_blocks: (1..n_blocks).collect(),
|
free_blocks: (1..n_blocks).collect(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -248,7 +250,10 @@ impl RadixAllocator {
|
||||||
}
|
}
|
||||||
|
|
||||||
if self.free_blocks.len() >= n_blocks_needed {
|
if self.free_blocks.len() >= n_blocks_needed {
|
||||||
Some(self.free_blocks.split_off(n_blocks_needed))
|
Some(
|
||||||
|
self.free_blocks
|
||||||
|
.split_off(self.free_blocks.len() - n_blocks_needed),
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
@ -316,9 +321,21 @@ impl Allocator for RadixAllocator {
|
||||||
// If there are prefill tokens that did not come from the cache,
|
// If there are prefill tokens that did not come from the cache,
|
||||||
// add them to the cache.
|
// add them to the cache.
|
||||||
if prefill_tokens.len() > allocation.cached_prefix_len {
|
if prefill_tokens.len() > allocation.cached_prefix_len {
|
||||||
// TODO: check if the prefill tokens are already in the cache???
|
let prefix_len = self
|
||||||
self.cache_blocks
|
.cache_blocks
|
||||||
.insert(prefill_tokens, &blocks[..prefill_tokens.len()]);
|
.insert(prefill_tokens, &blocks[..prefill_tokens.len()]);
|
||||||
|
|
||||||
|
// We can have a prefill with the following structure:
|
||||||
|
//
|
||||||
|
// |---| From the prefix cache.
|
||||||
|
// A B C D E F G
|
||||||
|
//|--------| Found in the trie during insertion.
|
||||||
|
//
|
||||||
|
// This means that while processing this request there was a
|
||||||
|
// partially overlapping request that had A..=E in its
|
||||||
|
// prefill. In this case we need to free the blocks D E.
|
||||||
|
self.free_blocks
|
||||||
|
.extend(&blocks[allocation.cached_prefix_len..prefix_len]);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Free non-prefill blocks.
|
// Free non-prefill blocks.
|
||||||
|
@ -334,3 +351,60 @@ struct RadixAllocation {
|
||||||
cached_prefix_len: usize,
|
cached_prefix_len: usize,
|
||||||
prefill_tokens: Option<Arc<Vec<u32>>>,
|
prefill_tokens: Option<Arc<Vec<u32>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use std::{rc::Rc, sync::Arc};
|
||||||
|
|
||||||
|
use super::{Allocator, RadixAllocator};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_prefix_cache() {
|
||||||
|
let mut cache = RadixAllocator::new(1, 12, None);
|
||||||
|
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
||||||
|
assert_eq!(allocation.0, vec![4, 5, 6, 7, 8, 9, 10, 11]);
|
||||||
|
assert_eq!(allocation.1, allocation.0);
|
||||||
|
assert_eq!(allocation.2, 0);
|
||||||
|
cache.free(allocation.0, allocation.3);
|
||||||
|
|
||||||
|
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
||||||
|
assert_eq!(allocation.0, vec![4, 5, 6, 7, 8, 9, 10, 11]);
|
||||||
|
assert_eq!(allocation.2, 4);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_older_prefixes_are_collected_first() {
|
||||||
|
let mut cache = RadixAllocator::new(1, 7, None);
|
||||||
|
let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
||||||
|
assert_eq!(allocation1.0, vec![3, 4, 5, 6]);
|
||||||
|
assert_eq!(allocation1.2, 0);
|
||||||
|
|
||||||
|
let allocation2 = cache.allocate(2, Some(Arc::new(vec![4, 5]))).unwrap();
|
||||||
|
assert_eq!(allocation2.0, vec![1, 2]);
|
||||||
|
assert_eq!(allocation2.2, 0);
|
||||||
|
|
||||||
|
cache.free(allocation1.0, allocation1.3);
|
||||||
|
cache.free(allocation2.0, allocation2.3);
|
||||||
|
|
||||||
|
// We should get the blocks of the first allocation, since they are older (freed first).
|
||||||
|
let allocation3 = cache.allocate(4, Some(Arc::new(vec![6, 7, 8, 9]))).unwrap();
|
||||||
|
assert_eq!(allocation3.0, vec![3, 4, 5, 6]);
|
||||||
|
assert_eq!(allocation3.2, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn correctly_free_when_fully_overlapping_prefills_in_flight() {
|
||||||
|
let mut cache = RadixAllocator::new(1, 10, None);
|
||||||
|
let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
||||||
|
let allocation2 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
||||||
|
|
||||||
|
cache.free(allocation2.0, allocation2.3);
|
||||||
|
cache.free(allocation1.0, allocation1.3);
|
||||||
|
|
||||||
|
let allocation3 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
||||||
|
assert_eq!(allocation3.2, 4);
|
||||||
|
|
||||||
|
// 10 blocks, of which 1 reserved for health checks, 4 for the cached blocks.
|
||||||
|
assert_eq!(cache.free_blocks.len(), 5);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -118,6 +118,11 @@ impl RadixTrie {
|
||||||
node.ref_count += 1;
|
node.ref_count += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Insert a prefill along with its blocks.
|
||||||
|
///
|
||||||
|
/// This method returns the length of the prefix that was already
|
||||||
|
/// in the trie. E.g. if the length is 10, this means that for
|
||||||
|
/// the first 10 elements of the key **the blocks are not updated**.
|
||||||
pub fn insert(&mut self, key: &[u32], blocks: &[u32]) -> usize {
|
pub fn insert(&mut self, key: &[u32], blocks: &[u32]) -> usize {
|
||||||
self.time += 1;
|
self.time += 1;
|
||||||
self.insert_(self.root, key, blocks)
|
self.insert_(self.root, key, blocks)
|
||||||
|
@ -152,10 +157,10 @@ impl RadixTrie {
|
||||||
let child_id = self.split_node(child_id, shared_prefix_len);
|
let child_id = self.split_node(child_id, shared_prefix_len);
|
||||||
let key = &key[shared_prefix_len..];
|
let key = &key[shared_prefix_len..];
|
||||||
let blocks = &blocks[shared_prefix_len..];
|
let blocks = &blocks[shared_prefix_len..];
|
||||||
self.insert_(child_id, key, blocks)
|
shared_prefix_len + self.insert_(child_id, key, blocks)
|
||||||
} else {
|
} else {
|
||||||
self.add_node(node_id, key, blocks);
|
self.add_node(node_id, key, blocks);
|
||||||
key.len()
|
0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue