Choosing input/total tokens automatically based on available VRAM? (#2673)
* Choosing input/total tokens automatically based on available VRAM? * Update doc. * Remove generated files. * Trying to fix non chunking targets. * Attempt #2 * fix. * QuantLinear is rocm compatible. * Much simpler logic after the overhead. * Updating logic + non flash. * Revert doc text. * Simple updates. * Fix integration mt0 (transformers update).
This commit is contained in:
parent
2e4f4ba1bb
commit
0c9b6cdd76
|
@ -5,6 +5,8 @@ router/tokenizer.json
|
||||||
|
|
||||||
backends/v2/src/client/pb
|
backends/v2/src/client/pb
|
||||||
backends/v3/src/client/pb
|
backends/v3/src/client/pb
|
||||||
|
backends/client/src/v2/pb
|
||||||
|
backends/client/src/v3/pb
|
||||||
|
|
||||||
# ROCm auto-generated files
|
# ROCm auto-generated files
|
||||||
*.hip
|
*.hip
|
||||||
|
|
|
@ -107,20 +107,22 @@ impl Client {
|
||||||
#[instrument(skip_all)]
|
#[instrument(skip_all)]
|
||||||
pub async fn warmup(
|
pub async fn warmup(
|
||||||
&mut self,
|
&mut self,
|
||||||
max_input_length: u32,
|
max_input_tokens: Option<u32>,
|
||||||
max_prefill_tokens: u32,
|
max_prefill_tokens: u32,
|
||||||
max_total_tokens: u32,
|
max_total_tokens: Option<u32>,
|
||||||
max_batch_size: Option<usize>,
|
max_batch_size: Option<usize>,
|
||||||
) -> Result<Option<u32>> {
|
) -> Result<(Option<u32>, u32, u32)> {
|
||||||
let mut n_tokens = 0;
|
let mut n_tokens = 0;
|
||||||
let mut requests = Vec::new();
|
let mut requests = Vec::new();
|
||||||
// Create requests
|
// Create requests
|
||||||
while n_tokens < max_prefill_tokens {
|
while n_tokens < max_prefill_tokens {
|
||||||
let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
|
let mut truncate = max_prefill_tokens - n_tokens;
|
||||||
|
if let Some(max_input_tokens) = max_input_tokens {
|
||||||
|
truncate = min(max_input_tokens, truncate);
|
||||||
|
}
|
||||||
|
|
||||||
let mut input_chunks = Vec::new();
|
let mut input_chunks = Vec::new();
|
||||||
input_chunks
|
input_chunks.push(Chunk::Text("_test ".to_string().repeat(truncate as usize)).into());
|
||||||
.push(Chunk::Text("_test ".to_string().repeat(max_input_length as usize)).into());
|
|
||||||
if n_tokens == 0 {
|
if n_tokens == 0 {
|
||||||
input_chunks.push(
|
input_chunks.push(
|
||||||
Chunk::Image(Image {
|
Chunk::Image(Image {
|
||||||
|
@ -136,7 +138,7 @@ impl Client {
|
||||||
// been updated to support chunks.
|
// been updated to support chunks.
|
||||||
|
|
||||||
let mut inputs = String::new();
|
let mut inputs = String::new();
|
||||||
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
|
inputs.push_str(&"_test ".to_string().repeat(truncate as usize));
|
||||||
if n_tokens == 0 {
|
if n_tokens == 0 {
|
||||||
// 1 request is enough to test vision heads.
|
// 1 request is enough to test vision heads.
|
||||||
// Sending images on other queries messes up easily with truncation.
|
// Sending images on other queries messes up easily with truncation.
|
||||||
|
@ -145,6 +147,12 @@ impl Client {
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let max_new_tokens = if let Some(max_total_tokens) = max_total_tokens {
|
||||||
|
max_total_tokens - truncate
|
||||||
|
} else {
|
||||||
|
1
|
||||||
|
};
|
||||||
|
|
||||||
requests.push(Request {
|
requests.push(Request {
|
||||||
id: 0,
|
id: 0,
|
||||||
inputs,
|
inputs,
|
||||||
|
@ -175,7 +183,7 @@ impl Client {
|
||||||
grammar_type: GrammarType::None as i32,
|
grammar_type: GrammarType::None as i32,
|
||||||
}),
|
}),
|
||||||
stopping_parameters: Some(StoppingCriteriaParameters {
|
stopping_parameters: Some(StoppingCriteriaParameters {
|
||||||
max_new_tokens: max_total_tokens - truncate,
|
max_new_tokens,
|
||||||
stop_sequences: vec![],
|
stop_sequences: vec![],
|
||||||
ignore_eos_token: true,
|
ignore_eos_token: true,
|
||||||
}),
|
}),
|
||||||
|
@ -183,7 +191,7 @@ impl Client {
|
||||||
top_n_tokens: 20,
|
top_n_tokens: 20,
|
||||||
adapter_id: None,
|
adapter_id: None,
|
||||||
});
|
});
|
||||||
n_tokens += max_input_length;
|
n_tokens += truncate;
|
||||||
|
|
||||||
// Check max_batch_size
|
// Check max_batch_size
|
||||||
if Some(requests.len()) == max_batch_size {
|
if Some(requests.len()) == max_batch_size {
|
||||||
|
@ -195,19 +203,23 @@ impl Client {
|
||||||
id: 0,
|
id: 0,
|
||||||
size: requests.len() as u32,
|
size: requests.len() as u32,
|
||||||
requests,
|
requests,
|
||||||
max_tokens: max_input_length,
|
max_tokens: max_input_tokens.unwrap_or(0),
|
||||||
max_blocks: 0,
|
max_blocks: 0,
|
||||||
};
|
};
|
||||||
|
|
||||||
let request = tonic::Request::new(WarmupRequest {
|
let request = tonic::Request::new(WarmupRequest {
|
||||||
batch: Some(batch),
|
batch: Some(batch),
|
||||||
max_input_length,
|
max_input_tokens,
|
||||||
max_prefill_tokens,
|
max_prefill_tokens,
|
||||||
max_total_tokens,
|
max_total_tokens,
|
||||||
})
|
})
|
||||||
.inject_context();
|
.inject_context();
|
||||||
let response = self.stub.warmup(request).await?.into_inner();
|
let response = self.stub.warmup(request).await?.into_inner();
|
||||||
Ok(response.max_supported_total_tokens)
|
Ok((
|
||||||
|
response.max_supported_total_tokens,
|
||||||
|
response.max_input_tokens,
|
||||||
|
response.max_total_tokens,
|
||||||
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generate one token for each request in the given batch
|
/// Generate one token for each request in the given batch
|
||||||
|
|
|
@ -101,11 +101,11 @@ impl ShardedClient {
|
||||||
#[instrument(skip(self))]
|
#[instrument(skip(self))]
|
||||||
pub async fn warmup(
|
pub async fn warmup(
|
||||||
&mut self,
|
&mut self,
|
||||||
max_input_length: u32,
|
max_input_length: Option<u32>,
|
||||||
max_prefill_tokens: u32,
|
max_prefill_tokens: u32,
|
||||||
max_total_tokens: u32,
|
max_total_tokens: Option<u32>,
|
||||||
max_batch_size: Option<usize>,
|
max_batch_size: Option<usize>,
|
||||||
) -> Result<Option<u32>> {
|
) -> Result<(Option<u32>, u32, u32)> {
|
||||||
let futures: Vec<_> = self
|
let futures: Vec<_> = self
|
||||||
.clients
|
.clients
|
||||||
.iter_mut()
|
.iter_mut()
|
||||||
|
@ -122,8 +122,16 @@ impl ShardedClient {
|
||||||
let results = join_all(futures)
|
let results = join_all(futures)
|
||||||
.await
|
.await
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.collect::<Result<Vec<Option<u32>>>>()?;
|
.collect::<Result<Vec<(Option<u32>, u32, u32)>>>()?;
|
||||||
Ok(results.into_iter().flatten().min())
|
|
||||||
|
// Take the minimum value
|
||||||
|
// Different shards hold different parts of vocab, might yield
|
||||||
|
// different available block size.
|
||||||
|
let min = results
|
||||||
|
.iter()
|
||||||
|
.min()
|
||||||
|
.expect("Expect at least 1 warmup result");
|
||||||
|
Ok(*min)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generate one token for each request in the given batch
|
/// Generate one token for each request in the given batch
|
||||||
|
|
|
@ -108,20 +108,22 @@ impl Client {
|
||||||
#[instrument(skip_all)]
|
#[instrument(skip_all)]
|
||||||
pub async fn warmup(
|
pub async fn warmup(
|
||||||
&mut self,
|
&mut self,
|
||||||
max_input_length: u32,
|
max_input_tokens: Option<u32>,
|
||||||
max_prefill_tokens: u32,
|
max_prefill_tokens: u32,
|
||||||
max_total_tokens: u32,
|
max_total_tokens: Option<u32>,
|
||||||
max_batch_size: Option<usize>,
|
max_batch_size: Option<usize>,
|
||||||
) -> Result<Option<u32>> {
|
) -> Result<(Option<u32>, u32, u32)> {
|
||||||
let mut n_tokens = 0;
|
let mut n_tokens = 0;
|
||||||
let mut requests = Vec::new();
|
let mut requests = Vec::new();
|
||||||
// Create requests
|
// Create requests
|
||||||
while n_tokens < max_prefill_tokens {
|
while n_tokens < max_prefill_tokens {
|
||||||
let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
|
let mut truncate = max_prefill_tokens - n_tokens;
|
||||||
|
if let Some(max_input_tokens) = max_input_tokens {
|
||||||
|
truncate = min(max_input_tokens, truncate);
|
||||||
|
}
|
||||||
|
|
||||||
let mut input_chunks = Vec::new();
|
let mut input_chunks = Vec::new();
|
||||||
input_chunks
|
input_chunks.push(Chunk::Text("_test ".to_string().repeat(truncate as usize)).into());
|
||||||
.push(Chunk::Text("_test ".to_string().repeat(max_input_length as usize)).into());
|
|
||||||
if n_tokens == 0 {
|
if n_tokens == 0 {
|
||||||
input_chunks.push(
|
input_chunks.push(
|
||||||
Chunk::Image(Image {
|
Chunk::Image(Image {
|
||||||
|
@ -137,7 +139,7 @@ impl Client {
|
||||||
// been updated to support chunks.
|
// been updated to support chunks.
|
||||||
|
|
||||||
let mut inputs = String::new();
|
let mut inputs = String::new();
|
||||||
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
|
inputs.push_str(&"_test ".to_string().repeat(truncate as usize));
|
||||||
if n_tokens == 0 {
|
if n_tokens == 0 {
|
||||||
// 1 request is enough to test vision heads.
|
// 1 request is enough to test vision heads.
|
||||||
// Sending images on other queries messes up easily with truncation.
|
// Sending images on other queries messes up easily with truncation.
|
||||||
|
@ -146,6 +148,12 @@ impl Client {
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let max_new_tokens = if let Some(max_total_tokens) = max_total_tokens {
|
||||||
|
max_total_tokens - truncate
|
||||||
|
} else {
|
||||||
|
1
|
||||||
|
};
|
||||||
|
|
||||||
requests.push(Request {
|
requests.push(Request {
|
||||||
id: 0,
|
id: 0,
|
||||||
inputs,
|
inputs,
|
||||||
|
@ -175,7 +183,7 @@ impl Client {
|
||||||
grammar_type: GrammarType::None as i32,
|
grammar_type: GrammarType::None as i32,
|
||||||
}),
|
}),
|
||||||
stopping_parameters: Some(StoppingCriteriaParameters {
|
stopping_parameters: Some(StoppingCriteriaParameters {
|
||||||
max_new_tokens: max_total_tokens - truncate,
|
max_new_tokens,
|
||||||
stop_sequences: vec![],
|
stop_sequences: vec![],
|
||||||
ignore_eos_token: true,
|
ignore_eos_token: true,
|
||||||
}),
|
}),
|
||||||
|
@ -183,7 +191,7 @@ impl Client {
|
||||||
top_n_tokens: 20,
|
top_n_tokens: 20,
|
||||||
adapter_id: None,
|
adapter_id: None,
|
||||||
});
|
});
|
||||||
n_tokens += max_input_length;
|
n_tokens += truncate;
|
||||||
|
|
||||||
// Check max_batch_size
|
// Check max_batch_size
|
||||||
if Some(requests.len()) == max_batch_size {
|
if Some(requests.len()) == max_batch_size {
|
||||||
|
@ -195,19 +203,23 @@ impl Client {
|
||||||
id: 0,
|
id: 0,
|
||||||
size: requests.len() as u32,
|
size: requests.len() as u32,
|
||||||
requests,
|
requests,
|
||||||
max_tokens: max_input_length,
|
max_tokens: max_input_tokens.unwrap_or(0),
|
||||||
max_blocks: 0,
|
max_blocks: 0,
|
||||||
};
|
};
|
||||||
|
|
||||||
let request = tonic::Request::new(WarmupRequest {
|
let request = tonic::Request::new(WarmupRequest {
|
||||||
batch: Some(batch),
|
batch: Some(batch),
|
||||||
max_input_length,
|
max_input_tokens,
|
||||||
max_prefill_tokens,
|
max_prefill_tokens,
|
||||||
max_total_tokens,
|
max_total_tokens,
|
||||||
})
|
})
|
||||||
.inject_context();
|
.inject_context();
|
||||||
let response = self.stub.warmup(request).await?.into_inner();
|
let response = self.stub.warmup(request).await?.into_inner();
|
||||||
Ok(response.max_supported_total_tokens)
|
Ok((
|
||||||
|
response.max_supported_total_tokens,
|
||||||
|
response.max_input_tokens,
|
||||||
|
response.max_total_tokens,
|
||||||
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generate one token for each request in the given batch
|
/// Generate one token for each request in the given batch
|
||||||
|
|
|
@ -102,11 +102,11 @@ impl ShardedClient {
|
||||||
#[instrument(skip(self))]
|
#[instrument(skip(self))]
|
||||||
pub async fn warmup(
|
pub async fn warmup(
|
||||||
&mut self,
|
&mut self,
|
||||||
max_input_length: u32,
|
max_input_length: Option<u32>,
|
||||||
max_prefill_tokens: u32,
|
max_prefill_tokens: u32,
|
||||||
max_total_tokens: u32,
|
max_total_tokens: Option<u32>,
|
||||||
max_batch_size: Option<usize>,
|
max_batch_size: Option<usize>,
|
||||||
) -> Result<Option<u32>> {
|
) -> Result<(Option<u32>, u32, u32)> {
|
||||||
let futures: Vec<_> = self
|
let futures: Vec<_> = self
|
||||||
.clients
|
.clients
|
||||||
.iter_mut()
|
.iter_mut()
|
||||||
|
@ -119,12 +119,19 @@ impl ShardedClient {
|
||||||
))
|
))
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
// Take the minimum value
|
|
||||||
let results = join_all(futures)
|
let results = join_all(futures)
|
||||||
.await
|
.await
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.collect::<Result<Vec<Option<u32>>>>()?;
|
.collect::<Result<Vec<(Option<u32>, u32, u32)>>>()?;
|
||||||
Ok(results.into_iter().flatten().min())
|
|
||||||
|
// Take the minimum value
|
||||||
|
// Different shards hold different parts of vocab, might yield
|
||||||
|
// different available block size.
|
||||||
|
let min = results
|
||||||
|
.iter()
|
||||||
|
.min()
|
||||||
|
.expect("Expect at least 1 warmup result");
|
||||||
|
Ok(*min)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generate one token for each request in the given batch
|
/// Generate one token for each request in the given batch
|
||||||
|
|
|
@ -37,12 +37,17 @@ pub struct BackendInfo {
|
||||||
pub attention_impl: String,
|
pub attention_impl: String,
|
||||||
#[schema(example = "1")]
|
#[schema(example = "1")]
|
||||||
pub block_size: u32,
|
pub block_size: u32,
|
||||||
|
|
||||||
|
#[schema(example = "30000")]
|
||||||
|
pub max_input_tokens: usize,
|
||||||
|
#[schema(example = "32000")]
|
||||||
|
pub max_total_tokens: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub async fn connect_backend(
|
pub async fn connect_backend(
|
||||||
max_input_tokens: usize,
|
max_input_tokens: Option<usize>,
|
||||||
max_total_tokens: usize,
|
max_total_tokens: Option<usize>,
|
||||||
master_shard_uds_path: String,
|
master_shard_uds_path: String,
|
||||||
waiting_served_ratio: f32,
|
waiting_served_ratio: f32,
|
||||||
max_batch_prefill_tokens: u32,
|
max_batch_prefill_tokens: u32,
|
||||||
|
@ -51,14 +56,32 @@ pub async fn connect_backend(
|
||||||
max_batch_size: Option<usize>,
|
max_batch_size: Option<usize>,
|
||||||
) -> Result<(BackendV3, BackendInfo), V3Error> {
|
) -> Result<(BackendV3, BackendInfo), V3Error> {
|
||||||
// Helper function
|
// Helper function
|
||||||
let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option<u32>| {
|
let check_max_batch_total_tokens = |(
|
||||||
|
max_supported_batch_total_tokens,
|
||||||
|
shard_max_input_tokens,
|
||||||
|
shard_max_total_tokens,
|
||||||
|
): (Option<u32>, u32, u32)|
|
||||||
|
-> Result<(u32, usize, usize), V3Error> {
|
||||||
|
if let Some(max_input_tokens) = max_input_tokens {
|
||||||
|
assert_eq!(max_input_tokens as u32, shard_max_input_tokens);
|
||||||
|
}
|
||||||
|
if let Some(max_total_tokens) = max_total_tokens {
|
||||||
|
assert_eq!(max_total_tokens as u32, shard_max_total_tokens);
|
||||||
|
}
|
||||||
match max_supported_batch_total_tokens {
|
match max_supported_batch_total_tokens {
|
||||||
// Older models do not support automatic max-batch-total-tokens
|
// Older models do not support automatic max-batch-total-tokens
|
||||||
None => {
|
None => {
|
||||||
let max_batch_total_tokens = max_batch_total_tokens
|
let max_batch_total_tokens = max_batch_total_tokens.unwrap_or(
|
||||||
.unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)));
|
16000
|
||||||
|
.max(shard_max_total_tokens)
|
||||||
|
.max(max_batch_prefill_tokens),
|
||||||
|
);
|
||||||
tracing::warn!("Model does not support automatic max batch total tokens");
|
tracing::warn!("Model does not support automatic max batch total tokens");
|
||||||
Ok(max_batch_total_tokens)
|
Ok((
|
||||||
|
max_batch_total_tokens,
|
||||||
|
shard_max_input_tokens as usize,
|
||||||
|
shard_max_total_tokens as usize,
|
||||||
|
))
|
||||||
}
|
}
|
||||||
// Flash attention models return their max supported total tokens
|
// Flash attention models return their max supported total tokens
|
||||||
Some(max_supported_batch_total_tokens) => {
|
Some(max_supported_batch_total_tokens) => {
|
||||||
|
@ -72,11 +95,15 @@ pub async fn connect_backend(
|
||||||
"Inferred max batch total tokens: {max_supported_batch_total_tokens}"
|
"Inferred max batch total tokens: {max_supported_batch_total_tokens}"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
if max_total_tokens as u32 > max_supported_batch_total_tokens {
|
if shard_max_total_tokens > max_supported_batch_total_tokens {
|
||||||
return Err(V3Error::NotEnoughMemory(max_total_tokens));
|
return Err(V3Error::NotEnoughMemory(shard_max_total_tokens as usize));
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(max_supported_batch_total_tokens)
|
Ok((
|
||||||
|
max_supported_batch_total_tokens,
|
||||||
|
shard_max_input_tokens as usize,
|
||||||
|
shard_max_total_tokens as usize,
|
||||||
|
))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -96,23 +123,25 @@ pub async fn connect_backend(
|
||||||
|
|
||||||
// Warmup model
|
// Warmup model
|
||||||
tracing::info!("Warming up model");
|
tracing::info!("Warming up model");
|
||||||
let max_batch_total_tokens = check_max_batch_total_tokens(
|
let answer = sharded_client
|
||||||
sharded_client
|
|
||||||
.warmup(
|
.warmup(
|
||||||
max_input_tokens as u32,
|
max_input_tokens.map(|p| p as u32),
|
||||||
max_batch_prefill_tokens,
|
max_batch_prefill_tokens,
|
||||||
max_total_tokens as u32,
|
max_total_tokens.map(|p| p as u32),
|
||||||
max_batch_size,
|
max_batch_size,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.map_err(V3Error::Warmup)?,
|
.map_err(V3Error::Warmup)?;
|
||||||
)?;
|
let (max_batch_total_tokens, max_input_tokens, max_total_tokens) =
|
||||||
|
check_max_batch_total_tokens(answer)?;
|
||||||
tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
|
tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
|
||||||
metrics::gauge!("tgi_batch_max_total_tokens").set(max_batch_total_tokens);
|
metrics::gauge!("tgi_batch_max_total_tokens").set(max_batch_total_tokens);
|
||||||
|
|
||||||
let backend_info = BackendInfo {
|
let backend_info = BackendInfo {
|
||||||
waiting_served_ratio,
|
waiting_served_ratio,
|
||||||
max_batch_total_tokens,
|
max_batch_total_tokens,
|
||||||
|
max_input_tokens,
|
||||||
|
max_total_tokens,
|
||||||
max_waiting_tokens,
|
max_waiting_tokens,
|
||||||
max_batch_size,
|
max_batch_size,
|
||||||
model_device_type: shard_info.device_type.clone(),
|
model_device_type: shard_info.device_type.clone(),
|
||||||
|
|
|
@ -18,10 +18,10 @@ struct Args {
|
||||||
max_stop_sequences: usize,
|
max_stop_sequences: usize,
|
||||||
#[clap(default_value = "5", long, env)]
|
#[clap(default_value = "5", long, env)]
|
||||||
max_top_n_tokens: u32,
|
max_top_n_tokens: u32,
|
||||||
#[clap(default_value = "1024", long, env)]
|
#[clap(long, env)]
|
||||||
max_input_tokens: usize,
|
max_input_tokens: Option<usize>,
|
||||||
#[clap(default_value = "2048", long, env)]
|
#[clap(long, env)]
|
||||||
max_total_tokens: usize,
|
max_total_tokens: Option<usize>,
|
||||||
#[clap(default_value = "1.2", long, env)]
|
#[clap(default_value = "1.2", long, env)]
|
||||||
waiting_served_ratio: f32,
|
waiting_served_ratio: f32,
|
||||||
#[clap(default_value = "4096", long, env)]
|
#[clap(default_value = "4096", long, env)]
|
||||||
|
@ -126,12 +126,6 @@ async fn main() -> Result<(), RouterError> {
|
||||||
text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output);
|
text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output);
|
||||||
|
|
||||||
// Validate args
|
// Validate args
|
||||||
if max_input_tokens >= max_total_tokens {
|
|
||||||
return Err(RouterError::ArgumentValidation(
|
|
||||||
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
if validation_workers == 0 {
|
if validation_workers == 0 {
|
||||||
return Err(RouterError::ArgumentValidation(
|
return Err(RouterError::ArgumentValidation(
|
||||||
"`validation_workers` must be > 0".to_string(),
|
"`validation_workers` must be > 0".to_string(),
|
||||||
|
@ -160,6 +154,28 @@ async fn main() -> Result<(), RouterError> {
|
||||||
// Validate remaining args now that the backend is known
|
// Validate remaining args now that the backend is known
|
||||||
let support_chunking = backend_info.support_chunking;
|
let support_chunking = backend_info.support_chunking;
|
||||||
let max_batch_total_tokens = backend_info.max_batch_total_tokens;
|
let max_batch_total_tokens = backend_info.max_batch_total_tokens;
|
||||||
|
|
||||||
|
if max_input_tokens.is_none() {
|
||||||
|
tracing::info!(
|
||||||
|
"Maximum input tokens defaulted to {}",
|
||||||
|
backend_info.max_input_tokens
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if max_total_tokens.is_none() {
|
||||||
|
tracing::info!(
|
||||||
|
"Maximum total tokens defaulted to {}",
|
||||||
|
backend_info.max_total_tokens
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let max_input_tokens = backend_info.max_input_tokens;
|
||||||
|
let max_total_tokens = backend_info.max_total_tokens;
|
||||||
|
if max_input_tokens >= max_total_tokens {
|
||||||
|
return Err(RouterError::ArgumentValidation(
|
||||||
|
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
if max_input_tokens as u32 > max_batch_prefill_tokens && !support_chunking {
|
if max_input_tokens as u32 > max_batch_prefill_tokens && !support_chunking {
|
||||||
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
|
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
|
||||||
}
|
}
|
||||||
|
|
|
@ -146,7 +146,7 @@ Options:
|
||||||
## MAX_INPUT_TOKENS
|
## MAX_INPUT_TOKENS
|
||||||
```shell
|
```shell
|
||||||
--max-input-tokens <MAX_INPUT_TOKENS>
|
--max-input-tokens <MAX_INPUT_TOKENS>
|
||||||
This is the maximum allowed input length (expressed in number of tokens) for users. The larger this value, the longer prompt users can send which can impact the overall memory required to handle the load. Please note that some models have a finite range of sequence they can handle. Default to min(max_position_embeddings - 1, 4095)
|
This is the maximum allowed input length (expressed in number of tokens) for users. The larger this value, the longer prompt users can send which can impact the overall memory required to handle the load. Please note that some models have a finite range of sequence they can handle. Default to min(max_allocatable, max_position_embeddings) - 1
|
||||||
|
|
||||||
[env: MAX_INPUT_TOKENS=]
|
[env: MAX_INPUT_TOKENS=]
|
||||||
|
|
||||||
|
@ -162,7 +162,7 @@ Options:
|
||||||
## MAX_TOTAL_TOKENS
|
## MAX_TOTAL_TOKENS
|
||||||
```shell
|
```shell
|
||||||
--max-total-tokens <MAX_TOTAL_TOKENS>
|
--max-total-tokens <MAX_TOTAL_TOKENS>
|
||||||
This is the most important value to set as it defines the "memory budget" of running clients requests. Clients will send input sequences and ask to generate `max_new_tokens` on top. with a value of `1512` users can send either a prompt of `1000` and ask for `512` new tokens, or send a prompt of `1` and ask for `1511` max_new_tokens. The larger this value, the larger amount each request will be in your RAM and the less effective batching can be. Default to min(max_position_embeddings, 4096)
|
This is the most important value to set as it defines the "memory budget" of running clients requests. Clients will send input sequences and ask to generate `max_new_tokens` on top. with a value of `1512` users can send either a prompt of `1000` and ask for `512` new tokens, or send a prompt of `1` and ask for `1511` max_new_tokens. The larger this value, the larger amount each request will be in your RAM and the less effective batching can be. Default to min(max_allocatable, max_position_embeddings)
|
||||||
|
|
||||||
[env: MAX_TOTAL_TOKENS=]
|
[env: MAX_TOTAL_TOKENS=]
|
||||||
|
|
||||||
|
|
|
@ -472,7 +472,7 @@ struct Args {
|
||||||
/// for users. The larger this value, the longer prompt users can send which
|
/// for users. The larger this value, the longer prompt users can send which
|
||||||
/// can impact the overall memory required to handle the load.
|
/// can impact the overall memory required to handle the load.
|
||||||
/// Please note that some models have a finite range of sequence they can handle.
|
/// Please note that some models have a finite range of sequence they can handle.
|
||||||
/// Default to min(max_position_embeddings - 1, 4095)
|
/// Default to min(max_allocatable, max_position_embeddings) - 1
|
||||||
#[clap(long, env)]
|
#[clap(long, env)]
|
||||||
max_input_tokens: Option<usize>,
|
max_input_tokens: Option<usize>,
|
||||||
|
|
||||||
|
@ -488,7 +488,7 @@ struct Args {
|
||||||
/// `1511` max_new_tokens.
|
/// `1511` max_new_tokens.
|
||||||
/// The larger this value, the larger amount each request will be in your RAM
|
/// The larger this value, the larger amount each request will be in your RAM
|
||||||
/// and the less effective batching can be.
|
/// and the less effective batching can be.
|
||||||
/// Default to min(max_position_embeddings, 4096)
|
/// Default to min(max_allocatable, max_position_embeddings)
|
||||||
#[clap(long, env)]
|
#[clap(long, env)]
|
||||||
max_total_tokens: Option<usize>,
|
max_total_tokens: Option<usize>,
|
||||||
|
|
||||||
|
@ -718,9 +718,9 @@ fn shard_manager(
|
||||||
cuda_memory_fraction: f32,
|
cuda_memory_fraction: f32,
|
||||||
rope_scaling: Option<RopeScaling>,
|
rope_scaling: Option<RopeScaling>,
|
||||||
rope_factor: Option<f32>,
|
rope_factor: Option<f32>,
|
||||||
max_total_tokens: usize,
|
max_total_tokens: Option<usize>,
|
||||||
max_batch_size: Option<usize>,
|
max_batch_size: Option<usize>,
|
||||||
max_input_tokens: usize,
|
max_input_tokens: Option<usize>,
|
||||||
lora_adapters: Option<String>,
|
lora_adapters: Option<String>,
|
||||||
otlp_endpoint: Option<String>,
|
otlp_endpoint: Option<String>,
|
||||||
otlp_service_name: String,
|
otlp_service_name: String,
|
||||||
|
@ -805,8 +805,10 @@ fn shard_manager(
|
||||||
shard_args.push(otlp_service_name);
|
shard_args.push(otlp_service_name);
|
||||||
|
|
||||||
// In case we use sliding window, we may ignore the sliding in flash for some backends depending on the parameter.
|
// In case we use sliding window, we may ignore the sliding in flash for some backends depending on the parameter.
|
||||||
|
if let Some(max_input_tokens) = max_input_tokens {
|
||||||
shard_args.push("--max-input-tokens".to_string());
|
shard_args.push("--max-input-tokens".to_string());
|
||||||
shard_args.push(max_input_tokens.to_string());
|
shard_args.push(max_input_tokens.to_string());
|
||||||
|
}
|
||||||
|
|
||||||
// Copy current process env
|
// Copy current process env
|
||||||
let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
|
let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
|
||||||
|
@ -854,10 +856,12 @@ fn shard_manager(
|
||||||
envs.push(("ROPE_FACTOR".into(), factor.to_string().into()));
|
envs.push(("ROPE_FACTOR".into(), factor.to_string().into()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if let Some(max_total_tokens) = max_total_tokens {
|
||||||
envs.push((
|
envs.push((
|
||||||
"MAX_TOTAL_TOKENS".into(),
|
"MAX_TOTAL_TOKENS".into(),
|
||||||
max_total_tokens.to_string().into(),
|
max_total_tokens.to_string().into(),
|
||||||
));
|
));
|
||||||
|
}
|
||||||
if let Some(max_batch_size) = max_batch_size {
|
if let Some(max_batch_size) = max_batch_size {
|
||||||
envs.push(("MAX_BATCH_SIZE".into(), max_batch_size.to_string().into()));
|
envs.push(("MAX_BATCH_SIZE".into(), max_batch_size.to_string().into()));
|
||||||
}
|
}
|
||||||
|
@ -1315,8 +1319,8 @@ fn spawn_shards(
|
||||||
num_shard: usize,
|
num_shard: usize,
|
||||||
args: &Args,
|
args: &Args,
|
||||||
cuda_graphs: Vec<usize>,
|
cuda_graphs: Vec<usize>,
|
||||||
max_total_tokens: usize,
|
max_total_tokens: Option<usize>,
|
||||||
max_input_tokens: usize,
|
max_input_tokens: Option<usize>,
|
||||||
quantize: Option<Quantization>,
|
quantize: Option<Quantization>,
|
||||||
max_log_level: LevelFilter,
|
max_log_level: LevelFilter,
|
||||||
shutdown: Arc<AtomicBool>,
|
shutdown: Arc<AtomicBool>,
|
||||||
|
@ -1434,8 +1438,8 @@ fn compute_type(num_shard: usize) -> Option<String> {
|
||||||
fn spawn_webserver(
|
fn spawn_webserver(
|
||||||
num_shard: usize,
|
num_shard: usize,
|
||||||
args: Args,
|
args: Args,
|
||||||
max_input_tokens: usize,
|
max_input_tokens: Option<usize>,
|
||||||
max_total_tokens: usize,
|
max_total_tokens: Option<usize>,
|
||||||
max_batch_prefill_tokens: u32,
|
max_batch_prefill_tokens: u32,
|
||||||
shutdown: Arc<AtomicBool>,
|
shutdown: Arc<AtomicBool>,
|
||||||
shutdown_receiver: &mpsc::Receiver<()>,
|
shutdown_receiver: &mpsc::Receiver<()>,
|
||||||
|
@ -1454,10 +1458,6 @@ fn spawn_webserver(
|
||||||
args.max_stop_sequences.to_string(),
|
args.max_stop_sequences.to_string(),
|
||||||
"--max-top-n-tokens".to_string(),
|
"--max-top-n-tokens".to_string(),
|
||||||
args.max_top_n_tokens.to_string(),
|
args.max_top_n_tokens.to_string(),
|
||||||
"--max-input-tokens".to_string(),
|
|
||||||
max_input_tokens.to_string(),
|
|
||||||
"--max-total-tokens".to_string(),
|
|
||||||
max_total_tokens.to_string(),
|
|
||||||
"--max-batch-prefill-tokens".to_string(),
|
"--max-batch-prefill-tokens".to_string(),
|
||||||
max_batch_prefill_tokens.to_string(),
|
max_batch_prefill_tokens.to_string(),
|
||||||
"--waiting-served-ratio".to_string(),
|
"--waiting-served-ratio".to_string(),
|
||||||
|
@ -1475,6 +1475,18 @@ fn spawn_webserver(
|
||||||
"--tokenizer-name".to_string(),
|
"--tokenizer-name".to_string(),
|
||||||
args.model_id,
|
args.model_id,
|
||||||
];
|
];
|
||||||
|
if let Some(max_input_tokens) = max_input_tokens {
|
||||||
|
router_args.extend_from_slice(&[
|
||||||
|
"--max-input-tokens".to_string(),
|
||||||
|
max_input_tokens.to_string(),
|
||||||
|
]);
|
||||||
|
}
|
||||||
|
if let Some(max_total_tokens) = max_total_tokens {
|
||||||
|
router_args.extend_from_slice(&[
|
||||||
|
"--max-total-tokens".to_string(),
|
||||||
|
max_total_tokens.to_string(),
|
||||||
|
]);
|
||||||
|
}
|
||||||
|
|
||||||
// Pass usage stats flags to router
|
// Pass usage stats flags to router
|
||||||
router_args.push("--usage-stats".to_string());
|
router_args.push("--usage-stats".to_string());
|
||||||
|
@ -1704,35 +1716,19 @@ fn main() -> Result<(), LauncherError> {
|
||||||
format!("Both `max_input_tokens` ({max_input_tokens}) and `max_input_length` ({max_input_length}) are set. Please define only `max_input_tokens` as `max_input_length is deprecated for naming consistency.",
|
format!("Both `max_input_tokens` ({max_input_tokens}) and `max_input_length` ({max_input_length}) are set. Please define only `max_input_tokens` as `max_input_length is deprecated for naming consistency.",
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
(Some(max_input_tokens), None) | (None, Some(max_input_tokens)) => max_input_tokens,
|
(Some(max_input_tokens), None) | (None, Some(max_input_tokens)) => {
|
||||||
(None, None) => {
|
Some(max_input_tokens)
|
||||||
let value = max_position_embeddings - 1;
|
|
||||||
tracing::info!("Default `max_input_tokens` to {value}");
|
|
||||||
value
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
let max_total_tokens = {
|
|
||||||
match args.max_total_tokens {
|
|
||||||
Some(max_total_tokens) => max_total_tokens,
|
|
||||||
None => {
|
|
||||||
let value = max_position_embeddings;
|
|
||||||
tracing::info!("Default `max_total_tokens` to {value}");
|
|
||||||
value
|
|
||||||
}
|
}
|
||||||
|
(None, None) => None,
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
let max_total_tokens = args.max_total_tokens;
|
||||||
let max_batch_prefill_tokens = {
|
let max_batch_prefill_tokens = {
|
||||||
match args.max_batch_prefill_tokens {
|
match args.max_batch_prefill_tokens {
|
||||||
Some(max_batch_prefill_tokens) => max_batch_prefill_tokens,
|
Some(max_batch_prefill_tokens) => max_batch_prefill_tokens,
|
||||||
None => {
|
None => {
|
||||||
let value: u32 = if let Some(max_batch_size) = args.max_batch_size {
|
// TODO figure out hardware optimal value
|
||||||
max_batch_size * max_input_tokens
|
let value = 4096.min(max_position_embeddings as u32);
|
||||||
} else {
|
|
||||||
// Adding some edge in order to account for potential block_size alignement
|
|
||||||
// issue.
|
|
||||||
max_input_tokens + 50
|
|
||||||
} as u32;
|
|
||||||
tracing::info!("Default `max_batch_prefill_tokens` to {value}");
|
tracing::info!("Default `max_batch_prefill_tokens` to {value}");
|
||||||
value
|
value
|
||||||
}
|
}
|
||||||
|
@ -1740,11 +1736,13 @@ fn main() -> Result<(), LauncherError> {
|
||||||
};
|
};
|
||||||
|
|
||||||
// Validate args
|
// Validate args
|
||||||
|
if let (Some(max_input_tokens), Some(max_total_tokens)) = (max_input_tokens, max_total_tokens) {
|
||||||
if max_input_tokens >= max_total_tokens {
|
if max_input_tokens >= max_total_tokens {
|
||||||
return Err(LauncherError::ArgumentValidation(
|
return Err(LauncherError::ArgumentValidation(
|
||||||
"`max_input_tokens must be < `max_total_tokens`".to_string(),
|
format!("`max_input_tokens`({max_input_tokens}) must be < `max_total_tokens`({max_total_tokens})"),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if matches!(args.quantize, Some(Quantization::Bitsandbytes)) {
|
if matches!(args.quantize, Some(Quantization::Bitsandbytes)) {
|
||||||
tracing::warn!("Bitsandbytes is deprecated, use `eetq` instead, which provides better latencies overall and is drop-in in most cases.");
|
tracing::warn!("Bitsandbytes is deprecated, use `eetq` instead, which provides better latencies overall and is drop-in in most cases.");
|
||||||
|
@ -1798,6 +1796,7 @@ fn main() -> Result<(), LauncherError> {
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
|
if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
|
||||||
|
if let Some(max_total_tokens) = max_total_tokens {
|
||||||
if max_total_tokens as u32 > *max_batch_total_tokens {
|
if max_total_tokens as u32 > *max_batch_total_tokens {
|
||||||
return Err(LauncherError::ArgumentValidation(format!(
|
return Err(LauncherError::ArgumentValidation(format!(
|
||||||
"`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
|
"`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
|
||||||
|
@ -1805,6 +1804,7 @@ fn main() -> Result<(), LauncherError> {
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if args.ngrok {
|
if args.ngrok {
|
||||||
if args.ngrok_authtoken.is_none() {
|
if args.ngrok_authtoken.is_none() {
|
||||||
|
|
|
@ -272,12 +272,18 @@ message DecodeResponse {
|
||||||
message WarmupRequest {
|
message WarmupRequest {
|
||||||
/// Batch to warmup on
|
/// Batch to warmup on
|
||||||
Batch batch = 1;
|
Batch batch = 1;
|
||||||
uint32 max_input_length = 2;
|
optional uint32 max_input_tokens = 2;
|
||||||
uint32 max_prefill_tokens = 3;
|
uint32 max_prefill_tokens = 3;
|
||||||
uint32 max_total_tokens = 4;
|
optional uint32 max_total_tokens = 4;
|
||||||
}
|
}
|
||||||
|
|
||||||
message WarmupResponse {
|
message WarmupResponse {
|
||||||
/// Maximum number of tokens supported by the model
|
/// Maximum number of tokens supported by the model
|
||||||
optional uint32 max_supported_total_tokens = 1;
|
optional uint32 max_supported_total_tokens = 1;
|
||||||
|
/// Maximum input tokens by clients should be equal to request value if it's set
|
||||||
|
/// Otherwise warmup automatically allocates a value here
|
||||||
|
uint32 max_input_tokens = 2;
|
||||||
|
/// Maximum total tokens by clients should be equal to request value if it's set
|
||||||
|
/// Otherwise warmup automatically allocates a value here
|
||||||
|
uint32 max_total_tokens = 3;
|
||||||
}
|
}
|
||||||
|
|
|
@ -86,6 +86,10 @@ tracer = trace.get_tracer(__name__)
|
||||||
SLIDING_WINDOW: Optional[int] = None
|
SLIDING_WINDOW: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
|
def small_power_of_2(n: int):
|
||||||
|
return 1 << ((n - 1).bit_length() - 1)
|
||||||
|
|
||||||
|
|
||||||
def set_sliding_window(sliding_window: int):
|
def set_sliding_window(sliding_window: int):
|
||||||
global SLIDING_WINDOW
|
global SLIDING_WINDOW
|
||||||
SLIDING_WINDOW = sliding_window
|
SLIDING_WINDOW = sliding_window
|
||||||
|
@ -1495,11 +1499,22 @@ class FlashCausalLM(Model):
|
||||||
self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
|
self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
def warmup(self, batch: FlashCausalLMBatch):
|
def warmup(
|
||||||
|
self,
|
||||||
|
batch: FlashCausalLMBatch,
|
||||||
|
max_input_tokens: Optional[int],
|
||||||
|
max_total_tokens: Optional[int],
|
||||||
|
):
|
||||||
# The warmup batch is the biggest batch we could ever receive
|
# The warmup batch is the biggest batch we could ever receive
|
||||||
self.kv_cache = []
|
self.kv_cache = []
|
||||||
empty_cache()
|
empty_cache()
|
||||||
|
|
||||||
|
# Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
|
||||||
|
# Calculate the number of blocks that can be allocated with the free memory
|
||||||
|
dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size()
|
||||||
|
cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
|
||||||
|
total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.init_kv_cache(
|
self.init_kv_cache(
|
||||||
batch.num_blocks,
|
batch.num_blocks,
|
||||||
|
@ -1511,10 +1526,11 @@ class FlashCausalLM(Model):
|
||||||
)
|
)
|
||||||
max_bt = batch.max_blocks
|
max_bt = batch.max_blocks
|
||||||
max_s = max_bt * BLOCK_SIZE
|
max_s = max_bt * BLOCK_SIZE
|
||||||
|
batch_num_blocks = batch.num_blocks
|
||||||
|
|
||||||
if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
|
if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
|
||||||
torch.cuda.tunable.tuning_enable(False)
|
torch.cuda.tunable.tuning_enable(False)
|
||||||
_, batch, _ = self.generate_token(batch)
|
_, _batch, _ = self.generate_token(batch)
|
||||||
except torch.cuda.OutOfMemoryError as e:
|
except torch.cuda.OutOfMemoryError as e:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Not enough memory to handle {batch.to_pb().current_tokens} prefill tokens. "
|
f"Not enough memory to handle {batch.to_pb().current_tokens} prefill tokens. "
|
||||||
|
@ -1523,14 +1539,7 @@ class FlashCausalLM(Model):
|
||||||
|
|
||||||
synchronize(self.device)
|
synchronize(self.device)
|
||||||
|
|
||||||
# Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
|
|
||||||
# Calculate the number of blocks that can be allocated with the free memory
|
|
||||||
dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size()
|
|
||||||
cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
|
|
||||||
total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
|
|
||||||
|
|
||||||
free_memory = get_free_memory(self.device, MEMORY_FRACTION)
|
free_memory = get_free_memory(self.device, MEMORY_FRACTION)
|
||||||
batch_num_blocks = batch.num_blocks if batch is not None else 0
|
|
||||||
|
|
||||||
num_blocks = (
|
num_blocks = (
|
||||||
# Leave 5% for some wiggle room
|
# Leave 5% for some wiggle room
|
||||||
|
@ -1540,8 +1549,27 @@ class FlashCausalLM(Model):
|
||||||
)
|
)
|
||||||
|
|
||||||
log_master(logger.info, f"KV-cache blocks: {num_blocks}, size: {BLOCK_SIZE}")
|
log_master(logger.info, f"KV-cache blocks: {num_blocks}, size: {BLOCK_SIZE}")
|
||||||
|
if max_total_tokens is None:
|
||||||
|
if get_support_chunking():
|
||||||
|
model_max_length = self.tokenizer.model_max_length
|
||||||
|
max_input_tokens = (
|
||||||
|
min((num_blocks * BLOCK_SIZE - 1), model_max_length)
|
||||||
|
if max_input_tokens is None
|
||||||
|
else max_input_tokens
|
||||||
|
)
|
||||||
|
max_total_tokens = num_blocks * BLOCK_SIZE
|
||||||
|
|
||||||
del batch
|
else:
|
||||||
|
max_total_tokens = sum(batch.cache_lengths)
|
||||||
|
max_input_tokens = (
|
||||||
|
max_total_tokens - 1
|
||||||
|
if max_input_tokens is None
|
||||||
|
else max_input_tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
del _batch, batch
|
||||||
|
self.kv_cache = []
|
||||||
|
empty_cache()
|
||||||
|
|
||||||
self.init_kv_cache(
|
self.init_kv_cache(
|
||||||
num_blocks,
|
num_blocks,
|
||||||
|
@ -1623,7 +1651,9 @@ class FlashCausalLM(Model):
|
||||||
logger.info, f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS})."
|
logger.info, f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS})."
|
||||||
)
|
)
|
||||||
|
|
||||||
return int(num_blocks * BLOCK_SIZE)
|
assert max_input_tokens is not None
|
||||||
|
assert max_total_tokens is not None
|
||||||
|
return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
|
||||||
|
|
||||||
def tunableop_warmup(self, seqlen: int):
|
def tunableop_warmup(self, seqlen: int):
|
||||||
input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
|
input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed
|
import torch.distributed
|
||||||
from transformers import AutoTokenizer, PreTrainedTokenizerBase
|
from transformers import AutoTokenizer, PreTrainedTokenizerBase
|
||||||
from typing import Optional
|
from typing import Optional, Union
|
||||||
from text_generation_server.models.custom_modeling.mamba_modeling import (
|
from text_generation_server.models.custom_modeling.mamba_modeling import (
|
||||||
MambaConfig,
|
MambaConfig,
|
||||||
)
|
)
|
||||||
|
@ -475,7 +475,9 @@ class Mamba(Model):
|
||||||
def batch_type(self) -> Type[MambaBatch]:
|
def batch_type(self) -> Type[MambaBatch]:
|
||||||
return MambaBatch
|
return MambaBatch
|
||||||
|
|
||||||
def warmup(self, batch) -> Optional[int]:
|
def warmup(
|
||||||
|
self, batch, max_input_tokens: Optional[int], max_total_tokens: Optional[int]
|
||||||
|
) -> Union[Optional[int], Optional[int], Optional[int]]:
|
||||||
# TODO: implement warmup for Mamba if needed
|
# TODO: implement warmup for Mamba if needed
|
||||||
if CUDA_GRAPHS:
|
if CUDA_GRAPHS:
|
||||||
if self.speculate is None or self.speculate == 0:
|
if self.speculate is None or self.speculate == 0:
|
||||||
|
@ -489,7 +491,12 @@ class Mamba(Model):
|
||||||
else:
|
else:
|
||||||
logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
|
logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
|
||||||
|
|
||||||
return None
|
if max_total_tokens is None:
|
||||||
|
max_total_tokens = min(self.tokenizer.model_max_length, 4096)
|
||||||
|
|
||||||
|
if max_input_tokens is None:
|
||||||
|
max_input_tokens = max_total_tokens - 1
|
||||||
|
return None, max_input_tokens, max_total_tokens
|
||||||
|
|
||||||
def cuda_graph_warmup(self, batch_size: int):
|
def cuda_graph_warmup(self, batch_size: int):
|
||||||
input_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=self.device)
|
input_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=self.device)
|
||||||
|
|
|
@ -128,9 +128,17 @@ class Model(ABC):
|
||||||
) -> Tuple[List[Generation], Optional[B], Tuple[int, int]]:
|
) -> Tuple[List[Generation], Optional[B], Tuple[int, int]]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def warmup(self, batch: B) -> Optional[int]:
|
def warmup(
|
||||||
|
self, batch: B, max_input_tokens: Optional[int], max_total_tokens: Optional[int]
|
||||||
|
) -> Tuple[Optional[int], int, int]:
|
||||||
self.generate_token(batch)
|
self.generate_token(batch)
|
||||||
return None
|
total = sum(len(i) for i in batch.input_ids)
|
||||||
|
if max_total_tokens is None:
|
||||||
|
max_total_tokens = total
|
||||||
|
|
||||||
|
if max_input_tokens is None:
|
||||||
|
max_input_tokens = max_total_tokens - 1
|
||||||
|
return None, max_input_tokens, max_total_tokens
|
||||||
|
|
||||||
def decode_token(
|
def decode_token(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -132,10 +132,22 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||||
batch = self.model.batch_type.from_pb(
|
batch = self.model.batch_type.from_pb(
|
||||||
request.batch, self.model.tokenizer, self.model.dtype, self.model.device
|
request.batch, self.model.tokenizer, self.model.dtype, self.model.device
|
||||||
)
|
)
|
||||||
max_supported_total_tokens = self.model.warmup(batch)
|
|
||||||
|
# Override default values with None for clearer semantics.
|
||||||
|
max_input_tokens = (
|
||||||
|
request.max_input_tokens if request.HasField("max_input_tokens") else None
|
||||||
|
)
|
||||||
|
max_total_tokens = (
|
||||||
|
request.max_total_tokens if request.HasField("max_total_tokens") else None
|
||||||
|
)
|
||||||
|
max_supported_total_tokens, max_input_tokens, max_total_tokens = (
|
||||||
|
self.model.warmup(batch, max_input_tokens, max_total_tokens)
|
||||||
|
)
|
||||||
|
|
||||||
return generate_pb2.WarmupResponse(
|
return generate_pb2.WarmupResponse(
|
||||||
max_supported_total_tokens=max_supported_total_tokens
|
max_supported_total_tokens=max_supported_total_tokens,
|
||||||
|
max_input_tokens=max_input_tokens,
|
||||||
|
max_total_tokens=max_total_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def Prefill(self, request, context):
|
async def Prefill(self, request, context):
|
||||||
|
|
Loading…
Reference in New Issue