From 590fc2c58dfb6869b2f23c80ace9ecbda9e0bcbe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Fri, 2 Aug 2024 13:50:56 +0000 Subject: [PATCH] Double linked data structures are still terrible in Rust. So use fake pointers. --- Cargo.lock | 1 + backends/v3/Cargo.toml | 1 + backends/v3/src/block_allocator.rs | 12 +- backends/v3/src/lib.rs | 2 +- backends/v3/src/radix.rs | 203 +++++++++++++++++------------ 5 files changed, 129 insertions(+), 90 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 92367d1e..3a5845a7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4045,6 +4045,7 @@ dependencies = [ "reqwest", "serde", "serde_json", + "slotmap", "text-generation-router", "thiserror", "tokenizers", diff --git a/backends/v3/Cargo.toml b/backends/v3/Cargo.toml index 5d9a140b..129ceb9c 100644 --- a/backends/v3/Cargo.toml +++ b/backends/v3/Cargo.toml @@ -33,6 +33,7 @@ rand = "0.8.5" reqwest = { version = "0.11.20", features = [] } serde = "1.0.188" serde_json = "1.0.107" +slotmap = "1.0.7" thiserror = "1.0.48" tokenizers = { workspace = true} tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } diff --git a/backends/v3/src/block_allocator.rs b/backends/v3/src/block_allocator.rs index 03b26b05..b4735428 100644 --- a/backends/v3/src/block_allocator.rs +++ b/backends/v3/src/block_allocator.rs @@ -1,11 +1,7 @@ -use std::{ - cmp::min, - collections::{hash_map::Entry, BTreeSet, HashMap}, - sync::Arc, -}; +use std::{cmp::min, collections::BTreeSet, sync::Arc}; use tokio::sync::{mpsc, oneshot}; -use crate::TrieNode; +use crate::RadixTrie; #[derive(Debug, Clone)] pub(crate) struct BlockAllocation { @@ -212,7 +208,7 @@ struct PrefixBlockState { } struct RadixAllocator { - cache_blocks: TrieNode, + cache_blocks: RadixTrie, /// Blocks that are immediately available for allocation. free_blocks: Vec, @@ -236,7 +232,7 @@ impl RadixAllocator { } RadixAllocator { - cache_blocks: TrieNode::new(vec![], vec![], 0), + cache_blocks: RadixTrie::new(), free_blocks: (1..n_blocks).collect(), leaves: BTreeSet::new(), time: 0, diff --git a/backends/v3/src/lib.rs b/backends/v3/src/lib.rs index 190274c6..81e6e4fa 100644 --- a/backends/v3/src/lib.rs +++ b/backends/v3/src/lib.rs @@ -6,7 +6,7 @@ mod radix; use crate::client::{ClientError, ShardedClient}; pub(crate) use backend::BackendV3; -pub(crate) use radix::TrieNode; +pub(crate) use radix::RadixTrie; use serde::Serialize; use thiserror::Error; use utoipa::ToSchema; diff --git a/backends/v3/src/radix.rs b/backends/v3/src/radix.rs index fb6756c8..c7236e97 100644 --- a/backends/v3/src/radix.rs +++ b/backends/v3/src/radix.rs @@ -1,5 +1,7 @@ use std::collections::{hash_map::Entry, HashMap}; +use slotmap::{DefaultKey, SlotMap}; + // Radix trie that is heavily inspired by radix attention from sglang. // // The trie is optimized for prefix caching: @@ -12,16 +14,115 @@ use std::collections::{hash_map::Entry, HashMap}; // - We store additional information in each node, such as last access // time and a reference count. -#[derive(Debug)] -pub struct TrieNode { - children: HashMap, +type NodeId = DefaultKey; + +pub struct RadixTrie { + root: DefaultKey, + nodes: SlotMap, + time: u64, +} + +impl RadixTrie { + pub fn new() -> Self { + let root = TrieNode::new(vec![], vec![], 0); + let mut nodes = SlotMap::new(); + let root = nodes.insert(root); + RadixTrie { + nodes, + root, + time: 0, + } + } + + pub fn find(&self, key: &[u32], blocks: &mut Vec) { + self.find_(self.root, key, blocks); + } + + fn find_(&self, node_id: NodeId, key: &[u32], blocks: &mut Vec) { + let node = &self.nodes[node_id]; + + if let Some(&child_id) = node.children.get(&key[0]) { + let child = &self.nodes[child_id]; + let shared_prefix_len = child.key.shared_prefix_len(key); + blocks.extend(&child.blocks[..shared_prefix_len]); + + let key = &key[shared_prefix_len..]; + if !key.is_empty() { + self.find_(child_id, key, blocks); + } + } + } + + pub fn insert(&mut self, key: &[u32], blocks: &[u32]) -> usize { + self.time += 1; + self.insert_(self.root, key, blocks) + } + + fn insert_(&mut self, node_id: NodeId, key: &[u32], blocks: &[u32]) -> usize { + assert_eq!(key.len(), blocks.len()); + + //let node = self.nodes.get_mut(node).unwrap(); + + if let Some(&child_id) = self.nodes[node_id].children.get(&key[0]) { + let child = self.nodes.get_mut(child_id).unwrap(); + let shared_prefix_len = child.key.shared_prefix_len(key); + + // We are done, the prefix is already in the trie. + if shared_prefix_len == key.len() { + return shared_prefix_len; + } + + // The node's prefix is a prefix of the insertion prefix. + if shared_prefix_len == child.key.len() { + return shared_prefix_len + + self.insert_( + child_id, + &key[shared_prefix_len..], + &blocks[shared_prefix_len..], + ); + } + + // The node's prefix and the insertion prefix only match partially, + // split the node to just contain the matching part. Then insert the + // remainder of the prefix into the node again. + self.split(child_id, shared_prefix_len); + let key = &key[shared_prefix_len..]; + let blocks = &blocks[shared_prefix_len..]; + self.insert_(child_id, key, blocks) + } else { + let child = TrieNode::new(key.to_vec(), blocks.to_vec(), self.time); + let child_id = self.nodes.insert(child); + let node = self.nodes.get_mut(node_id).unwrap(); + node.children.insert(key[0], child_id); + return key.len(); + } + } + + fn split(&mut self, node_id: NodeId, prefix_len: usize) { + let node = self.nodes.get_mut(node_id).unwrap(); + + let rest_key = node.key.split_off(prefix_len); + let rest_blocks = node.blocks.split_off(prefix_len); + let first = rest_key[0]; + + let new_id = self + .nodes + .insert(TrieNode::new(rest_key, rest_blocks, self.time)); + + let node = self.nodes.get_mut(node_id).unwrap(); + node.children.insert(first, new_id); + } +} + +struct TrieNode { + children: HashMap, key: Vec, blocks: Vec, last_accessed: u64, } impl TrieNode { - pub fn new(key: Vec, blocks: Vec, last_accessed: u64) -> Self { + fn new(key: Vec, blocks: Vec, last_accessed: u64) -> Self { TrieNode { children: HashMap::new(), key, @@ -29,66 +130,6 @@ impl TrieNode { last_accessed, } } - - pub fn find(&self, key: &[u32], blocks: &mut Vec) { - if let Some(child) = self.children.get(&key[0]) { - let shared_prefix_len = child.key.shared_prefix_len(key); - blocks.extend(&child.blocks[..shared_prefix_len]); - - let key = &key[shared_prefix_len..]; - if !key.is_empty() { - child.find(key, blocks); - } - } - } - - // Insert a prefix into the trie. Returns the length of the shared prefix. - pub fn insert(&mut self, key: &[u32], blocks: &[u32]) -> usize { - assert_eq!(key.len(), blocks.len()); - - match self.children.entry(key[0]) { - Entry::Occupied(entry) => { - let child = entry.into_mut(); - let shared_prefix_len = child.key.shared_prefix_len(key); - - // We are done, the prefix is already in the trie. - if shared_prefix_len == key.len() { - return shared_prefix_len; - } - - // The node's prefix is a prefix of the insertion prefix. - if shared_prefix_len == child.key.len() { - return shared_prefix_len - + child.insert(&key[shared_prefix_len..], &blocks[shared_prefix_len..]); - } - - // The node's prefix and the insertion prefix only match partially, - // split the node to just contain the matching part. Then insert the - // remainder of the prefix into the node again. - child.split(shared_prefix_len); - let key = &key[shared_prefix_len..]; - let blocks = &blocks[shared_prefix_len..]; - child.insert(key, blocks) - } - Entry::Vacant(entry) => { - let child = TrieNode::new(key.to_vec(), blocks.to_vec(), self.last_accessed); - entry.insert(child); - return key.len(); - } - } - - //node.last_accessed = last_accessed; - } - - fn split(&mut self, prefix_len: usize) { - let rest_key = self.key.split_off(prefix_len); - let rest_blocks = self.blocks.split_off(prefix_len); - - self.children.insert( - rest_key[0], - TrieNode::new(rest_key, rest_blocks, self.last_accessed), - ); - } } trait SharedPrefixLen { @@ -108,56 +149,56 @@ where mod tests { #[test] fn insertions_have_correct_prefix_len() { - let mut root = super::TrieNode::new(vec![], vec![], 0); + let mut trie = super::RadixTrie::new(); - assert_eq!(root.insert(&[0, 1, 2], &[0, 1, 2]), 3); + assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]), 3); // Already exists. - assert_eq!(root.insert(&[0, 1, 2], &[0, 1, 2]), 3); + assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]), 3); // Completely new at root-level - assert_eq!(root.insert(&[1, 2, 3], &[1, 2, 3]), 3); + assert_eq!(trie.insert(&[1, 2, 3], &[1, 2, 3]), 3); // Contains full prefix, but longer. - assert_eq!(root.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]), 5); + assert_eq!(trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]), 5); // Shares partial prefix, we need a split. assert_eq!( - root.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7]), + trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7]), 6 ); } #[test] fn prefix_get_returns_correct_blocks() { - let mut root = super::TrieNode::new(vec![], vec![], 0); - root.insert(&[0, 1, 2], &[0, 1, 2]); - root.insert(&[1, 2, 3], &[1, 2, 3]); - root.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]); - root.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7]); + let mut trie = super::RadixTrie::new(); + trie.insert(&[0, 1, 2], &[0, 1, 2]); + trie.insert(&[1, 2, 3], &[1, 2, 3]); + trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]); + trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7]); let mut blocks = Vec::new(); - root.find(&[0], &mut blocks); + trie.find(&[0], &mut blocks); assert_eq!(blocks, vec![0]); blocks.clear(); - root.find(&[0, 1, 2], &mut blocks); + trie.find(&[0, 1, 2], &mut blocks); assert_eq!(blocks, vec![0, 1, 2]); blocks.clear(); - root.find(&[1, 2, 3], &mut blocks); + trie.find(&[1, 2, 3], &mut blocks); assert_eq!(blocks, vec![1, 2, 3]); blocks.clear(); - root.find(&[0, 1, 2, 3], &mut blocks); + trie.find(&[0, 1, 2, 3], &mut blocks); assert_eq!(blocks, vec![0, 1, 2, 3]); blocks.clear(); - root.find(&[0, 1, 2, 3, 4], &mut blocks); + trie.find(&[0, 1, 2, 3, 4], &mut blocks); assert_eq!(blocks, vec![0, 1, 2, 3, 4]); blocks.clear(); - root.find(&[0, 1, 2, 3, 5], &mut blocks); + trie.find(&[0, 1, 2, 3, 5], &mut blocks); assert_eq!(blocks, vec![0, 1, 2, 3, 5]); } }