fix(launcher): parse num cuda devices from CUDA_VISIBLE_DEVICES and NVIDIA_VISIBLE_DEVICES

This commit is contained in:
OlivierDehaene 2023-05-30 13:27:48 +02:00
parent 146e72c3be
commit 49a6c8c1b2
1 changed files with 4 additions and 3 deletions

View File

@ -455,9 +455,10 @@ fn shutdown_shards(shutdown: Arc<Mutex<bool>>, shutdown_receiver: &mpsc::Receive
}
fn num_cuda_devices() -> Option<usize> {
let devices = env::var("CUDA_VISIBLE_DEVICES")
.map_err(|_| env::var("NVIDIA_VISIBLE_DEVICES"))
.ok()?;
let devices = match env::var("CUDA_VISIBLE_DEVICES") {
Ok(devices) => devices,
Err(_) => env::var("NVIDIA_VISIBLE_DEVICES").ok()?,
};
let n_devices = devices.split(',').count();
Some(n_devices)
}