diff --git a/Cargo.lock b/Cargo.lock
index 81b7c282..4b4e7670 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -142,6 +142,18 @@ version = "0.7.6"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50"
 
+[[package]]
+name = "async-channel"
+version = "2.3.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "89b47800b0be77592da0afd425cc03468052844aff33b84e33cc696f64e77b6a"
+dependencies = [
+ "concurrent-queue",
+ "event-listener-strategy",
+ "futures-core",
+ "pin-project-lite",
+]
+
 [[package]]
 name = "async-rustls"
 version = "0.3.0"
@@ -758,6 +770,15 @@ dependencies = [
  "static_assertions",
 ]
 
+[[package]]
+name = "concurrent-queue"
+version = "2.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4ca0197aee26d1ae37445ee532fefce43251d24cc7c166799f4d46817f1d3973"
+dependencies = [
+ "crossbeam-utils",
+]
+
 [[package]]
 name = "console"
 version = "0.15.8"
@@ -1158,6 +1179,27 @@ dependencies = [
  "cc",
 ]
 
+[[package]]
+name = "event-listener"
+version = "5.3.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6032be9bd27023a771701cc49f9f053c751055f71efb2e0ae5c15809093675ba"
+dependencies = [
+ "concurrent-queue",
+ "parking",
+ "pin-project-lite",
+]
+
+[[package]]
+name = "event-listener-strategy"
+version = "0.5.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0f214dc438f977e6d4e3500aaa277f5ad94ca83fbbd9b1a15713ce2344ccc5a1"
+dependencies = [
+ "event-listener",
+ "pin-project-lite",
+]
+
 [[package]]
 name = "exr"
 version = "1.72.0"
@@ -2922,6 +2964,12 @@ dependencies = [
  "unicode-width",
 ]
 
+[[package]]
+name = "parking"
+version = "2.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f38d5652c16fde515bb1ecef450ab0f6a219d619a7274976324d5e377f7dceba"
+
 [[package]]
 name = "parking_lot"
 version = "0.12.3"
@@ -4219,6 +4267,7 @@ dependencies = [
 name = "text-generation-backend-llamacpp"
 version = "2.4.1-dev0"
 dependencies = [
+ "async-channel",
  "async-trait",
  "clap 4.5.20",
  "cmake",
diff --git a/backends/llamacpp/Cargo.toml b/backends/llamacpp/Cargo.toml
index 0a5039b3..df2c3421 100644
--- a/backends/llamacpp/Cargo.toml
+++ b/backends/llamacpp/Cargo.toml
@@ -7,6 +7,7 @@ homepage.workspace = true
 
 [dependencies]
 async-trait = "0.1"
+async-channel = "2.3"
 clap = { version = "4.5.19", features = ["derive"] }
 cxx = "1.0"
 num_cpus = "1"
diff --git a/backends/llamacpp/src/backend.rs b/backends/llamacpp/src/backend.rs
index e846a476..5bcb913b 100644
--- a/backends/llamacpp/src/backend.rs
+++ b/backends/llamacpp/src/backend.rs
@@ -2,6 +2,7 @@ use crate::ffi::{
     create_worker_frontend, set_numactl_core_affinity, GenerationParams,
     LlamaCppWorkerFrontend, SamplingParams,
 };
+use async_channel::{unbounded as mpmc_unbounded, Receiver as MpmcReceiver, Sender as MpmcSender};
 use async_trait::async_trait;
 use cxx::UniquePtr;
 use log::warn;
@@ -19,7 +20,6 @@ use text_generation_router::{FinishReason, Token};
 use thiserror::Error;
 use tokenizers::Tokenizer;
 use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
-use tokio::sync::Semaphore;
 use tokio::task::JoinHandle;
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
@@ -102,18 +102,6 @@ pub enum LlamaCppBackendError {
     ModelInitializationFailed(PathBuf, String),
 }
 
-struct LlamaCppWorker {
-    sender: Sender<(GenerationContext, UnboundedSender<InferResult>)>,
-}
-
-impl LlamaCppWorker {
-    fn submit(&self, ctx: GenerationContext, sx: UnboundedSender<InferResult>) {
-        if let Err(err) = self.sender.send((ctx, sx)) {
-            // TODO: What do we do?
-        }
-    }
-}
-
 pub struct LlamaCppBackend {
     scheduler_sender: UnboundedSender<(GenerationContext, UnboundedSender<InferResult>)>,
     scheduler_handle: JoinHandle<()>,
@@ -141,29 +129,26 @@ impl LlamaCppBackend {
             ));
         }
 
-        let cores_allocation = get_cores_allocation(num_cores_per_instance as usize);
+        // Allocate the multi-consumer queue to orchestrate all the workers
+        let (backlog_submitter, backlog_receiver) = mpmc_unbounded();
 
         // Allocate all the workers
-        let streams = cores_allocation
-            .iter()
-            .map(
-                |affinity| match Self::allocate_worker(path, num_cores_per_instance as u32) {
-                    Ok(worker) => {
-                        let tokenizer = Arc::clone(&tokenizer);
-                        let (sender, receiver) = channel();
-                        let affinity = affinity.clone().collect::<Vec<usize>>();
-                        spawn(move || worker_loop(worker, affinity, tokenizer, receiver));
-
-                        Ok(LlamaCppWorker { sender })
-                    }
-                    Err(e) => Err(e),
-                },
-            )
-            .collect::<Result<Vec<_>, _>>()?;
+        let cores_allocation = get_cores_allocation(num_cores_per_instance as usize);
+        cores_allocation.iter().for_each(|affinity| {
+            match Self::allocate_worker(path, num_cores_per_instance as u32) {
+                Ok(worker) => {
+                    let tokenizer = Arc::clone(&tokenizer);
+                    let affinity = affinity.clone().collect::<Vec<usize>>();
+                    let backlog_receiver = backlog_receiver.clone();
+                    spawn(move || worker_loop(worker, affinity, tokenizer, backlog_receiver));
+                }
+                Err(e) => {}
+            }
+        });
 
         // Start the scheduler loop
         let (scheduler_sender, scheduler_receiver) = unbounded_channel();
-        let scheduler_handle = tokio::spawn(scheduler_loop(scheduler_receiver, streams));
+        let scheduler_handle = tokio::spawn(scheduler_loop(scheduler_receiver, backlog_submitter));
         Ok(Self {
             scheduler_sender,
             scheduler_handle,
@@ -263,24 +248,16 @@ fn llama_generate_callback(
 
 async fn scheduler_loop(
     mut queue: UnboundedReceiver<(GenerationContext, UnboundedSender<InferResult>)>,
-    mut workers: Vec<LlamaCppWorker>,
+    backlog: MpmcSender<(GenerationContext, UnboundedSender<InferResult>)>,
 ) {
-    // Semaphore allows us to wait for a worker to become available
-    let permits = Semaphore::new(workers.len());
-
     // Let's receive incoming requests
     loop {
        match queue.recv().await {
            None => break,
            Some((ctx, sender)) => {
-                let permit = permits.try_acquire();
-                if let Err(err) = permit {
-                    let _ = sender.send(Err(InferError::Overloaded(err)));
+                if let Err(e) = backlog.send((ctx, sender)).await {
+                    todo!("What do we do")
                }
-
-                // We can unwrap because we wouldn't have a semaphore available otherwise
-                let worker = workers.pop().unwrap();
-                worker.submit(ctx, sender);
            }
        }
    }
@@ -290,7 +267,7 @@ fn worker_loop(
     mut backend: UniquePtr<LlamaCppWorkerFrontend>,
     affinity: Vec<usize>,
     tokenizer: Arc<Tokenizer>,
-    backlog: Receiver<(GenerationContext, UnboundedSender<InferResult>)>,
+    backlog: MpmcReceiver<(GenerationContext, UnboundedSender<InferResult>)>,
 ) {
     // This loop will mostly decode single token at every step, so no need to rely on parallelism
     tokenizers::utils::parallelism::set_parallelism(false);
@@ -299,7 +276,7 @@ fn worker_loop(
     set_numactl_core_affinity(&affinity);
 
     loop {
-        if let Ok((generation, stream)) = backlog.recv() {
+        if let Ok((generation, stream)) = backlog.recv_blocking() {
             let start = Instant::now();
             let generation_params = generation.generation_params; // copy
             let sampling_params = generation.sampling_params; // copy
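Note: the change above replaces the per-worker `std::sync::mpsc` channel plus `Semaphore` bookkeeping with a single `async-channel` queue. Because that queue is multi-producer and multi-consumer, every worker thread clones the same `Receiver`, and whichever worker is idle pulls the next request; requests now wait in the shared backlog instead of being rejected with `InferError::Overloaded` when no permit is available. The sketch below shows the same dispatch pattern in isolation; the `Job` type, `run_worker` helper, and the worker/job counts are illustrative, not part of this codebase:

    use std::thread;

    #[derive(Debug)]
    struct Job(u32);

    // Each worker owns a clone of the shared MPMC receiver and blocks on it,
    // mirroring how worker_loop() calls backlog.recv_blocking() in the diff.
    fn run_worker(id: usize, backlog: async_channel::Receiver<Job>) {
        // recv_blocking() lets a plain OS thread wait on the async queue;
        // it returns Err once the channel is closed and drained.
        while let Ok(job) = backlog.recv_blocking() {
            println!("worker {id} handled {job:?}");
        }
    }

    fn main() {
        let (tx, rx) = async_channel::unbounded::<Job>();

        // Spawn a few workers sharing one receiver: the first idle worker
        // dequeues the next job, so load balancing needs no central scheduler.
        let handles: Vec<_> = (0..4)
            .map(|id| {
                let rx = rx.clone();
                thread::spawn(move || run_worker(id, rx))
            })
            .collect();

        for i in 0..16 {
            // send_blocking() is the sync counterpart of the send().await used
            // in scheduler_loop(); producers and consumers can mix sync/async.
            tx.send_blocking(Job(i)).expect("channel closed");
        }
        drop(tx); // dropping the last Sender closes the channel, ending the workers

        for handle in handles {
            handle.join().unwrap();
        }
    }

Unlike `tokio::sync::mpsc`, whose `Receiver` cannot be cloned, `async_channel::Receiver` is `Clone` while still delivering each message to exactly one consumer, which is what makes the one-queue-many-workers layout in this diff possible.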