feat: add payload limit (#2726)
* feat: add payload limit * update launcher
This commit is contained in:
parent
d5bc6a20bd
commit
ab7ccf5bc3
|
@ -62,6 +62,8 @@ struct Args {
|
|||
executor_worker: PathBuf,
|
||||
#[clap(default_value = "on", long, env)]
|
||||
usage_stats: usage_stats::UsageStatsLevel,
|
||||
#[clap(default_value = "2000000", long, env)]
|
||||
payload_limit: usize,
|
||||
}
|
||||
|
||||
async fn get_tokenizer(
|
||||
|
@ -217,6 +219,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
|
|||
auth_token,
|
||||
executor_worker,
|
||||
usage_stats,
|
||||
payload_limit,
|
||||
} = args;
|
||||
|
||||
// Launch Tokio runtime
|
||||
|
@ -287,6 +290,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
|
|||
tokenizer_name,
|
||||
tokenizer_config_path,
|
||||
revision,
|
||||
false,
|
||||
hostname,
|
||||
port,
|
||||
cors_allow_origin,
|
||||
|
@ -296,6 +300,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
|
|||
true,
|
||||
max_client_batch_size,
|
||||
usage_stats,
|
||||
payload_limit,
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
|
|
|
@ -70,6 +70,8 @@ struct Args {
|
|||
max_client_batch_size: usize,
|
||||
#[clap(default_value = "on", long, env)]
|
||||
usage_stats: usage_stats::UsageStatsLevel,
|
||||
#[clap(default_value = "2000000", long, env)]
|
||||
payload_limit: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Subcommand)]
|
||||
|
@ -114,6 +116,7 @@ async fn main() -> Result<(), RouterError> {
|
|||
disable_grammar_support,
|
||||
max_client_batch_size,
|
||||
usage_stats,
|
||||
payload_limit,
|
||||
} = args;
|
||||
|
||||
if let Some(Commands::PrintSchema) = command {
|
||||
|
@ -194,6 +197,7 @@ async fn main() -> Result<(), RouterError> {
|
|||
disable_grammar_support,
|
||||
max_client_batch_size,
|
||||
usage_stats,
|
||||
payload_limit,
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
|
|
|
@ -70,6 +70,8 @@ struct Args {
|
|||
max_client_batch_size: usize,
|
||||
#[clap(default_value = "on", long, env)]
|
||||
usage_stats: usage_stats::UsageStatsLevel,
|
||||
#[clap(default_value = "2000000", long, env)]
|
||||
payload_limit: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Subcommand)]
|
||||
|
@ -114,6 +116,7 @@ async fn main() -> Result<(), RouterError> {
|
|||
disable_grammar_support,
|
||||
max_client_batch_size,
|
||||
usage_stats,
|
||||
payload_limit,
|
||||
} = args;
|
||||
|
||||
if let Some(Commands::PrintSchema) = command {
|
||||
|
@ -210,6 +213,7 @@ async fn main() -> Result<(), RouterError> {
|
|||
disable_grammar_support,
|
||||
max_client_batch_size,
|
||||
usage_stats,
|
||||
payload_limit,
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
|
|
|
@ -456,6 +456,17 @@ Options:
|
|||
- off: Disables all collection of usage statistics
|
||||
- no-stack: Doesn't send the error stack trace or error type, but allows sending a crash event
|
||||
|
||||
```
|
||||
## PAYLOAD_LIMIT
|
||||
```shell
|
||||
--payload-limit <PAYLOAD_LIMIT>
|
||||
Payload size limit in bytes
|
||||
|
||||
Default is 2MB
|
||||
|
||||
[env: PAYLOAD_LIMIT=]
|
||||
[default: 2000000]
|
||||
|
||||
```
|
||||
## HELP
|
||||
```shell
|
||||
|
|
|
@ -692,6 +692,12 @@ struct Args {
|
|||
/// Default is on.
|
||||
#[clap(default_value = "on", long, env)]
|
||||
usage_stats: UsageStatsLevel,
|
||||
|
||||
/// Payload size limit in bytes
|
||||
///
|
||||
/// Default is 2MB
|
||||
#[clap(default_value = "2000000", long, env)]
|
||||
payload_limit: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
|
@ -1479,6 +1485,8 @@ fn spawn_webserver(
|
|||
format!("{}-0", args.shard_uds_path),
|
||||
"--tokenizer-name".to_string(),
|
||||
args.model_id,
|
||||
"--payload-limit".to_string(),
|
||||
args.payload_limit.to_string(),
|
||||
];
|
||||
if let Some(max_input_tokens) = max_input_tokens {
|
||||
router_args.extend_from_slice(&[
|
||||
|
|
|
@ -30,7 +30,7 @@ use crate::{
|
|||
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice};
|
||||
use crate::{ModelInfo, ModelsInfo};
|
||||
use async_stream::__private::AsyncStream;
|
||||
use axum::extract::Extension;
|
||||
use axum::extract::{DefaultBodyLimit, Extension};
|
||||
use axum::http::{HeaderMap, HeaderValue, Method, StatusCode};
|
||||
use axum::response::sse::{Event, KeepAlive, Sse};
|
||||
use axum::response::{IntoResponse, Response};
|
||||
|
@ -1674,6 +1674,7 @@ pub async fn run(
|
|||
disable_grammar_support: bool,
|
||||
max_client_batch_size: usize,
|
||||
usage_stats_level: usage_stats::UsageStatsLevel,
|
||||
payload_limit: usize,
|
||||
) -> Result<(), WebServerError> {
|
||||
// CORS allowed origins
|
||||
// map to go inside the option and then map to parse from String to HeaderValue
|
||||
|
@ -1928,6 +1929,7 @@ pub async fn run(
|
|||
model_info,
|
||||
compat_return_full_text,
|
||||
allow_origin,
|
||||
payload_limit,
|
||||
)
|
||||
.await;
|
||||
|
||||
|
@ -1987,6 +1989,7 @@ async fn start(
|
|||
model_info: HubModelInfo,
|
||||
compat_return_full_text: bool,
|
||||
allow_origin: Option<AllowOrigin>,
|
||||
payload_limit: usize,
|
||||
) -> Result<(), WebServerError> {
|
||||
// Determine the server port based on the feature and environment variable.
|
||||
let port = if cfg!(feature = "google") {
|
||||
|
@ -2384,6 +2387,7 @@ async fn start(
|
|||
.layer(Extension(compute_type))
|
||||
.layer(Extension(prom_handle.clone()))
|
||||
.layer(OtelAxumLayer::default())
|
||||
.layer(DefaultBodyLimit::max(payload_limit))
|
||||
.layer(cors_layer);
|
||||
|
||||
tracing::info!("Connected");
|
||||
|
|
|
@ -962,9 +962,9 @@ class FlashCausalLMBatch(Batch):
|
|||
self.input_lengths_tensor = torch.tensor(
|
||||
self.input_lengths, dtype=torch.int32, device=device
|
||||
)
|
||||
self.cu_seqlen_prefill = torch.nn.functional.pad(
|
||||
torch.cumsum(self.input_lengths_tensor, dim=0), (1, 0)
|
||||
).to(torch.int32)
|
||||
cu_seqlen_prefill = self.input_lengths_tensor.new_zeros(len(self) + 1)
|
||||
torch.cumsum(self.input_lengths_tensor, out=cu_seqlen_prefill[1:], dim=0)
|
||||
self.cu_seqlen_prefill = cu_seqlen_prefill.to(torch.int32)
|
||||
self.cache_lengths_tensor = torch.tensor(
|
||||
self.cache_lengths, dtype=torch.int32, device=device
|
||||
)
|
||||
|
@ -2020,9 +2020,8 @@ class FlashCausalLM(Model):
|
|||
|
||||
# For each member of the batch
|
||||
# Cumulative length
|
||||
cu_accepted_ids = torch.nn.functional.pad(
|
||||
torch.cumsum(accepted_ids, dim=0), (1, 0)
|
||||
)
|
||||
cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
|
||||
torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
|
||||
cumulative_length = 0
|
||||
for i, (
|
||||
request,
|
||||
|
|
|
@ -66,8 +66,9 @@ def block_tables_to_ragged(
|
|||
)
|
||||
|
||||
if has_triton():
|
||||
cu_seqlen = torch.nn.functional.pad(
|
||||
torch.cumsum(input_lengths_tensor + cache_lengths_tensor, dim=0), (1, 0)
|
||||
cu_seqlen = input_lengths_tensor.new_zeros(input_lengths_tensor.shape[0] + 1)
|
||||
torch.cumsum(
|
||||
input_lengths_tensor + cache_lengths_tensor, out=cu_seqlen[1:], dim=0
|
||||
)
|
||||
|
||||
def grid(meta):
|
||||
|
|
Loading…
Reference in New Issue