parent
5437d49beb
commit
7b3d460d21
|
@ -1,6 +1,7 @@
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use std::env;
|
use std::env;
|
||||||
|
use std::ffi::OsString;
|
||||||
use std::io::{BufRead, BufReader, Read};
|
use std::io::{BufRead, BufReader, Read};
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
use std::process::ExitCode;
|
use std::process::ExitCode;
|
||||||
|
@ -118,9 +119,10 @@ fn main() -> ExitCode {
|
||||||
download_argv.push(revision.to_string())
|
download_argv.push(revision.to_string())
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut env = Vec::new();
|
// Copy current process env
|
||||||
|
let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();
|
||||||
|
|
||||||
// If the HUGGINGFACE_HUB_CACHE env var is set, pass it to the shard
|
// If huggingface_hub_cache is set, pass it to the shard
|
||||||
// Useful when running inside a docker container
|
// Useful when running inside a docker container
|
||||||
if let Some(ref huggingface_hub_cache) = huggingface_hub_cache {
|
if let Some(ref huggingface_hub_cache) = huggingface_hub_cache {
|
||||||
env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
|
env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
|
||||||
|
@ -455,14 +457,18 @@ fn shard_manager(
|
||||||
shard_argv.push(otlp_endpoint);
|
shard_argv.push(otlp_endpoint);
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut env = vec![
|
// Copy current process env
|
||||||
("RANK".into(), rank.to_string().into()),
|
let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();
|
||||||
("WORLD_SIZE".into(), world_size.to_string().into()),
|
|
||||||
("MASTER_ADDR".into(), master_addr.into()),
|
// Torch Distributed Env vars
|
||||||
("MASTER_PORT".into(), master_port.to_string().into()),
|
env.push(("RANK".into(), rank.to_string().into()));
|
||||||
("SAFETENSORS_FAST_GPU".into(), "1".into()),
|
env.push(("WORLD_SIZE".into(), world_size.to_string().into()));
|
||||||
("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into()),
|
env.push(("MASTER_ADDR".into(), master_addr.into()));
|
||||||
];
|
env.push(("MASTER_PORT".into(), master_port.to_string().into()));
|
||||||
|
env.push(("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into()));
|
||||||
|
|
||||||
|
// Safetensors load fast
|
||||||
|
env.push(("SAFETENSORS_FAST_GPU".into(), "1".into()));
|
||||||
|
|
||||||
// If huggingface_hub_cache is some, pass it to the shard
|
// If huggingface_hub_cache is some, pass it to the shard
|
||||||
// Useful when running inside a docker container
|
// Useful when running inside a docker container
|
||||||
|
@ -484,17 +490,6 @@ fn shard_manager(
|
||||||
env.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into()))
|
env.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into()))
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the NCCL_SHM_DISABLE env var is set, pass it to the shard
|
|
||||||
// needed when running NCCL inside a docker container and when you can't increase shm size
|
|
||||||
if let Ok(nccl_shm_disalbe) = env::var("NCCL_SHM_DISABLE") {
|
|
||||||
env.push(("NCCL_SHM_DISABLE".into(), nccl_shm_disalbe.into()));
|
|
||||||
};
|
|
||||||
|
|
||||||
// If the CUDA_VISIBLE_DEVICES env var is set, pass it to the shard
|
|
||||||
if let Ok(cuda_visible_devices) = env::var("CUDA_VISIBLE_DEVICES") {
|
|
||||||
env.push(("CUDA_VISIBLE_DEVICES".into(), cuda_visible_devices.into()));
|
|
||||||
};
|
|
||||||
|
|
||||||
// Start process
|
// Start process
|
||||||
tracing::info!("Starting shard {rank}");
|
tracing::info!("Starting shard {rank}");
|
||||||
let mut p = match Popen::create(
|
let mut p = match Popen::create(
|
||||||
|
|
Loading…
Reference in New Issue