Improve the defaults for the launcher ()

# What does this PR do?

- Renamed `max_input_length` into `max_input_tokens` for consistency
(backward compatible change, will yell if both are set.)
- Will now use the config for `max_input_tokens` `max_total_token` and
`max_batch_total_tokens`.
- Capping the values to 16k in order to save VRAM on behalf of users
(overriddable by simply setting the values).

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)


## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?


## Who can review?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @


@OlivierDehaene OR @Narsil

 -->
This commit is contained in:
Nicolas Patry 2024-04-12 14:20:31 +02:00 committed by GitHub
parent 9d8f21cace
commit 1b2670c823
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 226 additions and 68 deletions
Cargo.lock
docs/source/basic_tutorials
integration-tests/models
launcher
router/src

2
Cargo.lock generated
View File

@ -3452,7 +3452,9 @@ dependencies = [
"clap",
"ctrlc",
"float_eq",
"hf-hub",
"nix",
"once_cell",
"reqwest",
"serde",
"serde_json",

View File

@ -60,9 +60,9 @@ Options:
[env: QUANTIZE=]
Possible values:
- awq: 4 bit quantization. Requires a specific AWQ quantized model: https://hf.co/models?search=awq. Should replace GPTQ models wherever possible because of the better latency
- eetq: 8 bit quantization, doesn't require specific model. Should be a drop-in replacement to bitsandbytes with much better performance. Kernels are from https://github.com/NetEase-FuXi/EETQ.git
- gptq: 4 bit quantization. Requires a specific GTPQ quantized model: https://hf.co/models?search=gptq. text-generation-inference will use exllama (faster) kernels wherever possible, and use triton kernel (wider support) when it's not. AWQ has faster kernels
- awq: 4 bit quantization. Requires a specific AWQ quantized model: <https://hf.co/models?search=awq>. Should replace GPTQ models wherever possible because of the better latency
- eetq: 8 bit quantization, doesn't require specific model. Should be a drop-in replacement to bitsandbytes with much better performance. Kernels are from <https://github.com/NetEase-FuXi/EETQ.git>
- gptq: 4 bit quantization. Requires a specific GTPQ quantized model: <https://hf.co/models?search=gptq>. text-generation-inference will use exllama (faster) kernels wherever possible, and use triton kernel (wider support) when it's not. AWQ has faster kernels
- bitsandbytes: Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half, but it is known that the model will be much slower to run than the native f16
- bitsandbytes-nf4: Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x, but it is known that the model will be much slower to run than the native f16
- bitsandbytes-fp4: Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better perplexity performance for you model
@ -129,23 +129,29 @@ Options:
[env: MAX_TOP_N_TOKENS=]
[default: 5]
```
## MAX_INPUT_TOKENS
```shell
--max-input-tokens <MAX_INPUT_TOKENS>
This is the maximum allowed input length (expressed in number of tokens) for users. The larger this value, the longer prompt users can send which can impact the overall memory required to handle the load. Please note that some models have a finite range of sequence they can handle. Default to min(max_position_embeddings - 1, 4095)
[env: MAX_INPUT_TOKENS=]
```
## MAX_INPUT_LENGTH
```shell
--max-input-length <MAX_INPUT_LENGTH>
This is the maximum allowed input length (expressed in number of tokens) for users. The larger this value, the longer prompt users can send which can impact the overall memory required to handle the load. Please note that some models have a finite range of sequence they can handle
Legacy version of [`Args::max_input_tokens`]
[env: MAX_INPUT_LENGTH=]
[default: 1024]
```
## MAX_TOTAL_TOKENS
```shell
--max-total-tokens <MAX_TOTAL_TOKENS>
This is the most important value to set as it defines the "memory budget" of running clients requests. Clients will send input sequences and ask to generate `max_new_tokens` on top. with a value of `1512` users can send either a prompt of `1000` and ask for `512` new tokens, or send a prompt of `1` and ask for `1511` max_new_tokens. The larger this value, the larger amount each request will be in your RAM and the less effective batching can be
This is the most important value to set as it defines the "memory budget" of running clients requests. Clients will send input sequences and ask to generate `max_new_tokens` on top. with a value of `1512` users can send either a prompt of `1000` and ask for `512` new tokens, or send a prompt of `1` and ask for `1511` max_new_tokens. The larger this value, the larger amount each request will be in your RAM and the less effective batching can be. Default to min(max_position_embeddings, 4096)
[env: MAX_TOTAL_TOKENS=]
[default: 2048]
```
## WAITING_SERVED_RATIO
@ -162,10 +168,9 @@ Options:
## MAX_BATCH_PREFILL_TOKENS
```shell
--max-batch-prefill-tokens <MAX_BATCH_PREFILL_TOKENS>
Limits the number of tokens for the prefill operation. Since this operation take the most memory and is compute bound, it is interesting to limit the number of requests that can be sent
Limits the number of tokens for the prefill operation. Since this operation take the most memory and is compute bound, it is interesting to limit the number of requests that can be sent. Default to `max_input_tokens + 50` to give a bit of room
[env: MAX_BATCH_PREFILL_TOKENS=]
[default: 4096]
```
## MAX_BATCH_TOTAL_TOKENS
@ -210,10 +215,9 @@ Options:
## CUDA_GRAPHS
```shell
--cuda-graphs <CUDA_GRAPHS>
Specify the batch sizes to compute cuda graphs for. Use "0" to disable
Specify the batch sizes to compute cuda graphs for. Use "0" to disable. Default = "1,2,4,8,16,32"
[env: CUDA_GRAPHS=]
[default: 1,2,4,8,16,32,64,96,128]
```
## HOSTNAME

View File

@ -3,7 +3,7 @@ import pytest
@pytest.fixture(scope="module")
def t5_sharded_handle(launcher):
with launcher("google/flan-t5-xxl", num_shard=2) as handle:
with launcher("google/flan-t5-xxl", num_shard=4) as handle:
yield handle

View File

@ -9,8 +9,10 @@ homepage.workspace = true
[dependencies]
clap = { version = "4.4.5", features = ["derive", "env"] }
ctrlc = { version = "3.4.1", features = ["termination"] }
hf-hub = "0.3.2"
nix = { version = "0.28.0", features = ["signal"] }
serde = { version = "1.0.188", features = ["derive"] }
once_cell = "1.19.0"
serde = { version = "1.0.188", features = ["derive"] }
serde_json = "1.0.107"
tracing = "0.1.37"
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }

View File

@ -1,4 +1,5 @@
use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, Repo, RepoType};
use nix::sys::signal::{self, Signal};
use nix::unistd::Pid;
use serde::Deserialize;
@ -19,17 +20,23 @@ use tracing_subscriber::EnvFilter;
mod env_runtime;
#[derive(Deserialize)]
struct Config {
max_position_embeddings: Option<usize>,
max_seq_len: Option<usize>,
}
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Quantization {
/// 4 bit quantization. Requires a specific AWQ quantized model:
/// https://hf.co/models?search=awq.
/// <https://hf.co/models?search=awq>.
/// Should replace GPTQ models wherever possible because of the better latency
Awq,
/// 8 bit quantization, doesn't require specific model.
/// Should be a drop-in replacement to bitsandbytes with much better performance.
/// Kernels are from https://github.com/NetEase-FuXi/EETQ.git
/// Kernels are from <https://github.com/NetEase-FuXi/EETQ.git>
Eetq,
/// 4 bit quantization. Requires a specific GTPQ quantized model: https://hf.co/models?search=gptq.
/// 4 bit quantization. Requires a specific GTPQ quantized model: <https://hf.co/models?search=gptq>.
/// text-generation-inference will use exllama (faster) kernels wherever possible, and use
/// triton kernel (wider support) when it's not.
/// AWQ has faster kernels.
@ -214,8 +221,13 @@ struct Args {
/// for users. The larger this value, the longer prompt users can send which
/// can impact the overall memory required to handle the load.
/// Please note that some models have a finite range of sequence they can handle.
#[clap(default_value = "1024", long, env)]
max_input_length: usize,
/// Default to min(max_position_embeddings - 1, 4095)
#[clap(long, env)]
max_input_tokens: Option<usize>,
/// Legacy version of [`Args::max_input_tokens`].
#[clap(long, env)]
max_input_length: Option<usize>,
/// This is the most important value to set as it defines the "memory budget"
/// of running clients requests.
@ -225,8 +237,9 @@ struct Args {
/// `1511` max_new_tokens.
/// The larger this value, the larger amount each request will be in your RAM
/// and the less effective batching can be.
#[clap(default_value = "2048", long, env)]
max_total_tokens: usize,
/// Default to min(max_position_embeddings, 4096)
#[clap(long, env)]
max_total_tokens: Option<usize>,
/// This represents the ratio of waiting queries vs running queries where
/// you want to start considering pausing the running queries to include the waiting
@ -244,8 +257,9 @@ struct Args {
/// Limits the number of tokens for the prefill operation.
/// Since this operation take the most memory and is compute bound, it is interesting
/// to limit the number of requests that can be sent.
#[clap(default_value = "4096", long, env)]
max_batch_prefill_tokens: u32,
/// Default to `max_input_tokens + 50` to give a bit of room.
#[clap(long, env)]
max_batch_prefill_tokens: Option<u32>,
/// **IMPORTANT** This is one critical control to allow maximum usage
/// of the available hardware.
@ -294,13 +308,9 @@ struct Args {
/// Specify the batch sizes to compute cuda graphs for.
/// Use "0" to disable.
#[clap(
long,
env,
value_delimiter = ',',
default_value = "1,2,4,8,16,32,64,96,128"
)]
cuda_graphs: Vec<usize>,
/// Default = "1,2,4,8,16,32"
#[clap(long, env, value_delimiter = ',')]
cuda_graphs: Option<Vec<usize>>,
/// The IP address to listen on
#[clap(default_value = "0.0.0.0", long, env)]
@ -808,6 +818,14 @@ enum LauncherError {
WebserverCannotStart,
}
impl core::fmt::Display for LauncherError {
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
write!(f, "{self:?}")
}
}
impl std::error::Error for LauncherError {}
fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
// Enter download tracing span
let _span = tracing::span!(tracing::Level::INFO, "download").entered();
@ -944,6 +962,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
fn spawn_shards(
num_shard: usize,
args: &Args,
cuda_graphs: Vec<usize>,
shutdown: Arc<AtomicBool>,
shutdown_receiver: &mpsc::Receiver<()>,
shutdown_sender: mpsc::Sender<()>,
@ -971,11 +990,7 @@ fn spawn_shards(
let disable_custom_kernels = args.disable_custom_kernels;
let watermark_gamma = args.watermark_gamma;
let watermark_delta = args.watermark_delta;
let cuda_graphs: Vec<usize> = args
.cuda_graphs
.iter()
.filter_map(|&c| if c > 0 { Some(c) } else { None })
.collect();
let cuda_graphs_clone = cuda_graphs.clone();
let cuda_memory_fraction = args.cuda_memory_fraction;
let rope_scaling = args.rope_scaling;
let rope_factor = args.rope_factor;
@ -997,7 +1012,7 @@ fn spawn_shards(
disable_custom_kernels,
watermark_gamma,
watermark_delta,
cuda_graphs,
cuda_graphs_clone,
cuda_memory_fraction,
rope_scaling,
rope_factor,
@ -1053,6 +1068,9 @@ fn compute_type(num_shard: usize) -> Option<String> {
fn spawn_webserver(
num_shard: usize,
args: Args,
max_input_tokens: usize,
max_total_tokens: usize,
max_batch_prefill_tokens: u32,
shutdown: Arc<AtomicBool>,
shutdown_receiver: &mpsc::Receiver<()>,
) -> Result<Child, LauncherError> {
@ -1068,12 +1086,12 @@ fn spawn_webserver(
args.max_stop_sequences.to_string(),
"--max-top-n-tokens".to_string(),
args.max_top_n_tokens.to_string(),
"--max-input-length".to_string(),
args.max_input_length.to_string(),
"--max-input-tokens".to_string(),
max_input_tokens.to_string(),
"--max-total-tokens".to_string(),
args.max_total_tokens.to_string(),
max_total_tokens.to_string(),
"--max-batch-prefill-tokens".to_string(),
args.max_batch_prefill_tokens.to_string(),
max_batch_prefill_tokens.to_string(),
"--waiting-served-ratio".to_string(),
args.waiting_served_ratio.to_string(),
"--max-waiting-tokens".to_string(),
@ -1251,19 +1269,129 @@ fn main() -> Result<(), LauncherError> {
tracing::info!("{:?}", args);
let get_max_position_embeddings = || -> Result<usize, Box<dyn std::error::Error>> {
let model_id = args.model_id.clone();
let mut path = std::path::Path::new(&args.model_id).to_path_buf();
let filename = if !path.exists() {
// Assume it's a hub id
let api = Api::new()?;
let repo = if let Some(ref revision) = args.revision {
api.repo(Repo::with_revision(
model_id,
RepoType::Model,
revision.to_string(),
))
} else {
api.model(model_id)
};
repo.get("config.json")?
} else {
path.push("config.json");
path
};
let content = std::fs::read_to_string(filename)?;
let config: Config = serde_json::from_str(&content)?;
// Quantization usually means you're even more RAM constrained.
let max_default = 4096;
let max_position_embeddings = match (config.max_position_embeddings, config.max_seq_len) {
(Some(max_position_embeddings), _) | (None, Some(max_position_embeddings)) => {
if max_position_embeddings > max_default {
let max = max_position_embeddings;
tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
max_default
} else {
max_position_embeddings
}
}
_ => {
return Err(Box::new(LauncherError::ArgumentValidation(
"no max defined".to_string(),
)));
}
};
Ok(max_position_embeddings)
};
let max_position_embeddings: usize = get_max_position_embeddings().unwrap_or(4096);
let max_input_tokens = {
match (args.max_input_tokens, args.max_input_length) {
(Some(max_input_tokens), Some(max_input_length)) => {
return Err(LauncherError::ArgumentValidation(
format!("Both `max_input_tokens` ({max_input_tokens}) and `max_input_length` ({max_input_length}) are set. Please define only `max_input_tokens` as `max_input_length is deprecated for naming consistency.",
)));
}
(Some(max_input_tokens), None) | (None, Some(max_input_tokens)) => max_input_tokens,
(None, None) => {
let value = max_position_embeddings - 1;
tracing::info!("Default `max_input_tokens` to {value}");
value
}
}
};
let max_total_tokens = {
match args.max_total_tokens {
Some(max_total_tokens) => max_total_tokens,
None => {
let value = max_position_embeddings;
tracing::info!("Default `max_total_tokens` to {value}");
value
}
}
};
let max_batch_prefill_tokens = {
match args.max_batch_prefill_tokens {
Some(max_batch_prefill_tokens) => max_batch_prefill_tokens,
None => {
let value: u32 = if let Some(max_batch_size) = args.max_batch_size {
max_batch_size * max_input_tokens
} else {
// Adding some edge in order to account for potential block_size alignement
// issue.
max_input_tokens + 50
} as u32;
tracing::info!("Default `max_batch_prefill_tokens` to {value}");
value
}
}
};
// Validate args
if args.max_input_length >= args.max_total_tokens {
if max_input_tokens >= max_total_tokens {
return Err(LauncherError::ArgumentValidation(
"`max_input_length` must be < `max_total_tokens`".to_string(),
"`max_input_tokens must be < `max_total_tokens`".to_string(),
));
}
if args.max_input_length as u32 > args.max_batch_prefill_tokens {
if max_input_tokens as u32 > max_batch_prefill_tokens {
return Err(LauncherError::ArgumentValidation(format!(
"`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {} and {}",
args.max_batch_prefill_tokens, args.max_input_length
"`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {} and {}",
max_batch_prefill_tokens, max_input_tokens
)));
}
let cuda_graphs = match (&args.cuda_graphs, &args.quantize) {
(Some(cuda_graphs), Some(_q)) => cuda_graphs.clone(),
#[allow(deprecated)]
(
None,
Some(
Quantization::Bitsandbytes
| Quantization::BitsandbytesNF4
| Quantization::BitsandbytesFP4,
),
) => {
tracing::info!("Bitsandbytes doesn't work with cuda graphs, deactivating them");
vec![]
}
_ => {
let cuda_graphs = vec![1, 2, 4, 8, 16, 32];
tracing::info!("Using default cuda graphs {cuda_graphs:?}");
cuda_graphs
}
};
if args.validation_workers == 0 {
return Err(LauncherError::ArgumentValidation(
"`validation_workers` must be > 0".to_string(),
@ -1282,16 +1410,16 @@ fn main() -> Result<(), LauncherError> {
}
if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
if args.max_batch_prefill_tokens > *max_batch_total_tokens {
if max_batch_prefill_tokens > *max_batch_total_tokens {
return Err(LauncherError::ArgumentValidation(format!(
"`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
args.max_batch_prefill_tokens, max_batch_total_tokens
max_batch_prefill_tokens, max_batch_total_tokens
)));
}
if args.max_total_tokens as u32 > *max_batch_total_tokens {
if max_total_tokens as u32 > *max_batch_total_tokens {
return Err(LauncherError::ArgumentValidation(format!(
"`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
args.max_total_tokens, max_batch_total_tokens
max_total_tokens, max_batch_total_tokens
)));
}
}
@ -1338,6 +1466,7 @@ fn main() -> Result<(), LauncherError> {
spawn_shards(
num_shard,
&args,
cuda_graphs,
shutdown.clone(),
&shutdown_receiver,
shutdown_sender,
@ -1352,11 +1481,19 @@ fn main() -> Result<(), LauncherError> {
return Ok(());
}
let mut webserver = spawn_webserver(num_shard, args, shutdown.clone(), &shutdown_receiver)
.map_err(|err| {
shutdown_shards(shutdown.clone(), &shutdown_receiver);
err
})?;
let mut webserver = spawn_webserver(
num_shard,
args,
max_input_tokens,
max_total_tokens,
max_batch_prefill_tokens,
shutdown.clone(),
&shutdown_receiver,
)
.map_err(|err| {
shutdown_shards(shutdown.clone(), &shutdown_receiver);
err
})?;
// Default exit code
let mut exit_code = Ok(());

View File

@ -35,7 +35,7 @@ struct Args {
#[clap(default_value = "5", long, env)]
max_top_n_tokens: u32,
#[clap(default_value = "1024", long, env)]
max_input_length: usize,
max_input_tokens: usize,
#[clap(default_value = "2048", long, env)]
max_total_tokens: usize,
#[clap(default_value = "1.2", long, env)]
@ -90,7 +90,7 @@ async fn main() -> Result<(), RouterError> {
max_best_of,
max_stop_sequences,
max_top_n_tokens,
max_input_length,
max_input_tokens,
max_total_tokens,
waiting_served_ratio,
max_batch_prefill_tokens,
@ -118,13 +118,13 @@ async fn main() -> Result<(), RouterError> {
init_logging(otlp_endpoint, json_output);
// Validate args
if max_input_length >= max_total_tokens {
if max_input_tokens >= max_total_tokens {
return Err(RouterError::ArgumentValidation(
"`max_input_length` must be < `max_total_tokens`".to_string(),
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
));
}
if max_input_length as u32 > max_batch_prefill_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {max_batch_prefill_tokens} and {max_input_length}")));
if max_input_tokens as u32 > max_batch_prefill_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
}
if validation_workers == 0 {
@ -311,7 +311,7 @@ async fn main() -> Result<(), RouterError> {
tracing::info!("Warming up model");
let max_supported_batch_total_tokens = match sharded_client
.warmup(
max_input_length as u32,
max_input_tokens as u32,
max_batch_prefill_tokens,
max_total_tokens as u32,
max_batch_size,
@ -374,7 +374,7 @@ async fn main() -> Result<(), RouterError> {
max_best_of,
max_stop_sequences,
max_top_n_tokens,
max_input_length,
max_input_tokens,
max_total_tokens,
waiting_served_ratio,
max_batch_prefill_tokens,
@ -402,12 +402,15 @@ async fn main() -> Result<(), RouterError> {
/// - otlp_endpoint is an optional URL to an Open Telemetry collector
/// - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (default to INFO)
/// - LOG_FORMAT may be TEXT or JSON (default to TEXT)
/// - LOG_COLORIZE may be "false" or "true" (default to "true" or ansi supported platforms)
fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
let mut layers = Vec::new();
// STDOUT/STDERR layer
let ansi = std::env::var("LOG_COLORIZE") != Ok("1".to_string());
let fmt_layer = tracing_subscriber::fmt::layer()
.with_file(true)
.with_ansi(ansi)
.with_line_number(true);
let fmt_layer = match json_output {

View File

@ -190,12 +190,14 @@ impl State {
token_budget: u32,
) -> Option<NextBatch> {
if self.entries.is_empty() {
tracing::debug!("No queue");
return None;
}
// Check if we have enough entries
if let Some(min_size) = min_size {
if self.entries.len() < min_size {
tracing::debug!("Not enough entries");
return None;
}
}
@ -222,6 +224,7 @@ impl State {
// was dropped by the client)
if entry.response_tx.is_closed() {
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
tracing::debug!("Dropping entry");
continue;
}
@ -258,10 +261,12 @@ impl State {
{
// Entry is over budget
// Add it back to the front
tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
self.entries.push_front((id, entry));
break;
}
tracing::debug!("Accepting entry");
// Create a new span to link the batch back to this entry
let entry_batch_span = info_span!(parent: &entry.span, "infer");
// Add relationships
@ -292,6 +297,7 @@ impl State {
// Empty batch
if batch_requests.is_empty() {
tracing::debug!("Filterered out all entries");
return None;
}

View File

@ -161,14 +161,17 @@ impl Validation {
} else {
return Err(ValidationError::UnsetMaxNewTokens);
};
let input_length = truncate.unwrap_or(self.max_input_length);
let mut input_length = truncate.unwrap_or(self.max_input_length);
// We don't have a tokenizer, therefore we have no idea how long is the query, let
// them through and hope for the best.
// Validate MaxNewTokens
if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 {
return Err(ValidationError::MaxNewTokens(
self.max_total_tokens - self.max_input_length,
max_new_tokens,
));
input_length = input_length.saturating_sub(max_new_tokens as usize);
// return Err(ValidationError::MaxNewTokens(
// self.max_total_tokens - self.max_input_length,
// max_new_tokens,
// ));
}
Ok((inputs, input_length, max_new_tokens))
@ -664,8 +667,9 @@ mod tests {
.validate_input("Hello".to_string(), None, Some(max_new_tokens))
.await
{
Err(ValidationError::MaxNewTokens(1, 10)) => (),
_ => panic!("Unexpected not max new tokens"),
// Err(ValidationError::MaxNewTokens(1, 10)) => (),
Ok((_s, 0, 10)) => (),
r => panic!("Unexpected not max new tokens: {r:?}"),
}
}