diff --git a/launcher/src/main.rs b/launcher/src/main.rs index a863ea4a..0810d979 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -455,9 +455,10 @@ fn shutdown_shards(shutdown: Arc>, shutdown_receiver: &mpsc::Receive } fn num_cuda_devices() -> Option { - let devices = env::var("CUDA_VISIBLE_DEVICES") - .map_err(|_| env::var("NVIDIA_VISIBLE_DEVICES")) - .ok()?; + let devices = match env::var("CUDA_VISIBLE_DEVICES") { + Ok(devices) => devices, + Err(_) => env::var("NVIDIA_VISIBLE_DEVICES").ok()?, + }; let n_devices = devices.split(',').count(); Some(n_devices) }