From 6297f1769f7d5559ed49d8ed6a9d78cf0bbae3bb Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Tue, 5 Nov 2024 16:35:33 +0100 Subject: [PATCH] feat: add payload limit --- backends/trtllm/src/main.rs | 5 +++++ backends/v2/src/main.rs | 4 ++++ backends/v3/src/main.rs | 4 ++++ launcher/src/main.rs | 8 ++++++++ router/src/server.rs | 6 +++++- .../text_generation_server/models/flash_causal_lm.py | 11 +++++------ .../text_generation_server/models/metadata_kernels.py | 5 +++-- 7 files changed, 34 insertions(+), 9 deletions(-) diff --git a/backends/trtllm/src/main.rs b/backends/trtllm/src/main.rs index 6a247fc1..8ab8c533 100644 --- a/backends/trtllm/src/main.rs +++ b/backends/trtllm/src/main.rs @@ -62,6 +62,8 @@ struct Args { executor_worker: PathBuf, #[clap(default_value = "on", long, env)] usage_stats: usage_stats::UsageStatsLevel, + #[clap(default_value = "2000000", long, env)] + payload_limit: usize, } async fn get_tokenizer( @@ -217,6 +219,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> { auth_token, executor_worker, usage_stats, + payload_limit, } = args; // Launch Tokio runtime @@ -287,6 +290,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> { tokenizer_name, tokenizer_config_path, revision, + false, hostname, port, cors_allow_origin, @@ -296,6 +300,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> { true, max_client_batch_size, usage_stats, + payload_limit, ) .await?; Ok(()) diff --git a/backends/v2/src/main.rs b/backends/v2/src/main.rs index ab4b7ce1..f537690e 100644 --- a/backends/v2/src/main.rs +++ b/backends/v2/src/main.rs @@ -70,6 +70,8 @@ struct Args { max_client_batch_size: usize, #[clap(default_value = "on", long, env)] usage_stats: usage_stats::UsageStatsLevel, + #[clap(default_value = "2000000", long, env)] + payload_limit: usize, } #[derive(Debug, Subcommand)] @@ -114,6 +116,7 @@ async fn main() -> Result<(), RouterError> { disable_grammar_support, max_client_batch_size, usage_stats, + payload_limit, } = args; if let Some(Commands::PrintSchema) = command { @@ -194,6 +197,7 @@ async fn main() -> Result<(), RouterError> { disable_grammar_support, max_client_batch_size, usage_stats, + payload_limit, ) .await?; Ok(()) diff --git a/backends/v3/src/main.rs b/backends/v3/src/main.rs index 279a8252..52e41b55 100644 --- a/backends/v3/src/main.rs +++ b/backends/v3/src/main.rs @@ -70,6 +70,8 @@ struct Args { max_client_batch_size: usize, #[clap(default_value = "on", long, env)] usage_stats: usage_stats::UsageStatsLevel, + #[clap(default_value = "2000000", long, env)] + payload_limit: usize, } #[derive(Debug, Subcommand)] @@ -114,6 +116,7 @@ async fn main() -> Result<(), RouterError> { disable_grammar_support, max_client_batch_size, usage_stats, + payload_limit, } = args; if let Some(Commands::PrintSchema) = command { @@ -210,6 +213,7 @@ async fn main() -> Result<(), RouterError> { disable_grammar_support, max_client_batch_size, usage_stats, + payload_limit, ) .await?; Ok(()) diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 64f4f515..9a62781c 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -687,6 +687,12 @@ struct Args { /// Defaul is on. #[clap(default_value = "on", long, env)] usage_stats: UsageStatsLevel, + + /// Payload size limit in bytes + /// + /// Default is 2MB + #[clap(default_value = "2000000", long, env)] + payload_limit: usize, } #[derive(Debug)] @@ -1474,6 +1480,8 @@ fn spawn_webserver( format!("{}-0", args.shard_uds_path), "--tokenizer-name".to_string(), args.model_id, + "--payload-limit".to_string(), + args.payload_limit.to_string(), ]; if let Some(max_input_tokens) = max_input_tokens { router_args.extend_from_slice(&[ diff --git a/router/src/server.rs b/router/src/server.rs index 7d8d518c..adb5315f 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -31,7 +31,7 @@ use crate::{ use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType}; use crate::{ModelInfo, ModelsInfo}; use async_stream::__private::AsyncStream; -use axum::extract::Extension; +use axum::extract::{DefaultBodyLimit, Extension}; use axum::http::{HeaderMap, HeaderValue, Method, StatusCode}; use axum::response::sse::{Event, KeepAlive, Sse}; use axum::response::{IntoResponse, Response}; @@ -1673,6 +1673,7 @@ pub async fn run( disable_grammar_support: bool, max_client_batch_size: usize, usage_stats_level: usage_stats::UsageStatsLevel, + payload_limit: usize, ) -> Result<(), WebServerError> { // CORS allowed origins // map to go inside the option and then map to parse from String to HeaderValue @@ -1926,6 +1927,7 @@ pub async fn run( model_info, compat_return_full_text, allow_origin, + payload_limit, ) .await; @@ -1985,6 +1987,7 @@ async fn start( model_info: HubModelInfo, compat_return_full_text: bool, allow_origin: Option, + payload_limit: usize, ) -> Result<(), WebServerError> { // Determine the server port based on the feature and environment variable. let port = if cfg!(feature = "google") { @@ -2382,6 +2385,7 @@ async fn start( .layer(Extension(compute_type)) .layer(Extension(prom_handle.clone())) .layer(OtelAxumLayer::default()) + .layer(DefaultBodyLimit::max(payload_limit)) .layer(cors_layer); tracing::info!("Connected"); diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index bb908fd0..36f70180 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -962,9 +962,9 @@ class FlashCausalLMBatch(Batch): self.input_lengths_tensor = torch.tensor( self.input_lengths, dtype=torch.int32, device=device ) - self.cu_seqlen_prefill = torch.nn.functional.pad( - torch.cumsum(self.input_lengths_tensor, dim=0), (1, 0) - ).to(torch.int32) + cu_seqlen_prefill = self.input_lengths_tensor.new_zeros(len(self) + 1) + torch.cumsum(self.input_lengths_tensor, out=cu_seqlen_prefill[1:], dim=0) + self.cu_seqlen_prefill = cu_seqlen_prefill.to(torch.int32) self.cache_lengths_tensor = torch.tensor( self.cache_lengths, dtype=torch.int32, device=device ) @@ -2020,9 +2020,8 @@ class FlashCausalLM(Model): # For each member of the batch # Cumulative length - cu_accepted_ids = torch.nn.functional.pad( - torch.cumsum(accepted_ids, dim=0), (1, 0) - ) + cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1) + torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:]) cumulative_length = 0 for i, ( request, diff --git a/server/text_generation_server/models/metadata_kernels.py b/server/text_generation_server/models/metadata_kernels.py index 783aab80..42b77121 100644 --- a/server/text_generation_server/models/metadata_kernels.py +++ b/server/text_generation_server/models/metadata_kernels.py @@ -66,8 +66,9 @@ def block_tables_to_ragged( ) if has_triton(): - cu_seqlen = torch.nn.functional.pad( - torch.cumsum(input_lengths_tensor + cache_lengths_tensor, dim=0), (1, 0) + cu_seqlen = input_lengths_tensor.new_zeros(input_lengths_tensor.shape[0] + 1) + torch.cumsum( + input_lengths_tensor + cache_lengths_tensor, out=cu_seqlen[1:], dim=0 ) def grid(meta):