Handling debugger.
This commit is contained in:
parent
c53968dc45
commit
682db34b6a
|
@ -43,13 +43,7 @@ impl BackendV3 {
|
|||
let attention: Attention = attention
|
||||
.parse()
|
||||
.unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"));
|
||||
let block_size = if attention == Attention::FlashDecoding {
|
||||
256
|
||||
} else if attention == Attention::FlashInfer {
|
||||
1
|
||||
} else {
|
||||
16
|
||||
};
|
||||
let block_size = attention.block_size();
|
||||
|
||||
let queue = Queue::new(
|
||||
requires_padding,
|
||||
|
|
|
@ -91,7 +91,11 @@ async fn block_allocator_task(
|
|||
window_size: Option<u32>,
|
||||
mut receiver: mpsc::UnboundedReceiver<BlockAllocatorCommand>,
|
||||
) {
|
||||
let mut allocator = RadixAllocator::new(block_size, blocks, window_size, prefix_caching);
|
||||
let mut allocator: Box<dyn Allocator + Send> = if prefix_caching {
|
||||
Box::new(RadixAllocator::new(block_size, blocks, window_size))
|
||||
} else {
|
||||
Box::new(SimpleAllocator::new(blocks, block_size, window_size))
|
||||
};
|
||||
while let Some(cmd) = receiver.recv().await {
|
||||
match cmd {
|
||||
BlockAllocatorCommand::Free {
|
||||
|
@ -124,12 +128,82 @@ enum BlockAllocatorCommand {
|
|||
},
|
||||
}
|
||||
|
||||
// pub trait Allocator {
|
||||
// fn allocate(
|
||||
// &mut self,
|
||||
// tokens: u32,
|
||||
// prefill_tokens: Option<Arc<Vec<u32>>>,
|
||||
// ) -> Option<BlockAllocation>;
|
||||
//
|
||||
// fn free(&mut self, blocks: Vec<u32>, allocation_id: u64);
|
||||
// }
|
||||
pub trait Allocator {
|
||||
fn allocate(
|
||||
&mut self,
|
||||
tokens: u32,
|
||||
prefill_tokens: Option<Arc<Vec<u32>>>,
|
||||
) -> Option<BlockAllocation>;
|
||||
|
||||
fn free(&mut self, blocks: Vec<u32>, allocation_id: u64);
|
||||
}
|
||||
pub struct SimpleAllocator {
|
||||
free_blocks: Vec<u32>,
|
||||
block_size: u32,
|
||||
window_size: Option<u32>,
|
||||
}
|
||||
|
||||
impl SimpleAllocator {
|
||||
fn new(blocks: u32, block_size: u32, window_size: Option<u32>) -> Self {
|
||||
SimpleAllocator {
|
||||
block_size,
|
||||
// Block 0 is reserved for health checks
|
||||
free_blocks: (1..blocks).collect(),
|
||||
window_size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Allocator for SimpleAllocator {
|
||||
fn allocate(
|
||||
&mut self,
|
||||
tokens: u32,
|
||||
_prefill_tokens: Option<Arc<Vec<u32>>>,
|
||||
) -> Option<BlockAllocation> {
|
||||
// Apply window size
|
||||
let (required_blocks, repeats) = {
|
||||
let (tokens, repeats) = match self.window_size {
|
||||
None => (tokens, 1),
|
||||
Some(window_size) => {
|
||||
let repeats = (tokens + window_size - 1) / window_size;
|
||||
let tokens = core::cmp::min(tokens, window_size);
|
||||
(tokens, repeats as usize)
|
||||
}
|
||||
};
|
||||
// Pad to a multiple of block size
|
||||
let required_blocks = (tokens + self.block_size - 1) / self.block_size;
|
||||
(required_blocks, repeats)
|
||||
};
|
||||
|
||||
let tokens = tokens as usize;
|
||||
if required_blocks > self.free_blocks.len() as u32 {
|
||||
None
|
||||
} else {
|
||||
let blocks = self
|
||||
.free_blocks
|
||||
.split_off(self.free_blocks.len() - required_blocks as usize);
|
||||
let mut slots =
|
||||
Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize);
|
||||
|
||||
'slots: for block_id in blocks.repeat(repeats).iter() {
|
||||
for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) {
|
||||
slots.push(s);
|
||||
if slots.len() == tokens {
|
||||
break 'slots;
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(BlockAllocation {
|
||||
allocation_id: 0,
|
||||
blocks,
|
||||
slots,
|
||||
prefix_len: 0,
|
||||
block_allocator: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn free(&mut self, blocks: Vec<u32>, _allocation_id: u64) {
|
||||
self.free_blocks.extend(blocks)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -333,7 +333,7 @@ impl State {
|
|||
break 'entry_loop;
|
||||
}
|
||||
Some(block_allocation) => {
|
||||
tracing::debug!("Allocation: {block_allocation:?}");
|
||||
tracing::info!("Allocation: {block_allocation:?}");
|
||||
max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
|
||||
Some(block_allocation)
|
||||
}
|
||||
|
|
|
@ -1,12 +1,10 @@
|
|||
use crate::block_allocator::{Allocator, BlockAllocation};
|
||||
use slotmap::{DefaultKey, SlotMap};
|
||||
use std::{
|
||||
collections::{BTreeSet, HashMap},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use slotmap::{DefaultKey, SlotMap};
|
||||
|
||||
use crate::block_allocator::BlockAllocation;
|
||||
|
||||
pub struct RadixAllocator {
|
||||
allocation_id: u64,
|
||||
|
||||
|
@ -21,25 +19,15 @@ pub struct RadixAllocator {
|
|||
// This isn't used because the prefix needs to match without the windowing
|
||||
// mechanism. This at worst is overallocating, not necessarily being wrong.
|
||||
window_size: Option<u32>,
|
||||
|
||||
/// Whether to actually use the radix tree for searching or not.
|
||||
prefix_caching: bool,
|
||||
}
|
||||
|
||||
impl RadixAllocator {
|
||||
pub fn new(
|
||||
block_size: u32,
|
||||
n_blocks: u32,
|
||||
window_size: Option<u32>,
|
||||
prefix_caching: bool,
|
||||
) -> Self {
|
||||
if prefix_caching {
|
||||
assert_eq!(
|
||||
block_size, 1,
|
||||
"Radix tree allocator only works with block_size=1, was: {}",
|
||||
block_size
|
||||
);
|
||||
}
|
||||
pub fn new(block_size: u32, n_blocks: u32, window_size: Option<u32>) -> Self {
|
||||
assert_eq!(
|
||||
block_size, 1,
|
||||
"Radix tree allocator only works with block_size=1, was: {}",
|
||||
block_size
|
||||
);
|
||||
// if window_size.is_some() {
|
||||
// unimplemented!("Window size not supported in the prefix-caching block allocator yet");
|
||||
// }
|
||||
|
@ -52,7 +40,6 @@ impl RadixAllocator {
|
|||
// Block 0 is reserved for health checks.
|
||||
free_blocks: (1..n_blocks).collect(),
|
||||
window_size,
|
||||
prefix_caching,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -81,24 +68,23 @@ impl RadixAllocator {
|
|||
}
|
||||
|
||||
// Allocator trait
|
||||
impl RadixAllocator {
|
||||
pub fn allocate(
|
||||
impl Allocator for RadixAllocator {
|
||||
fn allocate(
|
||||
&mut self,
|
||||
tokens: u32,
|
||||
prefill_tokens: Option<Arc<Vec<u32>>>,
|
||||
) -> Option<BlockAllocation> {
|
||||
let mut blocks = vec![];
|
||||
let prefix_node = match (self.prefix_caching, prefill_tokens.as_ref()) {
|
||||
(true, Some(prefill_tokens)) => {
|
||||
let node_id = self
|
||||
.cache_blocks
|
||||
.find(prefill_tokens.as_slice(), &mut blocks);
|
||||
// Even if this allocation fails below, we need to increase the
|
||||
// refcount to ensure that the prefix that was found is not evicted.
|
||||
let prefix_node = if let Some(prefill_tokens) = prefill_tokens.as_ref() {
|
||||
let node_id = self
|
||||
.cache_blocks
|
||||
.find(prefill_tokens.as_slice(), &mut blocks);
|
||||
// Even if this allocation fails below, we need to increase the
|
||||
// refcount to ensure that the prefix that was found is not evicted.
|
||||
|
||||
node_id
|
||||
}
|
||||
_ => self.cache_blocks.root_id(),
|
||||
node_id
|
||||
} else {
|
||||
self.cache_blocks.root_id()
|
||||
};
|
||||
|
||||
self.cache_blocks
|
||||
|
@ -108,7 +94,9 @@ impl RadixAllocator {
|
|||
let prefix_len = blocks.len();
|
||||
let suffix_len = tokens - prefix_len as u32;
|
||||
|
||||
match self.alloc_or_reclaim(suffix_len as usize) {
|
||||
let suffix_blocks = suffix_len;
|
||||
|
||||
match self.alloc_or_reclaim(suffix_blocks as usize) {
|
||||
Some(suffix_blocks) => blocks.extend(suffix_blocks),
|
||||
None => {
|
||||
self.cache_blocks
|
||||
|
@ -127,6 +115,8 @@ impl RadixAllocator {
|
|||
prefill_tokens: prefill_tokens.clone(),
|
||||
};
|
||||
|
||||
tracing::info!("Blocks {blocks:?}");
|
||||
|
||||
self.allocation_id += 1;
|
||||
self.allocations.insert(self.allocation_id, allocation);
|
||||
|
||||
|
@ -139,7 +129,7 @@ impl RadixAllocator {
|
|||
})
|
||||
}
|
||||
|
||||
pub fn free(&mut self, blocks: Vec<u32>, allocation_id: u64) {
|
||||
fn free(&mut self, blocks: Vec<u32>, allocation_id: u64) {
|
||||
let allocation = match self.allocations.remove(&allocation_id) {
|
||||
Some(allocation) => allocation,
|
||||
None => unreachable!("Tried to free an unknown allocation."),
|
||||
|
@ -613,7 +603,21 @@ mod tests {
|
|||
cache.free(allocation.blocks.clone(), allocation.allocation_id);
|
||||
|
||||
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
||||
assert_eq!(allocation.blocks, vec![1, 2, 3, 8, 9, 10, 11, 7]);
|
||||
assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]);
|
||||
assert_eq!(allocation.prefix_len, 0);
|
||||
}
|
||||
|
||||
// Exercises allocate/free/allocate with a block size of 256 and checks that the same
// single block is handed out both times.
// NOTE(review): `RadixAllocator::new` is called here with four arguments and
// `block_size = 256`; confirm this still matches the constructor, which elsewhere in this
// change takes three arguments and asserts `block_size == 1`.
#[test]
|
||||
fn allocator_block_size() {
|
||||
let mut cache = RadixAllocator::new(256, 12, None, false);
|
||||
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
||||
assert_eq!(allocation.blocks, vec![11]);
|
||||
// NOTE(review): this compares `slots` with itself, so the assertion can never fail;
// it should compare against the expected slot list.
assert_eq!(allocation.slots, allocation.slots);
|
||||
assert_eq!(allocation.prefix_len, 0);
|
||||
// Release the allocation so the second request can reuse the block.
cache.free(allocation.blocks.clone(), allocation.allocation_id);
|
||||

||||
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
||||
assert_eq!(allocation.blocks, vec![11]);
|
||||
assert_eq!(allocation.prefix_len, 0);
|
||||
}
|
||||
|
||||
|
|
|
@ -835,11 +835,11 @@
|
|||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1724379657,
|
||||
"narHash": "sha256-+CFDh1FUgyY7q0FiWhKJpHS7LlD3KbiqN5Z4Z+4bGmc=",
|
||||
"lastModified": 1724638882,
|
||||
"narHash": "sha256-ap2jIQi/FuUHR6HCht6ASWhoz8EiB99XmI8Esot38VE=",
|
||||
"owner": "oxalica",
|
||||
"repo": "rust-overlay",
|
||||
"rev": "a18034322c7703fcfe5d7352a77981ba4a936a61",
|
||||
"rev": "19b70f147b9c67a759e35824b241f1ed92e46694",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
|
|
|
@ -57,7 +57,7 @@
|
|||
{
|
||||
devShells = with pkgs; rec {
|
||||
|
||||
default = pure;
|
||||
default = impure;
|
||||
|
||||
pure = mkShell {
|
||||
buildInputs = [
|
||||
|
|
|
@ -8,7 +8,7 @@ use nix::unistd::Pid;
|
|||
use serde::Deserialize;
|
||||
use std::env;
|
||||
use std::ffi::OsString;
|
||||
use std::io::{BufRead, BufReader, Lines};
|
||||
use std::io::{BufRead, BufReader};
|
||||
use std::os::unix::process::{CommandExt, ExitStatusExt};
|
||||
use std::path::Path;
|
||||
use std::process::{Child, Command, ExitStatus, Stdio};
|
||||
|
@ -18,7 +18,10 @@ use std::sync::{mpsc, Arc};
|
|||
use std::thread;
|
||||
use std::thread::sleep;
|
||||
use std::time::{Duration, Instant};
|
||||
use std::{fs, io};
|
||||
use std::{
|
||||
fs, io,
|
||||
io::{Read, Write},
|
||||
};
|
||||
use thiserror::Error;
|
||||
use tracing_subscriber::{filter::LevelFilter, EnvFilter};
|
||||
|
||||
|
@ -833,6 +836,7 @@ fn shard_manager(
|
|||
.args(shard_args)
|
||||
.env_clear()
|
||||
.envs(envs)
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.process_group(0)
|
||||
|
@ -854,12 +858,13 @@ fn shard_manager(
|
|||
};
|
||||
|
||||
// Redirect STDOUT to the console
|
||||
let mut pstdin = p.stdin.take().unwrap();
|
||||
let shard_stdout_reader = BufReader::new(p.stdout.take().unwrap());
|
||||
let shard_stderr_reader = BufReader::new(p.stderr.take().unwrap());
|
||||
|
||||
//stdout tracing thread
|
||||
thread::spawn(move || {
|
||||
log_lines(shard_stdout_reader.lines());
|
||||
log_lines(shard_stdout_reader);
|
||||
});
|
||||
// We read stderr in another thread as it seems that lines() can block in some cases
|
||||
let (err_sender, err_receiver) = mpsc::channel();
|
||||
|
@ -868,6 +873,18 @@ fn shard_manager(
|
|||
err_sender.send(line).unwrap_or(());
|
||||
}
|
||||
});
|
||||
// Forward our own stdin to the shard's stdin from a dedicated thread, since reading stdin blocks
|
||||
thread::spawn(move || {
|
||||
let mut stdin = io::stdin(); // We get `Stdin` here.
|
||||
loop {
|
||||
let mut buffer = vec![0; 4096];
|
||||
if let Ok(n) = stdin.read(&mut buffer) {
|
||||
if n > 0 {
|
||||
let _ = pstdin.write_all(&buffer[..n]);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let mut ready = false;
|
||||
let start_time = Instant::now();
|
||||
|
@ -974,19 +991,36 @@ impl PythonLogMessage {
|
|||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&String> for PythonLogMessage {
|
||||
impl TryFrom<&[u8]> for PythonLogMessage {
|
||||
type Error = serde_json::Error;
|
||||
|
||||
fn try_from(value: &String) -> Result<Self, Self::Error> {
|
||||
serde_json::from_str::<Self>(value)
|
||||
fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
|
||||
serde_json::from_slice::<Self>(value)
|
||||
}
|
||||
}
|
||||
|
||||
fn log_lines<S: Sized + BufRead>(lines: Lines<S>) {
|
||||
for line in lines.map_while(Result::ok) {
|
||||
match PythonLogMessage::try_from(&line) {
|
||||
Ok(log) => log.trace(),
|
||||
Err(_) => tracing::debug!("{line}"),
|
||||
fn log_lines<R: Sized + Read>(mut bufread: BufReader<R>) {
|
||||
let mut buffer = vec![0u8; 4096];
|
||||
let mut stdout = std::io::stdout();
|
||||
loop {
|
||||
let n = bufread.read(&mut buffer);
|
||||
if let Ok(n) = n {
|
||||
if n > 0 {
|
||||
let mut lines = buffer[..n].split(|i| *i == b'\n').peekable();
|
||||
while let Some(line) = lines.next() {
|
||||
match PythonLogMessage::try_from(line) {
|
||||
Ok(log) => log.trace(),
|
||||
// For interactive debugging ?
|
||||
Err(_) => {
|
||||
stdout.write_all(line).unwrap();
|
||||
if lines.peek().is_some() {
|
||||
stdout.write_all(b"\n").unwrap();
|
||||
}
|
||||
stdout.flush().unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1146,7 +1180,7 @@ fn download_convert_model(
|
|||
let download_stdout = BufReader::new(download_process.stdout.take().unwrap());
|
||||
|
||||
thread::spawn(move || {
|
||||
log_lines(download_stdout.lines());
|
||||
log_lines(download_stdout);
|
||||
});
|
||||
|
||||
let download_stderr = BufReader::new(download_process.stderr.take().unwrap());
|
||||
|
|
|
@ -22,6 +22,16 @@ pub enum Attention {
|
|||
FlashInfer,
|
||||
}
|
||||
|
||||
impl Attention {
|
||||
pub fn block_size(&self) -> u32 {
|
||||
match self {
|
||||
Attention::FlashDecoding => 256,
|
||||
Attention::FlashInfer => 1,
|
||||
Attention::Paged => 16,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Unit error type carrying no payload.
/// NOTE(review): presumably the error produced when a string cannot be parsed into an
/// `Attention`; confirm against the parsing impl, which is not visible here.
#[derive(Debug)]
|
||||
pub struct ParseError;
|
||||
|
||||
|
|
Loading…
Reference in New Issue