feat(server): auto max_batch_total_tokens for flash att models (#630)
parent 5e6ddfd6a4
commit fe80f5360c
@@ -184,8 +184,8 @@ struct Args {
     /// depends on other parameters like if you're using quantization, flash attention
     /// or the model implementation, text-generation-inference cannot infer this number
     /// automatically.
-    #[clap(default_value = "16000", long, env)]
-    max_batch_total_tokens: u32,
+    #[clap(long, env)]
+    max_batch_total_tokens: Option<u32>,
 
     /// This setting defines how many tokens can be passed before forcing the waiting
     /// queries to be put on the batch (if the size of the batch allows for it).
@@ -369,12 +369,6 @@ fn shard_manager(
     // Copy current process env
     let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
 
-    // Use cuda allocator. It leads to less memory fragmentation
-    envs.push((
-        "PYTORCH_CUDA_ALLOC_CONF".into(),
-        "backend:cudaMallocAsync".into(),
-    ));
-
     // Torch Distributed Env vars
     envs.push(("RANK".into(), rank.to_string().into()));
     envs.push(("WORLD_SIZE".into(), world_size.to_string().into()));
@@ -428,7 +422,7 @@ fn shard_manager(
     }
 
     // Start process
-    tracing::info!("Starting shard {rank}");
+    tracing::info!("Starting shard");
     let mut p = match Command::new("text-generation-server")
         .args(shard_args)
         .envs(envs)
@@ -493,17 +487,17 @@ fn shard_manager(
         if shutdown.load(Ordering::SeqCst) {
             p.kill().unwrap();
             let _ = p.wait();
-            tracing::info!("Shard {rank} terminated");
+            tracing::info!("Shard terminated");
            return;
        }
 
        // Shard is ready
        if uds.exists() && !ready {
-            tracing::info!("Shard {rank} ready in {:?}", start_time.elapsed());
+            tracing::info!("Shard ready in {:?}", start_time.elapsed());
            status_sender.send(ShardStatus::Ready).unwrap();
            ready = true;
        } else if !ready && wait_time.elapsed() > Duration::from_secs(10) {
-            tracing::info!("Waiting for shard {rank} to be ready...");
+            tracing::info!("Waiting for shard to be ready...");
            wait_time = Instant::now();
        }
        sleep(Duration::from_millis(100));
@@ -860,8 +854,6 @@ fn spawn_webserver(
         args.max_total_tokens.to_string(),
         "--max-batch-prefill-tokens".to_string(),
         args.max_batch_prefill_tokens.to_string(),
-        "--max-batch-total-tokens".to_string(),
-        args.max_batch_total_tokens.to_string(),
         "--waiting-served-ratio".to_string(),
         args.waiting_served_ratio.to_string(),
         "--max-waiting-tokens".to_string(),
@@ -878,6 +870,12 @@ fn spawn_webserver(
         args.model_id,
     ];
 
+    // Model optional max batch total tokens
+    if let Some(max_batch_total_tokens) = args.max_batch_total_tokens {
+        router_args.push("--max-batch-total-tokens".to_string());
+        router_args.push(max_batch_total_tokens.to_string());
+    }
+
     // Model optional revision
     if let Some(ref revision) = args.revision {
         router_args.push("--revision".to_string());
@@ -1036,18 +1034,7 @@ fn main() -> Result<(), LauncherError> {
             args.max_batch_prefill_tokens, args.max_input_length
         )));
     }
-    if args.max_batch_prefill_tokens > args.max_batch_total_tokens {
-        return Err(LauncherError::ArgumentValidation(format!(
-            "`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
-            args.max_batch_prefill_tokens, args.max_batch_total_tokens
-        )));
-    }
-    if args.max_total_tokens as u32 > args.max_batch_total_tokens {
-        return Err(LauncherError::ArgumentValidation(format!(
-            "`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
-            args.max_total_tokens, args.max_batch_total_tokens
-        )));
-    }
     if args.validation_workers == 0 {
         return Err(LauncherError::ArgumentValidation(
             "`validation_workers` must be > 0".to_string(),
@@ -1065,6 +1052,21 @@ fn main() -> Result<(), LauncherError> {
         tracing::info!("Sharding model on {num_shard} processes");
     }
 
+    if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
+        if args.max_batch_prefill_tokens > *max_batch_total_tokens {
+            return Err(LauncherError::ArgumentValidation(format!(
+                "`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
+                args.max_batch_prefill_tokens, max_batch_total_tokens
+            )));
+        }
+        if args.max_total_tokens as u32 > *max_batch_total_tokens {
+            return Err(LauncherError::ArgumentValidation(format!(
+                "`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
+                args.max_total_tokens, max_batch_total_tokens
+            )));
+        }
+    }
+
     // Signal handler
     let running = Arc::new(AtomicBool::new(true));
     let r = running.clone();
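Because `max_batch_total_tokens` is now optional, the launcher can only enforce the prefill/total bounds when the flag was passed explicitly; otherwise the value is inferred (or defaulted) later, at warmup. A minimal standalone sketch of that check, with a hypothetical helper name:

```rust
/// Hypothetical helper mirroring the launcher check above: the bounds are only
/// enforced when the user explicitly passed `--max-batch-total-tokens`.
fn check_batch_total_tokens(
    max_batch_prefill_tokens: u32,
    max_total_tokens: u32,
    max_batch_total_tokens: Option<u32>,
) -> Result<(), String> {
    if let Some(max_batch_total_tokens) = max_batch_total_tokens {
        if max_batch_prefill_tokens > max_batch_total_tokens {
            return Err(format!(
                "`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}"
            ));
        }
        if max_total_tokens > max_batch_total_tokens {
            return Err(format!(
                "`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}"
            ));
        }
    }
    Ok(())
}

fn main() {
    // If the flag is not set, nothing can (or needs to) be checked here.
    assert!(check_batch_total_tokens(4096, 2048, None).is_ok());
    // With an explicit value, the same invariants as before apply.
    assert!(check_batch_total_tokens(4096, 2048, Some(1024)).is_err());
}
```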
@@ -198,9 +198,10 @@ message DecodeResponse {
 message WarmupRequest {
     /// Batch to warmup on
     Batch batch = 1;
-    /// Maximum number of tokens that the client will send
-    uint32 max_total_tokens = 2;
 }
 
 /// Empty response
-message WarmupResponse {}
+message WarmupResponse {
+    /// Maximum number of tokens supported by the model
+    optional uint32 max_supported_total_tokens = 1;
+}
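For reference, a proto3 `optional uint32` field surfaces on the Rust side as an `Option<u32>` in the prost-generated message, which is what lets older backends simply leave the field unset and what the client hunks below turn into a `Result<Option<u32>>`. A sketch of the generated shape (assumed, not the literal generated file):

```rust
// Assumed prost output for the message above; only the field layout matters here.
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct WarmupResponse {
    /// Maximum number of tokens supported by the model; `None` when the backend
    /// cannot infer it (non flash-attention models).
    #[prost(uint32, optional, tag = "1")]
    pub max_supported_total_tokens: ::core::option::Option<u32>,
}
```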
@@ -103,8 +103,7 @@ impl Client {
         &mut self,
         max_input_length: u32,
         max_prefill_tokens: u32,
-        max_total_tokens: u32,
-    ) -> Result<()> {
+    ) -> Result<Option<u32>> {
         let mut n_tokens = 0;
         let mut requests = Vec::new();
@@ -143,13 +142,9 @@ impl Client {
             max_tokens: 0,
         };
 
-        let request = tonic::Request::new(WarmupRequest {
-            batch: Some(batch),
-            max_total_tokens,
-        })
-        .inject_context();
-        self.stub.warmup(request).await?.into_inner();
-        Ok(())
+        let request = tonic::Request::new(WarmupRequest { batch: Some(batch) }).inject_context();
+        let response = self.stub.warmup(request).await?.into_inner();
+        Ok(response.max_supported_total_tokens)
     }
 
     /// Generate one token for each request in the given batch
@@ -95,14 +95,11 @@ impl ShardedClient {
         &mut self,
         max_input_length: u32,
         max_prefill_tokens: u32,
-        max_total_tokens: u32,
-    ) -> Result<()> {
+    ) -> Result<Option<u32>> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
-            .map(|client| {
-                Box::pin(client.warmup(max_input_length, max_prefill_tokens, max_total_tokens))
-            })
+            .map(|client| Box::pin(client.warmup(max_input_length, max_prefill_tokens)))
             .collect();
         // all shards return the same message
         join_all(futures).await.pop().unwrap()
@@ -53,7 +53,7 @@ impl Infer {
         generation_health: Arc<AtomicBool>,
     ) -> Self {
         // Infer shared state
-        let queue = Queue::new(requires_padding);
+        let queue = Queue::new(requires_padding, 16);
         let shared = Arc::new(Shared {
             batching_task: Notify::new(),
         });
@@ -37,8 +37,8 @@ struct Args {
     waiting_served_ratio: f32,
     #[clap(default_value = "4096", long, env)]
     max_batch_prefill_tokens: u32,
-    #[clap(default_value = "16000", long, env)]
-    max_batch_total_tokens: u32,
+    #[clap(long, env)]
+    max_batch_total_tokens: Option<u32>,
     #[clap(default_value = "20", long, env)]
     max_waiting_tokens: usize,
     #[clap(default_value = "0.0.0.0", long, env)]
@@ -110,18 +110,22 @@ fn main() -> Result<(), RouterError> {
     if max_input_length as u32 > max_batch_prefill_tokens {
         return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {max_batch_prefill_tokens} and {max_input_length}")));
     }
-    if max_batch_prefill_tokens > max_batch_total_tokens {
-        return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
-    }
-    if max_total_tokens as u32 > max_batch_total_tokens {
-        return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
-    }
     if validation_workers == 0 {
         return Err(RouterError::ArgumentValidation(
             "`validation_workers` must be > 0".to_string(),
         ));
     }
 
+    if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
+        if max_batch_prefill_tokens > *max_batch_total_tokens {
+            return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
+        }
+        if max_total_tokens as u32 > *max_batch_total_tokens {
+            return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
+        }
+    }
+
     // CORS allowed origins
     // map to go inside the option and then map to parse from String to HeaderValue
     // Finally, convert to AllowOrigin
@@ -210,14 +214,35 @@ fn main() -> Result<(), RouterError> {
 
     // Warmup model
     tracing::info!("Warming up model");
-    sharded_client
-        .warmup(
-            max_input_length as u32,
-            max_batch_prefill_tokens,
-            max_batch_total_tokens,
-        )
+    let max_supported_batch_total_tokens = match sharded_client
+        .warmup(max_input_length as u32, max_batch_prefill_tokens)
         .await
-        .map_err(RouterError::Warmup)?;
+        .map_err(RouterError::Warmup)?
+    {
+        // Older models do not support automatic max-batch-total-tokens
+        None => {
+            let max_batch_total_tokens = max_batch_total_tokens.unwrap_or(
+                16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)),
+            );
+            tracing::warn!("Model does not support automatic max batch total tokens");
+            max_batch_total_tokens
+        }
+        // Flash attention models return their max supported total tokens
+        Some(max_supported_batch_total_tokens) => {
+            // Warn if user added his own max-batch-total-tokens as we will ignore it
+            if max_batch_total_tokens.is_some() {
+                tracing::warn!(
+                    "`--max-batch-total-tokens` is deprecated for Flash \
+                    Attention models."
+                );
+                tracing::warn!(
+                    "Inferred max batch total tokens: {max_supported_batch_total_tokens}"
+                );
+            }
+            max_supported_batch_total_tokens
+        }
+    };
+    tracing::info!("Setting max batch total tokens to {max_supported_batch_total_tokens}");
     tracing::info!("Connected");
 
     let addr = match hostname.parse() {
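The fallback arm above boils down to: if the backend cannot report a supported budget, use the user-provided `--max-batch-total-tokens`, or else a 16000-token default that is never smaller than `max_total_tokens` or `max_batch_prefill_tokens`. A small self-contained sketch of that selection logic (names are illustrative, not the router's actual function):

```rust
/// Sketch of the budget selection performed after warmup.
fn resolve_max_batch_total_tokens(
    inferred: Option<u32>,   // returned by warmup (flash-attention models)
    user_flag: Option<u32>,  // --max-batch-total-tokens, if passed
    max_total_tokens: u32,
    max_batch_prefill_tokens: u32,
) -> u32 {
    match inferred {
        // Flash-attention models: the profiled value wins, the flag is ignored.
        Some(supported) => supported,
        // Older models: fall back to the flag, or the clamped 16000 default.
        None => user_flag.unwrap_or(16000.max(max_total_tokens).max(max_batch_prefill_tokens)),
    }
}

fn main() {
    // No inference support, no flag: the default floor applies.
    assert_eq!(resolve_max_batch_total_tokens(None, None, 2048, 4096), 16000);
    // An inferred value always takes precedence over the flag.
    assert_eq!(resolve_max_batch_total_tokens(Some(90_000), Some(32_000), 2048, 4096), 90_000);
}
```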
@@ -240,7 +265,7 @@ fn main() -> Result<(), RouterError> {
         max_total_tokens,
         waiting_served_ratio,
         max_batch_prefill_tokens,
-        max_batch_total_tokens,
+        max_supported_batch_total_tokens,
         max_waiting_tokens,
         sharded_client,
         tokenizer,
@@ -33,12 +33,12 @@ pub(crate) struct Queue {
 }
 
 impl Queue {
-    pub(crate) fn new(requires_padding: bool) -> Self {
+    pub(crate) fn new(requires_padding: bool, block_size: u32) -> Self {
         // Create channel
         let (queue_sender, queue_receiver) = flume::unbounded();
 
         // Launch background queue task
-        tokio::spawn(queue_task(requires_padding, queue_receiver));
+        tokio::spawn(queue_task(requires_padding, block_size, queue_receiver));
 
         Self { queue_sender }
     }
@@ -81,8 +81,12 @@ impl Queue {
 }
 
 // Background task responsible of the queue state
-async fn queue_task(requires_padding: bool, receiver: flume::Receiver<QueueCommand>) {
-    let mut state = State::new(requires_padding);
+async fn queue_task(
+    requires_padding: bool,
+    block_size: u32,
+    receiver: flume::Receiver<QueueCommand>,
+) {
+    let mut state = State::new(requires_padding, block_size);
 
     while let Ok(cmd) = receiver.recv_async().await {
         match cmd {
@@ -119,15 +123,19 @@ struct State {
 
     /// Whether the model is using padding
     requires_padding: bool,
+
+    /// Paged Attention block size
+    block_size: u32,
 }
 
 impl State {
-    fn new(requires_padding: bool) -> Self {
+    fn new(requires_padding: bool, block_size: u32) -> Self {
         Self {
             entries: VecDeque::with_capacity(128),
             next_id: 0,
             next_batch_id: 0,
             requires_padding,
+            block_size,
         }
     }
 
@@ -187,10 +195,21 @@ impl State {
                 max_input_length = max_input_length.max(entry.request.input_length);
                 prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length
             } else {
-                prefill_tokens += entry.request.input_length;
+                // pad to block size
+                prefill_tokens += ((entry.request.input_length + self.block_size - 1)
+                    / self.block_size)
+                    * self.block_size;
             }
 
-            decode_tokens += entry.request.stopping_parameters.max_new_tokens;
+            if self.requires_padding {
+                decode_tokens += entry.request.stopping_parameters.max_new_tokens;
+            } else {
+                // pad to block size
+                decode_tokens +=
+                    ((entry.request.stopping_parameters.max_new_tokens + self.block_size - 1)
+                        / self.block_size)
+                        * self.block_size;
+            }
 
             if prefill_tokens > prefill_token_budget
                 || (prefill_tokens + decode_tokens) > token_budget
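The `(n + block_size - 1) / block_size * block_size` expressions above are integer ceiling-division rounding: with paged attention, KV-cache is allocated in whole blocks, so each request's token count is rounded up to the next block boundary before it is charged against the budget. A tiny sketch with a worked example:

```rust
/// Round `tokens` up to the next multiple of `block_size` (block_size must be > 0).
fn pad_to_block(tokens: u32, block_size: u32) -> u32 {
    ((tokens + block_size - 1) / block_size) * block_size
}

fn main() {
    // With the 16-token blocks used for flash-attention models,
    // a 35-token prompt occupies 3 blocks worth of cache, i.e. 48 token slots.
    assert_eq!(pad_to_block(35, 16), 48);
    // A block size of 1 (as used in the tests below) leaves counts unchanged.
    assert_eq!(pad_to_block(35, 1), 35);
}
```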
@@ -321,7 +340,7 @@ mod tests {
 
     #[test]
     fn test_append() {
-        let mut state = State::new(false);
+        let mut state = State::new(false, 1);
         let (entry, _guard) = default_entry();
 
         assert_eq!(state.next_id, 0);
@@ -337,7 +356,7 @@ mod tests {
 
     #[test]
     fn test_next_batch_empty() {
-        let mut state = State::new(false);
+        let mut state = State::new(false, 1);
 
         assert!(state.next_batch(None, 1, 1).is_none());
         assert!(state.next_batch(Some(1), 1, 1).is_none());
@@ -345,7 +364,7 @@ mod tests {
 
     #[test]
     fn test_next_batch_min_size() {
-        let mut state = State::new(false);
+        let mut state = State::new(false, 1);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         state.append(entry1);
@@ -377,7 +396,7 @@ mod tests {
 
     #[test]
     fn test_next_batch_token_budget() {
-        let mut state = State::new(false);
+        let mut state = State::new(false, 1);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         state.append(entry1);
@@ -410,14 +429,14 @@ mod tests {
 
     #[tokio::test]
     async fn test_queue_append() {
-        let queue = Queue::new(false);
+        let queue = Queue::new(false, 1);
         let (entry, _guard) = default_entry();
         queue.append(entry);
     }
 
     #[tokio::test]
     async fn test_queue_next_batch_empty() {
-        let queue = Queue::new(false);
+        let queue = Queue::new(false, 1);
 
         assert!(queue.next_batch(None, 1, 1).await.is_none());
         assert!(queue.next_batch(Some(1), 1, 1).await.is_none());
@@ -425,7 +444,7 @@ mod tests {
 
     #[tokio::test]
     async fn test_queue_next_batch_min_size() {
-        let queue = Queue::new(false);
+        let queue = Queue::new(false, 1);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         queue.append(entry1);
@@ -458,7 +477,7 @@ mod tests {
 
     #[tokio::test]
     async fn test_queue_next_batch_token_budget() {
-        let queue = Queue::new(false);
+        let queue = Queue::new(false, 1);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         queue.append(entry1);
@@ -483,7 +502,7 @@ mod tests {
 
     #[tokio::test]
     async fn test_queue_next_batch_dropped_receiver() {
-        let queue = Queue::new(false);
+        let queue = Queue::new(false, 1);
         let (entry, _) = default_entry();
         queue.append(entry);
 
@@ -710,14 +710,14 @@ class FlashCausalLM(Model):
     def batch_type(self) -> Type[FlashCausalLMBatch]:
         return FlashCausalLMBatch
 
-    def warmup(self, batch: FlashCausalLMBatch, max_total_tokens: int):
+    def warmup(self, batch: FlashCausalLMBatch):
         global CACHE_MANAGER
 
         torch.cuda.empty_cache()
+        torch.cuda.reset_peak_memory_stats(self.device)
         try:
             CACHE_MANAGER = CacheManager(
-                # Adds some wiggle room
-                math.ceil(max_total_tokens / BLOCK_SIZE) + 10,
+                batch.blocks,
                 self.num_layers,
                 self.num_kv_heads,
                 self.head_size,
@@ -727,11 +727,43 @@ class FlashCausalLM(Model):
             _, batch = self.generate_token(batch)
         except Exception as e:
             raise RuntimeError(
-                f"Not enough memory to handle {max_total_tokens} total tokens with {len(batch.input_ids)} "
-                f"prefill tokens. "
-                f"You need to decrease `--max-batch-total-tokens` or `--max-batch-prefill-tokens`"
+                f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
+                f"You need to decrease `--max-batch-prefill-tokens`"
             ) from e
 
+        # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
+        # Calculate the number of blocks that can be allocated with the
+        # profiled peak memory.
+        torch.cuda.synchronize(self.device)
+        peak_memory = torch.cuda.max_memory_reserved(self.device)
+
+        dtype_size = torch.tensor([], dtype=self.dtype).element_size()
+        cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
+        total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
+
+        total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory
+
+        # 0.98 to add some wiggle room
+        num_blocks = (
+            int((total_gpu_memory * 0.98 - peak_memory) // total_cache_size)
+            # Add batch.blocks as we allocated it above, so it is included in the peak memory.
+            + batch.blocks
+        )
+
+        del CACHE_MANAGER
         del batch
+        torch.cuda.empty_cache()
+
+        CACHE_MANAGER = CacheManager(
+            num_blocks,
+            self.num_layers,
+            self.num_kv_heads,
+            self.head_size,
+            self.dtype,
+            self.device,
+        )
+
+        return int(num_blocks * BLOCK_SIZE)
+
     def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
         return self.tokenizer.decode(
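The block count above is straightforward arithmetic: one cache block stores keys and values (hence the factor 2) for BLOCK_SIZE token positions across all layers and KV heads, and whatever GPU memory is left after the profiled warmup peak (times a 0.98 safety factor) is divided by that per-block cost. A sketch of the same arithmetic with hypothetical model and GPU figures (a 7B-class config on a 24 GiB card, purely illustrative):

```rust
fn main() {
    // Hypothetical figures, for illustration only.
    let block_size: u64 = 16; // token positions per cache block
    let num_layers: u64 = 32;
    let num_kv_heads: u64 = 32;
    let head_size: u64 = 128;
    let dtype_size: u64 = 2; // float16 bytes

    // Bytes needed for one block: keys + values (x2), across all layers.
    let cache_block_size = block_size * num_kv_heads * head_size;
    let total_cache_size = num_layers * cache_block_size * 2 * dtype_size;

    let total_gpu_memory: u64 = 24 * 1024 * 1024 * 1024; // 24 GiB card
    let peak_memory: u64 = 15 * 1024 * 1024 * 1024; // profiled warmup peak
    let warmup_blocks: u64 = 64; // stand-in for batch.blocks

    // 0.98 leaves a little headroom; the warmup blocks are added back because
    // they were part of the profiled peak.
    let num_blocks =
        ((total_gpu_memory as f64 * 0.98 - peak_memory as f64) / total_cache_size as f64) as u64
            + warmup_blocks;

    println!("max batch total tokens ~= {}", num_blocks * block_size);
}
```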
@@ -991,7 +1023,6 @@ class FlashCausalLM(Model):
 
         if stopped:
             del batch
-            torch.cuda.empty_cache()
             # No need to return a batch if we know that all requests stopped
             return generations, None
 
@@ -58,8 +58,9 @@ class Model(ABC):
     def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
         raise NotImplementedError
 
-    def warmup(self, batch: B, max_total_tokens: int):
+    def warmup(self, batch: B) -> Optional[int]:
         self.generate_token(batch)
+        return None
 
     def decode_token(
         self,
@@ -51,21 +51,17 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         filtered_batch = batch.filter(request.request_ids)
         self.cache.set(filtered_batch)
 
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-
         return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
 
     async def Warmup(self, request, context):
         batch = self.model.batch_type.from_pb(
             request.batch, self.model.tokenizer, self.model.dtype, self.model.device
         )
-        self.model.warmup(batch, request.max_total_tokens)
+        max_supported_total_tokens = self.model.warmup(batch)
 
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-
-        return generate_pb2.WarmupResponse()
+        return generate_pb2.WarmupResponse(
+            max_supported_total_tokens=max_supported_total_tokens
+        )
 
     async def Prefill(self, request, context):
         batch = self.model.batch_type.from_pb(
@@ -96,8 +92,6 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
 
         if len(batches) > 1:
             batch = self.model.batch_type.concatenate(batches)
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
         else:
             batch = batches[0]
 