avoid join_all
This commit is contained in:
parent
b21ed583ac
commit
e5c27364be
|
@ -4,7 +4,8 @@ use crate::{ClientError, Result};
|
||||||
|
|
||||||
use crate::v3::{Chunk, InfoResponse, Input, TerminatedGeneration};
|
use crate::v3::{Chunk, InfoResponse, Input, TerminatedGeneration};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use futures::future::join_all;
|
use futures::stream::FuturesUnordered;
|
||||||
|
use futures::stream::StreamExt;
|
||||||
use tonic::transport::Uri;
|
use tonic::transport::Uri;
|
||||||
use tracing::instrument;
|
use tracing::instrument;
|
||||||
use v3::client::{DecodeTimings, PrefillTimings};
|
use v3::client::{DecodeTimings, PrefillTimings};
|
||||||
|
@ -29,8 +30,12 @@ impl ShardedClient {
|
||||||
async fn from_master_client(mut master_client: Client) -> Result<Self> {
|
async fn from_master_client(mut master_client: Client) -> Result<Self> {
|
||||||
// Get all uris/unix sockets from the master client
|
// Get all uris/unix sockets from the master client
|
||||||
let uris = master_client.service_discovery().await?;
|
let uris = master_client.service_discovery().await?;
|
||||||
let futures = uris.into_iter().map(Client::connect_uds);
|
let futures: FuturesUnordered<_> = uris.into_iter().map(Client::connect_uds).collect();
|
||||||
let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect();
|
let clients: Result<Vec<Client>> = futures
|
||||||
|
.collect::<Vec<Result<_>>>()
|
||||||
|
.await
|
||||||
|
.into_iter()
|
||||||
|
.collect();
|
||||||
Ok(Self::new(clients?))
|
Ok(Self::new(clients?))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -49,34 +54,43 @@ impl ShardedClient {
|
||||||
/// Get the model info
|
/// Get the model info
|
||||||
#[instrument(skip(self))]
|
#[instrument(skip(self))]
|
||||||
pub async fn info(&mut self) -> Result<ShardInfo> {
|
pub async fn info(&mut self) -> Result<ShardInfo> {
|
||||||
let futures: Vec<_> = self
|
let futures: FuturesUnordered<_> = self
|
||||||
.clients
|
.clients
|
||||||
.iter_mut()
|
.iter_mut()
|
||||||
.map(|client| client.info())
|
.map(|client| client.info())
|
||||||
.collect();
|
.collect();
|
||||||
join_all(futures).await.pop().unwrap().map(ShardInfo::from)
|
futures
|
||||||
|
.collect::<Vec<Result<_>>>()
|
||||||
|
.await
|
||||||
|
.pop()
|
||||||
|
.unwrap()
|
||||||
|
.map(ShardInfo::from)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// GRPC health check
|
/// GRPC health check
|
||||||
#[instrument(skip(self))]
|
#[instrument(skip(self))]
|
||||||
pub async fn health(&mut self) -> Result<HealthResponse> {
|
pub async fn health(&mut self) -> Result<HealthResponse> {
|
||||||
let futures: Vec<_> = self
|
let futures: FuturesUnordered<_> = self
|
||||||
.clients
|
.clients
|
||||||
.iter_mut()
|
.iter_mut()
|
||||||
.map(|client| client.health())
|
.map(|client| client.health())
|
||||||
.collect();
|
.collect();
|
||||||
join_all(futures).await.pop().unwrap()
|
futures.collect::<Vec<Result<_>>>().await.pop().unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Clear the past generations cache
|
/// Clear the past generations cache
|
||||||
#[instrument(skip(self))]
|
#[instrument(skip(self))]
|
||||||
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
|
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
|
||||||
let futures: Vec<_> = self
|
let futures: FuturesUnordered<_> = self
|
||||||
.clients
|
.clients
|
||||||
.iter_mut()
|
.iter_mut()
|
||||||
.map(|client| client.clear_cache(batch_id))
|
.map(|client| client.clear_cache(batch_id))
|
||||||
.collect();
|
.collect();
|
||||||
join_all(futures).await.into_iter().collect()
|
futures
|
||||||
|
.collect::<Vec<Result<_>>>()
|
||||||
|
.await
|
||||||
|
.into_iter()
|
||||||
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Filter a cached batch
|
/// Filter a cached batch
|
||||||
|
@ -87,7 +101,7 @@ impl ShardedClient {
|
||||||
kept_requests: Vec<KeptRequest>,
|
kept_requests: Vec<KeptRequest>,
|
||||||
terminated_request_ids: Vec<u64>,
|
terminated_request_ids: Vec<u64>,
|
||||||
) -> Result<(Option<CachedBatch>, Vec<TerminatedGeneration>)> {
|
) -> Result<(Option<CachedBatch>, Vec<TerminatedGeneration>)> {
|
||||||
let futures: Vec<_> = self
|
let futures: FuturesUnordered<_> = self
|
||||||
.clients
|
.clients
|
||||||
.iter_mut()
|
.iter_mut()
|
||||||
.map(|client| {
|
.map(|client| {
|
||||||
|
@ -99,7 +113,7 @@ impl ShardedClient {
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
// all shards return the same message
|
// all shards return the same message
|
||||||
join_all(futures).await.pop().unwrap()
|
futures.collect::<Vec<Result<_>>>().await.pop().unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Warmup on a max size batch
|
/// Warmup on a max size batch
|
||||||
|
@ -113,7 +127,7 @@ impl ShardedClient {
|
||||||
max_total_tokens: u32,
|
max_total_tokens: u32,
|
||||||
max_batch_size: Option<usize>,
|
max_batch_size: Option<usize>,
|
||||||
) -> Result<Option<u32>> {
|
) -> Result<Option<u32>> {
|
||||||
let futures: Vec<_> = self
|
let futures: FuturesUnordered<_> = self
|
||||||
.clients
|
.clients
|
||||||
.iter_mut()
|
.iter_mut()
|
||||||
.map(|client| {
|
.map(|client| {
|
||||||
|
@ -126,7 +140,8 @@ impl ShardedClient {
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
// Take the minimum value
|
// Take the minimum value
|
||||||
let results = join_all(futures)
|
let results = futures
|
||||||
|
.collect::<Vec<Result<_>>>()
|
||||||
.await
|
.await
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.collect::<Result<Vec<Option<u32>>>>()?;
|
.collect::<Result<Vec<Option<u32>>>>()?;
|
||||||
|
@ -142,14 +157,17 @@ impl ShardedClient {
|
||||||
&mut self,
|
&mut self,
|
||||||
batch: Batch,
|
batch: Batch,
|
||||||
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
|
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
|
||||||
let futures: Vec<_> = self
|
let futures: FuturesUnordered<_> = self
|
||||||
.clients
|
.clients
|
||||||
.iter_mut()
|
.iter_mut()
|
||||||
.map(|client| Box::pin(client.prefill(batch.clone())))
|
.map(|client| Box::pin(client.prefill(batch.clone())))
|
||||||
.collect();
|
.collect();
|
||||||
#[allow(clippy::type_complexity)]
|
#[allow(clippy::type_complexity)]
|
||||||
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
|
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> = futures
|
||||||
join_all(futures).await.into_iter().collect();
|
.collect::<Vec<Result<_>>>()
|
||||||
|
.await
|
||||||
|
.into_iter()
|
||||||
|
.collect();
|
||||||
let mut results = results?;
|
let mut results = results?;
|
||||||
|
|
||||||
let (mut generations, next_batch, mut timings) =
|
let (mut generations, next_batch, mut timings) =
|
||||||
|
@ -175,14 +193,17 @@ impl ShardedClient {
|
||||||
&mut self,
|
&mut self,
|
||||||
batches: Vec<CachedBatch>,
|
batches: Vec<CachedBatch>,
|
||||||
) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
|
) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
|
||||||
let futures: Vec<_> = self
|
let futures: FuturesUnordered<_> = self
|
||||||
.clients
|
.clients
|
||||||
.iter_mut()
|
.iter_mut()
|
||||||
.map(|client| Box::pin(client.decode(batches.clone())))
|
.map(|client| Box::pin(client.decode(batches.clone())))
|
||||||
.collect();
|
.collect();
|
||||||
#[allow(clippy::type_complexity)]
|
#[allow(clippy::type_complexity)]
|
||||||
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> =
|
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> = futures
|
||||||
join_all(futures).await.into_iter().collect();
|
.collect::<Vec<Result<_>>>()
|
||||||
|
.await
|
||||||
|
.into_iter()
|
||||||
|
.collect();
|
||||||
let mut results = results?;
|
let mut results = results?;
|
||||||
|
|
||||||
let (mut generations, next_batch, mut timings) =
|
let (mut generations, next_batch, mut timings) =
|
||||||
|
|
Loading…
Reference in New Issue