From c4bb5264acacf55f8f1e693ccd29f7a13d3bcba0 Mon Sep 17 00:00:00 2001
From: OlivierDehaene
Date: Thu, 6 Jul 2023 14:28:33 +0200
Subject: [PATCH] fix(server): decrease memory fragmentation (#557)

---
 server/text_generation_server/cache.py       |  4 +++
 .../models/flash_causal_lm.py                | 28 ++++++++++++-------
 2 files changed, 22 insertions(+), 10 deletions(-)

diff --git a/server/text_generation_server/cache.py b/server/text_generation_server/cache.py
index 79fcd3aa..4504733e 100644
--- a/server/text_generation_server/cache.py
+++ b/server/text_generation_server/cache.py
@@ -1,3 +1,5 @@
+import torch
+
 from typing import Dict, Optional, TypeVar
 
 from text_generation_server.models.types import Batch
@@ -20,6 +22,8 @@ class Cache:
         batch = self.pop(batch_id)
         if batch is not None:
             del batch
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
 
     def clear(self):
         keys = list(self.cache.keys())
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index bf5f5bbe..bebd3df5 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -638,6 +638,8 @@ class FlashCausalLMBatch(Batch):
         # Needed to avoid dropping blocks when the batches will go out of scope
         for b in batches:
             b.block_tables = None
+            del b
+        torch.cuda.empty_cache()
 
         return FlashCausalLMBatch(
             batch_id=batches[0].batch_id,
@@ -732,6 +734,7 @@ class FlashCausalLM(Model):
             )
             raise e
         del batch
+        torch.cuda.empty_cache()
 
     def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
         return self.tokenizer.decode(
@@ -775,16 +778,21 @@ class FlashCausalLM(Model):
             # Allocate blocks to this batch
             CACHE_MANAGER.allocate(batch)
 
-        out = self.forward(
-            batch.input_ids,
-            batch.position_ids,
-            batch.cu_seqlen_prefill,
-            batch.block_tables_tensor,
-            batch.slots[batch.slot_indices],
-            batch.input_lengths_tensor,
-            batch.max_seqlen,
-            batch.prefill_head_indices,
-        )
+        try:
+            out = self.forward(
+                batch.input_ids,
+                batch.position_ids,
+                batch.cu_seqlen_prefill,
+                batch.block_tables_tensor,
+                batch.slots[batch.slot_indices],
+                batch.input_lengths_tensor,
+                batch.max_seqlen,
+                batch.prefill_head_indices,
+            )
+        except Exception as e:
+            del batch
+            torch.cuda.empty_cache()
+            raise e
 
         if prefill:
             next_token_logits = (
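
Note (not part of the patch): the change relies on the fact that PyTorch's caching allocator keeps freed blocks reserved for reuse. Deleting the batch only lowers the "allocated" counter; torch.cuda.empty_cache() is what returns the now-unused reserved blocks to the driver, which is how fragmentation between differently sized batches is reduced. A minimal sketch of that effect, assuming a CUDA device and using a dummy tensor as a stand-in for a batch:

    import torch

    if torch.cuda.is_available():
        # Stand-in for a batch: ~1 GiB of fp32 data.
        batch = torch.empty(1024, 1024, 256, device="cuda")
        print(torch.cuda.memory_allocated(), torch.cuda.memory_reserved())

        # Dropping the last reference frees the tensor for reuse by the
        # allocator, but the blocks stay reserved (cached) by the process.
        del batch
        print(torch.cuda.memory_allocated(), torch.cuda.memory_reserved())

        # empty_cache() hands the cached blocks back to the CUDA driver;
        # this is the step the patch adds after evicting a batch.
        torch.cuda.empty_cache()
        print(torch.cuda.memory_allocated(), torch.cuda.memory_reserved())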
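
The generate_token change applies the same idea to the failure path: if the forward pass raises (typically a CUDA out-of-memory error during prefill), the batch is dropped and the allocator cache emptied before the exception propagates, so the process holds on to as little fragmented memory as possible while the caller handles the error. A hedged, self-contained sketch of that guard pattern; run_forward and batch here are hypothetical stand-ins for the model call and its inputs, not the server's API:

    import torch

    def guarded_forward(run_forward, batch):
        # Wrap the model call so a failure cannot strand GPU memory.
        try:
            return run_forward(batch)
        except Exception as e:
            # Drop the inputs and return cached blocks to the driver
            # before surfacing the error to the caller.
            del batch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            raise e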