diff --git a/router/client/src/sharded_client.rs b/router/client/src/sharded_client.rs
index 916e72b4..6fddfd7e 100644
--- a/router/client/src/sharded_client.rs
+++ b/router/client/src/sharded_client.rs
@@ -2,76 +2,18 @@
 use crate::Result;
 use crate::{Batch, Client, GeneratedText};
 use futures::future::join_all;
-use tokio::sync::{broadcast, mpsc};
+use futures::future::select_all;
 use tonic::transport::Uri;
 
-/// List of all available commands that can be sent through the command channel
-#[derive(Clone, Debug)]
-enum Command {
-    Generate(
-        Batch,
-        mpsc::Sender<Result<(Vec<GeneratedText>, Option<Batch>)>>,
-    ),
-    GenerateWithCache(
-        Vec<Batch>,
-        mpsc::Sender<Result<(Vec<GeneratedText>, Option<Batch>)>>,
-    ),
-    ClearCache(mpsc::Sender<Result<()>>),
-}
-
-/// Tokio task that handles the communication with a single shard
-///
-/// We subscribe on a broadcast channel to receive commands that will be sent by
-/// the ShardedClient.
-///
-/// Each command is fan out to all shards.
-///
-/// The result of the command is sent back to the ShardedClient through a mpsc channel (multi
-/// producer = the shards, single consumer = the ShardedClient).
-async fn client_task(mut client: Client, mut request_subscriber: broadcast::Receiver<Command>) {
-    while let Ok(message) = request_subscriber.recv().await {
-        match message {
-            Command::Generate(batch, response_tx) => {
-                let result = client.generate(batch).await;
-                // We can unwrap_or(()) here because the only error that can happen is if the
-                // receiver is dropped, which means that the ShardedClient already received a
-                // response from another shard
-                response_tx.try_send(result).unwrap_or(());
-            }
-            Command::GenerateWithCache(batches, response_tx) => {
-                let result = client.generate_with_cache(batches).await;
-                response_tx.try_send(result).unwrap_or(());
-            }
-            Command::ClearCache(response_tx) => {
-                let result = client.clear_cache().await;
-                response_tx.try_send(result).unwrap_or(());
-            }
-        };
-    }
-}
-
 /// Text Generation Inference gRPC multi client
 pub struct ShardedClient {
-    _clients: Vec<Client>,
-    request_tx: broadcast::Sender<Command>,
+    clients: Vec<Client>,
 }
 
 impl ShardedClient {
     fn new(clients: Vec<Client>) -> Self {
-        // The broadcast channel to communicate with the shards
-        // We use a capacity of one as the shards are not asynchronous and can only process one
-        // command at a time
-        let (request_tx, _) = broadcast::channel(1);
-
-        // Spawn client tasks
-        for client in clients.iter() {
-            let request_subscriber = request_tx.subscribe();
-            tokio::spawn(client_task(client.clone(), request_subscriber));
-        }
-
         Self {
-            _clients: clients,
-            request_tx,
+            clients,
         }
     }
 
@@ -101,15 +43,15 @@ impl ShardedClient {
     ///
     /// Returns a list of generated texts of request that met their stopping criteria
     /// and the next cached batch
-    pub async fn generate(&self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
-        // Create a channel to receive the response from the shards
-        // We will only ever receive one message on this channel
-        let (response_tx, mut response_rx) = mpsc::channel(1);
-        self.request_tx
-            .send(Command::Generate(batch, response_tx))
-            .unwrap();
+    pub async fn generate(&mut self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
+        let futures: Vec<_> = self
+            .clients
+            .iter_mut()
+            .map(|client| Box::pin(client.generate(batch.clone())))
+            .collect();
         // As soon as we receive one response, we can return as all shards will return the same
-        response_rx.recv().await.unwrap()
+        let (result, _, _) = select_all(futures).await;
+        result
     }
 
@@ -117,27 +59,26 @@ impl ShardedClient {
     /// Generate one token for each request in the given cached batch
     ///
     /// Returns a list of generated texts of request that met their stopping criteria
     /// and the next cached batch
     pub async fn generate_with_cache(
-        &self,
+        &mut self,
         batches: Vec<Batch>,
     ) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
-        // Create a channel to receive the response from the shards
-        // We will only ever receive one message on this channel
-        let (response_tx, mut response_rx) = mpsc::channel(1);
-        self.request_tx
-            .send(Command::GenerateWithCache(batches, response_tx))
-            .unwrap();
+        let futures: Vec<_> = self
+            .clients
+            .iter_mut()
+            .map(|client| Box::pin(client.generate_with_cache(batches.clone())))
+            .collect();
         // As soon as we receive one response, we can return as all shards will return the same
-        response_rx.recv().await.unwrap()
+        let (result, _, _) = select_all(futures).await;
+        result
     }
 
     /// Clear the past generations cache
-    pub async fn clear_cache(&self) -> Result<()> {
-        // Create a channel to receive the response from the shards
-        // We will only ever receive one message on this channel
-        let (response_tx, mut response_rx) = mpsc::channel(1);
-        self.request_tx
-            .send(Command::ClearCache(response_tx))
-            .unwrap();
-        response_rx.recv().await.unwrap()
+    pub async fn clear_cache(&mut self) -> Result<()> {
+        let futures: Vec<_> = self
+            .clients
+            .iter_mut()
+            .map(|client| client.clear_cache())
+            .collect();
+        join_all(futures).await.into_iter().collect()
     }
 }
diff --git a/router/src/batcher.rs b/router/src/batcher.rs
index 052716a4..f71428e5 100644
--- a/router/src/batcher.rs
+++ b/router/src/batcher.rs
@@ -39,9 +39,9 @@ impl Batcher {
 
         // Spawn batching background task that contains all the inference logic
         tokio::spawn(batching_task(
+            client,
             max_batch_size,
             max_waiting_tokens,
-            client,
             db.clone(),
             shared.clone(),
         ));
@@ -86,9 +86,9 @@ impl Batcher {
 /// Batches requests and sends them to the inference server
 #[instrument(skip(client, db, shared))]
 async fn batching_task(
+    mut client: ShardedClient,
     max_batch_size: usize,
     max_waiting_tokens: usize,
-    client: ShardedClient,
     db: Db,
     shared: Arc<Shared>,
 ) {
diff --git a/router/src/main.rs b/router/src/main.rs
index 6d1a0fb9..ea7ebd12 100644
--- a/router/src/main.rs
+++ b/router/src/main.rs
@@ -61,7 +61,7 @@ fn main() -> Result<(), std::io::Error> {
         .unwrap()
         .block_on(async {
             // Instantiate sharded client from the master unix socket
-            let sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
+            let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
                 .await
                 .expect("Could not connect to server");
             // Clear the cache; useful if the webserver rebooted
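The heart of this change is futures::future::select_all: generate and generate_with_cache now fan the batch out to every shard directly and return as soon as the first shard answers, since all shards compute the same result. Below is a minimal, self-contained sketch of that pattern, not code from this PR; shard_reply is a hypothetical stand-in for one shard's gRPC call, and the sketch assumes the futures crate plus tokio with the macros, rt-multi-thread, and time features.

    use futures::future::select_all;
    use std::time::Duration;

    // Hypothetical stand-in for one shard's gRPC call.
    async fn shard_reply(id: usize, delay_ms: u64) -> Result<usize, String> {
        tokio::time::sleep(Duration::from_millis(delay_ms)).await;
        Ok(id)
    }

    #[tokio::main]
    async fn main() {
        // Box::pin gives every future the same boxed, Unpin type,
        // which select_all requires.
        let futures = vec![
            Box::pin(shard_reply(0, 50)),
            Box::pin(shard_reply(1, 10)),
            Box::pin(shard_reply(2, 30)),
        ];
        // Polls all futures concurrently and resolves with the first one
        // to finish, returning its output, its index, and the leftovers.
        let (result, _index, _remaining) = select_all(futures).await;
        assert_eq!(result, Ok(1));
    }

Dropping _remaining abandons the slower in-flight calls, which plays the role that unwrap_or(()) on try_send played in the old channel-based design.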
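clear_cache, by contrast, keeps join_all because every shard must actually finish: collecting the resulting Vec<Result<()>> into a single Result<()> yields Ok(()) only if all shards succeeded and otherwise surfaces an error. A sketch of that collection step under the same assumptions, with clear_one again a hypothetical stand-in:

    use futures::future::join_all;

    // Hypothetical stand-in for one shard's clear_cache call.
    async fn clear_one(ok: bool) -> Result<(), String> {
        if ok { Ok(()) } else { Err("shard failed".to_string()) }
    }

    #[tokio::main]
    async fn main() {
        // join_all waits for every shard; the Vec<Result<(), String>> then
        // collects into Result<(), String>, stopping at the first Err.
        let all_ok: Result<(), String> = join_all(vec![clear_one(true), clear_one(true)])
            .await
            .into_iter()
            .collect();
        assert_eq!(all_ok, Ok(()));

        let one_failed: Result<(), String> = join_all(vec![clear_one(true), clear_one(false)])
            .await
            .into_iter()
            .collect();
        assert_eq!(one_failed, Err("shard failed".to_string()));
    }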