diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 40e7364f..f264e000 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -448,6 +448,8 @@ fn shard_manager(
     cuda_memory_fraction: f32,
     rope_scaling: Option<RopeScaling>,
     rope_factor: Option<f32>,
+    max_total_tokens: usize,
+    max_batch_size: Option<usize>,
     otlp_endpoint: Option<String>,
     status_sender: mpsc::Sender<ShardStatus>,
     shutdown: Arc<AtomicBool>,
@@ -512,6 +514,7 @@ fn shard_manager(
         (Some(scaling), Some(factor)) => Some((scaling, factor)),
         (None, Some(factor)) => Some((RopeScaling::Linear, factor)),
     };
+
     // OpenTelemetry
     if let Some(otlp_endpoint) = otlp_endpoint {
         shard_args.push("--otlp-endpoint".to_string());
@@ -564,6 +567,14 @@ fn shard_manager(
         envs.push(("ROPE_FACTOR".into(), factor.to_string().into()));
     }

+    envs.push((
+        "MAX_TOTAL_TOKENS".into(),
+        max_total_tokens.to_string().into(),
+    ));
+    if let Some(max_batch_size) = max_batch_size {
+        envs.push(("MAX_BATCH_SIZE".into(), max_batch_size.to_string().into()));
+    }
+
     // If huggingface_hub_cache is some, pass it to the shard
     // Useful when running inside a docker container
     if let Some(huggingface_hub_cache) = huggingface_hub_cache {
@@ -967,6 +978,7 @@ fn spawn_shards(
     num_shard: usize,
     args: &Args,
     cuda_graphs: Vec<usize>,
+    max_total_tokens: usize,
     shutdown: Arc<AtomicBool>,
     shutdown_receiver: &mpsc::Receiver<()>,
     shutdown_sender: mpsc::Sender<()>,
@@ -998,6 +1010,7 @@ fn spawn_shards(
         let cuda_memory_fraction = args.cuda_memory_fraction;
         let rope_scaling = args.rope_scaling;
         let rope_factor = args.rope_factor;
+        let max_batch_size = args.max_batch_size;
         thread::spawn(move || {
             shard_manager(
                 model_id,
@@ -1020,6 +1033,8 @@ fn spawn_shards(
                 cuda_memory_fraction,
                 rope_scaling,
                 rope_factor,
+                max_total_tokens,
+                max_batch_size,
                 otlp_endpoint,
                 status_sender,
                 shutdown,
@@ -1474,6 +1489,7 @@ fn main() -> Result<(), LauncherError> {
         num_shard,
         &args,
         cuda_graphs,
+        max_total_tokens,
         shutdown.clone(),
         &shutdown_receiver,
         shutdown_sender,
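
With this change the launcher exports the limits to each shard process as environment variables: `MAX_TOTAL_TOKENS` is always set, while `MAX_BATCH_SIZE` is only set when a value was configured. Below is a minimal sketch (not part of this diff) of how a consuming process could read those variables back; the function name and error handling here are illustrative assumptions, not the shard's actual implementation.

```rust
use std::env;

// Illustrative helper (hypothetical name): read the limits the launcher
// exports to the shard environment after this change.
fn read_launcher_limits() -> (usize, Option<usize>) {
    // MAX_TOTAL_TOKENS is always pushed into `envs`, so treat it as required.
    let max_total_tokens: usize = env::var("MAX_TOTAL_TOKENS")
        .expect("MAX_TOTAL_TOKENS should be set by the launcher")
        .parse()
        .expect("MAX_TOTAL_TOKENS should be an integer");

    // MAX_BATCH_SIZE is only exported when a max batch size was configured,
    // so missing means "no limit".
    let max_batch_size: Option<usize> = env::var("MAX_BATCH_SIZE")
        .ok()
        .map(|v| v.parse().expect("MAX_BATCH_SIZE should be an integer"));

    (max_total_tokens, max_batch_size)
}
```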