feat: add payload limit

OlivierDehaene 2024-11-05 16:35:33 +01:00
parent b1f9044d6c
commit 6297f1769f
7 changed files with 34 additions and 9 deletions

View File

@@ -62,6 +62,8 @@ struct Args {
     executor_worker: PathBuf,
     #[clap(default_value = "on", long, env)]
     usage_stats: usage_stats::UsageStatsLevel,
+    #[clap(default_value = "2000000", long, env)]
+    payload_limit: usize,
 }
 
 async fn get_tokenizer(
@@ -217,6 +219,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
         auth_token,
         executor_worker,
         usage_stats,
+        payload_limit,
     } = args;
 
     // Launch Tokio runtime
@@ -287,6 +290,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
         tokenizer_name,
         tokenizer_config_path,
         revision,
+        false,
         hostname,
         port,
         cors_allow_origin,
@@ -296,6 +300,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
         true,
         max_client_batch_size,
         usage_stats,
+        payload_limit,
     )
     .await?;
     Ok(())

View File

@@ -70,6 +70,8 @@ struct Args {
     max_client_batch_size: usize,
     #[clap(default_value = "on", long, env)]
     usage_stats: usage_stats::UsageStatsLevel,
+    #[clap(default_value = "2000000", long, env)]
+    payload_limit: usize,
 }
 
 #[derive(Debug, Subcommand)]
@@ -114,6 +116,7 @@ async fn main() -> Result<(), RouterError> {
         disable_grammar_support,
         max_client_batch_size,
         usage_stats,
+        payload_limit,
     } = args;
 
     if let Some(Commands::PrintSchema) = command {
@@ -194,6 +197,7 @@ async fn main() -> Result<(), RouterError> {
         disable_grammar_support,
         max_client_batch_size,
         usage_stats,
+        payload_limit,
     )
     .await?;
     Ok(())

View File

@@ -70,6 +70,8 @@ struct Args {
     max_client_batch_size: usize,
     #[clap(default_value = "on", long, env)]
     usage_stats: usage_stats::UsageStatsLevel,
+    #[clap(default_value = "2000000", long, env)]
+    payload_limit: usize,
 }
 
 #[derive(Debug, Subcommand)]
@@ -114,6 +116,7 @@ async fn main() -> Result<(), RouterError> {
         disable_grammar_support,
         max_client_batch_size,
         usage_stats,
+        payload_limit,
     } = args;
 
     if let Some(Commands::PrintSchema) = command {
@@ -210,6 +213,7 @@ async fn main() -> Result<(), RouterError> {
         disable_grammar_support,
         max_client_batch_size,
         usage_stats,
+        payload_limit,
     )
     .await?;
     Ok(())

View File

@@ -687,6 +687,12 @@ struct Args {
     /// Default is on.
     #[clap(default_value = "on", long, env)]
     usage_stats: UsageStatsLevel,
+
+    /// Payload size limit in bytes
+    ///
+    /// Default is 2MB
+    #[clap(default_value = "2000000", long, env)]
+    payload_limit: usize,
 }
 
 #[derive(Debug)]
@@ -1474,6 +1480,8 @@ fn spawn_webserver(
         format!("{}-0", args.shard_uds_path),
         "--tokenizer-name".to_string(),
         args.model_id,
+        "--payload-limit".to_string(),
+        args.payload_limit.to_string(),
     ];
     if let Some(max_input_tokens) = max_input_tokens {
         router_args.extend_from_slice(&[

View File

@@ -31,7 +31,7 @@ use crate::{
 use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType};
 use crate::{ModelInfo, ModelsInfo};
 use async_stream::__private::AsyncStream;
-use axum::extract::Extension;
+use axum::extract::{DefaultBodyLimit, Extension};
 use axum::http::{HeaderMap, HeaderValue, Method, StatusCode};
 use axum::response::sse::{Event, KeepAlive, Sse};
 use axum::response::{IntoResponse, Response};
@@ -1673,6 +1673,7 @@ pub async fn run(
     disable_grammar_support: bool,
     max_client_batch_size: usize,
     usage_stats_level: usage_stats::UsageStatsLevel,
+    payload_limit: usize,
 ) -> Result<(), WebServerError> {
     // CORS allowed origins
     // map to go inside the option and then map to parse from String to HeaderValue
@@ -1926,6 +1927,7 @@ pub async fn run(
         model_info,
         compat_return_full_text,
         allow_origin,
+        payload_limit,
     )
     .await;
 
@@ -1985,6 +1987,7 @@ async fn start(
     model_info: HubModelInfo,
     compat_return_full_text: bool,
     allow_origin: Option<AllowOrigin>,
+    payload_limit: usize,
 ) -> Result<(), WebServerError> {
     // Determine the server port based on the feature and environment variable.
     let port = if cfg!(feature = "google") {
@@ -2382,6 +2385,7 @@ async fn start(
         .layer(Extension(compute_type))
         .layer(Extension(prom_handle.clone()))
         .layer(OtelAxumLayer::default())
+        .layer(DefaultBodyLimit::max(payload_limit))
         .layer(cors_layer);
 
     tracing::info!("Connected");
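
Taken together, the Rust changes thread one new setting end to end: the launcher gains a `--payload-limit` flag (also settable via the `PAYLOAD_LIMIT` env var, default 2000000 bytes), forwards it to the router, and the router's `run()`/`start()` pass it into axum's `DefaultBodyLimit` layer, so oversized request bodies are rejected before they reach a handler. A minimal client-side sketch of the expected behavior, assuming a local instance on port 3000 (the port and model are assumptions; axum answers requests whose body exceeds `DefaultBodyLimit` with HTTP 413):

```python
import requests

# Hypothetical local TGI server, e.g. started with:
#   text-generation-launcher --model-id <model> --payload-limit 2000000
URL = "http://localhost:3000/generate"

# Well under the 2_000_000-byte default limit: should be accepted.
small = requests.post(
    URL, json={"inputs": "Hello", "parameters": {"max_new_tokens": 8}}
)
print(small.status_code)  # expected: 200

# ~3 MB of input pushes the request body past the default limit.
oversized = requests.post(
    URL, json={"inputs": "x" * 3_000_000, "parameters": {"max_new_tokens": 8}}
)
print(oversized.status_code)  # expected: 413 (Payload Too Large)
```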

View File

@@ -962,9 +962,9 @@ class FlashCausalLMBatch(Batch):
         self.input_lengths_tensor = torch.tensor(
             self.input_lengths, dtype=torch.int32, device=device
         )
-        self.cu_seqlen_prefill = torch.nn.functional.pad(
-            torch.cumsum(self.input_lengths_tensor, dim=0), (1, 0)
-        ).to(torch.int32)
+        cu_seqlen_prefill = self.input_lengths_tensor.new_zeros(len(self) + 1)
+        torch.cumsum(self.input_lengths_tensor, out=cu_seqlen_prefill[1:], dim=0)
+        self.cu_seqlen_prefill = cu_seqlen_prefill.to(torch.int32)
         self.cache_lengths_tensor = torch.tensor(
             self.cache_lengths, dtype=torch.int32, device=device
         )
@@ -2020,9 +2020,8 @@ class FlashCausalLM(Model):
         # For each member of the batch
         # Cumulative length
-        cu_accepted_ids = torch.nn.functional.pad(
-            torch.cumsum(accepted_ids, dim=0), (1, 0)
-        )
+        cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
+        torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
         cumulative_length = 0
 
         for i, (
             request,

View File

@@ -66,8 +66,9 @@ def block_tables_to_ragged(
     )
 
     if has_triton():
-        cu_seqlen = torch.nn.functional.pad(
-            torch.cumsum(input_lengths_tensor + cache_lengths_tensor, dim=0), (1, 0)
+        cu_seqlen = input_lengths_tensor.new_zeros(input_lengths_tensor.shape[0] + 1)
+        torch.cumsum(
+            input_lengths_tensor + cache_lengths_tensor, out=cu_seqlen[1:], dim=0
         )
 
     def grid(meta):
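
The three Python hunks all apply the same micro-optimization: instead of computing a cumulative sum and then left-padding it with a zero via `torch.nn.functional.pad` (which allocates an intermediate tensor), they preallocate an `n + 1`-element zero tensor and write the cumulative sum directly into its `[1:]` slice with `torch.cumsum(..., out=...)`. A standalone sketch of the equivalence (tensor names are illustrative, not from the diff):

```python
import torch

lengths = torch.tensor([3, 5, 2])

# Old pattern: cumsum allocates one tensor, pad allocates another.
old = torch.nn.functional.pad(torch.cumsum(lengths, dim=0), (1, 0))

# New pattern: preallocate n + 1 zeros, cumsum straight into the tail slice.
new = lengths.new_zeros(lengths.shape[0] + 1)
torch.cumsum(lengths, dim=0, out=new[1:])

assert torch.equal(old, new)  # both are tensor([0, 3, 8, 10])
```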