Choosing input/total tokens automatically based on available VRAM? (#2673)

* Choosing input/total tokens automatically based on available VRAM?

* Update doc.

* Remove generated files.

* Trying to fix non chunking targets.

* Attempt #2

* fix.

* QuantLinear is rocm compatible.

* Much simpler logic after the overhead.

* Updating logic + non flash.

* Revert doc text.

* Simple updates.

* Fix integration mt0 (transformers update).
This commit is contained in:
Nicolas Patry 2024-10-28 04:59:49 +01:00 committed by GitHub
parent 2e4f4ba1bb
commit 0c9b6cdd76
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 285 additions and 136 deletions

2
.gitignore vendored
View File

@ -5,6 +5,8 @@ router/tokenizer.json
backends/v2/src/client/pb
backends/v3/src/client/pb
backends/client/src/v2/pb
backends/client/src/v3/pb
# ROCm auto-generated files
*.hip

View File

@ -107,20 +107,22 @@ impl Client {
#[instrument(skip_all)]
pub async fn warmup(
&mut self,
max_input_length: u32,
max_input_tokens: Option<u32>,
max_prefill_tokens: u32,
max_total_tokens: u32,
max_total_tokens: Option<u32>,
max_batch_size: Option<usize>,
) -> Result<Option<u32>> {
) -> Result<(Option<u32>, u32, u32)> {
let mut n_tokens = 0;
let mut requests = Vec::new();
// Create requests
while n_tokens < max_prefill_tokens {
let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
let mut truncate = max_prefill_tokens - n_tokens;
if let Some(max_input_tokens) = max_input_tokens {
truncate = min(max_input_tokens, truncate);
}
let mut input_chunks = Vec::new();
input_chunks
.push(Chunk::Text("_test ".to_string().repeat(max_input_length as usize)).into());
input_chunks.push(Chunk::Text("_test ".to_string().repeat(truncate as usize)).into());
if n_tokens == 0 {
input_chunks.push(
Chunk::Image(Image {
@ -136,7 +138,7 @@ impl Client {
// been updated to support chunks.
let mut inputs = String::new();
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
inputs.push_str(&"_test ".to_string().repeat(truncate as usize));
if n_tokens == 0 {
// 1 request is enough to test vision heads.
// Sending images on other queries messes up easily with truncation.
@ -145,6 +147,12 @@ impl Client {
));
}
let max_new_tokens = if let Some(max_total_tokens) = max_total_tokens {
max_total_tokens - truncate
} else {
1
};
requests.push(Request {
id: 0,
inputs,
@ -175,7 +183,7 @@ impl Client {
grammar_type: GrammarType::None as i32,
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: max_total_tokens - truncate,
max_new_tokens,
stop_sequences: vec![],
ignore_eos_token: true,
}),
@ -183,7 +191,7 @@ impl Client {
top_n_tokens: 20,
adapter_id: None,
});
n_tokens += max_input_length;
n_tokens += truncate;
// Check max_batch_size
if Some(requests.len()) == max_batch_size {
@ -195,19 +203,23 @@ impl Client {
id: 0,
size: requests.len() as u32,
requests,
max_tokens: max_input_length,
max_tokens: max_input_tokens.unwrap_or(0),
max_blocks: 0,
};
let request = tonic::Request::new(WarmupRequest {
batch: Some(batch),
max_input_length,
max_input_tokens,
max_prefill_tokens,
max_total_tokens,
})
.inject_context();
let response = self.stub.warmup(request).await?.into_inner();
Ok(response.max_supported_total_tokens)
Ok((
response.max_supported_total_tokens,
response.max_input_tokens,
response.max_total_tokens,
))
}
/// Generate one token for each request in the given batch

View File

@ -101,11 +101,11 @@ impl ShardedClient {
#[instrument(skip(self))]
pub async fn warmup(
&mut self,
max_input_length: u32,
max_input_length: Option<u32>,
max_prefill_tokens: u32,
max_total_tokens: u32,
max_total_tokens: Option<u32>,
max_batch_size: Option<usize>,
) -> Result<Option<u32>> {
) -> Result<(Option<u32>, u32, u32)> {
let futures: Vec<_> = self
.clients
.iter_mut()
@ -122,8 +122,16 @@ impl ShardedClient {
let results = join_all(futures)
.await
.into_iter()
.collect::<Result<Vec<Option<u32>>>>()?;
Ok(results.into_iter().flatten().min())
.collect::<Result<Vec<(Option<u32>, u32, u32)>>>()?;
// Take the minimum value
// Different shards hold different parts of vocab, might yield
// different available block size.
let min = results
.iter()
.min()
.expect("Expect at least 1 warmup result");
Ok(*min)
}
/// Generate one token for each request in the given batch

View File

@ -108,20 +108,22 @@ impl Client {
#[instrument(skip_all)]
pub async fn warmup(
&mut self,
max_input_length: u32,
max_input_tokens: Option<u32>,
max_prefill_tokens: u32,
max_total_tokens: u32,
max_total_tokens: Option<u32>,
max_batch_size: Option<usize>,
) -> Result<Option<u32>> {
) -> Result<(Option<u32>, u32, u32)> {
let mut n_tokens = 0;
let mut requests = Vec::new();
// Create requests
while n_tokens < max_prefill_tokens {
let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
let mut truncate = max_prefill_tokens - n_tokens;
if let Some(max_input_tokens) = max_input_tokens {
truncate = min(max_input_tokens, truncate);
}
let mut input_chunks = Vec::new();
input_chunks
.push(Chunk::Text("_test ".to_string().repeat(max_input_length as usize)).into());
input_chunks.push(Chunk::Text("_test ".to_string().repeat(truncate as usize)).into());
if n_tokens == 0 {
input_chunks.push(
Chunk::Image(Image {
@ -137,7 +139,7 @@ impl Client {
// been updated to support chunks.
let mut inputs = String::new();
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
inputs.push_str(&"_test ".to_string().repeat(truncate as usize));
if n_tokens == 0 {
// 1 request is enough to test vision heads.
// Sending images on other queries messes up easily with truncation.
@ -146,6 +148,12 @@ impl Client {
));
}
let max_new_tokens = if let Some(max_total_tokens) = max_total_tokens {
max_total_tokens - truncate
} else {
1
};
requests.push(Request {
id: 0,
inputs,
@ -175,7 +183,7 @@ impl Client {
grammar_type: GrammarType::None as i32,
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: max_total_tokens - truncate,
max_new_tokens,
stop_sequences: vec![],
ignore_eos_token: true,
}),
@ -183,7 +191,7 @@ impl Client {
top_n_tokens: 20,
adapter_id: None,
});
n_tokens += max_input_length;
n_tokens += truncate;
// Check max_batch_size
if Some(requests.len()) == max_batch_size {
@ -195,19 +203,23 @@ impl Client {
id: 0,
size: requests.len() as u32,
requests,
max_tokens: max_input_length,
max_tokens: max_input_tokens.unwrap_or(0),
max_blocks: 0,
};
let request = tonic::Request::new(WarmupRequest {
batch: Some(batch),
max_input_length,
max_input_tokens,
max_prefill_tokens,
max_total_tokens,
})
.inject_context();
let response = self.stub.warmup(request).await?.into_inner();
Ok(response.max_supported_total_tokens)
Ok((
response.max_supported_total_tokens,
response.max_input_tokens,
response.max_total_tokens,
))
}
/// Generate one token for each request in the given batch

View File

@ -102,11 +102,11 @@ impl ShardedClient {
#[instrument(skip(self))]
pub async fn warmup(
&mut self,
max_input_length: u32,
max_input_length: Option<u32>,
max_prefill_tokens: u32,
max_total_tokens: u32,
max_total_tokens: Option<u32>,
max_batch_size: Option<usize>,
) -> Result<Option<u32>> {
) -> Result<(Option<u32>, u32, u32)> {
let futures: Vec<_> = self
.clients
.iter_mut()
@ -119,12 +119,19 @@ impl ShardedClient {
))
})
.collect();
// Take the minimum value
let results = join_all(futures)
.await
.into_iter()
.collect::<Result<Vec<Option<u32>>>>()?;
Ok(results.into_iter().flatten().min())
.collect::<Result<Vec<(Option<u32>, u32, u32)>>>()?;
// Take the minimum value
// Different shards hold different parts of vocab, might yield
// different available block size.
let min = results
.iter()
.min()
.expect("Expect at least 1 warmup result");
Ok(*min)
}
/// Generate one token for each request in the given batch

View File

@ -37,12 +37,17 @@ pub struct BackendInfo {
pub attention_impl: String,
#[schema(example = "1")]
pub block_size: u32,
#[schema(example = "30000")]
pub max_input_tokens: usize,
#[schema(example = "32000")]
pub max_total_tokens: usize,
}
#[allow(clippy::too_many_arguments)]
pub async fn connect_backend(
max_input_tokens: usize,
max_total_tokens: usize,
max_input_tokens: Option<usize>,
max_total_tokens: Option<usize>,
master_shard_uds_path: String,
waiting_served_ratio: f32,
max_batch_prefill_tokens: u32,
@ -51,14 +56,32 @@ pub async fn connect_backend(
max_batch_size: Option<usize>,
) -> Result<(BackendV3, BackendInfo), V3Error> {
// Helper function
let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option<u32>| {
let check_max_batch_total_tokens = |(
max_supported_batch_total_tokens,
shard_max_input_tokens,
shard_max_total_tokens,
): (Option<u32>, u32, u32)|
-> Result<(u32, usize, usize), V3Error> {
if let Some(max_input_tokens) = max_input_tokens {
assert_eq!(max_input_tokens as u32, shard_max_input_tokens);
}
if let Some(max_total_tokens) = max_total_tokens {
assert_eq!(max_total_tokens as u32, shard_max_total_tokens);
}
match max_supported_batch_total_tokens {
// Older models do not support automatic max-batch-total-tokens
None => {
let max_batch_total_tokens = max_batch_total_tokens
.unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)));
let max_batch_total_tokens = max_batch_total_tokens.unwrap_or(
16000
.max(shard_max_total_tokens)
.max(max_batch_prefill_tokens),
);
tracing::warn!("Model does not support automatic max batch total tokens");
Ok(max_batch_total_tokens)
Ok((
max_batch_total_tokens,
shard_max_input_tokens as usize,
shard_max_total_tokens as usize,
))
}
// Flash attention models return their max supported total tokens
Some(max_supported_batch_total_tokens) => {
@ -72,11 +95,15 @@ pub async fn connect_backend(
"Inferred max batch total tokens: {max_supported_batch_total_tokens}"
);
}
if max_total_tokens as u32 > max_supported_batch_total_tokens {
return Err(V3Error::NotEnoughMemory(max_total_tokens));
if shard_max_total_tokens > max_supported_batch_total_tokens {
return Err(V3Error::NotEnoughMemory(shard_max_total_tokens as usize));
}
Ok(max_supported_batch_total_tokens)
Ok((
max_supported_batch_total_tokens,
shard_max_input_tokens as usize,
shard_max_total_tokens as usize,
))
}
}
};
@ -96,23 +123,25 @@ pub async fn connect_backend(
// Warmup model
tracing::info!("Warming up model");
let max_batch_total_tokens = check_max_batch_total_tokens(
sharded_client
let answer = sharded_client
.warmup(
max_input_tokens as u32,
max_input_tokens.map(|p| p as u32),
max_batch_prefill_tokens,
max_total_tokens as u32,
max_total_tokens.map(|p| p as u32),
max_batch_size,
)
.await
.map_err(V3Error::Warmup)?,
)?;
.map_err(V3Error::Warmup)?;
let (max_batch_total_tokens, max_input_tokens, max_total_tokens) =
check_max_batch_total_tokens(answer)?;
tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
metrics::gauge!("tgi_batch_max_total_tokens").set(max_batch_total_tokens);
let backend_info = BackendInfo {
waiting_served_ratio,
max_batch_total_tokens,
max_input_tokens,
max_total_tokens,
max_waiting_tokens,
max_batch_size,
model_device_type: shard_info.device_type.clone(),

View File

@ -18,10 +18,10 @@ struct Args {
max_stop_sequences: usize,
#[clap(default_value = "5", long, env)]
max_top_n_tokens: u32,
#[clap(default_value = "1024", long, env)]
max_input_tokens: usize,
#[clap(default_value = "2048", long, env)]
max_total_tokens: usize,
#[clap(long, env)]
max_input_tokens: Option<usize>,
#[clap(long, env)]
max_total_tokens: Option<usize>,
#[clap(default_value = "1.2", long, env)]
waiting_served_ratio: f32,
#[clap(default_value = "4096", long, env)]
@ -126,12 +126,6 @@ async fn main() -> Result<(), RouterError> {
text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output);
// Validate args
if max_input_tokens >= max_total_tokens {
return Err(RouterError::ArgumentValidation(
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
));
}
if validation_workers == 0 {
return Err(RouterError::ArgumentValidation(
"`validation_workers` must be > 0".to_string(),
@ -160,6 +154,28 @@ async fn main() -> Result<(), RouterError> {
// Validate remaining args now that the backend is known
let support_chunking = backend_info.support_chunking;
let max_batch_total_tokens = backend_info.max_batch_total_tokens;
if max_input_tokens.is_none() {
tracing::info!(
"Maximum input tokens defaulted to {}",
backend_info.max_input_tokens
);
}
if max_total_tokens.is_none() {
tracing::info!(
"Maximum total tokens defaulted to {}",
backend_info.max_total_tokens
);
}
let max_input_tokens = backend_info.max_input_tokens;
let max_total_tokens = backend_info.max_total_tokens;
if max_input_tokens >= max_total_tokens {
return Err(RouterError::ArgumentValidation(
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
));
}
if max_input_tokens as u32 > max_batch_prefill_tokens && !support_chunking {
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
}

View File

@ -146,7 +146,7 @@ Options:
## MAX_INPUT_TOKENS
```shell
--max-input-tokens <MAX_INPUT_TOKENS>
This is the maximum allowed input length (expressed in number of tokens) for users. The larger this value, the longer prompt users can send which can impact the overall memory required to handle the load. Please note that some models have a finite range of sequence they can handle. Default to min(max_position_embeddings - 1, 4095)
This is the maximum allowed input length (expressed in number of tokens) for users. The larger this value, the longer prompt users can send which can impact the overall memory required to handle the load. Please note that some models have a finite range of sequence they can handle. Default to min(max_allocatable, max_position_embeddings) - 1
[env: MAX_INPUT_TOKENS=]
@ -162,7 +162,7 @@ Options:
## MAX_TOTAL_TOKENS
```shell
--max-total-tokens <MAX_TOTAL_TOKENS>
This is the most important value to set as it defines the "memory budget" of running clients requests. Clients will send input sequences and ask to generate `max_new_tokens` on top. with a value of `1512` users can send either a prompt of `1000` and ask for `512` new tokens, or send a prompt of `1` and ask for `1511` max_new_tokens. The larger this value, the larger amount each request will be in your RAM and the less effective batching can be. Default to min(max_position_embeddings, 4096)
This is the most important value to set as it defines the "memory budget" of running clients requests. Clients will send input sequences and ask to generate `max_new_tokens` on top. with a value of `1512` users can send either a prompt of `1000` and ask for `512` new tokens, or send a prompt of `1` and ask for `1511` max_new_tokens. The larger this value, the larger amount each request will be in your RAM and the less effective batching can be. Default to min(max_allocatable, max_position_embeddings)
[env: MAX_TOTAL_TOKENS=]

View File

@ -472,7 +472,7 @@ struct Args {
/// for users. The larger this value, the longer prompt users can send which
/// can impact the overall memory required to handle the load.
/// Please note that some models have a finite range of sequence they can handle.
/// Default to min(max_position_embeddings - 1, 4095)
/// Default to min(max_allocatable, max_position_embeddings) - 1
#[clap(long, env)]
max_input_tokens: Option<usize>,
@ -488,7 +488,7 @@ struct Args {
/// `1511` max_new_tokens.
/// The larger this value, the larger amount each request will be in your RAM
/// and the less effective batching can be.
/// Default to min(max_position_embeddings, 4096)
/// Default to min(max_allocatable, max_position_embeddings)
#[clap(long, env)]
max_total_tokens: Option<usize>,
@ -718,9 +718,9 @@ fn shard_manager(
cuda_memory_fraction: f32,
rope_scaling: Option<RopeScaling>,
rope_factor: Option<f32>,
max_total_tokens: usize,
max_total_tokens: Option<usize>,
max_batch_size: Option<usize>,
max_input_tokens: usize,
max_input_tokens: Option<usize>,
lora_adapters: Option<String>,
otlp_endpoint: Option<String>,
otlp_service_name: String,
@ -805,8 +805,10 @@ fn shard_manager(
shard_args.push(otlp_service_name);
// In case we use sliding window, we may ignore the sliding in flash for some backends depending on the parameter.
if let Some(max_input_tokens) = max_input_tokens {
shard_args.push("--max-input-tokens".to_string());
shard_args.push(max_input_tokens.to_string());
}
// Copy current process env
let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
@ -854,10 +856,12 @@ fn shard_manager(
envs.push(("ROPE_FACTOR".into(), factor.to_string().into()));
}
if let Some(max_total_tokens) = max_total_tokens {
envs.push((
"MAX_TOTAL_TOKENS".into(),
max_total_tokens.to_string().into(),
));
}
if let Some(max_batch_size) = max_batch_size {
envs.push(("MAX_BATCH_SIZE".into(), max_batch_size.to_string().into()));
}
@ -1315,8 +1319,8 @@ fn spawn_shards(
num_shard: usize,
args: &Args,
cuda_graphs: Vec<usize>,
max_total_tokens: usize,
max_input_tokens: usize,
max_total_tokens: Option<usize>,
max_input_tokens: Option<usize>,
quantize: Option<Quantization>,
max_log_level: LevelFilter,
shutdown: Arc<AtomicBool>,
@ -1434,8 +1438,8 @@ fn compute_type(num_shard: usize) -> Option<String> {
fn spawn_webserver(
num_shard: usize,
args: Args,
max_input_tokens: usize,
max_total_tokens: usize,
max_input_tokens: Option<usize>,
max_total_tokens: Option<usize>,
max_batch_prefill_tokens: u32,
shutdown: Arc<AtomicBool>,
shutdown_receiver: &mpsc::Receiver<()>,
@ -1454,10 +1458,6 @@ fn spawn_webserver(
args.max_stop_sequences.to_string(),
"--max-top-n-tokens".to_string(),
args.max_top_n_tokens.to_string(),
"--max-input-tokens".to_string(),
max_input_tokens.to_string(),
"--max-total-tokens".to_string(),
max_total_tokens.to_string(),
"--max-batch-prefill-tokens".to_string(),
max_batch_prefill_tokens.to_string(),
"--waiting-served-ratio".to_string(),
@ -1475,6 +1475,18 @@ fn spawn_webserver(
"--tokenizer-name".to_string(),
args.model_id,
];
if let Some(max_input_tokens) = max_input_tokens {
router_args.extend_from_slice(&[
"--max-input-tokens".to_string(),
max_input_tokens.to_string(),
]);
}
if let Some(max_total_tokens) = max_total_tokens {
router_args.extend_from_slice(&[
"--max-total-tokens".to_string(),
max_total_tokens.to_string(),
]);
}
// Pass usage stats flags to router
router_args.push("--usage-stats".to_string());
@ -1704,35 +1716,19 @@ fn main() -> Result<(), LauncherError> {
format!("Both `max_input_tokens` ({max_input_tokens}) and `max_input_length` ({max_input_length}) are set. Please define only `max_input_tokens` as `max_input_length is deprecated for naming consistency.",
)));
}
(Some(max_input_tokens), None) | (None, Some(max_input_tokens)) => max_input_tokens,
(None, None) => {
let value = max_position_embeddings - 1;
tracing::info!("Default `max_input_tokens` to {value}");
value
}
}
};
let max_total_tokens = {
match args.max_total_tokens {
Some(max_total_tokens) => max_total_tokens,
None => {
let value = max_position_embeddings;
tracing::info!("Default `max_total_tokens` to {value}");
value
(Some(max_input_tokens), None) | (None, Some(max_input_tokens)) => {
Some(max_input_tokens)
}
(None, None) => None,
}
};
let max_total_tokens = args.max_total_tokens;
let max_batch_prefill_tokens = {
match args.max_batch_prefill_tokens {
Some(max_batch_prefill_tokens) => max_batch_prefill_tokens,
None => {
let value: u32 = if let Some(max_batch_size) = args.max_batch_size {
max_batch_size * max_input_tokens
} else {
// Adding some edge in order to account for potential block_size alignement
// issue.
max_input_tokens + 50
} as u32;
// TODO figure out hardware optimal value
let value = 4096.min(max_position_embeddings as u32);
tracing::info!("Default `max_batch_prefill_tokens` to {value}");
value
}
@ -1740,11 +1736,13 @@ fn main() -> Result<(), LauncherError> {
};
// Validate args
if let (Some(max_input_tokens), Some(max_total_tokens)) = (max_input_tokens, max_total_tokens) {
if max_input_tokens >= max_total_tokens {
return Err(LauncherError::ArgumentValidation(
"`max_input_tokens must be < `max_total_tokens`".to_string(),
format!("`max_input_tokens`({max_input_tokens}) must be < `max_total_tokens`({max_total_tokens})"),
));
}
}
if matches!(args.quantize, Some(Quantization::Bitsandbytes)) {
tracing::warn!("Bitsandbytes is deprecated, use `eetq` instead, which provides better latencies overall and is drop-in in most cases.");
@ -1798,6 +1796,7 @@ fn main() -> Result<(), LauncherError> {
}
if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
if let Some(max_total_tokens) = max_total_tokens {
if max_total_tokens as u32 > *max_batch_total_tokens {
return Err(LauncherError::ArgumentValidation(format!(
"`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
@ -1805,6 +1804,7 @@ fn main() -> Result<(), LauncherError> {
)));
}
}
}
if args.ngrok {
if args.ngrok_authtoken.is_none() {

View File

@ -272,12 +272,18 @@ message DecodeResponse {
message WarmupRequest {
/// Batch to warmup on
Batch batch = 1;
uint32 max_input_length = 2;
optional uint32 max_input_tokens = 2;
uint32 max_prefill_tokens = 3;
uint32 max_total_tokens = 4;
optional uint32 max_total_tokens = 4;
}
/// Values returned by the server once warmup has run.
message WarmupResponse {
/// Maximum number of tokens supported by the model.
/// Left unset by models that cannot report this value.
optional uint32 max_supported_total_tokens = 1;
/// Maximum number of input tokens a client may send.
/// Equal to the value in the request when one was set;
/// otherwise warmup chooses a value automatically.
uint32 max_input_tokens = 2;
/// Maximum number of total tokens allowed per client request.
/// Equal to the value in the request when one was set;
/// otherwise warmup chooses a value automatically.
uint32 max_total_tokens = 3;
}

View File

@ -86,6 +86,10 @@ tracer = trace.get_tracer(__name__)
SLIDING_WINDOW: Optional[int] = None
def small_power_of_2(n: int) -> int:
    """Return the largest power of two strictly smaller than ``n``.

    For example ``small_power_of_2(8) == 4``, ``small_power_of_2(9) == 8``
    and ``small_power_of_2(2) == 1``.

    Args:
        n: Upper bound (exclusive). Meaningful for ``n >= 2``; results for
            ``n <= 0`` are whatever the bit arithmetic yields and should not
            be relied upon.

    Returns:
        ``1 << k`` where ``k = (n - 1).bit_length() - 1``.

    Raises:
        ValueError: If ``n == 1``, because no power of two lies strictly
            below it.
    """
    shift = (n - 1).bit_length() - 1
    if shift < 0:
        # Only n == 1 gets here: (0).bit_length() is 0, so the shift count
        # would be -1. Python would raise ValueError("negative shift count");
        # keep the exception type but report the actual problem.
        raise ValueError(
            f"small_power_of_2 is undefined for n={n}: "
            "no power of two is strictly smaller"
        )
    return 1 << shift
def set_sliding_window(sliding_window: int):
global SLIDING_WINDOW
SLIDING_WINDOW = sliding_window
@ -1495,11 +1499,22 @@ class FlashCausalLM(Model):
self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
torch.cuda.synchronize()
def warmup(self, batch: FlashCausalLMBatch):
def warmup(
self,
batch: FlashCausalLMBatch,
max_input_tokens: Optional[int],
max_total_tokens: Optional[int],
):
# The warmup batch is the biggest batch we could ever receive
self.kv_cache = []
empty_cache()
# Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
# Calculate the number of blocks that can be allocated with the free memory
dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size()
cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
try:
self.init_kv_cache(
batch.num_blocks,
@ -1511,10 +1526,11 @@ class FlashCausalLM(Model):
)
max_bt = batch.max_blocks
max_s = max_bt * BLOCK_SIZE
batch_num_blocks = batch.num_blocks
if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
torch.cuda.tunable.tuning_enable(False)
_, batch, _ = self.generate_token(batch)
_, _batch, _ = self.generate_token(batch)
except torch.cuda.OutOfMemoryError as e:
raise RuntimeError(
f"Not enough memory to handle {batch.to_pb().current_tokens} prefill tokens. "
@ -1523,14 +1539,7 @@ class FlashCausalLM(Model):
synchronize(self.device)
# Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
# Calculate the number of blocks that can be allocated with the free memory
dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size()
cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
free_memory = get_free_memory(self.device, MEMORY_FRACTION)
batch_num_blocks = batch.num_blocks if batch is not None else 0
num_blocks = (
# Leave 5% for some wiggle room
@ -1540,8 +1549,27 @@ class FlashCausalLM(Model):
)
log_master(logger.info, f"KV-cache blocks: {num_blocks}, size: {BLOCK_SIZE}")
if max_total_tokens is None:
if get_support_chunking():
model_max_length = self.tokenizer.model_max_length
max_input_tokens = (
min((num_blocks * BLOCK_SIZE - 1), model_max_length)
if max_input_tokens is None
else max_input_tokens
)
max_total_tokens = num_blocks * BLOCK_SIZE
del batch
else:
max_total_tokens = sum(batch.cache_lengths)
max_input_tokens = (
max_total_tokens - 1
if max_input_tokens is None
else max_input_tokens
)
del _batch, batch
self.kv_cache = []
empty_cache()
self.init_kv_cache(
num_blocks,
@ -1623,7 +1651,9 @@ class FlashCausalLM(Model):
logger.info, f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS})."
)
return int(num_blocks * BLOCK_SIZE)
assert max_input_tokens is not None
assert max_total_tokens is not None
return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
def tunableop_warmup(self, seqlen: int):
input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)

View File

@ -1,7 +1,7 @@
import torch
import torch.distributed
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from typing import Optional
from typing import Optional, Union
from text_generation_server.models.custom_modeling.mamba_modeling import (
MambaConfig,
)
@ -475,7 +475,9 @@ class Mamba(Model):
def batch_type(self) -> Type[MambaBatch]:
return MambaBatch
def warmup(self, batch) -> Optional[int]:
def warmup(
self, batch, max_input_tokens: Optional[int], max_total_tokens: Optional[int]
) -> Union[Optional[int], Optional[int], Optional[int]]:
# TODO: implement warmup for Mamba if needed
if CUDA_GRAPHS:
if self.speculate is None or self.speculate == 0:
@ -489,7 +491,12 @@ class Mamba(Model):
else:
logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
return None
if max_total_tokens is None:
max_total_tokens = min(self.tokenizer.model_max_length, 4096)
if max_input_tokens is None:
max_input_tokens = max_total_tokens - 1
return None, max_input_tokens, max_total_tokens
def cuda_graph_warmup(self, batch_size: int):
input_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=self.device)

View File

@ -128,9 +128,17 @@ class Model(ABC):
) -> Tuple[List[Generation], Optional[B], Tuple[int, int]]:
raise NotImplementedError
def warmup(self, batch: B) -> Optional[int]:
def warmup(
self, batch: B, max_input_tokens: Optional[int], max_total_tokens: Optional[int]
) -> Tuple[Optional[int], int, int]:
self.generate_token(batch)
return None
total = sum(len(i) for i in batch.input_ids)
if max_total_tokens is None:
max_total_tokens = total
if max_input_tokens is None:
max_input_tokens = max_total_tokens - 1
return None, max_input_tokens, max_total_tokens
def decode_token(
self,

View File

@ -132,10 +132,22 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
batch = self.model.batch_type.from_pb(
request.batch, self.model.tokenizer, self.model.dtype, self.model.device
)
max_supported_total_tokens = self.model.warmup(batch)
# Override default values with None for clearer semantics.
max_input_tokens = (
request.max_input_tokens if request.HasField("max_input_tokens") else None
)
max_total_tokens = (
request.max_total_tokens if request.HasField("max_total_tokens") else None
)
max_supported_total_tokens, max_input_tokens, max_total_tokens = (
self.model.warmup(batch, max_input_tokens, max_total_tokens)
)
return generate_pb2.WarmupResponse(
max_supported_total_tokens=max_supported_total_tokens
max_supported_total_tokens=max_supported_total_tokens,
max_input_tokens=max_input_tokens,
max_total_tokens=max_total_tokens,
)
async def Prefill(self, request, context):