feat(server): auto max_batch_total_tokens for flash att models (#630)

This commit is contained in:
OlivierDehaene 2023-07-19 09:31:25 +02:00 committed by GitHub
parent 5e6ddfd6a4
commit fe80f5360c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 159 additions and 94 deletions

View File

@ -184,8 +184,8 @@ struct Args {
/// depends on other parameters like if you're using quantization, flash attention /// depends on other parameters like if you're using quantization, flash attention
/// or the model implementation, text-generation-inference cannot infer this number /// or the model implementation, text-generation-inference cannot infer this number
/// automatically. /// automatically.
#[clap(default_value = "16000", long, env)] #[clap(long, env)]
max_batch_total_tokens: u32, max_batch_total_tokens: Option<u32>,
/// This setting defines how many tokens can be passed before forcing the waiting /// This setting defines how many tokens can be passed before forcing the waiting
/// queries to be put on the batch (if the size of the batch allows for it). /// queries to be put on the batch (if the size of the batch allows for it).
@ -369,12 +369,6 @@ fn shard_manager(
// Copy current process env // Copy current process env
let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect(); let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
// Use cuda allocator. It leads to less memory fragmentation
envs.push((
"PYTORCH_CUDA_ALLOC_CONF".into(),
"backend:cudaMallocAsync".into(),
));
// Torch Distributed Env vars // Torch Distributed Env vars
envs.push(("RANK".into(), rank.to_string().into())); envs.push(("RANK".into(), rank.to_string().into()));
envs.push(("WORLD_SIZE".into(), world_size.to_string().into())); envs.push(("WORLD_SIZE".into(), world_size.to_string().into()));
@ -428,7 +422,7 @@ fn shard_manager(
} }
// Start process // Start process
tracing::info!("Starting shard {rank}"); tracing::info!("Starting shard");
let mut p = match Command::new("text-generation-server") let mut p = match Command::new("text-generation-server")
.args(shard_args) .args(shard_args)
.envs(envs) .envs(envs)
@ -493,17 +487,17 @@ fn shard_manager(
if shutdown.load(Ordering::SeqCst) { if shutdown.load(Ordering::SeqCst) {
p.kill().unwrap(); p.kill().unwrap();
let _ = p.wait(); let _ = p.wait();
tracing::info!("Shard {rank} terminated"); tracing::info!("Shard terminated");
return; return;
} }
// Shard is ready // Shard is ready
if uds.exists() && !ready { if uds.exists() && !ready {
tracing::info!("Shard {rank} ready in {:?}", start_time.elapsed()); tracing::info!("Shard ready in {:?}", start_time.elapsed());
status_sender.send(ShardStatus::Ready).unwrap(); status_sender.send(ShardStatus::Ready).unwrap();
ready = true; ready = true;
} else if !ready && wait_time.elapsed() > Duration::from_secs(10) { } else if !ready && wait_time.elapsed() > Duration::from_secs(10) {
tracing::info!("Waiting for shard {rank} to be ready..."); tracing::info!("Waiting for shard to be ready...");
wait_time = Instant::now(); wait_time = Instant::now();
} }
sleep(Duration::from_millis(100)); sleep(Duration::from_millis(100));
@ -860,8 +854,6 @@ fn spawn_webserver(
args.max_total_tokens.to_string(), args.max_total_tokens.to_string(),
"--max-batch-prefill-tokens".to_string(), "--max-batch-prefill-tokens".to_string(),
args.max_batch_prefill_tokens.to_string(), args.max_batch_prefill_tokens.to_string(),
"--max-batch-total-tokens".to_string(),
args.max_batch_total_tokens.to_string(),
"--waiting-served-ratio".to_string(), "--waiting-served-ratio".to_string(),
args.waiting_served_ratio.to_string(), args.waiting_served_ratio.to_string(),
"--max-waiting-tokens".to_string(), "--max-waiting-tokens".to_string(),
@ -878,6 +870,12 @@ fn spawn_webserver(
args.model_id, args.model_id,
]; ];
// Model optional max batch total tokens
if let Some(max_batch_total_tokens) = args.max_batch_total_tokens {
router_args.push("--max-batch-total-tokens".to_string());
router_args.push(max_batch_total_tokens.to_string());
}
// Model optional revision // Model optional revision
if let Some(ref revision) = args.revision { if let Some(ref revision) = args.revision {
router_args.push("--revision".to_string()); router_args.push("--revision".to_string());
@ -1036,18 +1034,7 @@ fn main() -> Result<(), LauncherError> {
args.max_batch_prefill_tokens, args.max_input_length args.max_batch_prefill_tokens, args.max_input_length
))); )));
} }
if args.max_batch_prefill_tokens > args.max_batch_total_tokens {
return Err(LauncherError::ArgumentValidation(format!(
"`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
args.max_batch_prefill_tokens, args.max_batch_total_tokens
)));
}
if args.max_total_tokens as u32 > args.max_batch_total_tokens {
return Err(LauncherError::ArgumentValidation(format!(
"`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
args.max_total_tokens, args.max_batch_total_tokens
)));
}
if args.validation_workers == 0 { if args.validation_workers == 0 {
return Err(LauncherError::ArgumentValidation( return Err(LauncherError::ArgumentValidation(
"`validation_workers` must be > 0".to_string(), "`validation_workers` must be > 0".to_string(),
@ -1065,6 +1052,21 @@ fn main() -> Result<(), LauncherError> {
tracing::info!("Sharding model on {num_shard} processes"); tracing::info!("Sharding model on {num_shard} processes");
} }
if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
if args.max_batch_prefill_tokens > *max_batch_total_tokens {
return Err(LauncherError::ArgumentValidation(format!(
"`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
args.max_batch_prefill_tokens, max_batch_total_tokens
)));
}
if args.max_total_tokens as u32 > *max_batch_total_tokens {
return Err(LauncherError::ArgumentValidation(format!(
"`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
args.max_total_tokens, max_batch_total_tokens
)));
}
}
// Signal handler // Signal handler
let running = Arc::new(AtomicBool::new(true)); let running = Arc::new(AtomicBool::new(true));
let r = running.clone(); let r = running.clone();

View File

@ -198,9 +198,10 @@ message DecodeResponse {
message WarmupRequest { message WarmupRequest {
/// Batch to warmup on /// Batch to warmup on
Batch batch = 1; Batch batch = 1;
/// Maximum number of tokens that the client will send
uint32 max_total_tokens = 2;
} }
/// Empty response /// Empty response
message WarmupResponse {} message WarmupResponse {
/// Maximum number of tokens supported by the model
optional uint32 max_supported_total_tokens = 1;
}

View File

@ -103,8 +103,7 @@ impl Client {
&mut self, &mut self,
max_input_length: u32, max_input_length: u32,
max_prefill_tokens: u32, max_prefill_tokens: u32,
max_total_tokens: u32, ) -> Result<Option<u32>> {
) -> Result<()> {
let mut n_tokens = 0; let mut n_tokens = 0;
let mut requests = Vec::new(); let mut requests = Vec::new();
@ -143,13 +142,9 @@ impl Client {
max_tokens: 0, max_tokens: 0,
}; };
let request = tonic::Request::new(WarmupRequest { let request = tonic::Request::new(WarmupRequest { batch: Some(batch) }).inject_context();
batch: Some(batch), let response = self.stub.warmup(request).await?.into_inner();
max_total_tokens, Ok(response.max_supported_total_tokens)
})
.inject_context();
self.stub.warmup(request).await?.into_inner();
Ok(())
} }
/// Generate one token for each request in the given batch /// Generate one token for each request in the given batch

View File

@ -95,14 +95,11 @@ impl ShardedClient {
&mut self, &mut self,
max_input_length: u32, max_input_length: u32,
max_prefill_tokens: u32, max_prefill_tokens: u32,
max_total_tokens: u32, ) -> Result<Option<u32>> {
) -> Result<()> {
let futures: Vec<_> = self let futures: Vec<_> = self
.clients .clients
.iter_mut() .iter_mut()
.map(|client| { .map(|client| Box::pin(client.warmup(max_input_length, max_prefill_tokens)))
Box::pin(client.warmup(max_input_length, max_prefill_tokens, max_total_tokens))
})
.collect(); .collect();
// all shards return the same message // all shards return the same message
join_all(futures).await.pop().unwrap() join_all(futures).await.pop().unwrap()

View File

@ -53,7 +53,7 @@ impl Infer {
generation_health: Arc<AtomicBool>, generation_health: Arc<AtomicBool>,
) -> Self { ) -> Self {
// Infer shared state // Infer shared state
let queue = Queue::new(requires_padding); let queue = Queue::new(requires_padding, 16);
let shared = Arc::new(Shared { let shared = Arc::new(Shared {
batching_task: Notify::new(), batching_task: Notify::new(),
}); });

View File

@ -37,8 +37,8 @@ struct Args {
waiting_served_ratio: f32, waiting_served_ratio: f32,
#[clap(default_value = "4096", long, env)] #[clap(default_value = "4096", long, env)]
max_batch_prefill_tokens: u32, max_batch_prefill_tokens: u32,
#[clap(default_value = "16000", long, env)] #[clap(long, env)]
max_batch_total_tokens: u32, max_batch_total_tokens: Option<u32>,
#[clap(default_value = "20", long, env)] #[clap(default_value = "20", long, env)]
max_waiting_tokens: usize, max_waiting_tokens: usize,
#[clap(default_value = "0.0.0.0", long, env)] #[clap(default_value = "0.0.0.0", long, env)]
@ -110,18 +110,22 @@ fn main() -> Result<(), RouterError> {
if max_input_length as u32 > max_batch_prefill_tokens { if max_input_length as u32 > max_batch_prefill_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {max_batch_prefill_tokens} and {max_input_length}"))); return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {max_batch_prefill_tokens} and {max_input_length}")));
} }
if max_batch_prefill_tokens > max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
}
if max_total_tokens as u32 > max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
}
if validation_workers == 0 { if validation_workers == 0 {
return Err(RouterError::ArgumentValidation( return Err(RouterError::ArgumentValidation(
"`validation_workers` must be > 0".to_string(), "`validation_workers` must be > 0".to_string(),
)); ));
} }
if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
if max_batch_prefill_tokens > *max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
}
if max_total_tokens as u32 > *max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
}
}
// CORS allowed origins // CORS allowed origins
// map to go inside the option and then map to parse from String to HeaderValue // map to go inside the option and then map to parse from String to HeaderValue
// Finally, convert to AllowOrigin // Finally, convert to AllowOrigin
@ -210,14 +214,35 @@ fn main() -> Result<(), RouterError> {
// Warmup model // Warmup model
tracing::info!("Warming up model"); tracing::info!("Warming up model");
sharded_client let max_supported_batch_total_tokens = match sharded_client
.warmup( .warmup(max_input_length as u32, max_batch_prefill_tokens)
max_input_length as u32,
max_batch_prefill_tokens,
max_batch_total_tokens,
)
.await .await
.map_err(RouterError::Warmup)?; .map_err(RouterError::Warmup)?
{
// Older models do not support automatic max-batch-total-tokens
None => {
let max_batch_total_tokens = max_batch_total_tokens.unwrap_or(
16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)),
);
tracing::warn!("Model does not support automatic max batch total tokens");
max_batch_total_tokens
}
// Flash attention models return their max supported total tokens
Some(max_supported_batch_total_tokens) => {
// Warn if user added his own max-batch-total-tokens as we will ignore it
if max_batch_total_tokens.is_some() {
tracing::warn!(
"`--max-batch-total-tokens` is deprecated for Flash \
Attention models."
);
tracing::warn!(
"Inferred max batch total tokens: {max_supported_batch_total_tokens}"
);
}
max_supported_batch_total_tokens
}
};
tracing::info!("Setting max batch total tokens to {max_supported_batch_total_tokens}");
tracing::info!("Connected"); tracing::info!("Connected");
let addr = match hostname.parse() { let addr = match hostname.parse() {
@ -240,7 +265,7 @@ fn main() -> Result<(), RouterError> {
max_total_tokens, max_total_tokens,
waiting_served_ratio, waiting_served_ratio,
max_batch_prefill_tokens, max_batch_prefill_tokens,
max_batch_total_tokens, max_supported_batch_total_tokens,
max_waiting_tokens, max_waiting_tokens,
sharded_client, sharded_client,
tokenizer, tokenizer,

View File

@ -33,12 +33,12 @@ pub(crate) struct Queue {
} }
impl Queue { impl Queue {
pub(crate) fn new(requires_padding: bool) -> Self { pub(crate) fn new(requires_padding: bool, block_size: u32) -> Self {
// Create channel // Create channel
let (queue_sender, queue_receiver) = flume::unbounded(); let (queue_sender, queue_receiver) = flume::unbounded();
// Launch background queue task // Launch background queue task
tokio::spawn(queue_task(requires_padding, queue_receiver)); tokio::spawn(queue_task(requires_padding, block_size, queue_receiver));
Self { queue_sender } Self { queue_sender }
} }
@ -81,8 +81,12 @@ impl Queue {
} }
// Background task responsible of the queue state // Background task responsible of the queue state
async fn queue_task(requires_padding: bool, receiver: flume::Receiver<QueueCommand>) { async fn queue_task(
let mut state = State::new(requires_padding); requires_padding: bool,
block_size: u32,
receiver: flume::Receiver<QueueCommand>,
) {
let mut state = State::new(requires_padding, block_size);
while let Ok(cmd) = receiver.recv_async().await { while let Ok(cmd) = receiver.recv_async().await {
match cmd { match cmd {
@ -119,15 +123,19 @@ struct State {
/// Whether the model is using padding /// Whether the model is using padding
requires_padding: bool, requires_padding: bool,
/// Paged Attention block size
block_size: u32,
} }
impl State { impl State {
fn new(requires_padding: bool) -> Self { fn new(requires_padding: bool, block_size: u32) -> Self {
Self { Self {
entries: VecDeque::with_capacity(128), entries: VecDeque::with_capacity(128),
next_id: 0, next_id: 0,
next_batch_id: 0, next_batch_id: 0,
requires_padding, requires_padding,
block_size,
} }
} }
@ -187,10 +195,21 @@ impl State {
max_input_length = max_input_length.max(entry.request.input_length); max_input_length = max_input_length.max(entry.request.input_length);
prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length
} else { } else {
prefill_tokens += entry.request.input_length; // pad to block size
prefill_tokens += ((entry.request.input_length + self.block_size - 1)
/ self.block_size)
* self.block_size;
} }
decode_tokens += entry.request.stopping_parameters.max_new_tokens; if self.requires_padding {
decode_tokens += entry.request.stopping_parameters.max_new_tokens;
} else {
// pad to block size
decode_tokens +=
((entry.request.stopping_parameters.max_new_tokens + self.block_size - 1)
/ self.block_size)
* self.block_size;
}
if prefill_tokens > prefill_token_budget if prefill_tokens > prefill_token_budget
|| (prefill_tokens + decode_tokens) > token_budget || (prefill_tokens + decode_tokens) > token_budget
@ -321,7 +340,7 @@ mod tests {
#[test] #[test]
fn test_append() { fn test_append() {
let mut state = State::new(false); let mut state = State::new(false, 1);
let (entry, _guard) = default_entry(); let (entry, _guard) = default_entry();
assert_eq!(state.next_id, 0); assert_eq!(state.next_id, 0);
@ -337,7 +356,7 @@ mod tests {
#[test] #[test]
fn test_next_batch_empty() { fn test_next_batch_empty() {
let mut state = State::new(false); let mut state = State::new(false, 1);
assert!(state.next_batch(None, 1, 1).is_none()); assert!(state.next_batch(None, 1, 1).is_none());
assert!(state.next_batch(Some(1), 1, 1).is_none()); assert!(state.next_batch(Some(1), 1, 1).is_none());
@ -345,7 +364,7 @@ mod tests {
#[test] #[test]
fn test_next_batch_min_size() { fn test_next_batch_min_size() {
let mut state = State::new(false); let mut state = State::new(false, 1);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
state.append(entry1); state.append(entry1);
@ -377,7 +396,7 @@ mod tests {
#[test] #[test]
fn test_next_batch_token_budget() { fn test_next_batch_token_budget() {
let mut state = State::new(false); let mut state = State::new(false, 1);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
state.append(entry1); state.append(entry1);
@ -410,14 +429,14 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_append() { async fn test_queue_append() {
let queue = Queue::new(false); let queue = Queue::new(false, 1);
let (entry, _guard) = default_entry(); let (entry, _guard) = default_entry();
queue.append(entry); queue.append(entry);
} }
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_empty() { async fn test_queue_next_batch_empty() {
let queue = Queue::new(false); let queue = Queue::new(false, 1);
assert!(queue.next_batch(None, 1, 1).await.is_none()); assert!(queue.next_batch(None, 1, 1).await.is_none());
assert!(queue.next_batch(Some(1), 1, 1).await.is_none()); assert!(queue.next_batch(Some(1), 1, 1).await.is_none());
@ -425,7 +444,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_min_size() { async fn test_queue_next_batch_min_size() {
let queue = Queue::new(false); let queue = Queue::new(false, 1);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
queue.append(entry1); queue.append(entry1);
@ -458,7 +477,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_token_budget() { async fn test_queue_next_batch_token_budget() {
let queue = Queue::new(false); let queue = Queue::new(false, 1);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
queue.append(entry1); queue.append(entry1);
@ -483,7 +502,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_dropped_receiver() { async fn test_queue_next_batch_dropped_receiver() {
let queue = Queue::new(false); let queue = Queue::new(false, 1);
let (entry, _) = default_entry(); let (entry, _) = default_entry();
queue.append(entry); queue.append(entry);

View File

@ -710,14 +710,14 @@ class FlashCausalLM(Model):
def batch_type(self) -> Type[FlashCausalLMBatch]: def batch_type(self) -> Type[FlashCausalLMBatch]:
return FlashCausalLMBatch return FlashCausalLMBatch
def warmup(self, batch: FlashCausalLMBatch, max_total_tokens: int): def warmup(self, batch: FlashCausalLMBatch):
global CACHE_MANAGER global CACHE_MANAGER
torch.cuda.empty_cache() torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats(self.device)
try: try:
CACHE_MANAGER = CacheManager( CACHE_MANAGER = CacheManager(
# Adds some wiggle room batch.blocks,
math.ceil(max_total_tokens / BLOCK_SIZE) + 10,
self.num_layers, self.num_layers,
self.num_kv_heads, self.num_kv_heads,
self.head_size, self.head_size,
@ -727,11 +727,43 @@ class FlashCausalLM(Model):
_, batch = self.generate_token(batch) _, batch = self.generate_token(batch)
except Exception as e: except Exception as e:
raise RuntimeError( raise RuntimeError(
f"Not enough memory to handle {max_total_tokens} total tokens with {len(batch.input_ids)} " f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
f"prefill tokens. " f"You need to decrease `--max-batch-prefill-tokens`"
f"You need to decrease `--max-batch-total-tokens` or `--max-batch-prefill-tokens`"
) from e ) from e
# Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
# Calculate the number of blocks that can be allocated with the
# profiled peak memory.
torch.cuda.synchronize(self.device)
peak_memory = torch.cuda.max_memory_reserved(self.device)
dtype_size = torch.tensor([], dtype=self.dtype).element_size()
cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory
# 0.98 to add some wiggle room
num_blocks = (
int((total_gpu_memory * 0.98 - peak_memory) // total_cache_size)
# Add batch.blocks as we allocated it above, so it is included in the peak memory.
+ batch.blocks
)
del CACHE_MANAGER
del batch del batch
torch.cuda.empty_cache()
CACHE_MANAGER = CacheManager(
num_blocks,
self.num_layers,
self.num_kv_heads,
self.head_size,
self.dtype,
self.device,
)
return int(num_blocks * BLOCK_SIZE)
def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str: def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
return self.tokenizer.decode( return self.tokenizer.decode(
@ -991,7 +1023,6 @@ class FlashCausalLM(Model):
if stopped: if stopped:
del batch del batch
torch.cuda.empty_cache()
# No need to return a batch if we know that all requests stopped # No need to return a batch if we know that all requests stopped
return generations, None return generations, None

View File

@ -58,8 +58,9 @@ class Model(ABC):
def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]: def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
raise NotImplementedError raise NotImplementedError
def warmup(self, batch: B, max_total_tokens: int): def warmup(self, batch: B) -> Optional[int]:
self.generate_token(batch) self.generate_token(batch)
return None
def decode_token( def decode_token(
self, self,

View File

@ -51,21 +51,17 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
filtered_batch = batch.filter(request.request_ids) filtered_batch = batch.filter(request.request_ids)
self.cache.set(filtered_batch) self.cache.set(filtered_batch)
if torch.cuda.is_available():
torch.cuda.empty_cache()
return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb()) return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
async def Warmup(self, request, context): async def Warmup(self, request, context):
batch = self.model.batch_type.from_pb( batch = self.model.batch_type.from_pb(
request.batch, self.model.tokenizer, self.model.dtype, self.model.device request.batch, self.model.tokenizer, self.model.dtype, self.model.device
) )
self.model.warmup(batch, request.max_total_tokens) max_supported_total_tokens = self.model.warmup(batch)
if torch.cuda.is_available(): return generate_pb2.WarmupResponse(
torch.cuda.empty_cache() max_supported_total_tokens=max_supported_total_tokens
)
return generate_pb2.WarmupResponse()
async def Prefill(self, request, context): async def Prefill(self, request, context):
batch = self.model.batch_type.from_pb( batch = self.model.batch_type.from_pb(
@ -96,8 +92,6 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
if len(batches) > 1: if len(batches) > 1:
batch = self.model.batch_type.concatenate(batches) batch = self.model.batch_type.concatenate(batches)
if torch.cuda.is_available():
torch.cuda.empty_cache()
else: else:
batch = batches[0] batch = batches[0]