fix(launcher): parse num cuda devices from CUDA_VISIBLE_DEVICES and NVIDIA_VISIBLE_DEVICES
This commit is contained in:
parent
5fde8d9991
commit
146e72c3be
|
@ -455,11 +455,11 @@ fn shutdown_shards(shutdown: Arc<Mutex<bool>>, shutdown_receiver: &mpsc::Receive
|
|||
}
|
||||
|
||||
fn num_cuda_devices() -> Option<usize> {
|
||||
if let Ok(cuda_visible_devices) = env::var("CUDA_VISIBLE_DEVICES") {
|
||||
let n_devices = cuda_visible_devices.split(',').count();
|
||||
return Some(n_devices);
|
||||
}
|
||||
None
|
||||
let devices = env::var("CUDA_VISIBLE_DEVICES")
|
||||
.map_err(|_| env::var("NVIDIA_VISIBLE_DEVICES"))
|
||||
.ok()?;
|
||||
let n_devices = devices.split(',').count();
|
||||
Some(n_devices)
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
|
@ -509,9 +509,9 @@ fn find_num_shards(sharded: Option<bool>, num_shard: Option<usize>) -> usize {
|
|||
let num_shard = match (sharded, num_shard) {
|
||||
(Some(true), None) => {
|
||||
// try to default to the number of available GPUs
|
||||
tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES");
|
||||
let n_devices =
|
||||
num_cuda_devices().expect("--num-shard and CUDA_VISIBLE_DEVICES are not set");
|
||||
tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES");
|
||||
let n_devices = num_cuda_devices()
|
||||
.expect("--num-shard and CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES are not set");
|
||||
if n_devices <= 1 {
|
||||
panic!("`sharded` is true but only found {n_devices} CUDA devices");
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue