import torch

from dataclasses import dataclass
from typing import Dict, List, Optional

from bloom_inference.pb import generate_pb2
from bloom_inference.utils import NextTokenChooser, StoppingCriteria


@dataclass
class CacheEntry:
    """Past state for one batch of in-flight generation requests."""

    # Unique identifier of the batch this entry belongs to.
    batch_id: int
    # IDs of the individual requests that make up the batch.
    request_ids: List[int]
    # Model inputs for the next forward pass, keyed by tensor name.
    input_ids: Dict[str, torch.Tensor]
    # Full token sequence accumulated so far for each request.
    all_input_ids: List[torch.Tensor]
    # Per-request sampling logic (temperature, top-k, top-p, ...).
    next_token_choosers: List[NextTokenChooser]
    # Per-request stop conditions (max new tokens, stop sequences, ...).
    stopping_criterias: List[StoppingCriteria]

    def __len__(self):
        return len(self.request_ids)

    def to_pb(self):
        # Serialize to the gRPC CacheEntry message; the reported sequence
        # length is the longest sequence in the batch.
        return generate_pb2.CacheEntry(
            id=self.batch_id,
            request_ids=self.request_ids,
            sequence_length=max(len(entry) for entry in self.all_input_ids),
        )
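
# A minimal sketch of how an entry round-trips through to_pb(), assuming a
# single-request batch; constructing NextTokenChooser/StoppingCriteria is
# project-specific and elided here. Not part of the original module:
#
#     entry = CacheEntry(
#         batch_id=0,
#         request_ids=[0],
#         input_ids={"input_ids": torch.tensor([[1, 2, 3]])},
#         all_input_ids=[torch.tensor([1, 2, 3])],
#         next_token_choosers=[...],
#         stopping_criterias=[...],
#     )
#     assert len(entry) == 1
#     pb = entry.to_pb()  # id=0, request_ids=[0], sequence_length=3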

class Cache:
    """In-memory store of CacheEntry objects, keyed by batch id."""

    def __init__(self):
        # Entries are stored under `entry.batch_id`, which is an int,
        # so the cache is keyed by int as well.
        self.cache: Dict[int, CacheEntry] = {}

    def pop(self, batch_id: int) -> Optional[CacheEntry]:
        # Remove and return the entry, or None if the batch is unknown.
        return self.cache.pop(batch_id, None)

    def set(self, entry: Optional[CacheEntry]):
        if entry is not None:
            self.cache[entry.batch_id] = entry

    def delete(self, batch_id: int):
        del self.cache[batch_id]

    def clear(self):
        self.cache.clear()

    def __len__(self):
        return len(self.cache)
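
# Usage sketch, assuming `entry` is a populated CacheEntry as above.
# Not part of the original module, just an illustration:
#
#     cache = Cache()
#     cache.set(entry)                  # stored under entry.batch_id
#     same = cache.pop(entry.batch_id)  # removes and returns it (or None)
#     assert same is entry and len(cache) == 0
#     cache.set(entry)
#     cache.delete(entry.batch_id)      # raises KeyError if missing
#     cache.clear()                     # drop everything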