fix(launcher): parse num cuda devices from CUDA_VISIBLE_DEVICES and NVIDIA_VISIBLE_DEVICES
This commit is contained in:
parent
146e72c3be
commit
49a6c8c1b2
|
@ -455,9 +455,10 @@ fn shutdown_shards(shutdown: Arc<Mutex<bool>>, shutdown_receiver: &mpsc::Receive
|
||||||
}
|
}
|
||||||
|
|
||||||
fn num_cuda_devices() -> Option<usize> {
|
fn num_cuda_devices() -> Option<usize> {
|
||||||
let devices = env::var("CUDA_VISIBLE_DEVICES")
|
let devices = match env::var("CUDA_VISIBLE_DEVICES") {
|
||||||
.map_err(|_| env::var("NVIDIA_VISIBLE_DEVICES"))
|
Ok(devices) => devices,
|
||||||
.ok()?;
|
Err(_) => env::var("NVIDIA_VISIBLE_DEVICES").ok()?,
|
||||||
|
};
|
||||||
let n_devices = devices.split(',').count();
|
let n_devices = devices.split(',').count();
|
||||||
Some(n_devices)
|
Some(n_devices)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue