Adding batch_dimension flag (to be used for Neuron and other forced-padding targets).
parent 0da00be52c
commit e1dc168188
@@ -2787,7 +2787,7 @@ dependencies = [
 "tabled",
 "text-generation-client",
 "thiserror",
-"tokenizers",
+"tokenizers 0.14.1",
 "tokio",
 "tracing",
 "tracing-subscriber",
@@ -2850,7 +2850,7 @@ dependencies = [
 "serde_json",
 "text-generation-client",
 "thiserror",
-"tokenizers",
+"tokenizers 0.15.1",
 "tokio",
 "tokio-stream",
 "tower-http",
@@ -2972,6 +2972,40 @@ dependencies = [
 "unicode_categories",
 ]
 
+[[package]]
+name = "tokenizers"
+version = "0.15.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6db445cceba5dfeb0f9702be7d6bfd91801ddcbe8fe8722defe7f2e96da75812"
+dependencies = [
+"aho-corasick",
+"clap",
+"derive_builder",
+"esaxx-rs",
+"getrandom",
+"hf-hub",
+"indicatif",
+"itertools 0.11.0",
+"lazy_static",
+"log",
+"macro_rules_attribute",
+"monostate",
+"onig",
+"paste",
+"rand",
+"rayon",
+"rayon-cond",
+"regex",
+"regex-syntax 0.7.5",
+"serde",
+"serde_json",
+"spm_precompiled",
+"thiserror",
+"unicode-normalization-alignments",
+"unicode-segmentation",
+"unicode_categories",
+]
+
 [[package]]
 name = "tokio"
 version = "1.35.1"
@@ -368,6 +368,12 @@ struct Args {
     #[clap(long, env)]
     ngrok_edge: Option<String>,
 
+    /// Specific flag for hardware targets that do not support unpadded inference.
+    /// For those we do not send the tokenizer to the router, so that all the
+    /// scheduling assumes those pad tokens exist (and potentially even more).
+    #[clap(long, env)]
+    batch_dimension: bool,
+
     /// The path to the tokenizer config file. This path is used to load the tokenizer configuration which may
     /// include a `chat_template`. If not provided, the default config will be used from the model hub.
     #[clap(long, env)]
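For reference, a minimal sketch of how a clap boolean flag declared this way behaves, assuming clap's derive and env features are enabled; only the `batch_dimension` field mirrors the diff, the rest is illustrative:

use clap::Parser;

/// Abbreviated stand-in for the launcher's Args struct.
#[derive(Parser, Debug)]
struct Args {
    /// Defaults to false; set by passing `--batch-dimension` on the command
    /// line or via the BATCH_DIMENSION environment variable.
    #[clap(long, env)]
    batch_dimension: bool,
}

fn main() {
    let args = Args::parse();
    println!("batch_dimension = {}", args.batch_dimension);
}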
@@ -1034,6 +1040,10 @@ fn spawn_webserver(
         args.model_id,
     ];
 
+    if args.batch_dimension {
+        router_args.push("--batch-dimension".to_string());
+    }
+
     // Tokenizer config path
     if let Some(ref tokenizer_config_path) = args.tokenizer_config_path {
         router_args.push("--tokenizer-config-path".to_string());
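The flag is forwarded as a single bare token because a clap boolean flag takes no value. A self-contained sketch of the same forwarding pattern; the binary name, surrounding arguments, and the hard-coded value are illustrative, not taken from the diff:

use std::process::Command;

fn main() {
    // In the launcher this value comes from the parsed Args.
    let batch_dimension = true;
    let mut router_args: Vec<String> = vec![
        "--max-total-tokens".to_string(),
        "2048".to_string(),
    ];
    // Push the bare flag token only when the flag is set.
    if batch_dimension {
        router_args.push("--batch-dimension".to_string());
    }
    // Illustrative spawn; the real launcher also wires env, stdio and shutdown.
    match Command::new("text-generation-router").args(&router_args).status() {
        Ok(status) => println!("router exited with {status}"),
        Err(e) => eprintln!("failed to spawn router: {e}"),
    }
}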
@@ -37,6 +37,8 @@ struct Args {
     max_input_length: usize,
     #[clap(default_value = "2048", long, env)]
     max_total_tokens: usize,
+    #[clap(long, env)]
+    batch_dimension: bool,
     #[clap(default_value = "1.2", long, env)]
     waiting_served_ratio: f32,
     #[clap(default_value = "4096", long, env)]
@@ -87,6 +89,7 @@ async fn main() -> Result<(), RouterError> {
         max_top_n_tokens,
         max_input_length,
         max_total_tokens,
+        batch_dimension,
         waiting_served_ratio,
         max_batch_prefill_tokens,
         max_batch_total_tokens,
@@ -340,6 +343,7 @@ async fn main() -> Result<(), RouterError> {
         max_top_n_tokens,
         max_input_length,
         max_total_tokens,
+        batch_dimension,
         waiting_served_ratio,
         max_batch_prefill_tokens,
         max_supported_batch_total_tokens,
@@ -758,6 +758,7 @@ pub async fn run(
     max_top_n_tokens: u32,
     max_input_length: usize,
     max_total_tokens: usize,
+    batch_dimension: bool,
     waiting_served_ratio: f32,
     max_batch_prefill_tokens: u32,
     max_batch_total_tokens: u32,
@@ -833,6 +834,7 @@ pub async fn run(
         max_top_n_tokens,
         max_input_length,
         max_total_tokens,
+        batch_dimension,
     );
     let generation_health = Arc::new(AtomicBool::new(false));
     let health_ext = Health::new(client.clone(), generation_health.clone());
@@ -19,6 +19,7 @@ pub struct Validation {
     max_top_n_tokens: u32,
     max_input_length: usize,
     max_total_tokens: usize,
+    batched_dimension: bool,
     /// Channel to communicate with the background tokenization task
    sender: Option<mpsc::UnboundedSender<TokenizerRequest>>,
 }
@@ -32,6 +33,7 @@ impl Validation {
         max_top_n_tokens: u32,
         max_input_length: usize,
         max_total_tokens: usize,
+        batched_dimension: bool,
     ) -> Self {
         // If we have a fast tokenizer
         let sender = if let Some(tokenizer) = tokenizer {
@@ -66,6 +68,7 @@ impl Validation {
             max_top_n_tokens,
             max_input_length,
             max_total_tokens,
+            batched_dimension,
         }
     }
@@ -103,61 +106,62 @@ impl Validation {
     ) -> Result<(String, usize, u32), ValidationError> {
         // If we have a fast tokenizer
         if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
-            // Create response channel
+            // Padded targets skip the exact-length checks below.
+            if !self.batched_dimension {
             let input_length = encoding.len();
 
             // Get total tokens
             let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
                 max_new_tokens
             } else {
                 self.max_total_tokens.saturating_sub(input_length) as u32
             };
             let total_tokens = input_length + max_new_tokens as usize;
 
             // Validate MaxTotalTokens
             if total_tokens > self.max_total_tokens {
                 return Err(ValidationError::MaxTotalTokens(
                     self.max_total_tokens,
                     input_length,
                     max_new_tokens,
                 ));
+            }
+
+            // Validate InputLength
+            if input_length > self.max_input_length {
+                return Err(ValidationError::InputLength(
+                    self.max_input_length,
+                    input_length,
+                ));
+            }
+
+            // Record the real input length before returning early.
+            metrics::histogram!("tgi_request_input_length", input_length as f64);
+            return Ok((inputs, input_length, max_new_tokens));
             }
-
-            // Validate InputLength
-            if input_length > self.max_input_length {
-                return Err(ValidationError::InputLength(
-                    self.max_input_length,
-                    input_length,
-                ));
-            }
-
-            metrics::histogram!("tgi_request_input_length", input_length as f64);
-            Ok((inputs, input_length, max_new_tokens))
         }
-        // Return inputs without validation
-        else {
+        // Either we don't have a tokenizer, or batched_dimension purposefully
+        // ignores the actual length so the job is scheduled as if fully padded.
         // In this case, we don't know the real length in tokens of the inputs
         // However, the inputs will be truncated by the python servers
         // We make sure that truncate + max_new_tokens <= self.max_total_tokens
         let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
             max_new_tokens
         } else if let Some(truncate) = truncate {
             self.max_total_tokens.saturating_sub(truncate) as u32
         } else {
             return Err(ValidationError::UnsetMaxNewTokens);
         };
         let input_length = truncate.unwrap_or(self.max_input_length);
 
         // Validate MaxNewTokens
         if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 {
             return Err(ValidationError::MaxNewTokens(
                 self.max_total_tokens - self.max_input_length,
                 max_new_tokens,
             ));
-            }
-
-            Ok((inputs, input_length, max_new_tokens))
         }
 
+        Ok((inputs, input_length, max_new_tokens))
     }
 
     /// Validate a payload and get the number of tokens in the input
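For intuition, here is a standalone sketch of the fallback budget arithmetic above: when the real token count is unknown (no tokenizer, or batched_dimension is set), the scheduler assumes the worst case, i.e. `truncate` if given, otherwise max_input_length. The function name, the String error, and main are illustrative, not the real Validation method or its ValidationError variants:

// Standalone sketch of the padded-scheduling fallback path.
fn padded_budget(
    truncate: Option<usize>,
    max_new_tokens: Option<u32>,
    max_input_length: usize,
    max_total_tokens: usize,
) -> Result<(usize, u32), String> {
    let max_new_tokens = match (max_new_tokens, truncate) {
        (Some(requested), _) => requested,
        // No explicit request: leave room for the truncated prompt.
        (None, Some(t)) => max_total_tokens.saturating_sub(t) as u32,
        (None, None) => return Err("max_new_tokens must be set".to_string()),
    };
    // Assume the worst-case (padded) prompt length.
    let input_length = truncate.unwrap_or(max_input_length);
    // Mirror of the MaxNewTokens check: assumed prompt + new tokens must fit.
    if input_length as u32 + max_new_tokens > max_total_tokens as u32 {
        return Err(format!(
            "budget exceeded: at most {} new tokens are possible",
            max_total_tokens - max_input_length
        ));
    }
    Ok((input_length, max_new_tokens))
}

fn main() {
    // max_input_length = 1024, max_total_tokens = 2048, no truncate:
    // the assumed prompt length is 1024, and 1024 + 512 <= 2048 fits.
    println!("{:?}", padded_budget(None, Some(512), 1024, 2048));
}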