Adding batch_dimension flag (to be used for Neuron and other forced padding targets).

Nicolas Patry 2024-02-05 14:29:32 +01:00
parent 0da00be52c
commit e1dc168188
5 changed files with 105 additions and 51 deletions

Cargo.lock (generated)

@@ -2787,7 +2787,7 @@ dependencies = [
  "tabled",
  "text-generation-client",
  "thiserror",
- "tokenizers",
+ "tokenizers 0.14.1",
  "tokio",
  "tracing",
  "tracing-subscriber",
@@ -2850,7 +2850,7 @@ dependencies = [
  "serde_json",
  "text-generation-client",
  "thiserror",
- "tokenizers",
+ "tokenizers 0.15.1",
  "tokio",
  "tokio-stream",
  "tower-http",
@@ -2972,6 +2972,40 @@ dependencies = [
  "unicode_categories",
 ]
 
+[[package]]
+name = "tokenizers"
+version = "0.15.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6db445cceba5dfeb0f9702be7d6bfd91801ddcbe8fe8722defe7f2e96da75812"
+dependencies = [
+ "aho-corasick",
+ "clap",
+ "derive_builder",
+ "esaxx-rs",
+ "getrandom",
+ "hf-hub",
+ "indicatif",
+ "itertools 0.11.0",
+ "lazy_static",
+ "log",
+ "macro_rules_attribute",
+ "monostate",
+ "onig",
+ "paste",
+ "rand",
+ "rayon",
+ "rayon-cond",
+ "regex",
+ "regex-syntax 0.7.5",
+ "serde",
+ "serde_json",
+ "spm_precompiled",
+ "thiserror",
+ "unicode-normalization-alignments",
+ "unicode-segmentation",
+ "unicode_categories",
+]
+
 [[package]]
 name = "tokio"
 version = "1.35.1"

@@ -368,6 +368,12 @@ struct Args {
     #[clap(long, env)]
     ngrok_edge: Option<String>,
 
+    /// Specific flag for hardware targets that do not support unpadded inference.
+    /// For those we do not send the tokenizer to the router so that all the scheduling
+    /// assumes those pad tokens exist (and potentially even more).
+    #[clap(long, env)]
+    batch_dimension: bool,
+
     /// The path to the tokenizer config file. This path is used to load the tokenizer configuration which may
     /// include a `chat_template`. If not provided, the default config will be used from the model hub.
     #[clap(long, env)]
@@ -1034,6 +1040,10 @@ fn spawn_webserver(
         args.model_id,
     ];
 
+    if args.batch_dimension {
+        router_args.push("--batch-dimension".to_string());
+    }
+
     // Tokenizer config path
     if let Some(ref tokenizer_config_path) = args.tokenizer_config_path {
         router_args.push("--tokenizer-config-path".to_string());

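For reference, the hunk above relies on clap's derive: a plain `bool` field annotated with `#[clap(long, env)]` becomes an optional `--batch-dimension` switch (also settable through a `BATCH_DIMENSION` environment variable), and the launcher only forwards it to the router when it is set. Below is a minimal, self-contained sketch of that pattern, assuming clap 4 with the `derive` and `env` features; the struct name and the other argument values are illustrative, not the launcher's real ones.

use clap::Parser;

/// Illustrative stand-in for the launcher's much larger Args struct.
#[derive(Parser, Debug)]
struct DemoArgs {
    /// Set on hardware targets that force padded (fixed-shape) batches.
    #[clap(long, env)]
    batch_dimension: bool,
}

fn main() {
    let args = DemoArgs::parse();

    // Mirrors spawn_webserver: the flag is only forwarded when it was set.
    let mut router_args: Vec<String> = vec!["--port".to_string(), "3000".to_string()];
    if args.batch_dimension {
        router_args.push("--batch-dimension".to_string());
    }

    println!("router args: {router_args:?}");
}

With this shape, `DemoArgs::parse()` reports `batch_dimension: true` either for `--batch-dimension` on the command line or for `BATCH_DIMENSION=true` in the environment, which is why the launcher can simply forward the flag string to the router process.
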
@@ -37,6 +37,8 @@ struct Args {
     max_input_length: usize,
     #[clap(default_value = "2048", long, env)]
     max_total_tokens: usize,
+    #[clap(long, env)]
+    batch_dimension: bool,
     #[clap(default_value = "1.2", long, env)]
     waiting_served_ratio: f32,
     #[clap(default_value = "4096", long, env)]
@@ -87,6 +89,7 @@ async fn main() -> Result<(), RouterError> {
         max_top_n_tokens,
         max_input_length,
         max_total_tokens,
+        batch_dimension,
         waiting_served_ratio,
         max_batch_prefill_tokens,
         max_batch_total_tokens,
@@ -340,6 +343,7 @@ async fn main() -> Result<(), RouterError> {
         max_top_n_tokens,
         max_input_length,
         max_total_tokens,
+        batch_dimension,
         waiting_served_ratio,
         max_batch_prefill_tokens,
         max_supported_batch_total_tokens,

@@ -758,6 +758,7 @@ pub async fn run(
     max_top_n_tokens: u32,
     max_input_length: usize,
     max_total_tokens: usize,
+    batch_dimension: bool,
     waiting_served_ratio: f32,
     max_batch_prefill_tokens: u32,
     max_batch_total_tokens: u32,
@@ -833,6 +834,7 @@ pub async fn run(
         max_top_n_tokens,
         max_input_length,
         max_total_tokens,
+        batch_dimension,
     );
     let generation_health = Arc::new(AtomicBool::new(false));
     let health_ext = Health::new(client.clone(), generation_health.clone());

@@ -19,6 +19,7 @@ pub struct Validation {
     max_top_n_tokens: u32,
     max_input_length: usize,
     max_total_tokens: usize,
+    batched_dimension: bool,
     /// Channel to communicate with the background tokenization task
     sender: Option<mpsc::UnboundedSender<TokenizerRequest>>,
 }
@@ -32,6 +33,7 @@ impl Validation {
         max_top_n_tokens: u32,
         max_input_length: usize,
         max_total_tokens: usize,
+        batched_dimension: bool,
     ) -> Self {
         // If we have a fast tokenizer
         let sender = if let Some(tokenizer) = tokenizer {
@@ -66,6 +68,7 @@ impl Validation {
             max_top_n_tokens,
             max_input_length,
             max_total_tokens,
+            batched_dimension,
         }
     }
@@ -103,61 +106,62 @@ impl Validation {
     ) -> Result<(String, usize, u32), ValidationError> {
         // If we have a fast tokenizer
         if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
-            // Create response channel
+            if !self.batched_dimension {
                 let input_length = encoding.len();
 
                 // Get total tokens
                 let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
                     max_new_tokens
                 } else {
                     self.max_total_tokens.saturating_sub(input_length) as u32
                 };
                 let total_tokens = input_length + max_new_tokens as usize;
 
                 // Validate MaxTotalTokens
                 if total_tokens > self.max_total_tokens {
                     return Err(ValidationError::MaxTotalTokens(
                         self.max_total_tokens,
                         input_length,
                         max_new_tokens,
                     ));
                 }
 
-            // Validate InputLength
-            if input_length > self.max_input_length {
-                return Err(ValidationError::InputLength(
-                    self.max_input_length,
-                    input_length,
-                ));
-            }
-
-            metrics::histogram!("tgi_request_input_length", input_length as f64);
-
-            Ok((inputs, input_length, max_new_tokens))
+                // Validate InputLength
+                if input_length > self.max_input_length {
+                    return Err(ValidationError::InputLength(
+                        self.max_input_length,
+                        input_length,
+                    ));
+                }
+
+                metrics::histogram!("tgi_request_input_length", input_length as f64);
+                return Ok((inputs, input_length, max_new_tokens));
+            }
         }
-        // Return inputs without validation
-        else {
+        // Either we don't have a tokenizer or batched_dimension purposefully
+        // will ignore the actual length in order to schedule the job correctly.
         // In this case, we don't know the real length in tokens of the inputs
         // However, the inputs will be truncated by the python servers
         // We make sure that truncate + max_new_tokens <= self.max_total_tokens
         let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
             max_new_tokens
         } else if let Some(truncate) = truncate {
             self.max_total_tokens.saturating_sub(truncate) as u32
         } else {
             return Err(ValidationError::UnsetMaxNewTokens);
         };
         let input_length = truncate.unwrap_or(self.max_input_length);
 
         // Validate MaxNewTokens
         if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 {
             return Err(ValidationError::MaxNewTokens(
                 self.max_total_tokens - self.max_input_length,
                 max_new_tokens,
             ));
         }
 
         Ok((inputs, input_length, max_new_tokens))
-        }
     }
 
     /// Validate a payload and get the number of tokens in the input
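
To make the scheduling consequence of the flag concrete: when `batched_dimension` is set (or no fast tokenizer is available), the request is budgeted at its padded size rather than its real token count. Below is a small, self-contained sketch of that padded budget, with plain types and a `String` error standing in for the router's real `ValidationError`; the function name and values are illustrative, not the router's API.

/// Illustrative stand-in for the padded path above: compute the
/// (assumed_input_length, max_new_tokens) pair the scheduler reserves.
fn padded_budget(
    truncate: Option<usize>,
    max_new_tokens: Option<u32>,
    max_input_length: usize,
    max_total_tokens: usize,
) -> Result<(usize, u32), String> {
    // Without a real token count, max_new_tokens must be given or derivable from truncate.
    let max_new_tokens: u32 = match (max_new_tokens, truncate) {
        (Some(m), _) => m,
        (None, Some(t)) => max_total_tokens.saturating_sub(t) as u32,
        (None, None) => return Err("max_new_tokens is not set".to_string()),
    };

    // The request is costed at the padded length, not the actual prompt length.
    let input_length = truncate.unwrap_or(max_input_length);

    if input_length as u32 + max_new_tokens > max_total_tokens as u32 {
        return Err(format!(
            "at most {} new tokens fit once padding is accounted for",
            max_total_tokens - max_input_length
        ));
    }
    Ok((input_length, max_new_tokens))
}

fn main() {
    // With a 4096-token total budget and truncation to 1024, the scheduler
    // reserves 1024 padded input tokens and up to 3072 new tokens.
    assert_eq!(padded_budget(Some(1024), None, 1024, 4096), Ok((1024, 3072)));
    println!("padded budget example holds");
}

This mirrors why the flag exists: on targets such as Neuron, where batches are padded to fixed shapes, admitting a request based on its unpadded length could over-commit the batch budget.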