fix(launcher): parse num cuda devices from CUDA_VISIBLE_DEVICES and NVIDIA_VISIBLE_DEVICES

This commit is contained in:
OlivierDehaene 2023-05-30 13:27:48 +02:00
parent 146e72c3be
commit 49a6c8c1b2
1 changed files with 4 additions and 3 deletions

View File

@ -455,9 +455,10 @@ fn shutdown_shards(shutdown: Arc<Mutex<bool>>, shutdown_receiver: &mpsc::Receive
} }
fn num_cuda_devices() -> Option<usize> { fn num_cuda_devices() -> Option<usize> {
let devices = env::var("CUDA_VISIBLE_DEVICES") let devices = match env::var("CUDA_VISIBLE_DEVICES") {
.map_err(|_| env::var("NVIDIA_VISIBLE_DEVICES")) Ok(devices) => devices,
.ok()?; Err(_) => env::var("NVIDIA_VISIBLE_DEVICES").ok()?,
};
let n_devices = devices.split(',').count(); let n_devices = devices.split(',').count();
Some(n_devices) Some(n_devices)
} }