feat(client): Simplify sharded logic

OlivierDehaene 2022-10-22 23:40:05 +02:00
parent c8ce9b2515
commit beb552127a
3 changed files with 29 additions and 88 deletions
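The diff below removes the broadcast/mpsc command plumbing from the sharded client: instead of spawning one task per shard and fanning commands out over a channel, ShardedClient now holds its Vec<Client> directly, maps each request over the shards, and races the resulting futures. Because every shard computes the same answer, the first response to arrive can be returned immediately. A minimal sketch of that racing pattern, using a hypothetical Shard type in place of the real gRPC client:

use futures::future::select_all;

// Hypothetical stand-in for one shard's gRPC client.
struct Shard(usize);

impl Shard {
    // Every shard returns the same result; only the latency differs.
    async fn generate(&mut self, batch: &str) -> Result<String, String> {
        Ok(format!("shard {} processed {batch}", self.0))
    }
}

#[tokio::main]
async fn main() {
    let mut shards = vec![Shard(0), Shard(1), Shard(2)];
    // Fan the same input out to every shard. Box::pin is needed because
    // select_all requires its futures to be Unpin.
    let futures: Vec<_> = shards
        .iter_mut()
        .map(|shard| Box::pin(shard.generate("batch")))
        .collect();
    // Resolve as soon as the first future completes; the unfinished
    // futures come back in `_remaining` and are cancelled on drop.
    let (result, _index, _remaining) = select_all(futures).await;
    println!("{result:?}");
}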

View File

@@ -2,76 +2,18 @@
 use crate::Result;
 use crate::{Batch, Client, GeneratedText};
+use futures::future::join_all;
-use tokio::sync::{broadcast, mpsc};
+use futures::future::select_all;
 use tonic::transport::Uri;
 
-/// List of all available commands that can be sent through the command channel
-#[derive(Clone, Debug)]
-enum Command {
-    Generate(
-        Batch,
-        mpsc::Sender<Result<(Vec<GeneratedText>, Option<Batch>)>>,
-    ),
-    GenerateWithCache(
-        Vec<Batch>,
-        mpsc::Sender<Result<(Vec<GeneratedText>, Option<Batch>)>>,
-    ),
-    ClearCache(mpsc::Sender<Result<()>>),
-}
-
-/// Tokio task that handles the communication with a single shard
-///
-/// We subscribe on a broadcast channel to receive commands that will be sent by
-/// the ShardedClient.
-///
-/// Each command is fan out to all shards.
-///
-/// The result of the command is sent back to the ShardedClient through a mpsc channel (multi
-/// producer = the shards, single consumer = the ShardedClient).
-async fn client_task(mut client: Client, mut request_subscriber: broadcast::Receiver<Command>) {
-    while let Ok(message) = request_subscriber.recv().await {
-        match message {
-            Command::Generate(batch, response_tx) => {
-                let result = client.generate(batch).await;
-                // We can unwrap_or(()) here because the only error that can happen is if the
-                // receiver is dropped, which means that the ShardedClient already received a
-                // response from another shard
-                response_tx.try_send(result).unwrap_or(());
-            }
-            Command::GenerateWithCache(batches, response_tx) => {
-                let result = client.generate_with_cache(batches).await;
-                response_tx.try_send(result).unwrap_or(());
-            }
-            Command::ClearCache(response_tx) => {
-                let result = client.clear_cache().await;
-                response_tx.try_send(result).unwrap_or(());
-            }
-        };
-    }
-}
-
 /// Text Generation Inference gRPC multi client
 pub struct ShardedClient {
-    _clients: Vec<Client>,
-    request_tx: broadcast::Sender<Command>,
+    clients: Vec<Client>,
 }
 
 impl ShardedClient {
     fn new(clients: Vec<Client>) -> Self {
-        // The broadcast channel to communicate with the shards
-        // We use a capacity of one as the shards are not asynchronous and can only process one
-        // command at a time
-        let (request_tx, _) = broadcast::channel(1);
-
-        // Spawn client tasks
-        for client in clients.iter() {
-            let request_subscriber = request_tx.subscribe();
-            tokio::spawn(client_task(client.clone(), request_subscriber));
-        }
-
         Self {
-            _clients: clients,
-            request_tx,
+            clients,
         }
     }
@@ -101,15 +43,15 @@ impl ShardedClient {
     ///
     /// Returns a list of generated texts of request that met their stopping criteria
     /// and the next cached batch
-    pub async fn generate(&self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
-        // Create a channel to receive the response from the shards
-        // We will only ever receive one message on this channel
-        let (response_tx, mut response_rx) = mpsc::channel(1);
-        self.request_tx
-            .send(Command::Generate(batch, response_tx))
-            .unwrap();
+    pub async fn generate(&mut self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
+        let futures: Vec<_> = self
+            .clients
+            .iter_mut()
+            .map(|client| Box::pin(client.generate(batch.clone())))
+            .collect();
         // As soon as we receive one response, we can return as all shards will return the same
-        response_rx.recv().await.unwrap()
+        let (result, _, _) = select_all(futures).await;
+        result
     }
 
     /// Generate one token for each request in the given cached batch
@@ -117,27 +59,26 @@ impl ShardedClient {
     /// Returns a list of generated texts of request that met their stopping criteria
     /// and the next cached batch
     pub async fn generate_with_cache(
-        &self,
+        &mut self,
         batches: Vec<Batch>,
     ) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
-        // Create a channel to receive the response from the shards
-        // We will only ever receive one message on this channel
-        let (response_tx, mut response_rx) = mpsc::channel(1);
-        self.request_tx
-            .send(Command::GenerateWithCache(batches, response_tx))
-            .unwrap();
+        let futures: Vec<_> = self
+            .clients
+            .iter_mut()
+            .map(|client| Box::pin(client.generate_with_cache(batches.clone())))
+            .collect();
         // As soon as we receive one response, we can return as all shards will return the same
-        response_rx.recv().await.unwrap()
+        let (result, _, _) = select_all(futures).await;
+        result
     }
 
     /// Clear the past generations cache
-    pub async fn clear_cache(&self) -> Result<()> {
-        // Create a channel to receive the response from the shards
-        // We will only ever receive one message on this channel
-        let (response_tx, mut response_rx) = mpsc::channel(1);
-        self.request_tx
-            .send(Command::ClearCache(response_tx))
-            .unwrap();
-        response_rx.recv().await.unwrap()
+    pub async fn clear_cache(&mut self) -> Result<()> {
+        let futures: Vec<_> = self
+            .clients
+            .iter_mut()
+            .map(|client| client.clear_cache())
+            .collect();
+        join_all(futures).await.into_iter().collect()
     }
 }
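Two completion policies coexist in the new ShardedClient: generate and generate_with_cache race the shards with select_all, since any single reply is sufficient, while clear_cache waits for every shard with join_all. The final collect() works because an iterator of Result<(), E> values can be collected into a single Result<(), E> that short-circuits on the first error. A self-contained sketch of that idiom, with a hypothetical clear_one in place of the gRPC call:

use futures::future::join_all;

// Hypothetical per-shard operation that must succeed on every shard.
async fn clear_one(shard_id: usize) -> Result<(), String> {
    if shard_id == 42 {
        Err(format!("shard {shard_id} failed"))
    } else {
        Ok(())
    }
}

#[tokio::main]
async fn main() {
    // No Box::pin here: join_all does not require Unpin futures.
    let futures: Vec<_> = (0..4).map(clear_one).collect();
    // Vec<Result<(), String>> collapses into Result<(), String>:
    // Ok(()) only if every shard succeeded, otherwise the first Err.
    let cleared: Result<(), String> = join_all(futures).await.into_iter().collect();
    assert!(cleared.is_ok());
}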

View File

@@ -39,9 +39,9 @@ impl Batcher {
         // Spawn batching background task that contains all the inference logic
         tokio::spawn(batching_task(
+            client,
             max_batch_size,
             max_waiting_tokens,
-            client,
             db.clone(),
             shared.clone(),
         ));
@@ -86,9 +86,9 @@ impl Batcher {
 /// Batches requests and sends them to the inference server
 #[instrument(skip(client, db, shared))]
 async fn batching_task(
+    mut client: ShardedClient,
     max_batch_size: usize,
     max_waiting_tokens: usize,
-    client: ShardedClient,
     db: Db,
     shared: Arc<Shared>,
 ) {
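The batcher change is a knock-on effect: generate and generate_with_cache now take &mut self, so the spawned task needs a mutable binding, and the client argument moves to the front of the parameter list as mut client. A mut in parameter position is local to the function body; callers still move the value in unchanged. A tiny illustration, with a hypothetical Counter standing in for ShardedClient:

// Hypothetical stand-in for ShardedClient.
struct Counter {
    n: u64,
}

impl Counter {
    // Requires exclusive access, like the reworked ShardedClient methods.
    fn bump(&mut self) -> u64 {
        self.n += 1;
        self.n
    }
}

// `mut` here only affects the body: it lets the task call &mut methods
// on the value it owns, and changes nothing for the caller.
async fn background_task(mut counter: Counter) {
    let _ = counter.bump();
}

#[tokio::main]
async fn main() {
    // The caller moves the value in, mirroring tokio::spawn(batching_task(client, ..)).
    tokio::spawn(background_task(Counter { n: 0 })).await.unwrap();
}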

View File

@@ -61,7 +61,7 @@ fn main() -> Result<(), std::io::Error> {
         .unwrap()
         .block_on(async {
             // Instantiate sharded client from the master unix socket
-            let sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
+            let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
                 .await
                 .expect("Could not connect to server");
             // Clear the cache; useful if the webserver rebooted
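With the binding declared mut, the startup code can call the client's &mut self methods directly. A compressed sketch of the lines that follow the comment above, assuming only the API visible in this diff (the second expect message is illustrative):

let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
    .await
    .expect("Could not connect to server");
// clear_cache() now takes &mut self, hence the `mut` binding above.
sharded_client
    .clear_cache()
    .await
    .expect("Unable to clear cache");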