feat(launcher): default num_shard to CUDA_VISIBLE_DEVICES if possible (#108)
This commit is contained in:
parent
0ac38d336a
commit
5fd2dcb513
|
@ -115,13 +115,11 @@ fn main() -> ExitCode {
|
|||
None => {
|
||||
// try to default to the number of available GPUs
|
||||
tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES");
|
||||
let cuda_visible_devices = env::var("CUDA_VISIBLE_DEVICES")
|
||||
let n_devices = num_cuda_devices()
|
||||
.expect("--num-shard and CUDA_VISIBLE_DEVICES are not set");
|
||||
let n_devices = cuda_visible_devices.split(",").count();
|
||||
if n_devices <= 1 {
|
||||
panic!("`sharded` is true but only found {n_devices} CUDA devices");
|
||||
}
|
||||
tracing::info!("Sharding on {n_devices} found CUDA devices");
|
||||
n_devices
|
||||
}
|
||||
Some(num_shard) => {
|
||||
|
@ -144,9 +142,19 @@ fn main() -> ExitCode {
|
|||
}
|
||||
}
|
||||
} else {
|
||||
// default to a single shard
|
||||
num_shard.unwrap_or(1)
|
||||
match num_shard {
|
||||
// get num_shard from CUDA_VISIBLE_DEVICES or default to a single shard
|
||||
None => num_cuda_devices().unwrap_or(1),
|
||||
Some(num_shard) => num_shard,
|
||||
}
|
||||
};
|
||||
if num_shard < 1 {
|
||||
panic!("`num_shard` cannot be < 1");
|
||||
}
|
||||
|
||||
if num_shard > 1 {
|
||||
tracing::info!("Sharding model on {num_shard} processes");
|
||||
}
|
||||
|
||||
// Signal handler
|
||||
let running = Arc::new(AtomicBool::new(true));
|
||||
|
@ -669,3 +677,11 @@ fn shutdown_shards(shutdown: Arc<Mutex<bool>>, shutdown_receiver: &mpsc::Receive
|
|||
// This will block till all shutdown_sender are dropped
|
||||
let _ = shutdown_receiver.recv();
|
||||
}
|
||||
|
||||
fn num_cuda_devices() -> Option<usize> {
|
||||
if let Ok(cuda_visible_devices) = env::var("CUDA_VISIBLE_DEVICES") {
|
||||
let n_devices = cuda_visible_devices.split(',').count();
|
||||
return Some(n_devices);
|
||||
}
|
||||
None
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue