# local-llm-server/llm_server/routes/cache.py
import sys
import traceback
from typing import Union
import redis as redis_pkg
import simplejson as json
from flask_caching import Cache
from redis import Redis
from redis.typing import ExpiryT, FieldT
# Flask-Caching instance backed by the Redis server at localhost:6379 (db 0).
# Its keys carry the 'local_llm_flask' prefix so they stay distinguishable
# from keys written through RedisWrapper.
flask_cache = Cache(
    config={
        'CACHE_TYPE': 'RedisCache',
        'CACHE_REDIS_URL': 'redis://localhost:6379/0',
        'CACHE_KEY_PREFIX': 'local_llm_flask',
    },
)
# Length of a 31-day month in seconds (2678400). Spelled out as a product so
# the value is self-evidently a whole number of days.
ONE_MONTH_SECONDS = 31 * 24 * 60 * 60
class RedisWrapper:
    """
    A wrapper around ``redis.Redis`` that namespaces keys.

    Every key passed to a method is rewritten to ``"<prefix>:<key>"`` before it
    reaches Redis, so keys written by wrappers with different prefixes in the
    same database do not collide.
    """

    def __init__(self, prefix, **kwargs):
        """
        :param prefix: namespace prepended (as ``"<prefix>:"``) to every key.
        :param kwargs: forwarded unchanged to the ``redis.Redis`` constructor.

        A probe key (``"<prefix>:____"``) is written immediately so that an
        unreachable server is reported at construction time. In that case a
        hint is printed to stderr and the process exits with status 1.
        """
        self.redis = Redis(**kwargs)
        self.prefix = prefix
        try:
            self.set('____', 1)
        except redis_pkg.exceptions.ConnectionError as e:
            # Diagnostics go to stderr so they are not mixed into normal output.
            print('Failed to connect to the Redis server:', e, file=sys.stderr)
            print('Did you install and start the Redis server?', file=sys.stderr)
            sys.exit(1)

    def _key(self, key):
        """Return *key* qualified with this wrapper's prefix."""
        return f"{self.prefix}:{key}"

    def set(self, key, value, ex: Union[ExpiryT, None] = None):
        """
        Store *value* under *key*.

        :param ex: optional expiry, forwarded to ``Redis.set`` as ``ex``.
        :return: the result of ``Redis.set``.
        """
        return self.redis.set(self._key(key), value, ex=ex)

    def get(self, key, dtype=None, default=None):
        """
        Fetch the value stored under *key*, optionally converting it.

        :param key: un-prefixed key name.
        :param dtype: optional target type for the raw value. ``str`` decodes
            it as UTF-8, ``dict``/``list`` parse it as JSON, and any other
            callable is applied directly to the raw value (e.g. ``int``).
        :param default: returned when the key is missing or holds an empty value.
        :return: the converted value; the raw value when no *dtype* is given or
            the conversion fails; or *default*.
        """
        raw = self.redis.get(self._key(key))
        if not raw:
            # Both a missing key (None) and an empty value fall back to default.
            return default
        if not dtype:
            return raw
        try:
            if dtype is str:
                return raw.decode('utf-8')
            if dtype in (dict, list):
                return json.loads(raw.decode('utf-8'))
            return dtype(raw)
        except Exception:
            # Best effort: report the failed conversion, then hand back the raw
            # value. Narrowed from a bare ``except:`` so KeyboardInterrupt and
            # SystemExit are no longer swallowed.
            traceback.print_exc()
            return raw

    def incr(self, key, amount=1):
        """Increment the value at *key* by *amount*; return ``Redis.incr``'s result."""
        return self.redis.incr(self._key(key), amount)

    def decr(self, key, amount=1):
        """Decrement the value at *key* by *amount*; return ``Redis.decr``'s result."""
        return self.redis.decr(self._key(key), amount)

    def sadd(self, key: str, *values: FieldT):
        """Add *values* to the set stored at *key*."""
        return self.redis.sadd(self._key(key), *values)

    def srem(self, key: str, *values: FieldT):
        """Remove *values* from the set stored at *key*."""
        return self.redis.srem(self._key(key), *values)

    def sismember(self, key: str, value: str):
        """Test whether *value* is a member of the set stored at *key*."""
        return self.redis.sismember(self._key(key), value)

    def lindex(self, name: str, index: int):
        """Return the element at *index* of the list stored at *name*."""
        return self.redis.lindex(self._key(name), index)

    def lrem(self, name: str, count: int, value: str):
        """Remove occurrences of *value* from the list at *name* (see ``Redis.lrem``)."""
        return self.redis.lrem(self._key(name), count, value)

    def rpush(self, name: str, *values: FieldT):
        """Append *values* to the end of the list stored at *name*."""
        return self.redis.rpush(self._key(name), *values)

    def llen(self, name: str):
        """Return the length of the list stored at *name*."""
        return self.redis.llen(self._key(name))

    def set_dict(
        self,
        key: str,
        dict_value: Union[list, dict],
        ex: Union[ExpiryT, None] = None,
    ):
        """
        Serialise *dict_value* to JSON and store it under *key*.

        :param key: un-prefixed key name.
        :param dict_value: JSON-serialisable ``dict`` or ``list``.
        :param ex: optional expiry, forwarded to :meth:`set`.
        """
        return self.set(key, json.dumps(dict_value), ex=ex)

    def get_dict(self, key):
        """
        Load the JSON value stored under *key* (as written by :meth:`set_dict`).

        :return: the decoded JSON value, or an empty ``dict`` when the key is
            missing or empty.
        """
        raw = self.get(key)
        if not raw:
            return dict()
        return json.loads(raw.decode('utf-8'))

    def flush(self):
        """
        Delete every key under this wrapper's prefix.

        :return: list of the deleted keys, exactly as yielded by ``scan_iter``.
        """
        flushed = []
        for key in self.redis.scan_iter(self._key('*')):
            flushed.append(key)
            self.redis.delete(key)
        return flushed
# Module-wide Redis helper; every key it touches is stored as "local_llm:<key>".
redis = RedisWrapper(prefix='local_llm')