diff --git a/server/text_generation_server/interceptor.py b/server/text_generation_server/interceptor.py
index a3247d1..725105f 100644
--- a/server/text_generation_server/interceptor.py
+++ b/server/text_generation_server/interceptor.py
@@ -1,3 +1,4 @@
+import torch
 import grpc
 
 from google.rpc import status_pb2, code_pb2
@@ -22,6 +23,9 @@ class ExceptionInterceptor(AsyncServerInterceptor):
             method_name = method_name.split("/")[-1]
             logger.exception(f"Method {method_name} encountered an error.")
 
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
             await context.abort_with_status(
                 rpc_status.to_status(
                     status_pb2.Status(code=code_pb2.INTERNAL, message=str(err))
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 5420556..4e5804f 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -639,7 +639,6 @@ class FlashCausalLMBatch(Batch):
         for b in batches:
             b.block_tables = None
             del b
-        torch.cuda.empty_cache()
 
         return FlashCausalLMBatch(
             batch_id=batches[0].batch_id,
@@ -733,7 +732,6 @@ class FlashCausalLM(Model):
                 f"You need to decrease `--max-batch-total-tokens` or `--max-batch-prefill-tokens`"
             ) from e
         del batch
-        torch.cuda.empty_cache()
 
     def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
         return self.tokenizer.decode(
@@ -790,7 +788,6 @@ class FlashCausalLM(Model):
             )
         except Exception as e:
             del batch
-            torch.cuda.empty_cache()
             raise e
 
         if prefill:
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index c375330..7bc62ce 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -51,6 +51,9 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         filtered_batch = batch.filter(request.request_ids)
         self.cache.set(filtered_batch)
 
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
         return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
 
     async def Warmup(self, request, context):
@@ -58,6 +61,10 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             request.batch, self.model.tokenizer, self.model.dtype, self.model.device
         )
         self.model.warmup(batch, request.max_total_tokens)
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
         return generate_pb2.WarmupResponse()
 
     async def Prefill(self, request, context):
@@ -89,6 +96,8 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
 
         if len(batches) > 1:
             batch = self.model.batch_type.concatenate(batches)
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
         else:
             batch = batches[0]