feat: add payload limit
This commit is contained in:
parent
b1f9044d6c
commit
6297f1769f
|
@ -62,6 +62,8 @@ struct Args {
|
||||||
executor_worker: PathBuf,
|
executor_worker: PathBuf,
|
||||||
#[clap(default_value = "on", long, env)]
|
#[clap(default_value = "on", long, env)]
|
||||||
usage_stats: usage_stats::UsageStatsLevel,
|
usage_stats: usage_stats::UsageStatsLevel,
|
||||||
|
#[clap(default_value = "2000000", long, env)]
|
||||||
|
payload_limit: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_tokenizer(
|
async fn get_tokenizer(
|
||||||
|
@ -217,6 +219,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
|
||||||
auth_token,
|
auth_token,
|
||||||
executor_worker,
|
executor_worker,
|
||||||
usage_stats,
|
usage_stats,
|
||||||
|
payload_limit,
|
||||||
} = args;
|
} = args;
|
||||||
|
|
||||||
// Launch Tokio runtime
|
// Launch Tokio runtime
|
||||||
|
@ -287,6 +290,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
|
||||||
tokenizer_name,
|
tokenizer_name,
|
||||||
tokenizer_config_path,
|
tokenizer_config_path,
|
||||||
revision,
|
revision,
|
||||||
|
false,
|
||||||
hostname,
|
hostname,
|
||||||
port,
|
port,
|
||||||
cors_allow_origin,
|
cors_allow_origin,
|
||||||
|
@ -296,6 +300,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
|
||||||
true,
|
true,
|
||||||
max_client_batch_size,
|
max_client_batch_size,
|
||||||
usage_stats,
|
usage_stats,
|
||||||
|
payload_limit,
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|
|
@ -70,6 +70,8 @@ struct Args {
|
||||||
max_client_batch_size: usize,
|
max_client_batch_size: usize,
|
||||||
#[clap(default_value = "on", long, env)]
|
#[clap(default_value = "on", long, env)]
|
||||||
usage_stats: usage_stats::UsageStatsLevel,
|
usage_stats: usage_stats::UsageStatsLevel,
|
||||||
|
#[clap(default_value = "2000000", long, env)]
|
||||||
|
payload_limit: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Subcommand)]
|
#[derive(Debug, Subcommand)]
|
||||||
|
@ -114,6 +116,7 @@ async fn main() -> Result<(), RouterError> {
|
||||||
disable_grammar_support,
|
disable_grammar_support,
|
||||||
max_client_batch_size,
|
max_client_batch_size,
|
||||||
usage_stats,
|
usage_stats,
|
||||||
|
payload_limit,
|
||||||
} = args;
|
} = args;
|
||||||
|
|
||||||
if let Some(Commands::PrintSchema) = command {
|
if let Some(Commands::PrintSchema) = command {
|
||||||
|
@ -194,6 +197,7 @@ async fn main() -> Result<(), RouterError> {
|
||||||
disable_grammar_support,
|
disable_grammar_support,
|
||||||
max_client_batch_size,
|
max_client_batch_size,
|
||||||
usage_stats,
|
usage_stats,
|
||||||
|
payload_limit,
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|
|
@ -70,6 +70,8 @@ struct Args {
|
||||||
max_client_batch_size: usize,
|
max_client_batch_size: usize,
|
||||||
#[clap(default_value = "on", long, env)]
|
#[clap(default_value = "on", long, env)]
|
||||||
usage_stats: usage_stats::UsageStatsLevel,
|
usage_stats: usage_stats::UsageStatsLevel,
|
||||||
|
#[clap(default_value = "2000000", long, env)]
|
||||||
|
payload_limit: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Subcommand)]
|
#[derive(Debug, Subcommand)]
|
||||||
|
@ -114,6 +116,7 @@ async fn main() -> Result<(), RouterError> {
|
||||||
disable_grammar_support,
|
disable_grammar_support,
|
||||||
max_client_batch_size,
|
max_client_batch_size,
|
||||||
usage_stats,
|
usage_stats,
|
||||||
|
payload_limit,
|
||||||
} = args;
|
} = args;
|
||||||
|
|
||||||
if let Some(Commands::PrintSchema) = command {
|
if let Some(Commands::PrintSchema) = command {
|
||||||
|
@ -210,6 +213,7 @@ async fn main() -> Result<(), RouterError> {
|
||||||
disable_grammar_support,
|
disable_grammar_support,
|
||||||
max_client_batch_size,
|
max_client_batch_size,
|
||||||
usage_stats,
|
usage_stats,
|
||||||
|
payload_limit,
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|
|
@ -687,6 +687,12 @@ struct Args {
|
||||||
/// Defaul is on.
|
/// Defaul is on.
|
||||||
#[clap(default_value = "on", long, env)]
|
#[clap(default_value = "on", long, env)]
|
||||||
usage_stats: UsageStatsLevel,
|
usage_stats: UsageStatsLevel,
|
||||||
|
|
||||||
|
/// Payload size limit in bytes
|
||||||
|
///
|
||||||
|
/// Default is 2MB
|
||||||
|
#[clap(default_value = "2000000", long, env)]
|
||||||
|
payload_limit: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
|
@ -1474,6 +1480,8 @@ fn spawn_webserver(
|
||||||
format!("{}-0", args.shard_uds_path),
|
format!("{}-0", args.shard_uds_path),
|
||||||
"--tokenizer-name".to_string(),
|
"--tokenizer-name".to_string(),
|
||||||
args.model_id,
|
args.model_id,
|
||||||
|
"--payload-limit".to_string(),
|
||||||
|
args.payload_limit.to_string(),
|
||||||
];
|
];
|
||||||
if let Some(max_input_tokens) = max_input_tokens {
|
if let Some(max_input_tokens) = max_input_tokens {
|
||||||
router_args.extend_from_slice(&[
|
router_args.extend_from_slice(&[
|
||||||
|
|
|
@ -31,7 +31,7 @@ use crate::{
|
||||||
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType};
|
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType};
|
||||||
use crate::{ModelInfo, ModelsInfo};
|
use crate::{ModelInfo, ModelsInfo};
|
||||||
use async_stream::__private::AsyncStream;
|
use async_stream::__private::AsyncStream;
|
||||||
use axum::extract::Extension;
|
use axum::extract::{DefaultBodyLimit, Extension};
|
||||||
use axum::http::{HeaderMap, HeaderValue, Method, StatusCode};
|
use axum::http::{HeaderMap, HeaderValue, Method, StatusCode};
|
||||||
use axum::response::sse::{Event, KeepAlive, Sse};
|
use axum::response::sse::{Event, KeepAlive, Sse};
|
||||||
use axum::response::{IntoResponse, Response};
|
use axum::response::{IntoResponse, Response};
|
||||||
|
@ -1673,6 +1673,7 @@ pub async fn run(
|
||||||
disable_grammar_support: bool,
|
disable_grammar_support: bool,
|
||||||
max_client_batch_size: usize,
|
max_client_batch_size: usize,
|
||||||
usage_stats_level: usage_stats::UsageStatsLevel,
|
usage_stats_level: usage_stats::UsageStatsLevel,
|
||||||
|
payload_limit: usize,
|
||||||
) -> Result<(), WebServerError> {
|
) -> Result<(), WebServerError> {
|
||||||
// CORS allowed origins
|
// CORS allowed origins
|
||||||
// map to go inside the option and then map to parse from String to HeaderValue
|
// map to go inside the option and then map to parse from String to HeaderValue
|
||||||
|
@ -1926,6 +1927,7 @@ pub async fn run(
|
||||||
model_info,
|
model_info,
|
||||||
compat_return_full_text,
|
compat_return_full_text,
|
||||||
allow_origin,
|
allow_origin,
|
||||||
|
payload_limit,
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
|
@ -1985,6 +1987,7 @@ async fn start(
|
||||||
model_info: HubModelInfo,
|
model_info: HubModelInfo,
|
||||||
compat_return_full_text: bool,
|
compat_return_full_text: bool,
|
||||||
allow_origin: Option<AllowOrigin>,
|
allow_origin: Option<AllowOrigin>,
|
||||||
|
payload_limit: usize,
|
||||||
) -> Result<(), WebServerError> {
|
) -> Result<(), WebServerError> {
|
||||||
// Determine the server port based on the feature and environment variable.
|
// Determine the server port based on the feature and environment variable.
|
||||||
let port = if cfg!(feature = "google") {
|
let port = if cfg!(feature = "google") {
|
||||||
|
@ -2382,6 +2385,7 @@ async fn start(
|
||||||
.layer(Extension(compute_type))
|
.layer(Extension(compute_type))
|
||||||
.layer(Extension(prom_handle.clone()))
|
.layer(Extension(prom_handle.clone()))
|
||||||
.layer(OtelAxumLayer::default())
|
.layer(OtelAxumLayer::default())
|
||||||
|
.layer(DefaultBodyLimit::max(payload_limit))
|
||||||
.layer(cors_layer);
|
.layer(cors_layer);
|
||||||
|
|
||||||
tracing::info!("Connected");
|
tracing::info!("Connected");
|
||||||
|
|
|
@ -962,9 +962,9 @@ class FlashCausalLMBatch(Batch):
|
||||||
self.input_lengths_tensor = torch.tensor(
|
self.input_lengths_tensor = torch.tensor(
|
||||||
self.input_lengths, dtype=torch.int32, device=device
|
self.input_lengths, dtype=torch.int32, device=device
|
||||||
)
|
)
|
||||||
self.cu_seqlen_prefill = torch.nn.functional.pad(
|
cu_seqlen_prefill = self.input_lengths_tensor.new_zeros(len(self) + 1)
|
||||||
torch.cumsum(self.input_lengths_tensor, dim=0), (1, 0)
|
torch.cumsum(self.input_lengths_tensor, out=cu_seqlen_prefill[1:], dim=0)
|
||||||
).to(torch.int32)
|
self.cu_seqlen_prefill = cu_seqlen_prefill.to(torch.int32)
|
||||||
self.cache_lengths_tensor = torch.tensor(
|
self.cache_lengths_tensor = torch.tensor(
|
||||||
self.cache_lengths, dtype=torch.int32, device=device
|
self.cache_lengths, dtype=torch.int32, device=device
|
||||||
)
|
)
|
||||||
|
@ -2020,9 +2020,8 @@ class FlashCausalLM(Model):
|
||||||
|
|
||||||
# For each member of the batch
|
# For each member of the batch
|
||||||
# Cumulative length
|
# Cumulative length
|
||||||
cu_accepted_ids = torch.nn.functional.pad(
|
cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
|
||||||
torch.cumsum(accepted_ids, dim=0), (1, 0)
|
torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
|
||||||
)
|
|
||||||
cumulative_length = 0
|
cumulative_length = 0
|
||||||
for i, (
|
for i, (
|
||||||
request,
|
request,
|
||||||
|
|
|
@ -66,8 +66,9 @@ def block_tables_to_ragged(
|
||||||
)
|
)
|
||||||
|
|
||||||
if has_triton():
|
if has_triton():
|
||||||
cu_seqlen = torch.nn.functional.pad(
|
cu_seqlen = input_lengths_tensor.new_zeros(input_lengths_tensor.shape[0] + 1)
|
||||||
torch.cumsum(input_lengths_tensor + cache_lengths_tensor, dim=0), (1, 0)
|
torch.cumsum(
|
||||||
|
input_lengths_tensor + cache_lengths_tensor, out=cu_seqlen[1:], dim=0
|
||||||
)
|
)
|
||||||
|
|
||||||
def grid(meta):
|
def grid(meta):
|
||||||
|
|
Loading…
Reference in New Issue