fix(rust-client): use join_all instead of select_all to hopefully fix nccl issues (#162)
This commit is contained in:
parent
e63a21eb4d
commit
5cddc055e6
|
@ -2,7 +2,6 @@
|
||||||
use crate::Result;
|
use crate::Result;
|
||||||
use crate::{Batch, Client, Generation};
|
use crate::{Batch, Client, Generation};
|
||||||
use futures::future::join_all;
|
use futures::future::join_all;
|
||||||
use futures::future::select_all;
|
|
||||||
use tonic::transport::Uri;
|
use tonic::transport::Uri;
|
||||||
use tracing::instrument;
|
use tracing::instrument;
|
||||||
|
|
||||||
|
@ -60,9 +59,8 @@ impl ShardedClient {
|
||||||
.iter_mut()
|
.iter_mut()
|
||||||
.map(|client| Box::pin(client.prefill(batch.clone())))
|
.map(|client| Box::pin(client.prefill(batch.clone())))
|
||||||
.collect();
|
.collect();
|
||||||
// As soon as we receive one response, we can return as all shards will return the same
|
// all shards return the same message
|
||||||
let (result, _, _) = select_all(futures).await;
|
join_all(futures).await.pop().unwrap()
|
||||||
result
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generate one token for each request in the given cached batches
|
/// Generate one token for each request in the given cached batches
|
||||||
|
@ -79,8 +77,7 @@ impl ShardedClient {
|
||||||
.iter_mut()
|
.iter_mut()
|
||||||
.map(|client| Box::pin(client.decode(batches.clone())))
|
.map(|client| Box::pin(client.decode(batches.clone())))
|
||||||
.collect();
|
.collect();
|
||||||
// As soon as we receive one response, we can return as all shards will return the same
|
// all shards return the same message
|
||||||
let (result, _, _) = select_all(futures).await;
|
join_all(futures).await.pop().unwrap()
|
||||||
result
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue