From 49a6c8c1b28742e806dd95a36af82db5b45d181d Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Tue, 30 May 2023 13:27:48 +0200 Subject: [PATCH] fix(launcher): parse num cuda devices from CUDA_VISIBLE_DEVICES and NVIDIA_VISIBLE_DEVICES --- launcher/src/main.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/launcher/src/main.rs b/launcher/src/main.rs index a863ea4a..0810d979 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -455,9 +455,10 @@ fn shutdown_shards(shutdown: Arc>, shutdown_receiver: &mpsc::Receive } fn num_cuda_devices() -> Option { - let devices = env::var("CUDA_VISIBLE_DEVICES") - .map_err(|_| env::var("NVIDIA_VISIBLE_DEVICES")) - .ok()?; + let devices = match env::var("CUDA_VISIBLE_DEVICES") { + Ok(devices) => devices, + Err(_) => env::var("NVIDIA_VISIBLE_DEVICES").ok()?, + }; let n_devices = devices.split(',').count(); Some(n_devices) }