2022-11-04 11:03:04 -06:00
|
|
|
from typing import Dict, Optional, TypeVar
|
2022-10-08 04:30:12 -06:00
|
|
|
|
2022-10-28 11:24:00 -06:00
|
|
|
from text_generation.models.types import Batch
|
|
|
|
|
2022-11-04 11:03:04 -06:00
|
|
|
# Type variable constrained to Batch (or a subclass); used in the Cache
# annotations below for the stored entry type.
B = TypeVar("B", bound=Batch)
|
|
|
|
|
2022-10-08 04:30:12 -06:00
|
|
|
|
|
|
|
class Cache:
    """In-memory store of batch entries keyed by their ``batch_id``."""

    def __init__(self):
        """Create an empty cache."""
        # Maps batch_id -> stored entry.
        self.cache: Dict[int, B] = {}

    def pop(self, batch_id: int) -> Optional[B]:
        """Remove and return the entry stored under ``batch_id``.

        Returns:
            The stored entry, or ``None`` when no entry has that id.
        """
        return self.cache.pop(batch_id, None)

    def set(self, entry: B):
        """Store ``entry`` under ``entry.batch_id``.

        A ``None`` entry is silently ignored. An existing entry with the
        same id is overwritten.
        """
        if entry is not None:
            self.cache[entry.batch_id] = entry

    def delete(self, batch_id: int):
        """Remove the entry stored under ``batch_id``.

        Raises:
            KeyError: If no entry is stored under ``batch_id``. Use
                :meth:`pop` for a non-raising removal.
        """
        del self.cache[batch_id]

    def clear(self):
        """Remove every entry from the cache."""
        self.cache.clear()

    def __len__(self):
        """Return the number of entries currently stored."""
        # len() on the dict itself; no need to build a keys view first.
        return len(self.cache)
|