From f9cf3456250e420af65e1d813ccee6af749658ad Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Fri, 26 Apr 2024 15:44:44 +0200
Subject: [PATCH] Adding new env variables for TPU backends. (#1755)

# What does this PR do?

On TPU (and probably Inferentia), the model needs to know right off the bat about
`MAX_BATCH_SIZE` and `MAX_TOTAL_TOKENS`, since the entire cache is determined by both.
This PR sends that information to the shards so they can allocate accordingly.
It should be a no-op for other backends.

Fixes # (issue)

## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if
      that's the case).
- [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)?
      Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the
      [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and
      [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

## Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free
to tag members/contributors who may be interested in your PR.
---
 launcher/src/main.rs | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 40e7364f..f264e000 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -448,6 +448,8 @@ fn shard_manager(
     cuda_memory_fraction: f32,
     rope_scaling: Option<RopeScaling>,
     rope_factor: Option<f32>,
+    max_total_tokens: usize,
+    max_batch_size: Option<usize>,
     otlp_endpoint: Option<String>,
     status_sender: mpsc::Sender<ShardStatus>,
     shutdown: Arc<AtomicBool>,
@@ -512,6 +514,7 @@ fn shard_manager(
         (Some(scaling), Some(factor)) => Some((scaling, factor)),
         (None, Some(factor)) => Some((RopeScaling::Linear, factor)),
     };
+
     // OpenTelemetry
     if let Some(otlp_endpoint) = otlp_endpoint {
         shard_args.push("--otlp-endpoint".to_string());
@@ -564,6 +567,14 @@ fn shard_manager(
         envs.push(("ROPE_FACTOR".into(), factor.to_string().into()));
     }

+    envs.push((
+        "MAX_TOTAL_TOKENS".into(),
+        max_total_tokens.to_string().into(),
+    ));
+    if let Some(max_batch_size) = max_batch_size {
+        envs.push(("MAX_BATCH_SIZE".into(), max_batch_size.to_string().into()));
+    }
+
     // If huggingface_hub_cache is some, pass it to the shard
     // Useful when running inside a docker container
     if let Some(huggingface_hub_cache) = huggingface_hub_cache {
@@ -967,6 +978,7 @@ fn spawn_shards(
     num_shard: usize,
     args: &Args,
     cuda_graphs: Vec<usize>,
+    max_total_tokens: usize,
     shutdown: Arc<AtomicBool>,
     shutdown_receiver: &mpsc::Receiver<()>,
     shutdown_sender: mpsc::Sender<()>,
@@ -998,6 +1010,7 @@ fn spawn_shards(
         let cuda_memory_fraction = args.cuda_memory_fraction;
         let rope_scaling = args.rope_scaling;
         let rope_factor = args.rope_factor;
+        let max_batch_size = args.max_batch_size;
         thread::spawn(move || {
             shard_manager(
                 model_id,
@@ -1020,6 +1033,8 @@ fn spawn_shards(
                 cuda_memory_fraction,
                 rope_scaling,
                 rope_factor,
+                max_total_tokens,
+                max_batch_size,
                 otlp_endpoint,
                 status_sender,
                 shutdown,
@@ -1474,6 +1489,7 @@ fn main() -> Result<(), LauncherError> {
         num_shard,
         &args,
         cuda_graphs,
+        max_total_tokens,
         shutdown.clone(),
         &shutdown_receiver,
         shutdown_sender,
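
For reviewers who want to see the downstream effect, here is a minimal, self-contained Rust sketch of how a static-cache backend (TPU/Inferentia) might consume the two variables the launcher now exports. The `read_env_usize` helper, the `4096` fallback, and the log messages are illustrative assumptions, not part of this patch or of the actual shard server:

```rust
use std::env;

/// Illustrative helper (not in the patch): parse an optional integer env variable.
fn read_env_usize(name: &str) -> Option<usize> {
    env::var(name).ok().and_then(|v| v.parse::<usize>().ok())
}

fn main() {
    // After this patch the launcher always exports MAX_TOTAL_TOKENS to every shard.
    // The 4096 fallback here is an arbitrary, illustrative default.
    let max_total_tokens = read_env_usize("MAX_TOTAL_TOKENS").unwrap_or(4096);

    // MAX_BATCH_SIZE is only exported when the user passed --max-batch-size.
    let max_batch_size = read_env_usize("MAX_BATCH_SIZE");

    // A static-cache backend could size its KV cache up front from these values
    // instead of discovering them during warmup.
    match max_batch_size {
        Some(batch_size) => println!(
            "pre-allocating cache for {batch_size} sequences x {max_total_tokens} tokens"
        ),
        None => println!(
            "MAX_BATCH_SIZE not set; keeping dynamic allocation (max_total_tokens={max_total_tokens})"
        ),
    }
}
```

Note that in the diff `MAX_TOTAL_TOKENS` is pushed unconditionally, while `MAX_BATCH_SIZE` is guarded by `if let Some(max_batch_size)`, so a backend reading these variables should treat the batch size as optional, as the sketch does.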