feat(client): Simplify sharded logic
parent c8ce9b2515
commit beb552127a

@@ -2,76 +2,18 @@
 use crate::Result;
 use crate::{Batch, Client, GeneratedText};
-use tokio::sync::{broadcast, mpsc};
+use futures::future::join_all;
+use futures::future::select_all;
 use tonic::transport::Uri;
 
-/// List of all available commands that can be sent through the command channel
-#[derive(Clone, Debug)]
-enum Command {
-    Generate(
-        Batch,
-        mpsc::Sender<Result<(Vec<GeneratedText>, Option<Batch>)>>,
-    ),
-    GenerateWithCache(
-        Vec<Batch>,
-        mpsc::Sender<Result<(Vec<GeneratedText>, Option<Batch>)>>,
-    ),
-    ClearCache(mpsc::Sender<Result<()>>),
-}
-
-/// Tokio task that handles the communication with a single shard
-///
-/// We subscribe on a broadcast channel to receive commands that will be sent by
-/// the ShardedClient.
-///
-/// Each command is fanned out to all shards.
-///
-/// The result of the command is sent back to the ShardedClient through an mpsc channel (multi
-/// producer = the shards, single consumer = the ShardedClient).
-async fn client_task(mut client: Client, mut request_subscriber: broadcast::Receiver<Command>) {
-    while let Ok(message) = request_subscriber.recv().await {
-        match message {
-            Command::Generate(batch, response_tx) => {
-                let result = client.generate(batch).await;
-                // We can unwrap_or(()) here because the only error that can happen is if the
-                // receiver is dropped, which means that the ShardedClient already received a
-                // response from another shard
-                response_tx.try_send(result).unwrap_or(());
-            }
-            Command::GenerateWithCache(batches, response_tx) => {
-                let result = client.generate_with_cache(batches).await;
-                response_tx.try_send(result).unwrap_or(());
-            }
-            Command::ClearCache(response_tx) => {
-                let result = client.clear_cache().await;
-                response_tx.try_send(result).unwrap_or(());
-            }
-        };
-    }
-}
-
 /// Text Generation Inference gRPC multi client
 pub struct ShardedClient {
-    _clients: Vec<Client>,
-    request_tx: broadcast::Sender<Command>,
+    clients: Vec<Client>,
 }
 
 impl ShardedClient {
     fn new(clients: Vec<Client>) -> Self {
-        // The broadcast channel to communicate with the shards
-        // We use a capacity of one as the shards are not asynchronous and can only process one
-        // command at a time
-        let (request_tx, _) = broadcast::channel(1);
-
-        // Spawn client tasks
-        for client in clients.iter() {
-            let request_subscriber = request_tx.subscribe();
-            tokio::spawn(client_task(client.clone(), request_subscriber));
-        }
-
-        Self {
-            _clients: clients,
-            request_tx,
-        }
+        Self {
+            clients,
+        }
     }
@@ -101,15 +43,15 @@ impl ShardedClient
     ///
     /// Returns a list of generated texts of requests that met their stopping criteria
     /// and the next cached batch
-    pub async fn generate(&self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
-        // Create a channel to receive the response from the shards
-        // We will only ever receive one message on this channel
-        let (response_tx, mut response_rx) = mpsc::channel(1);
-        self.request_tx
-            .send(Command::Generate(batch, response_tx))
-            .unwrap();
-        // As soon as we receive one response, we can return as all shards will return the same
-        response_rx.recv().await.unwrap()
+    pub async fn generate(&mut self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
+        let futures: Vec<_> = self
+            .clients
+            .iter_mut()
+            .map(|client| Box::pin(client.generate(batch.clone())))
+            .collect();
+        // As soon as we receive one response, we can return as all shards will return the same
+        let (result, _, _) = select_all(futures).await;
+        result
     }
 
     /// Generate one token for each request in the given cached batch
@@ -117,27 +59,26 @@ impl ShardedClient
     /// Returns a list of generated texts of requests that met their stopping criteria
     /// and the next cached batch
     pub async fn generate_with_cache(
-        &self,
+        &mut self,
         batches: Vec<Batch>,
     ) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
-        // Create a channel to receive the response from the shards
-        // We will only ever receive one message on this channel
-        let (response_tx, mut response_rx) = mpsc::channel(1);
-        self.request_tx
-            .send(Command::GenerateWithCache(batches, response_tx))
-            .unwrap();
-        // As soon as we receive one response, we can return as all shards will return the same
-        response_rx.recv().await.unwrap()
+        let futures: Vec<_> = self
+            .clients
+            .iter_mut()
+            .map(|client| Box::pin(client.generate_with_cache(batches.clone())))
+            .collect();
+        // As soon as we receive one response, we can return as all shards will return the same
+        let (result, _, _) = select_all(futures).await;
+        result
     }
 
     /// Clear the past generations cache
-    pub async fn clear_cache(&self) -> Result<()> {
-        // Create a channel to receive the response from the shards
-        // We will only ever receive one message on this channel
-        let (response_tx, mut response_rx) = mpsc::channel(1);
-        self.request_tx
-            .send(Command::ClearCache(response_tx))
-            .unwrap();
-        response_rx.recv().await.unwrap()
+    pub async fn clear_cache(&mut self) -> Result<()> {
+        let futures: Vec<_> = self
+            .clients
+            .iter_mut()
+            .map(|client| client.clear_cache())
+            .collect();
+        join_all(futures).await.into_iter().collect()
     }
 }
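
clear_cache takes the opposite tack: every shard really has to clear its cache, so the futures are joined rather than raced. A small sketch of how join_all followed by collect() folds a Vec<Result<()>> into one Result<()>, with String standing in for the real error type:

use futures::future::join_all;

#[tokio::main]
async fn main() {
    let futures: Vec<_> = (0..3)
        .map(|i| async move {
            // Pretend each shard clears its cache and reports success
            if i == 42 { Err(format!("shard {i} failed")) } else { Ok(()) }
        })
        .collect();
    // FromIterator for Result: Ok(()) only if every element is Ok,
    // otherwise the first Err encountered
    let result: Result<(), String> = join_all(futures).await.into_iter().collect();
    assert!(result.is_ok());
}

The remaining hunks are the fallout of the new &mut self receivers: the Batcher and main must now hold mutable handles to the ShardedClient.
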
@@ -39,9 +39,9 @@ impl Batcher
 
         // Spawn batching background task that contains all the inference logic
         tokio::spawn(batching_task(
+            client,
             max_batch_size,
             max_waiting_tokens,
-            client,
             db.clone(),
             shared.clone(),
         ));

@@ -86,9 +86,9 @@ impl Batcher
 /// Batches requests and sends them to the inference server
 #[instrument(skip(client, db, shared))]
 async fn batching_task(
+    mut client: ShardedClient,
     max_batch_size: usize,
     max_waiting_tokens: usize,
-    client: ShardedClient,
     db: Db,
     shared: Arc<Shared>,
 ) {

@@ -61,7 +61,7 @@ fn main() -> Result<(), std::io::Error> {
         .unwrap()
         .block_on(async {
             // Instantiate sharded client from the master unix socket
-            let sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
+            let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
                 .await
                 .expect("Could not connect to server");
             // Clear the cache; useful if the webserver rebooted
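
Those &mut receivers explain the mut additions above: generate, generate_with_cache, and clear_cache now borrow the ShardedClient mutably (iter_mut over the shard clients), so batching_task must take mut client and main needs a mut sharded_client binding. In miniature, with a hypothetical stand-in type:

// Once a method takes &mut self, every owner up the call chain
// needs a mutable binding to call it
struct Shard {
    calls: u32,
}

impl Shard {
    fn generate(&mut self) -> u32 {
        self.calls += 1;
        self.calls
    }
}

fn main() {
    // `let shard = ...` would no longer compile here
    let mut shard = Shard { calls: 0 };
    assert_eq!(shard.generate(), 1);
}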