feat(backend): rely on a multi-consumer queue to schedule work across workers

Morgan Funtowicz 2024-11-22 13:32:56 +01:00
parent 84eead219a
commit 5a85661661
3 changed files with 71 additions and 44 deletions
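
For orientation before the diff: the commit drops the per-worker std::sync::mpsc channel and the Semaphore-based dispatch, and instead pushes every request into a single unbounded async-channel queue that all worker threads drain. Below is a minimal, self-contained sketch of that shape, not code from this repository; it assumes async-channel 2.x and tokio with the rt-multi-thread and macros features, and the Job type, worker count, and main function are illustrative stand-ins for the backend's (GenerationContext, UnboundedSender<InferResult>) pairs.

use std::thread;

use async_channel::unbounded;
use tokio::sync::mpsc::unbounded_channel;

#[derive(Debug)]
struct Job(u32); // stand-in for (GenerationContext, UnboundedSender<InferResult>)

#[tokio::main]
async fn main() {
    // One multi-producer multi-consumer queue shared by every worker.
    let (backlog_tx, backlog_rx) = unbounded::<Job>();

    // Each worker thread owns a clone of the receiver and blocks on it.
    let workers: Vec<_> = (0..4)
        .map(|id| {
            let rx = backlog_rx.clone();
            thread::spawn(move || {
                // recv_blocking() returns Err once every sender is dropped.
                while let Ok(job) = rx.recv_blocking() {
                    println!("worker {id} picked up {job:?}");
                }
            })
        })
        .collect();

    // The async scheduler simply forwards incoming requests into the backlog.
    let (request_tx, mut request_rx) = unbounded_channel::<Job>();
    let scheduler = tokio::spawn(async move {
        while let Some(job) = request_rx.recv().await {
            if backlog_tx.send(job).await.is_err() {
                break; // every worker receiver is gone
            }
        }
        // backlog_tx is dropped here, closing the queue and letting workers exit.
    });

    for i in 0..16 {
        request_tx.send(Job(i)).unwrap();
    }
    drop(request_tx);

    scheduler.await.unwrap();
    for worker in workers {
        worker.join().unwrap();
    }
}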

Cargo.lock (generated)

@@ -142,6 +142,18 @@ version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50"
[[package]]
name = "async-channel"
version = "2.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "89b47800b0be77592da0afd425cc03468052844aff33b84e33cc696f64e77b6a"
dependencies = [
"concurrent-queue",
"event-listener-strategy",
"futures-core",
"pin-project-lite",
]
[[package]]
name = "async-rustls"
version = "0.3.0"
@@ -758,6 +770,15 @@ dependencies = [
"static_assertions",
]
[[package]]
name = "concurrent-queue"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ca0197aee26d1ae37445ee532fefce43251d24cc7c166799f4d46817f1d3973"
dependencies = [
"crossbeam-utils",
]
[[package]]
name = "console"
version = "0.15.8"
@@ -1158,6 +1179,27 @@ dependencies = [
"cc",
]
[[package]]
name = "event-listener"
version = "5.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6032be9bd27023a771701cc49f9f053c751055f71efb2e0ae5c15809093675ba"
dependencies = [
"concurrent-queue",
"parking",
"pin-project-lite",
]
[[package]]
name = "event-listener-strategy"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0f214dc438f977e6d4e3500aaa277f5ad94ca83fbbd9b1a15713ce2344ccc5a1"
dependencies = [
"event-listener",
"pin-project-lite",
]
[[package]]
name = "exr"
version = "1.72.0"
@@ -2922,6 +2964,12 @@ dependencies = [
"unicode-width",
]
[[package]]
name = "parking"
version = "2.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f38d5652c16fde515bb1ecef450ab0f6a219d619a7274976324d5e377f7dceba"
[[package]]
name = "parking_lot"
version = "0.12.3"
@@ -4219,6 +4267,7 @@ dependencies = [
name = "text-generation-backend-llamacpp"
version = "2.4.1-dev0"
dependencies = [
"async-channel",
"async-trait",
"clap 4.5.20",
"cmake",

Cargo.toml (text-generation-backend-llamacpp)

@@ -7,6 +7,7 @@ homepage.workspace = true
[dependencies]
async-trait = "0.1"
async-channel = "2.3"
clap = { version = "4.5.19", features = ["derive"] }
cxx = "1.0"
num_cpus = "1"
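
The async-channel entry above is the only new direct dependency in this manifest; the Cargo.lock hunks earlier pull in its transitive crates (concurrent-queue, event-listener, event-listener-strategy, parking). The property the backend leans on is that an async-channel receiver can be cloned and every message is delivered to exactly one consumer rather than broadcast to all of them. A tiny stand-alone illustration, assuming async-channel 2.x and not taken from this repository:

use async_channel::unbounded;

fn main() {
    let (tx, rx) = unbounded::<u32>();
    let rx2 = rx.clone();

    tx.send_blocking(1).unwrap();
    tx.send_blocking(2).unwrap();

    // Clones of the receiver compete for items: each message reaches exactly
    // one of them, which is what lets several worker threads share one backlog.
    let a = rx.recv_blocking().unwrap();
    let b = rx2.recv_blocking().unwrap();
    assert_eq!(a + b, 3);
}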

Backend source (Rust)

@@ -2,6 +2,7 @@ use crate::ffi::{
create_worker_frontend, set_numactl_core_affinity, GenerationParams, LlamaCppWorkerFrontend,
SamplingParams,
};
use async_channel::{unbounded as mpmc_unbounded, Receiver as MpmcReceiver, Sender as MpmcSender};
use async_trait::async_trait;
use cxx::UniquePtr;
use log::warn;
@@ -19,7 +20,6 @@ use text_generation_router::{FinishReason, Token};
use thiserror::Error;
use tokenizers::Tokenizer;
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
use tokio::sync::Semaphore;
use tokio::task::JoinHandle;
use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
@@ -102,18 +102,6 @@ pub enum LlamaCppBackendError {
ModelInitializationFailed(PathBuf, String),
}
struct LlamaCppWorker {
sender: Sender<(GenerationContext, UnboundedSender<InferResult>)>,
}
impl LlamaCppWorker {
fn submit(&self, ctx: GenerationContext, sx: UnboundedSender<InferResult>) {
if let Err(err) = self.sender.send((ctx, sx)) {
// TODO: What do we do?
}
}
}
pub struct LlamaCppBackend {
scheduler_sender: UnboundedSender<(GenerationContext, UnboundedSender<InferResult>)>,
scheduler_handle: JoinHandle<()>,
@@ -141,29 +129,26 @@ impl LlamaCppBackend {
));
}
let cores_allocation = get_cores_allocation(num_cores_per_instance as usize);
// Allocate the multi-consumer queue to orchestrate all the workers
let (backlog_submitter, backlog_receiver) = mpmc_unbounded();
// Allocate all the workers
let streams = cores_allocation
.iter()
.map(
|affinity| match Self::allocate_worker(path, num_cores_per_instance as u32) {
Ok(worker) => {
let tokenizer = Arc::clone(&tokenizer);
let (sender, receiver) = channel();
let affinity = affinity.clone().collect::<Vec<_>>();
spawn(move || worker_loop(worker, affinity, tokenizer, receiver));
Ok(LlamaCppWorker { sender })
}
Err(e) => Err(e),
},
)
.collect::<Result<Vec<_>, _>>()?;
let cores_allocation = get_cores_allocation(num_cores_per_instance as usize);
cores_allocation.iter().for_each(|affinity| {
match Self::allocate_worker(path, num_cores_per_instance as u32) {
Ok(worker) => {
let tokenizer = Arc::clone(&tokenizer);
let affinity = affinity.clone().collect::<Vec<_>>();
let backlog_receiver = backlog_receiver.clone();
spawn(move || worker_loop(worker, affinity, tokenizer, backlog_receiver));
}
Err(e) => {}
}
});
// Start the scheduler loop
let (scheduler_sender, scheduler_receiver) = unbounded_channel();
let scheduler_handle = tokio::spawn(scheduler_loop(scheduler_receiver, streams));
let scheduler_handle = tokio::spawn(scheduler_loop(scheduler_receiver, backlog_submitter));
Ok(Self {
scheduler_sender,
scheduler_handle,
@@ -263,24 +248,16 @@ fn llama_generate_callback(
async fn scheduler_loop(
mut queue: UnboundedReceiver<(GenerationContext, UnboundedSender<InferResult>)>,
mut workers: Vec<LlamaCppWorker>,
backlog: MpmcSender<(GenerationContext, UnboundedSender<InferResult>)>,
) {
// Semaphore allows us to wait for a worker to become available
let permits = Semaphore::new(workers.len());
// Let's receive incoming requests
loop {
match queue.recv().await {
None => break,
Some((ctx, sender)) => {
let permit = permits.try_acquire();
if let Err(err) = permit {
let _ = sender.send(Err(InferError::Overloaded(err)));
if let Err(e) = backlog.send((ctx, sender)).await {
todo!("What do we do")
}
// We can unwrap because we wouldn't have a semaphore available otherwise
let worker = workers.pop().unwrap();
worker.submit(ctx, sender);
}
}
}
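
Because this hunk interleaves the removed Semaphore-based dispatch with the new forwarding code, here is the post-change scheduler_loop reassembled from the surviving lines. The two placeholder structs stand in for the backend's GenerationContext and InferResult; the control flow, including the unresolved todo!, is taken from the diff rather than invented here.

use async_channel::Sender as MpmcSender;
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};

struct GenerationContext; // placeholder for the backend's request payload
struct InferResult;       // placeholder for the backend's streamed result

async fn scheduler_loop(
    mut queue: UnboundedReceiver<(GenerationContext, UnboundedSender<InferResult>)>,
    backlog: MpmcSender<(GenerationContext, UnboundedSender<InferResult>)>,
) {
    loop {
        match queue.recv().await {
            // The request side hung up: stop scheduling.
            None => break,
            Some((ctx, sender)) => {
                // Forward the request into the shared backlog. A send error
                // means every worker receiver has already been dropped.
                if backlog.send((ctx, sender)).await.is_err() {
                    todo!("What do we do")
                }
            }
        }
    }
}

Note the behavioural shift: the removed code rejected a request with InferError::Overloaded as soon as no Semaphore permit was available, while the new loop queues everything into an unbounded backlog, so any back-pressure now has to be enforced elsewhere.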
@@ -290,7 +267,7 @@ fn worker_loop(
mut backend: UniquePtr<LlamaCppWorkerFrontend>,
affinity: Vec<usize>,
tokenizer: Arc<Tokenizer>,
backlog: Receiver<(GenerationContext, UnboundedSender<InferResult>)>,
backlog: MpmcReceiver<(GenerationContext, UnboundedSender<InferResult>)>,
) {
// This loop will mostly decode single token at every step, so no need to rely on parallelism
tokenizers::utils::parallelism::set_parallelism(false);
@@ -299,7 +276,7 @@
set_numactl_core_affinity(&affinity);
loop {
if let Ok((generation, stream)) = backlog.recv() {
if let Ok((generation, stream)) = backlog.recv_blocking() {
let start = Instant::now();
let generation_params = generation.generation_params; // copy
let sampling_params = generation.sampling_params; // copy
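
On the worker side the std::sync::mpsc recv() becomes async_channel's recv_blocking(), which parks the dedicated worker thread without involving the tokio runtime and returns Err once every sender has been dropped, giving the loop a natural exit. A small stand-alone illustration of that shutdown path, assuming async-channel 2.x; the names are illustrative, not the backend's.

use std::thread;

use async_channel::unbounded;

fn main() {
    let (tx, rx) = unbounded::<&'static str>();

    let worker = thread::spawn(move || {
        // Blocks this OS thread until an item arrives; returns Err once the
        // channel is closed (all senders dropped) and fully drained.
        while let Ok(msg) = rx.recv_blocking() {
            println!("processing {msg}");
        }
        println!("backlog closed, worker exiting");
    });

    tx.send_blocking("one request").unwrap();
    drop(tx); // dropping the last sender closes the queue

    worker.join().unwrap();
}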