fix(server): decrease memory fragmentation (#557)

Authored by OlivierDehaene on 2023-07-06 14:28:33 +02:00; committed by GitHub
parent 6f42942772
commit c4bb5264ac
2 changed files with 22 additions and 10 deletions
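
The fix leans on a property of PyTorch's CUDA caching allocator: GPU memory freed by dropping tensor references stays reserved by the process for later reuse unless torch.cuda.empty_cache() hands the unused blocks back to the driver. A minimal standalone sketch of that behavior (illustrative only, not part of this commit):

import torch

# Illustrative sketch only (not from the commit): after a large tensor is
# deleted, memory_allocated() drops but memory_reserved() stays high because
# the caching allocator keeps the freed block around; empty_cache() returns
# those cached, unused blocks to the driver.
if torch.cuda.is_available():
    x = torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device="cuda")  # ~256 MiB
    del x
    print("allocated:", torch.cuda.memory_allocated())
    print("reserved before empty_cache:", torch.cuda.memory_reserved())
    torch.cuda.empty_cache()
    print("reserved after empty_cache:", torch.cuda.memory_reserved())

Returning those blocks between batches is what keeps a long-running server from accumulating reserved-but-unusable fragments.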

server/text_generation_server/cache.py

@@ -1,3 +1,5 @@
+import torch
+
 from typing import Dict, Optional, TypeVar
 
 from text_generation_server.models.types import Batch
@@ -20,6 +22,8 @@ class Cache:
         batch = self.pop(batch_id)
         if batch is not None:
             del batch
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
 
     def clear(self):
         keys = list(self.cache.keys())
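
Both files repeat the same two-step idiom once a batch is finished: drop the last Python reference so its tensors become collectable, then release the allocator's now-unused blocks, with the CUDA call guarded so CPU-only deployments are unaffected. A hedged sketch of that idiom as a standalone helper (the helper name is ours, not the repository's):

import torch

def release_finished_batch(batch) -> None:
    # Hypothetical helper (not in the repo) mirroring the idiom added in
    # Cache.delete above. It only helps if the argument is the last live
    # reference to the batch (in Cache.delete it is, because the batch was
    # already popped out of the cache dict).
    del batch  # drop this frame's reference so the batch's tensors can be freed
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # return the freed, cached blocks to the driver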

server/text_generation_server/models/flash_causal_lm.py

@@ -638,6 +638,8 @@ class FlashCausalLMBatch(Batch):
         # Needed to avoid dropping blocks when the batches will go out of scope
         for b in batches:
             b.block_tables = None
+            del b
+        torch.cuda.empty_cache()
 
         return FlashCausalLMBatch(
             batch_id=batches[0].batch_id,
@@ -732,6 +734,7 @@ class FlashCausalLM(Model):
             )
             raise e
         del batch
+        torch.cuda.empty_cache()
 
     def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
         return self.tokenizer.decode(
@@ -775,6 +778,7 @@ class FlashCausalLM(Model):
             # Allocate blocks to this batch
             CACHE_MANAGER.allocate(batch)
 
+        try:
             out = self.forward(
                 batch.input_ids,
                 batch.position_ids,
@@ -785,6 +789,10 @@ class FlashCausalLM(Model):
                 batch.max_seqlen,
                 batch.prefill_head_indices,
             )
+        except Exception as e:
+            del batch
+            torch.cuda.empty_cache()
+            raise e
 
         if prefill:
             next_token_logits = (
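
The last two hunks wrap the forward pass so that a failure, typically a CUDA out-of-memory during a large prefill, cleans up before the error propagates. A minimal sketch of that error-path pattern with illustrative names (run_forward, model, batch are ours, not the repository's API):

import torch

def run_forward(model, batch):
    # Illustrative only: mirrors the try/except added around the
    # FlashCausalLM.forward call in generate_token.
    try:
        return model.forward(batch)
    except Exception as e:
        # Drop this frame's reference to the batch and give the cached blocks
        # back to the driver before re-raising, so the server is not left
        # holding fragmented memory it can no longer use.
        del batch
        if torch.cuda.is_available():  # guard added here so the sketch is CPU-safe
            torch.cuda.empty_cache()
        raise e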