Handling debugger.

This commit is contained in:
Nicolas Patry 2024-08-26 14:59:27 +02:00
parent c53968dc45
commit 682db34b6a
No known key found for this signature in database
GPG Key ID: 64AF4752B2967863
8 changed files with 186 additions and 70 deletions

View File

@ -43,13 +43,7 @@ impl BackendV3 {
let attention: Attention = attention let attention: Attention = attention
.parse() .parse()
.unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`")); .unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"));
let block_size = if attention == Attention::FlashDecoding { let block_size = attention.block_size();
256
} else if attention == Attention::FlashInfer {
1
} else {
16
};
let queue = Queue::new( let queue = Queue::new(
requires_padding, requires_padding,

View File

@ -91,7 +91,11 @@ async fn block_allocator_task(
window_size: Option<u32>, window_size: Option<u32>,
mut receiver: mpsc::UnboundedReceiver<BlockAllocatorCommand>, mut receiver: mpsc::UnboundedReceiver<BlockAllocatorCommand>,
) { ) {
let mut allocator = RadixAllocator::new(block_size, blocks, window_size, prefix_caching); let mut allocator: Box<dyn Allocator + Send> = if prefix_caching {
Box::new(RadixAllocator::new(block_size, blocks, window_size))
} else {
Box::new(SimpleAllocator::new(blocks, block_size, window_size))
};
while let Some(cmd) = receiver.recv().await { while let Some(cmd) = receiver.recv().await {
match cmd { match cmd {
BlockAllocatorCommand::Free { BlockAllocatorCommand::Free {
@ -124,12 +128,82 @@ enum BlockAllocatorCommand {
}, },
} }
// pub trait Allocator { pub trait Allocator {
// fn allocate( fn allocate(
// &mut self, &mut self,
// tokens: u32, tokens: u32,
// prefill_tokens: Option<Arc<Vec<u32>>>, prefill_tokens: Option<Arc<Vec<u32>>>,
// ) -> Option<BlockAllocation>; ) -> Option<BlockAllocation>;
//
// fn free(&mut self, blocks: Vec<u32>, allocation_id: u64); fn free(&mut self, blocks: Vec<u32>, allocation_id: u64);
// } }
/// Free-list block allocator: hands out whole blocks from a pool and takes
/// them back on `free`, without any prefix matching.
pub struct SimpleAllocator {
    /// Ids of the blocks that are currently available for allocation.
    free_blocks: Vec<u32>,
    /// Number of slots contained in each block.
    block_size: u32,
    /// Optional window size (in tokens); when set, an allocation never needs
    /// more blocks than what is required to cover one window.
    window_size: Option<u32>,
}

impl SimpleAllocator {
    /// Creates an allocator over `blocks` blocks of `block_size` slots each.
    ///
    /// Every block id in `1..blocks` starts out free.
    fn new(blocks: u32, block_size: u32, window_size: Option<u32>) -> Self {
        // Block 0 is reserved for health checks
        let free_blocks: Vec<u32> = (1..blocks).collect();
        Self {
            free_blocks,
            block_size,
            window_size,
        }
    }
}
impl Allocator for SimpleAllocator {
    /// Reserves blocks for a request of `tokens` tokens and lists its slots.
    ///
    /// When a window size is configured, only enough blocks to cover one
    /// window are reserved and their slots are reused (repeated) until
    /// `tokens` slots have been produced. The prefill tokens are ignored.
    ///
    /// Returns `None`, leaving the free list untouched, when fewer free blocks
    /// remain than the request needs.
    fn allocate(
        &mut self,
        tokens: u32,
        _prefill_tokens: Option<Arc<Vec<u32>>>,
    ) -> Option<BlockAllocation> {
        // Apply window size
        let (windowed_tokens, repeats) = if let Some(window_size) = self.window_size {
            let repeats = (tokens + window_size - 1) / window_size;
            (core::cmp::min(tokens, window_size), repeats as usize)
        } else {
            (tokens, 1)
        };
        // Pad to a multiple of block size
        let required_blocks = (windowed_tokens + self.block_size - 1) / self.block_size;

        let available = self.free_blocks.len();
        if required_blocks > available as u32 {
            return None;
        }

        // The reserved blocks are taken from the tail of the free list.
        let blocks = self
            .free_blocks
            .split_off(available - required_blocks as usize);

        let wanted = tokens as usize;
        let block_size = self.block_size;
        let mut slots =
            Vec::with_capacity((required_blocks * block_size * repeats as u32) as usize);
        // Walk the reserved blocks `repeats` times, emitting each block's slot
        // range in order, and stop as soon as `tokens` slots were produced.
        slots.extend(
            blocks
                .repeat(repeats)
                .iter()
                .flat_map(|block_id| (block_id * block_size)..((block_id + 1) * block_size))
                .take(wanted),
        );

        Some(BlockAllocation {
            allocation_id: 0,
            blocks,
            slots,
            prefix_len: 0,
            block_allocator: None,
        })
    }

    /// Puts `blocks` back at the end of the free list; the allocation id is
    /// not used by this allocator.
    fn free(&mut self, blocks: Vec<u32>, _allocation_id: u64) {
        let mut released = blocks;
        self.free_blocks.append(&mut released);
    }
}

View File

@ -333,7 +333,7 @@ impl State {
break 'entry_loop; break 'entry_loop;
} }
Some(block_allocation) => { Some(block_allocation) => {
tracing::debug!("Allocation: {block_allocation:?}"); tracing::info!("Allocation: {block_allocation:?}");
max_blocks = max(max_blocks, block_allocation.blocks.len() as u32); max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
Some(block_allocation) Some(block_allocation)
} }

View File

@ -1,12 +1,10 @@
use crate::block_allocator::{Allocator, BlockAllocation};
use slotmap::{DefaultKey, SlotMap};
use std::{ use std::{
collections::{BTreeSet, HashMap}, collections::{BTreeSet, HashMap},
sync::Arc, sync::Arc,
}; };
use slotmap::{DefaultKey, SlotMap};
use crate::block_allocator::BlockAllocation;
pub struct RadixAllocator { pub struct RadixAllocator {
allocation_id: u64, allocation_id: u64,
@ -21,25 +19,15 @@ pub struct RadixAllocator {
// This isn't used because the prefix need to match without the windowing // This isn't used because the prefix need to match without the windowing
// mecanism. This at worst is overallocating, not necessarily being wrong. // mecanism. This at worst is overallocating, not necessarily being wrong.
window_size: Option<u32>, window_size: Option<u32>,
/// Wether to actual use the radix tree for searching or not.
prefix_caching: bool,
} }
impl RadixAllocator { impl RadixAllocator {
pub fn new( pub fn new(block_size: u32, n_blocks: u32, window_size: Option<u32>) -> Self {
block_size: u32,
n_blocks: u32,
window_size: Option<u32>,
prefix_caching: bool,
) -> Self {
if prefix_caching {
assert_eq!( assert_eq!(
block_size, 1, block_size, 1,
"Radix tree allocator only works with block_size=1, was: {}", "Radix tree allocator only works with block_size=1, was: {}",
block_size block_size
); );
}
// if window_size.is_some() { // if window_size.is_some() {
// unimplemented!("Window size not supported in the prefix-caching block allocator yet"); // unimplemented!("Window size not supported in the prefix-caching block allocator yet");
// } // }
@ -52,7 +40,6 @@ impl RadixAllocator {
// Block 0 is reserved for health checks. // Block 0 is reserved for health checks.
free_blocks: (1..n_blocks).collect(), free_blocks: (1..n_blocks).collect(),
window_size, window_size,
prefix_caching,
} }
} }
@ -81,15 +68,14 @@ impl RadixAllocator {
} }
// Allocator trait // Allocator trait
impl RadixAllocator { impl Allocator for RadixAllocator {
pub fn allocate( fn allocate(
&mut self, &mut self,
tokens: u32, tokens: u32,
prefill_tokens: Option<Arc<Vec<u32>>>, prefill_tokens: Option<Arc<Vec<u32>>>,
) -> Option<BlockAllocation> { ) -> Option<BlockAllocation> {
let mut blocks = vec![]; let mut blocks = vec![];
let prefix_node = match (self.prefix_caching, prefill_tokens.as_ref()) { let prefix_node = if let Some(prefill_tokens) = prefill_tokens.as_ref() {
(true, Some(prefill_tokens)) => {
let node_id = self let node_id = self
.cache_blocks .cache_blocks
.find(prefill_tokens.as_slice(), &mut blocks); .find(prefill_tokens.as_slice(), &mut blocks);
@ -97,8 +83,8 @@ impl RadixAllocator {
// refcount to ensure that the prefix that was found is not evicted. // refcount to ensure that the prefix that was found is not evicted.
node_id node_id
} } else {
_ => self.cache_blocks.root_id(), self.cache_blocks.root_id()
}; };
self.cache_blocks self.cache_blocks
@ -108,7 +94,9 @@ impl RadixAllocator {
let prefix_len = blocks.len(); let prefix_len = blocks.len();
let suffix_len = tokens - prefix_len as u32; let suffix_len = tokens - prefix_len as u32;
match self.alloc_or_reclaim(suffix_len as usize) { let suffix_blocks = suffix_len;
match self.alloc_or_reclaim(suffix_blocks as usize) {
Some(suffix_blocks) => blocks.extend(suffix_blocks), Some(suffix_blocks) => blocks.extend(suffix_blocks),
None => { None => {
self.cache_blocks self.cache_blocks
@ -127,6 +115,8 @@ impl RadixAllocator {
prefill_tokens: prefill_tokens.clone(), prefill_tokens: prefill_tokens.clone(),
}; };
tracing::info!("Blocks {blocks:?}");
self.allocation_id += 1; self.allocation_id += 1;
self.allocations.insert(self.allocation_id, allocation); self.allocations.insert(self.allocation_id, allocation);
@ -139,7 +129,7 @@ impl RadixAllocator {
}) })
} }
pub fn free(&mut self, blocks: Vec<u32>, allocation_id: u64) { fn free(&mut self, blocks: Vec<u32>, allocation_id: u64) {
let allocation = match self.allocations.remove(&allocation_id) { let allocation = match self.allocations.remove(&allocation_id) {
Some(allocation) => allocation, Some(allocation) => allocation,
None => unreachable!("Tried to free an unknown allocation."), None => unreachable!("Tried to free an unknown allocation."),
@ -613,7 +603,21 @@ mod tests {
cache.free(allocation.blocks.clone(), allocation.allocation_id); cache.free(allocation.blocks.clone(), allocation.allocation_id);
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
assert_eq!(allocation.blocks, vec![1, 2, 3, 8, 9, 10, 11, 7]); assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]);
assert_eq!(allocation.prefix_len, 0);
}
#[test]
fn allocator_block_size() {
let mut cache = RadixAllocator::new(256, 12, None, false);
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
assert_eq!(allocation.blocks, vec![11]);
assert_eq!(allocation.slots, allocation.slots);
assert_eq!(allocation.prefix_len, 0);
cache.free(allocation.blocks.clone(), allocation.allocation_id);
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
assert_eq!(allocation.blocks, vec![11]);
assert_eq!(allocation.prefix_len, 0); assert_eq!(allocation.prefix_len, 0);
} }

View File

@ -835,11 +835,11 @@
] ]
}, },
"locked": { "locked": {
"lastModified": 1724379657, "lastModified": 1724638882,
"narHash": "sha256-+CFDh1FUgyY7q0FiWhKJpHS7LlD3KbiqN5Z4Z+4bGmc=", "narHash": "sha256-ap2jIQi/FuUHR6HCht6ASWhoz8EiB99XmI8Esot38VE=",
"owner": "oxalica", "owner": "oxalica",
"repo": "rust-overlay", "repo": "rust-overlay",
"rev": "a18034322c7703fcfe5d7352a77981ba4a936a61", "rev": "19b70f147b9c67a759e35824b241f1ed92e46694",
"type": "github" "type": "github"
}, },
"original": { "original": {

View File

@ -57,7 +57,7 @@
{ {
devShells = with pkgs; rec { devShells = with pkgs; rec {
default = pure; default = impure;
pure = mkShell { pure = mkShell {
buildInputs = [ buildInputs = [

View File

@ -8,7 +8,7 @@ use nix::unistd::Pid;
use serde::Deserialize; use serde::Deserialize;
use std::env; use std::env;
use std::ffi::OsString; use std::ffi::OsString;
use std::io::{BufRead, BufReader, Lines}; use std::io::{BufRead, BufReader};
use std::os::unix::process::{CommandExt, ExitStatusExt}; use std::os::unix::process::{CommandExt, ExitStatusExt};
use std::path::Path; use std::path::Path;
use std::process::{Child, Command, ExitStatus, Stdio}; use std::process::{Child, Command, ExitStatus, Stdio};
@ -18,7 +18,10 @@ use std::sync::{mpsc, Arc};
use std::thread; use std::thread;
use std::thread::sleep; use std::thread::sleep;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use std::{fs, io}; use std::{
fs, io,
io::{Read, Write},
};
use thiserror::Error; use thiserror::Error;
use tracing_subscriber::{filter::LevelFilter, EnvFilter}; use tracing_subscriber::{filter::LevelFilter, EnvFilter};
@ -833,6 +836,7 @@ fn shard_manager(
.args(shard_args) .args(shard_args)
.env_clear() .env_clear()
.envs(envs) .envs(envs)
.stdin(Stdio::piped())
.stdout(Stdio::piped()) .stdout(Stdio::piped())
.stderr(Stdio::piped()) .stderr(Stdio::piped())
.process_group(0) .process_group(0)
@ -854,12 +858,13 @@ fn shard_manager(
}; };
// Redirect STDOUT to the console // Redirect STDOUT to the console
let mut pstdin = p.stdin.take().unwrap();
let shard_stdout_reader = BufReader::new(p.stdout.take().unwrap()); let shard_stdout_reader = BufReader::new(p.stdout.take().unwrap());
let shard_stderr_reader = BufReader::new(p.stderr.take().unwrap()); let shard_stderr_reader = BufReader::new(p.stderr.take().unwrap());
//stdout tracing thread //stdout tracing thread
thread::spawn(move || { thread::spawn(move || {
log_lines(shard_stdout_reader.lines()); log_lines(shard_stdout_reader);
}); });
// We read stderr in another thread as it seems that lines() can block in some cases // We read stderr in another thread as it seems that lines() can block in some cases
let (err_sender, err_receiver) = mpsc::channel(); let (err_sender, err_receiver) = mpsc::channel();
@ -868,6 +873,18 @@ fn shard_manager(
err_sender.send(line).unwrap_or(()); err_sender.send(line).unwrap_or(());
} }
}); });
// Forward this process's stdin to the shard's stdin from a dedicated thread, because reading stdin blocks
thread::spawn(move || {
let mut stdin = io::stdin(); // We get `Stdin` here.
loop {
let mut buffer = vec![0; 4096];
if let Ok(n) = stdin.read(&mut buffer) {
if n > 0 {
let _ = pstdin.write_all(&buffer[..n]);
}
}
}
});
let mut ready = false; let mut ready = false;
let start_time = Instant::now(); let start_time = Instant::now();
@ -974,19 +991,36 @@ impl PythonLogMessage {
} }
} }
impl TryFrom<&String> for PythonLogMessage { impl TryFrom<&[u8]> for PythonLogMessage {
type Error = serde_json::Error; type Error = serde_json::Error;
fn try_from(value: &String) -> Result<Self, Self::Error> { fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
serde_json::from_str::<Self>(value) serde_json::from_slice::<Self>(value)
} }
} }
fn log_lines<S: Sized + BufRead>(lines: Lines<S>) { fn log_lines<R: Sized + Read>(mut bufread: BufReader<R>) {
for line in lines.map_while(Result::ok) { let mut buffer = vec![0u8; 4096];
match PythonLogMessage::try_from(&line) { let mut stdout = std::io::stdout();
loop {
let n = bufread.read(&mut buffer);
if let Ok(n) = n {
if n > 0 {
let mut lines = buffer[..n].split(|i| *i == b'\n').peekable();
while let Some(line) = lines.next() {
match PythonLogMessage::try_from(line) {
Ok(log) => log.trace(), Ok(log) => log.trace(),
Err(_) => tracing::debug!("{line}"), // For interactive debugging ?
Err(_) => {
stdout.write_all(line).unwrap();
if lines.peek().is_some() {
stdout.write_all(b"\n").unwrap();
}
stdout.flush().unwrap();
}
}
}
}
} }
} }
} }
@ -1146,7 +1180,7 @@ fn download_convert_model(
let download_stdout = BufReader::new(download_process.stdout.take().unwrap()); let download_stdout = BufReader::new(download_process.stdout.take().unwrap());
thread::spawn(move || { thread::spawn(move || {
log_lines(download_stdout.lines()); log_lines(download_stdout);
}); });
let download_stderr = BufReader::new(download_process.stderr.take().unwrap()); let download_stderr = BufReader::new(download_process.stderr.take().unwrap());

View File

@ -22,6 +22,16 @@ pub enum Attention {
FlashInfer, FlashInfer,
} }
impl Attention {
pub fn block_size(&self) -> u32 {
match self {
Attention::FlashDecoding => 256,
Attention::FlashInfer => 1,
Attention::Paged => 16,
}
}
}
#[derive(Debug)] #[derive(Debug)]
pub struct ParseError; pub struct ParseError;