hf_text-generation-inference/server/text_generation_server/cache.py

35 lines
796 B
Python
Raw Normal View History

import torch
from typing import Dict, Optional, TypeVar
2022-10-08 04:30:12 -06:00
2023-03-07 10:52:22 -07:00
from text_generation_server.models.types import Batch
# Type variable for cached entries; bound to Batch, so it stands for Batch or any subclass.
B = TypeVar("B", bound=Batch)
2022-10-08 04:30:12 -06:00
class Cache:
    """In-memory store of batches, keyed by each batch's ``batch_id``.

    Entries are added with :meth:`set` and taken out with :meth:`pop`.
    :meth:`delete` and :meth:`clear` additionally call
    ``torch.cuda.empty_cache()`` (when CUDA is available) after dropping
    entries.
    """

    def __init__(self) -> None:
        # Maps batch_id -> batch object.
        self.cache: Dict[int, B] = {}

    def pop(self, batch_id: int) -> Optional[B]:
        """Remove and return the batch stored under ``batch_id``.

        Returns:
            The stored batch, or ``None`` if ``batch_id`` is not in the cache.
        """
        return self.cache.pop(batch_id, None)

    def set(self, entry: Optional[B]) -> None:
        """Store ``entry`` under ``entry.batch_id``.

        An existing entry with the same id is overwritten. Passing ``None``
        is a no-op.
        """
        if entry is not None:
            self.cache[entry.batch_id] = entry

    def delete(self, batch_id: int) -> None:
        """Drop the batch stored under ``batch_id`` and release CUDA cache.

        Does nothing when ``batch_id`` is not in the cache.
        """
        batch = self.pop(batch_id)
        if batch is None:
            return
        # Only the cache's reference and this local one are dropped here; the
        # batch is actually freed only if nothing else still references it.
        del batch
        self._release_device_memory()

    def clear(self) -> None:
        """Drop every cached batch and release CUDA cache once.

        Does nothing when the cache is already empty.
        """
        if not self.cache:
            return
        # Drop all entries first, then release device memory a single time
        # instead of once per entry.
        self.cache.clear()
        self._release_device_memory()

    def __len__(self) -> int:
        """Return the number of batches currently stored."""
        return len(self.cache)

    @staticmethod
    def _release_device_memory() -> None:
        """Call ``torch.cuda.empty_cache()`` if CUDA is available."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()