fix(rust-client): use join_all instead of select_all to hopefully fix nccl issues (#162)

This commit is contained in:
OlivierDehaene 2023-04-09 20:07:02 +02:00 committed by GitHub
parent e63a21eb4d
commit 5cddc055e6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 4 additions and 7 deletions

View File

@ -2,7 +2,6 @@
use crate::Result; use crate::Result;
use crate::{Batch, Client, Generation}; use crate::{Batch, Client, Generation};
use futures::future::join_all; use futures::future::join_all;
use futures::future::select_all;
use tonic::transport::Uri; use tonic::transport::Uri;
use tracing::instrument; use tracing::instrument;
@ -60,9 +59,8 @@ impl ShardedClient {
.iter_mut() .iter_mut()
.map(|client| Box::pin(client.prefill(batch.clone()))) .map(|client| Box::pin(client.prefill(batch.clone())))
.collect(); .collect();
// As soon as we receive one response, we can return as all shards will return the same // all shards return the same message
let (result, _, _) = select_all(futures).await; join_all(futures).await.pop().unwrap()
result
} }
/// Generate one token for each request in the given cached batches /// Generate one token for each request in the given cached batches
@ -79,8 +77,7 @@ impl ShardedClient {
.iter_mut() .iter_mut()
.map(|client| Box::pin(client.decode(batches.clone()))) .map(|client| Box::pin(client.decode(batches.clone())))
.collect(); .collect();
// As soon as we receive one response, we can return as all shards will return the same // all shards return the same message
let (result, _, _) = select_all(futures).await; join_all(futures).await.pop().unwrap()
result
} }
} }