Handling debugger.
This commit is contained in:
parent
c53968dc45
commit
682db34b6a
|
@ -43,13 +43,7 @@ impl BackendV3 {
|
||||||
let attention: Attention = attention
|
let attention: Attention = attention
|
||||||
.parse()
|
.parse()
|
||||||
.unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"));
|
.unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"));
|
||||||
let block_size = if attention == Attention::FlashDecoding {
|
let block_size = attention.block_size();
|
||||||
256
|
|
||||||
} else if attention == Attention::FlashInfer {
|
|
||||||
1
|
|
||||||
} else {
|
|
||||||
16
|
|
||||||
};
|
|
||||||
|
|
||||||
let queue = Queue::new(
|
let queue = Queue::new(
|
||||||
requires_padding,
|
requires_padding,
|
||||||
|
|
|
@ -91,7 +91,11 @@ async fn block_allocator_task(
|
||||||
window_size: Option<u32>,
|
window_size: Option<u32>,
|
||||||
mut receiver: mpsc::UnboundedReceiver<BlockAllocatorCommand>,
|
mut receiver: mpsc::UnboundedReceiver<BlockAllocatorCommand>,
|
||||||
) {
|
) {
|
||||||
let mut allocator = RadixAllocator::new(block_size, blocks, window_size, prefix_caching);
|
let mut allocator: Box<dyn Allocator + Send> = if prefix_caching {
|
||||||
|
Box::new(RadixAllocator::new(block_size, blocks, window_size))
|
||||||
|
} else {
|
||||||
|
Box::new(SimpleAllocator::new(blocks, block_size, window_size))
|
||||||
|
};
|
||||||
while let Some(cmd) = receiver.recv().await {
|
while let Some(cmd) = receiver.recv().await {
|
||||||
match cmd {
|
match cmd {
|
||||||
BlockAllocatorCommand::Free {
|
BlockAllocatorCommand::Free {
|
||||||
|
@ -124,12 +128,82 @@ enum BlockAllocatorCommand {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// pub trait Allocator {
|
pub trait Allocator {
|
||||||
// fn allocate(
|
fn allocate(
|
||||||
// &mut self,
|
&mut self,
|
||||||
// tokens: u32,
|
tokens: u32,
|
||||||
// prefill_tokens: Option<Arc<Vec<u32>>>,
|
prefill_tokens: Option<Arc<Vec<u32>>>,
|
||||||
// ) -> Option<BlockAllocation>;
|
) -> Option<BlockAllocation>;
|
||||||
//
|
|
||||||
// fn free(&mut self, blocks: Vec<u32>, allocation_id: u64);
|
fn free(&mut self, blocks: Vec<u32>, allocation_id: u64);
|
||||||
// }
|
}
|
||||||
|
pub struct SimpleAllocator {
|
||||||
|
free_blocks: Vec<u32>,
|
||||||
|
block_size: u32,
|
||||||
|
window_size: Option<u32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SimpleAllocator {
|
||||||
|
fn new(blocks: u32, block_size: u32, window_size: Option<u32>) -> Self {
|
||||||
|
SimpleAllocator {
|
||||||
|
block_size,
|
||||||
|
// Block 0 is reserved for health checks
|
||||||
|
free_blocks: (1..blocks).collect(),
|
||||||
|
window_size,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Allocator for SimpleAllocator {
|
||||||
|
fn allocate(
|
||||||
|
&mut self,
|
||||||
|
tokens: u32,
|
||||||
|
_prefill_tokens: Option<Arc<Vec<u32>>>,
|
||||||
|
) -> Option<BlockAllocation> {
|
||||||
|
// Apply window size
|
||||||
|
let (required_blocks, repeats) = {
|
||||||
|
let (tokens, repeats) = match self.window_size {
|
||||||
|
None => (tokens, 1),
|
||||||
|
Some(window_size) => {
|
||||||
|
let repeats = (tokens + window_size - 1) / window_size;
|
||||||
|
let tokens = core::cmp::min(tokens, window_size);
|
||||||
|
(tokens, repeats as usize)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
// Pad to a multiple of block size
|
||||||
|
let required_blocks = (tokens + self.block_size - 1) / self.block_size;
|
||||||
|
(required_blocks, repeats)
|
||||||
|
};
|
||||||
|
|
||||||
|
let tokens = tokens as usize;
|
||||||
|
if required_blocks > self.free_blocks.len() as u32 {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
let blocks = self
|
||||||
|
.free_blocks
|
||||||
|
.split_off(self.free_blocks.len() - required_blocks as usize);
|
||||||
|
let mut slots =
|
||||||
|
Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize);
|
||||||
|
|
||||||
|
'slots: for block_id in blocks.repeat(repeats).iter() {
|
||||||
|
for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) {
|
||||||
|
slots.push(s);
|
||||||
|
if slots.len() == tokens {
|
||||||
|
break 'slots;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Some(BlockAllocation {
|
||||||
|
allocation_id: 0,
|
||||||
|
blocks,
|
||||||
|
slots,
|
||||||
|
prefix_len: 0,
|
||||||
|
block_allocator: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn free(&mut self, blocks: Vec<u32>, _allocation_id: u64) {
|
||||||
|
self.free_blocks.extend(blocks)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -333,7 +333,7 @@ impl State {
|
||||||
break 'entry_loop;
|
break 'entry_loop;
|
||||||
}
|
}
|
||||||
Some(block_allocation) => {
|
Some(block_allocation) => {
|
||||||
tracing::debug!("Allocation: {block_allocation:?}");
|
tracing::info!("Allocation: {block_allocation:?}");
|
||||||
max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
|
max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
|
||||||
Some(block_allocation)
|
Some(block_allocation)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,12 +1,10 @@
|
||||||
|
use crate::block_allocator::{Allocator, BlockAllocation};
|
||||||
|
use slotmap::{DefaultKey, SlotMap};
|
||||||
use std::{
|
use std::{
|
||||||
collections::{BTreeSet, HashMap},
|
collections::{BTreeSet, HashMap},
|
||||||
sync::Arc,
|
sync::Arc,
|
||||||
};
|
};
|
||||||
|
|
||||||
use slotmap::{DefaultKey, SlotMap};
|
|
||||||
|
|
||||||
use crate::block_allocator::BlockAllocation;
|
|
||||||
|
|
||||||
pub struct RadixAllocator {
|
pub struct RadixAllocator {
|
||||||
allocation_id: u64,
|
allocation_id: u64,
|
||||||
|
|
||||||
|
@ -21,25 +19,15 @@ pub struct RadixAllocator {
|
||||||
// This isn't used because the prefix need to match without the windowing
|
// This isn't used because the prefix need to match without the windowing
|
||||||
// mecanism. This at worst is overallocating, not necessarily being wrong.
|
// mecanism. This at worst is overallocating, not necessarily being wrong.
|
||||||
window_size: Option<u32>,
|
window_size: Option<u32>,
|
||||||
|
|
||||||
/// Wether to actual use the radix tree for searching or not.
|
|
||||||
prefix_caching: bool,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RadixAllocator {
|
impl RadixAllocator {
|
||||||
pub fn new(
|
pub fn new(block_size: u32, n_blocks: u32, window_size: Option<u32>) -> Self {
|
||||||
block_size: u32,
|
|
||||||
n_blocks: u32,
|
|
||||||
window_size: Option<u32>,
|
|
||||||
prefix_caching: bool,
|
|
||||||
) -> Self {
|
|
||||||
if prefix_caching {
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
block_size, 1,
|
block_size, 1,
|
||||||
"Radix tree allocator only works with block_size=1, was: {}",
|
"Radix tree allocator only works with block_size=1, was: {}",
|
||||||
block_size
|
block_size
|
||||||
);
|
);
|
||||||
}
|
|
||||||
// if window_size.is_some() {
|
// if window_size.is_some() {
|
||||||
// unimplemented!("Window size not supported in the prefix-caching block allocator yet");
|
// unimplemented!("Window size not supported in the prefix-caching block allocator yet");
|
||||||
// }
|
// }
|
||||||
|
@ -52,7 +40,6 @@ impl RadixAllocator {
|
||||||
// Block 0 is reserved for health checks.
|
// Block 0 is reserved for health checks.
|
||||||
free_blocks: (1..n_blocks).collect(),
|
free_blocks: (1..n_blocks).collect(),
|
||||||
window_size,
|
window_size,
|
||||||
prefix_caching,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -81,15 +68,14 @@ impl RadixAllocator {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Allocator trait
|
// Allocator trait
|
||||||
impl RadixAllocator {
|
impl Allocator for RadixAllocator {
|
||||||
pub fn allocate(
|
fn allocate(
|
||||||
&mut self,
|
&mut self,
|
||||||
tokens: u32,
|
tokens: u32,
|
||||||
prefill_tokens: Option<Arc<Vec<u32>>>,
|
prefill_tokens: Option<Arc<Vec<u32>>>,
|
||||||
) -> Option<BlockAllocation> {
|
) -> Option<BlockAllocation> {
|
||||||
let mut blocks = vec![];
|
let mut blocks = vec![];
|
||||||
let prefix_node = match (self.prefix_caching, prefill_tokens.as_ref()) {
|
let prefix_node = if let Some(prefill_tokens) = prefill_tokens.as_ref() {
|
||||||
(true, Some(prefill_tokens)) => {
|
|
||||||
let node_id = self
|
let node_id = self
|
||||||
.cache_blocks
|
.cache_blocks
|
||||||
.find(prefill_tokens.as_slice(), &mut blocks);
|
.find(prefill_tokens.as_slice(), &mut blocks);
|
||||||
|
@ -97,8 +83,8 @@ impl RadixAllocator {
|
||||||
// refcount to ensure that the prefix that was found is not evicted.
|
// refcount to ensure that the prefix that was found is not evicted.
|
||||||
|
|
||||||
node_id
|
node_id
|
||||||
}
|
} else {
|
||||||
_ => self.cache_blocks.root_id(),
|
self.cache_blocks.root_id()
|
||||||
};
|
};
|
||||||
|
|
||||||
self.cache_blocks
|
self.cache_blocks
|
||||||
|
@ -108,7 +94,9 @@ impl RadixAllocator {
|
||||||
let prefix_len = blocks.len();
|
let prefix_len = blocks.len();
|
||||||
let suffix_len = tokens - prefix_len as u32;
|
let suffix_len = tokens - prefix_len as u32;
|
||||||
|
|
||||||
match self.alloc_or_reclaim(suffix_len as usize) {
|
let suffix_blocks = suffix_len;
|
||||||
|
|
||||||
|
match self.alloc_or_reclaim(suffix_blocks as usize) {
|
||||||
Some(suffix_blocks) => blocks.extend(suffix_blocks),
|
Some(suffix_blocks) => blocks.extend(suffix_blocks),
|
||||||
None => {
|
None => {
|
||||||
self.cache_blocks
|
self.cache_blocks
|
||||||
|
@ -127,6 +115,8 @@ impl RadixAllocator {
|
||||||
prefill_tokens: prefill_tokens.clone(),
|
prefill_tokens: prefill_tokens.clone(),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
tracing::info!("Blocks {blocks:?}");
|
||||||
|
|
||||||
self.allocation_id += 1;
|
self.allocation_id += 1;
|
||||||
self.allocations.insert(self.allocation_id, allocation);
|
self.allocations.insert(self.allocation_id, allocation);
|
||||||
|
|
||||||
|
@ -139,7 +129,7 @@ impl RadixAllocator {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn free(&mut self, blocks: Vec<u32>, allocation_id: u64) {
|
fn free(&mut self, blocks: Vec<u32>, allocation_id: u64) {
|
||||||
let allocation = match self.allocations.remove(&allocation_id) {
|
let allocation = match self.allocations.remove(&allocation_id) {
|
||||||
Some(allocation) => allocation,
|
Some(allocation) => allocation,
|
||||||
None => unreachable!("Tried to free an unknown allocation."),
|
None => unreachable!("Tried to free an unknown allocation."),
|
||||||
|
@ -613,7 +603,21 @@ mod tests {
|
||||||
cache.free(allocation.blocks.clone(), allocation.allocation_id);
|
cache.free(allocation.blocks.clone(), allocation.allocation_id);
|
||||||
|
|
||||||
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
||||||
assert_eq!(allocation.blocks, vec![1, 2, 3, 8, 9, 10, 11, 7]);
|
assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]);
|
||||||
|
assert_eq!(allocation.prefix_len, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn allocator_block_size() {
|
||||||
|
let mut cache = RadixAllocator::new(256, 12, None, false);
|
||||||
|
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
||||||
|
assert_eq!(allocation.blocks, vec![11]);
|
||||||
|
assert_eq!(allocation.slots, allocation.slots);
|
||||||
|
assert_eq!(allocation.prefix_len, 0);
|
||||||
|
cache.free(allocation.blocks.clone(), allocation.allocation_id);
|
||||||
|
|
||||||
|
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
||||||
|
assert_eq!(allocation.blocks, vec![11]);
|
||||||
assert_eq!(allocation.prefix_len, 0);
|
assert_eq!(allocation.prefix_len, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -835,11 +835,11 @@
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1724379657,
|
"lastModified": 1724638882,
|
||||||
"narHash": "sha256-+CFDh1FUgyY7q0FiWhKJpHS7LlD3KbiqN5Z4Z+4bGmc=",
|
"narHash": "sha256-ap2jIQi/FuUHR6HCht6ASWhoz8EiB99XmI8Esot38VE=",
|
||||||
"owner": "oxalica",
|
"owner": "oxalica",
|
||||||
"repo": "rust-overlay",
|
"repo": "rust-overlay",
|
||||||
"rev": "a18034322c7703fcfe5d7352a77981ba4a936a61",
|
"rev": "19b70f147b9c67a759e35824b241f1ed92e46694",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
|
|
|
@ -57,7 +57,7 @@
|
||||||
{
|
{
|
||||||
devShells = with pkgs; rec {
|
devShells = with pkgs; rec {
|
||||||
|
|
||||||
default = pure;
|
default = impure;
|
||||||
|
|
||||||
pure = mkShell {
|
pure = mkShell {
|
||||||
buildInputs = [
|
buildInputs = [
|
||||||
|
|
|
@ -8,7 +8,7 @@ use nix::unistd::Pid;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use std::env;
|
use std::env;
|
||||||
use std::ffi::OsString;
|
use std::ffi::OsString;
|
||||||
use std::io::{BufRead, BufReader, Lines};
|
use std::io::{BufRead, BufReader};
|
||||||
use std::os::unix::process::{CommandExt, ExitStatusExt};
|
use std::os::unix::process::{CommandExt, ExitStatusExt};
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
use std::process::{Child, Command, ExitStatus, Stdio};
|
use std::process::{Child, Command, ExitStatus, Stdio};
|
||||||
|
@ -18,7 +18,10 @@ use std::sync::{mpsc, Arc};
|
||||||
use std::thread;
|
use std::thread;
|
||||||
use std::thread::sleep;
|
use std::thread::sleep;
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
use std::{fs, io};
|
use std::{
|
||||||
|
fs, io,
|
||||||
|
io::{Read, Write},
|
||||||
|
};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tracing_subscriber::{filter::LevelFilter, EnvFilter};
|
use tracing_subscriber::{filter::LevelFilter, EnvFilter};
|
||||||
|
|
||||||
|
@ -833,6 +836,7 @@ fn shard_manager(
|
||||||
.args(shard_args)
|
.args(shard_args)
|
||||||
.env_clear()
|
.env_clear()
|
||||||
.envs(envs)
|
.envs(envs)
|
||||||
|
.stdin(Stdio::piped())
|
||||||
.stdout(Stdio::piped())
|
.stdout(Stdio::piped())
|
||||||
.stderr(Stdio::piped())
|
.stderr(Stdio::piped())
|
||||||
.process_group(0)
|
.process_group(0)
|
||||||
|
@ -854,12 +858,13 @@ fn shard_manager(
|
||||||
};
|
};
|
||||||
|
|
||||||
// Redirect STDOUT to the console
|
// Redirect STDOUT to the console
|
||||||
|
let mut pstdin = p.stdin.take().unwrap();
|
||||||
let shard_stdout_reader = BufReader::new(p.stdout.take().unwrap());
|
let shard_stdout_reader = BufReader::new(p.stdout.take().unwrap());
|
||||||
let shard_stderr_reader = BufReader::new(p.stderr.take().unwrap());
|
let shard_stderr_reader = BufReader::new(p.stderr.take().unwrap());
|
||||||
|
|
||||||
//stdout tracing thread
|
//stdout tracing thread
|
||||||
thread::spawn(move || {
|
thread::spawn(move || {
|
||||||
log_lines(shard_stdout_reader.lines());
|
log_lines(shard_stdout_reader);
|
||||||
});
|
});
|
||||||
// We read stderr in another thread as it seems that lines() can block in some cases
|
// We read stderr in another thread as it seems that lines() can block in some cases
|
||||||
let (err_sender, err_receiver) = mpsc::channel();
|
let (err_sender, err_receiver) = mpsc::channel();
|
||||||
|
@ -868,6 +873,18 @@ fn shard_manager(
|
||||||
err_sender.send(line).unwrap_or(());
|
err_sender.send(line).unwrap_or(());
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
// We read stdin in another thread as it seems that lines() can block in some cases
|
||||||
|
thread::spawn(move || {
|
||||||
|
let mut stdin = io::stdin(); // We get `Stdin` here.
|
||||||
|
loop {
|
||||||
|
let mut buffer = vec![0; 4096];
|
||||||
|
if let Ok(n) = stdin.read(&mut buffer) {
|
||||||
|
if n > 0 {
|
||||||
|
let _ = pstdin.write_all(&buffer[..n]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
let mut ready = false;
|
let mut ready = false;
|
||||||
let start_time = Instant::now();
|
let start_time = Instant::now();
|
||||||
|
@ -974,19 +991,36 @@ impl PythonLogMessage {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TryFrom<&String> for PythonLogMessage {
|
impl TryFrom<&[u8]> for PythonLogMessage {
|
||||||
type Error = serde_json::Error;
|
type Error = serde_json::Error;
|
||||||
|
|
||||||
fn try_from(value: &String) -> Result<Self, Self::Error> {
|
fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
|
||||||
serde_json::from_str::<Self>(value)
|
serde_json::from_slice::<Self>(value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn log_lines<S: Sized + BufRead>(lines: Lines<S>) {
|
fn log_lines<R: Sized + Read>(mut bufread: BufReader<R>) {
|
||||||
for line in lines.map_while(Result::ok) {
|
let mut buffer = vec![0u8; 4096];
|
||||||
match PythonLogMessage::try_from(&line) {
|
let mut stdout = std::io::stdout();
|
||||||
|
loop {
|
||||||
|
let n = bufread.read(&mut buffer);
|
||||||
|
if let Ok(n) = n {
|
||||||
|
if n > 0 {
|
||||||
|
let mut lines = buffer[..n].split(|i| *i == b'\n').peekable();
|
||||||
|
while let Some(line) = lines.next() {
|
||||||
|
match PythonLogMessage::try_from(line) {
|
||||||
Ok(log) => log.trace(),
|
Ok(log) => log.trace(),
|
||||||
Err(_) => tracing::debug!("{line}"),
|
// For interactive debugging ?
|
||||||
|
Err(_) => {
|
||||||
|
stdout.write_all(line).unwrap();
|
||||||
|
if lines.peek().is_some() {
|
||||||
|
stdout.write_all(b"\n").unwrap();
|
||||||
|
}
|
||||||
|
stdout.flush().unwrap();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1146,7 +1180,7 @@ fn download_convert_model(
|
||||||
let download_stdout = BufReader::new(download_process.stdout.take().unwrap());
|
let download_stdout = BufReader::new(download_process.stdout.take().unwrap());
|
||||||
|
|
||||||
thread::spawn(move || {
|
thread::spawn(move || {
|
||||||
log_lines(download_stdout.lines());
|
log_lines(download_stdout);
|
||||||
});
|
});
|
||||||
|
|
||||||
let download_stderr = BufReader::new(download_process.stderr.take().unwrap());
|
let download_stderr = BufReader::new(download_process.stderr.take().unwrap());
|
||||||
|
|
|
@ -22,6 +22,16 @@ pub enum Attention {
|
||||||
FlashInfer,
|
FlashInfer,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Attention {
|
||||||
|
pub fn block_size(&self) -> u32 {
|
||||||
|
match self {
|
||||||
|
Attention::FlashDecoding => 256,
|
||||||
|
Attention::FlashInfer => 1,
|
||||||
|
Attention::Paged => 16,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct ParseError;
|
pub struct ParseError;
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue