From 379472c4c2e401b1efd66d7d47edc00b96f5ce14 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 6 Sep 2024 11:55:23 +0200 Subject: [PATCH] radix trie: add assertions (#2491) These should all be cheap assertions. Also: * Fixup some comments. * Delete a `remove` that was done unnecessarily twice. --- backends/v3/src/radix.rs | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/backends/v3/src/radix.rs b/backends/v3/src/radix.rs index b85be00b..bb6582b0 100644 --- a/backends/v3/src/radix.rs +++ b/backends/v3/src/radix.rs @@ -73,14 +73,13 @@ impl Allocator for RadixAllocator { let node_id = self .cache_blocks .find(prefill_tokens.as_slice(), &mut blocks); - // Even if this allocation fails below, we need to increase he - // refcount to ensure that the prefix that was found is not evicted. - node_id } else { self.cache_blocks.root_id() }; + // Even if this allocation fails below, we need to increase the + // refcount to ensure that the prefix that was found is not evicted. self.cache_blocks .incref(prefix_node) .expect("Failed to increment refcount"); @@ -303,6 +302,11 @@ impl RadixTrie { node.ref_count -= 1; if node.ref_count == 0 { + assert!( + node.children.is_empty(), + "Nodes with children must have refcount > 0" + ); + self.leaves.insert((node.last_accessed, node_id)); } @@ -330,7 +334,7 @@ impl RadixTrie { /// Evict `n_blocks` from the trie. /// /// Returns the evicted blocks. When the length is less than `n_blocks`, - /// not enough blocks could beevicted. + /// not enough blocks could be evicted. pub fn evict(&mut self, n_blocks: usize) -> Vec { // NOTE: we don't return Result here. 
If any of the unwrapping fails, // it's a programming error in the trie implementation, not a user @@ -345,6 +349,12 @@ impl RadixTrie { let blocks_needed = n_blocks - evicted.len(); let node = self.nodes.get(node_id).expect("Leave does not exist"); + assert_eq!( + node.ref_count, 0, + "Leaf must have refcount of 0, got {}", + node.ref_count + ); + if blocks_needed >= node.blocks.len() { // We need to evict the whole node if we need more blocks than it has. let node = self.remove_node(node_id); @@ -500,12 +510,16 @@ impl RadixTrie { fn remove_node(&mut self, node_id: NodeId) -> TrieNode { // Unwrap here, passing in an unknown id is a programming error. let node = self.nodes.remove(node_id).expect("Unknown node"); + assert!( + node.children.is_empty(), + "Tried to remove a node with {} children", + node.children.len() + ); let parent_id = node.parent.expect("Attempted to remove root node"); let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node"); parent.children.remove(&node.key[0]); self.decref(parent_id) .expect("Failed to decrease parent refcount"); - self.nodes.remove(node_id); node } @@ -579,6 +593,9 @@ impl TrieNode { fn shared_prefix(left: &[u32], right: &[u32], block_size: usize) -> usize { let full = left.iter().zip(right).take_while(|(a, b)| a == b).count(); + // NOTE: this is the case because the child node was chosen based on + // matching the first character of the key/prefix. + assert!(full > 0, "Prefixes must at least share 1 token"); (full / block_size) * block_size }