feat: add payload limit (#2726)

* feat: add payload limit

* update launcher
This commit is contained in:
OlivierDehaene 2024-11-21 19:20:15 +01:00 committed by GitHub
parent d5bc6a20bd
commit ab7ccf5bc3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 45 additions and 9 deletions

View File

@ -62,6 +62,8 @@ struct Args {
executor_worker: PathBuf,
#[clap(default_value = "on", long, env)]
usage_stats: usage_stats::UsageStatsLevel,
#[clap(default_value = "2000000", long, env)]
payload_limit: usize,
}
async fn get_tokenizer(
@ -217,6 +219,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
auth_token,
executor_worker,
usage_stats,
payload_limit,
} = args;
// Launch Tokio runtime
@ -287,6 +290,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
tokenizer_name,
tokenizer_config_path,
revision,
false,
hostname,
port,
cors_allow_origin,
@ -296,6 +300,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
true,
max_client_batch_size,
usage_stats,
payload_limit,
)
.await?;
Ok(())

View File

@ -70,6 +70,8 @@ struct Args {
max_client_batch_size: usize,
#[clap(default_value = "on", long, env)]
usage_stats: usage_stats::UsageStatsLevel,
#[clap(default_value = "2000000", long, env)]
payload_limit: usize,
}
#[derive(Debug, Subcommand)]
@ -114,6 +116,7 @@ async fn main() -> Result<(), RouterError> {
disable_grammar_support,
max_client_batch_size,
usage_stats,
payload_limit,
} = args;
if let Some(Commands::PrintSchema) = command {
@ -194,6 +197,7 @@ async fn main() -> Result<(), RouterError> {
disable_grammar_support,
max_client_batch_size,
usage_stats,
payload_limit,
)
.await?;
Ok(())

View File

@ -70,6 +70,8 @@ struct Args {
max_client_batch_size: usize,
#[clap(default_value = "on", long, env)]
usage_stats: usage_stats::UsageStatsLevel,
#[clap(default_value = "2000000", long, env)]
payload_limit: usize,
}
#[derive(Debug, Subcommand)]
@ -114,6 +116,7 @@ async fn main() -> Result<(), RouterError> {
disable_grammar_support,
max_client_batch_size,
usage_stats,
payload_limit,
} = args;
if let Some(Commands::PrintSchema) = command {
@ -210,6 +213,7 @@ async fn main() -> Result<(), RouterError> {
disable_grammar_support,
max_client_batch_size,
usage_stats,
payload_limit,
)
.await?;
Ok(())

View File

@ -456,6 +456,17 @@ Options:
- off: Disables all collection of usage statistics
- no-stack: Doesn't send the error stack trace or error type, but allows sending a crash event
```
## PAYLOAD_LIMIT
```shell
--payload-limit <PAYLOAD_LIMIT>
Payload size limit in bytes
Default is 2MB
[env: PAYLOAD_LIMIT=]
[default: 2000000]
```
## HELP
```shell

View File

@ -692,6 +692,12 @@ struct Args {
/// Defaul is on.
#[clap(default_value = "on", long, env)]
usage_stats: UsageStatsLevel,
/// Payload size limit in bytes
///
/// Default is 2MB
#[clap(default_value = "2000000", long, env)]
payload_limit: usize,
}
#[derive(Debug)]
@ -1479,6 +1485,8 @@ fn spawn_webserver(
format!("{}-0", args.shard_uds_path),
"--tokenizer-name".to_string(),
args.model_id,
"--payload-limit".to_string(),
args.payload_limit.to_string(),
];
if let Some(max_input_tokens) = max_input_tokens {
router_args.extend_from_slice(&[

View File

@ -30,7 +30,7 @@ use crate::{
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice};
use crate::{ModelInfo, ModelsInfo};
use async_stream::__private::AsyncStream;
use axum::extract::Extension;
use axum::extract::{DefaultBodyLimit, Extension};
use axum::http::{HeaderMap, HeaderValue, Method, StatusCode};
use axum::response::sse::{Event, KeepAlive, Sse};
use axum::response::{IntoResponse, Response};
@ -1674,6 +1674,7 @@ pub async fn run(
disable_grammar_support: bool,
max_client_batch_size: usize,
usage_stats_level: usage_stats::UsageStatsLevel,
payload_limit: usize,
) -> Result<(), WebServerError> {
// CORS allowed origins
// map to go inside the option and then map to parse from String to HeaderValue
@ -1928,6 +1929,7 @@ pub async fn run(
model_info,
compat_return_full_text,
allow_origin,
payload_limit,
)
.await;
@ -1987,6 +1989,7 @@ async fn start(
model_info: HubModelInfo,
compat_return_full_text: bool,
allow_origin: Option<AllowOrigin>,
payload_limit: usize,
) -> Result<(), WebServerError> {
// Determine the server port based on the feature and environment variable.
let port = if cfg!(feature = "google") {
@ -2384,6 +2387,7 @@ async fn start(
.layer(Extension(compute_type))
.layer(Extension(prom_handle.clone()))
.layer(OtelAxumLayer::default())
.layer(DefaultBodyLimit::max(payload_limit))
.layer(cors_layer);
tracing::info!("Connected");

View File

@ -962,9 +962,9 @@ class FlashCausalLMBatch(Batch):
self.input_lengths_tensor = torch.tensor(
self.input_lengths, dtype=torch.int32, device=device
)
self.cu_seqlen_prefill = torch.nn.functional.pad(
torch.cumsum(self.input_lengths_tensor, dim=0), (1, 0)
).to(torch.int32)
cu_seqlen_prefill = self.input_lengths_tensor.new_zeros(len(self) + 1)
torch.cumsum(self.input_lengths_tensor, out=cu_seqlen_prefill[1:], dim=0)
self.cu_seqlen_prefill = cu_seqlen_prefill.to(torch.int32)
self.cache_lengths_tensor = torch.tensor(
self.cache_lengths, dtype=torch.int32, device=device
)
@ -2020,9 +2020,8 @@ class FlashCausalLM(Model):
# For each member of the batch
# Cumulative length
cu_accepted_ids = torch.nn.functional.pad(
torch.cumsum(accepted_ids, dim=0), (1, 0)
)
cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
cumulative_length = 0
for i, (
request,

View File

@ -66,8 +66,9 @@ def block_tables_to_ragged(
)
if has_triton():
cu_seqlen = torch.nn.functional.pad(
torch.cumsum(input_lengths_tensor + cache_lengths_tensor, dim=0), (1, 0)
cu_seqlen = input_lengths_tensor.new_zeros(input_lengths_tensor.shape[0] + 1)
torch.cumsum(
input_lengths_tensor + cache_lengths_tensor, out=cu_seqlen[1:], dim=0
)
def grid(meta):