diff --git a/router/client/src/v3/sharded_client.rs b/router/client/src/v3/sharded_client.rs index 1f9ec3ad..f89bf75d 100644 --- a/router/client/src/v3/sharded_client.rs +++ b/router/client/src/v3/sharded_client.rs @@ -4,7 +4,8 @@ use crate::{ClientError, Result}; use crate::v3::{Chunk, InfoResponse, Input, TerminatedGeneration}; use async_trait::async_trait; -use futures::future::join_all; +use futures::stream::FuturesUnordered; +use futures::stream::StreamExt; use tonic::transport::Uri; use tracing::instrument; use v3::client::{DecodeTimings, PrefillTimings}; @@ -29,8 +30,12 @@ impl ShardedClient { async fn from_master_client(mut master_client: Client) -> Result { // Get all uris/unix sockets from the master client let uris = master_client.service_discovery().await?; - let futures = uris.into_iter().map(Client::connect_uds); - let clients: Result> = join_all(futures).await.into_iter().collect(); + let futures: FuturesUnordered<_> = uris.into_iter().map(Client::connect_uds).collect(); + let clients: Result> = futures + .collect::>>() + .await + .into_iter() + .collect(); Ok(Self::new(clients?)) } @@ -49,34 +54,43 @@ impl ShardedClient { /// Get the model info #[instrument(skip(self))] pub async fn info(&mut self) -> Result { - let futures: Vec<_> = self + let futures: FuturesUnordered<_> = self .clients .iter_mut() .map(|client| client.info()) .collect(); - join_all(futures).await.pop().unwrap().map(ShardInfo::from) + futures + .collect::>>() + .await + .pop() + .unwrap() + .map(ShardInfo::from) } /// GRPC health check #[instrument(skip(self))] pub async fn health(&mut self) -> Result { - let futures: Vec<_> = self + let futures: FuturesUnordered<_> = self .clients .iter_mut() .map(|client| client.health()) .collect(); - join_all(futures).await.pop().unwrap() + futures.collect::>>().await.pop().unwrap() } /// Clear the past generations cache #[instrument(skip(self))] pub async fn clear_cache(&mut self, batch_id: Option) -> Result<()> { - let futures: Vec<_> = self + let futures: FuturesUnordered<_> = self .clients .iter_mut() .map(|client| client.clear_cache(batch_id)) .collect(); - join_all(futures).await.into_iter().collect() + futures + .collect::>>() + .await + .into_iter() + .collect() } /// Filter a cached batch @@ -87,7 +101,7 @@ impl ShardedClient { kept_requests: Vec, terminated_request_ids: Vec, ) -> Result<(Option, Vec)> { - let futures: Vec<_> = self + let futures: FuturesUnordered<_> = self .clients .iter_mut() .map(|client| { @@ -99,7 +113,7 @@ impl ShardedClient { }) .collect(); // all shards return the same message - join_all(futures).await.pop().unwrap() + futures.collect::>>().await.pop().unwrap() } /// Warmup on a max size batch @@ -113,7 +127,7 @@ impl ShardedClient { max_total_tokens: u32, max_batch_size: Option, ) -> Result> { - let futures: Vec<_> = self + let futures: FuturesUnordered<_> = self .clients .iter_mut() .map(|client| { @@ -126,7 +140,8 @@ impl ShardedClient { }) .collect(); // Take the minimum value - let results = join_all(futures) + let results = futures + .collect::>>() .await .into_iter() .collect::>>>()?; @@ -142,14 +157,17 @@ impl ShardedClient { &mut self, batch: Batch, ) -> Result<(Vec, Option, PrefillTimings)> { - let futures: Vec<_> = self + let futures: FuturesUnordered<_> = self .clients .iter_mut() .map(|client| Box::pin(client.prefill(batch.clone()))) .collect(); #[allow(clippy::type_complexity)] - let results: Result, Option, PrefillTimings)>> = - join_all(futures).await.into_iter().collect(); + let results: Result, Option, PrefillTimings)>> = futures + .collect::>>() + .await + .into_iter() + .collect(); let mut results = results?; let (mut generations, next_batch, mut timings) = @@ -175,14 +193,17 @@ impl ShardedClient { &mut self, batches: Vec, ) -> Result<(Vec, Option, DecodeTimings)> { - let futures: Vec<_> = self + let futures: FuturesUnordered<_> = self .clients .iter_mut() .map(|client| Box::pin(client.decode(batches.clone()))) .collect(); #[allow(clippy::type_complexity)] - let results: Result, Option, DecodeTimings)>> = - join_all(futures).await.into_iter().collect(); + let results: Result, Option, DecodeTimings)>> = futures + .collect::>>() + .await + .into_iter() + .collect(); let mut results = results?; let (mut generations, next_batch, mut timings) =