fix(server): decrease memory fragmentation (#557)
commit c4bb5264ac
parent 6f42942772
@@ -1,3 +1,5 @@
+import torch
+
 from typing import Dict, Optional, TypeVar
 
 from text_generation_server.models.types import Batch
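Note: the added import torch at the top of the cache module is what allows the torch.cuda calls in the next hunk.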
@@ -20,6 +22,8 @@ class Cache:
         batch = self.pop(batch_id)
         if batch is not None:
             del batch
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
 
     def clear(self):
         keys = list(self.cache.keys())
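Taken together, the two hunks above give Cache.delete the following shape. This is a minimal sketch, assuming the __init__ and pop bodies of the surrounding class (only the delete body appears in the diff):

    from typing import Dict, Optional

    import torch

    from text_generation_server.models.types import Batch

    class Cache:
        def __init__(self):
            self.cache: Dict[int, Batch] = {}

        def pop(self, batch_id: int) -> Optional[Batch]:
            return self.cache.pop(batch_id, None)

        def delete(self, batch_id: int):
            batch = self.pop(batch_id)
            if batch is not None:
                # Drop the last reference so the batch's CUDA tensors become
                # collectible, then hand cached blocks back to the driver.
                del batch
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()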
@@ -638,6 +638,8 @@ class FlashCausalLMBatch(Batch):
         # Needed to avoid dropping blocks when the batches will go out of scope
         for b in batches:
             b.block_tables = None
+            del b
+        torch.cuda.empty_cache()
 
         return FlashCausalLMBatch(
             batch_id=batches[0].batch_id,
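The del b plus empty_cache() pair matters because PyTorch's caching allocator keeps freed blocks reserved for reuse, and after a large concatenation those stale blocks can fragment the pool. A small self-contained illustration of that behavior (the tensor size is arbitrary):

    import torch

    if torch.cuda.is_available():
        x = torch.empty(1024, 1024, device="cuda")  # ~4 MB allocation
        print(torch.cuda.memory_allocated())  # counts live tensors
        print(torch.cuda.memory_reserved())   # counts the allocator's pool
        del x                                 # tensor freed, pool unchanged
        print(torch.cuda.memory_reserved())   # still reserved by the cache
        torch.cuda.empty_cache()              # return unused blocks
        print(torch.cuda.memory_reserved())   # shrinks back toward zero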
@@ -732,6 +734,7 @@ class FlashCausalLM(Model):
             )
             raise e
         del batch
+        torch.cuda.empty_cache()
 
     def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
         return self.tokenizer.decode(
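Here the batch is freed unconditionally once the surrounding try/except has passed, and empty_cache() returns its memory before regular serving continues; the hunk context only shows that this sits just above decode, so the enclosing method (plausibly a one-shot path such as warmup) is an assumption.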
@@ -775,16 +778,21 @@ class FlashCausalLM(Model):
         # Allocate blocks to this batch
         CACHE_MANAGER.allocate(batch)
 
-        out = self.forward(
-            batch.input_ids,
-            batch.position_ids,
-            batch.cu_seqlen_prefill,
-            batch.block_tables_tensor,
-            batch.slots[batch.slot_indices],
-            batch.input_lengths_tensor,
-            batch.max_seqlen,
-            batch.prefill_head_indices,
-        )
+        try:
+            out = self.forward(
+                batch.input_ids,
+                batch.position_ids,
+                batch.cu_seqlen_prefill,
+                batch.block_tables_tensor,
+                batch.slots[batch.slot_indices],
+                batch.input_lengths_tensor,
+                batch.max_seqlen,
+                batch.prefill_head_indices,
+            )
+        except Exception as e:
+            del batch
+            torch.cuda.empty_cache()
+            raise e
 
         if prefill:
             next_token_logits = (
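The last hunk wraps the forward pass so that an OOM (or any other failure) frees the failed batch before the error propagates. As a generic sketch of that pattern, with every name except torch.* hypothetical:

    import torch

    def guarded_forward(model, batch):
        # Hypothetical wrapper mirroring the try/except added above.
        try:
            return model(batch)
        except Exception:
            # Free the failed batch's tensors first, then return the cached
            # blocks to CUDA so one transient OOM does not leave the pool
            # fragmented for the requests that follow.
            del batch
            torch.cuda.empty_cache()
            raise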