avoid join_all

This commit is contained in:
OlivierDehaene 2024-06-18 15:44:28 +02:00
parent b21ed583ac
commit e5c27364be
1 changed files with 40 additions and 19 deletions

View File

@ -4,7 +4,8 @@ use crate::{ClientError, Result};
use crate::v3::{Chunk, InfoResponse, Input, TerminatedGeneration}; use crate::v3::{Chunk, InfoResponse, Input, TerminatedGeneration};
use async_trait::async_trait; use async_trait::async_trait;
use futures::future::join_all; use futures::stream::FuturesUnordered;
use futures::stream::StreamExt;
use tonic::transport::Uri; use tonic::transport::Uri;
use tracing::instrument; use tracing::instrument;
use v3::client::{DecodeTimings, PrefillTimings}; use v3::client::{DecodeTimings, PrefillTimings};
@ -29,8 +30,12 @@ impl ShardedClient {
async fn from_master_client(mut master_client: Client) -> Result<Self> { async fn from_master_client(mut master_client: Client) -> Result<Self> {
// Get all uris/unix sockets from the master client // Get all uris/unix sockets from the master client
let uris = master_client.service_discovery().await?; let uris = master_client.service_discovery().await?;
let futures = uris.into_iter().map(Client::connect_uds); let futures: FuturesUnordered<_> = uris.into_iter().map(Client::connect_uds).collect();
let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect(); let clients: Result<Vec<Client>> = futures
.collect::<Vec<Result<_>>>()
.await
.into_iter()
.collect();
Ok(Self::new(clients?)) Ok(Self::new(clients?))
} }
@ -49,34 +54,43 @@ impl ShardedClient {
/// Get the model info /// Get the model info
#[instrument(skip(self))] #[instrument(skip(self))]
pub async fn info(&mut self) -> Result<ShardInfo> { pub async fn info(&mut self) -> Result<ShardInfo> {
let futures: Vec<_> = self let futures: FuturesUnordered<_> = self
.clients .clients
.iter_mut() .iter_mut()
.map(|client| client.info()) .map(|client| client.info())
.collect(); .collect();
join_all(futures).await.pop().unwrap().map(ShardInfo::from) futures
.collect::<Vec<Result<_>>>()
.await
.pop()
.unwrap()
.map(ShardInfo::from)
} }
/// GRPC health check /// GRPC health check
#[instrument(skip(self))] #[instrument(skip(self))]
pub async fn health(&mut self) -> Result<HealthResponse> { pub async fn health(&mut self) -> Result<HealthResponse> {
let futures: Vec<_> = self let futures: FuturesUnordered<_> = self
.clients .clients
.iter_mut() .iter_mut()
.map(|client| client.health()) .map(|client| client.health())
.collect(); .collect();
join_all(futures).await.pop().unwrap() futures.collect::<Vec<Result<_>>>().await.pop().unwrap()
} }
/// Clear the past generations cache /// Clear the past generations cache
#[instrument(skip(self))] #[instrument(skip(self))]
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> { pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
let futures: Vec<_> = self let futures: FuturesUnordered<_> = self
.clients .clients
.iter_mut() .iter_mut()
.map(|client| client.clear_cache(batch_id)) .map(|client| client.clear_cache(batch_id))
.collect(); .collect();
join_all(futures).await.into_iter().collect() futures
.collect::<Vec<Result<_>>>()
.await
.into_iter()
.collect()
} }
/// Filter a cached batch /// Filter a cached batch
@ -87,7 +101,7 @@ impl ShardedClient {
kept_requests: Vec<KeptRequest>, kept_requests: Vec<KeptRequest>,
terminated_request_ids: Vec<u64>, terminated_request_ids: Vec<u64>,
) -> Result<(Option<CachedBatch>, Vec<TerminatedGeneration>)> { ) -> Result<(Option<CachedBatch>, Vec<TerminatedGeneration>)> {
let futures: Vec<_> = self let futures: FuturesUnordered<_> = self
.clients .clients
.iter_mut() .iter_mut()
.map(|client| { .map(|client| {
@ -99,7 +113,7 @@ impl ShardedClient {
}) })
.collect(); .collect();
// all shards return the same message // all shards return the same message
join_all(futures).await.pop().unwrap() futures.collect::<Vec<Result<_>>>().await.pop().unwrap()
} }
/// Warmup on a max size batch /// Warmup on a max size batch
@ -113,7 +127,7 @@ impl ShardedClient {
max_total_tokens: u32, max_total_tokens: u32,
max_batch_size: Option<usize>, max_batch_size: Option<usize>,
) -> Result<Option<u32>> { ) -> Result<Option<u32>> {
let futures: Vec<_> = self let futures: FuturesUnordered<_> = self
.clients .clients
.iter_mut() .iter_mut()
.map(|client| { .map(|client| {
@ -126,7 +140,8 @@ impl ShardedClient {
}) })
.collect(); .collect();
// Take the minimum value // Take the minimum value
let results = join_all(futures) let results = futures
.collect::<Vec<Result<_>>>()
.await .await
.into_iter() .into_iter()
.collect::<Result<Vec<Option<u32>>>>()?; .collect::<Result<Vec<Option<u32>>>>()?;
@ -142,14 +157,17 @@ impl ShardedClient {
&mut self, &mut self,
batch: Batch, batch: Batch,
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> { ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
let futures: Vec<_> = self let futures: FuturesUnordered<_> = self
.clients .clients
.iter_mut() .iter_mut()
.map(|client| Box::pin(client.prefill(batch.clone()))) .map(|client| Box::pin(client.prefill(batch.clone())))
.collect(); .collect();
#[allow(clippy::type_complexity)] #[allow(clippy::type_complexity)]
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> = let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> = futures
join_all(futures).await.into_iter().collect(); .collect::<Vec<Result<_>>>()
.await
.into_iter()
.collect();
let mut results = results?; let mut results = results?;
let (mut generations, next_batch, mut timings) = let (mut generations, next_batch, mut timings) =
@ -175,14 +193,17 @@ impl ShardedClient {
&mut self, &mut self,
batches: Vec<CachedBatch>, batches: Vec<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> { ) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
let futures: Vec<_> = self let futures: FuturesUnordered<_> = self
.clients .clients
.iter_mut() .iter_mut()
.map(|client| Box::pin(client.decode(batches.clone()))) .map(|client| Box::pin(client.decode(batches.clone())))
.collect(); .collect();
#[allow(clippy::type_complexity)] #[allow(clippy::type_complexity)]
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> = let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> = futures
join_all(futures).await.into_iter().collect(); .collect::<Vec<Result<_>>>()
.await
.into_iter()
.collect();
let mut results = results?; let mut results = results?;
let (mut generations, next_batch, mut timings) = let (mut generations, next_batch, mut timings) =