This repository was archived on 2024-10-27. You can view its files and clone it, but you cannot push or open issues or pull requests.
local-llm-server/llm_server/custom_redis.py

123 lines
3.5 KiB
Python

import logging
import pickle
import sys
import traceback
from typing import Union
import redis as redis_pkg
import simplejson as json
from flask_caching import Cache
from redis import Redis
from redis.typing import ExpiryT, KeyT, PatternT
# Flask-Caching instance backed by Redis DB 15 on localhost, configured with the
# key prefix 'local_llm_flask' (separate DB from the RedisCustom client below).
flask_cache = Cache(config=dict(
    CACHE_TYPE='RedisCache',
    CACHE_REDIS_URL='redis://localhost:6379/15',
    CACHE_KEY_PREFIX='local_llm_flask',
))
# One (31-day) month expressed in seconds: 31 * 24 * 60 * 60 = 2,678,400.
ONE_MONTH_SECONDS = 31 * 24 * 60 * 60
class RedisCustom(Redis):
    """
    A simple wrapper class for Redis to create a "namespace" within a DB,
    which simplifies key management.

    Every key handled by this class is stored as ``"<prefix>:<key>"``.
    Commands inherited from ``redis.Redis`` (``set``, ``delete``, ``incr`` ...)
    are namespaced by the overridden ``execute_command``. The methods
    overridden here (``get``, ``keys``, ``exists``, ``setp``/``getp``,
    ``flush``) build the prefixed key themselves and talk to the wrapped
    ``self.redis`` client instead.
    """

    # TODO: is there a better way to do this instead of overriding every single method?

    # Commands whose arguments (after the command name) are all keys, so every
    # one of them must be namespaced rather than only the first.
    _MULTI_KEY_COMMANDS = frozenset({'DEL', 'UNLINK', 'EXISTS', 'TOUCH', 'MGET'})

    def __init__(self, prefix, **kwargs):
        """
        :param prefix: Namespace prepended (as ``"<prefix>:"``) to every key.
        :param kwargs: Connection options forwarded to ``redis.Redis``.

        Exits the process with status 1 if the server cannot be reached.
        """
        # Forward kwargs to the inherited client as well, so commands routed
        # through execute_command() reach the same server/DB as self.redis.
        super().__init__(**kwargs)
        self.redis = Redis(**kwargs)
        self.prefix = prefix
        try:
            # Write a throwaway sentinel key (stored as "<prefix>:____") to
            # fail fast when the server is unreachable.
            self.set('____', 1)
        except redis_pkg.exceptions.ConnectionError as e:
            logger = logging.getLogger('redis')
            logger.setLevel(logging.INFO)
            logger.error(f'Failed to connect to the Redis server: {e}\nDid you install and start the Redis server?')
            sys.exit(1)

    def _key(self, key):
        """Return ``key`` qualified with this instance's namespace."""
        return f"{self.prefix}:{key}"

    def execute_command(self, *args, **options):
        """
        Namespace the key argument(s) of a command, then send it.

        ``GET`` is passed through unprefixed (NOTE(review): presumably because
        ``get`` is overridden and prefixes the key itself — confirm). Commands
        without arguments (e.g. ``PING``) are passed through unchanged instead
        of failing on a missing ``args[1]``.
        """
        if len(args) < 2 or args[0] == 'GET':
            return super().execute_command(*args, **options)
        args = list(args)
        if args[0] in self._MULTI_KEY_COMMANDS:
            args[1:] = [self._key(name) for name in args[1:]]
        else:
            # Assumes the first argument is the (single) key, as for SET/INCR/EXPIRE...
            args[1] = self._key(args[1])
        return super().execute_command(*args, **options)

    def get(self, key, default=None, dtype=None):
        """
        Fetch ``<prefix>:<key>`` and optionally convert it.

        :param key: Un-prefixed key name.
        :param default: Returned when the stored value is missing or falsy
            (e.g. an empty string). Must not be a class.
        :param dtype: Optional conversion: ``str`` decodes UTF-8, ``dict``/``list``
            parse JSON, any other callable is applied to the raw bytes.
        :return: The converted value. If conversion fails, the traceback is
            printed and the raw value is returned instead.
        :raises TypeError: if ``default`` is a class.
        """
        # TODO: use pickle
        import inspect
        if inspect.isclass(default):
            # A class here is most likely a dtype passed positionally: get(key, str).
            raise TypeError(f'get(): default must not be a class (got {default!r}); pass types via dtype=')

        d = self.redis.get(self._key(key))
        if dtype and d:
            try:
                if dtype == str:
                    return d.decode('utf-8')
                if dtype in (dict, list):
                    return json.loads(d.decode('utf-8'))
                return dtype(d)
            except Exception:
                # Best-effort conversion: report it and fall back to the raw value below.
                traceback.print_exc()
        if not d:
            return default
        return d

    def keys(self, pattern: PatternT = "*", **kwargs):
        """
        Return the keys of this namespace matching ``pattern`` with the
        ``"<prefix>:"`` part stripped.

        Uses the blocking ``KEYS`` command on the wrapped client.
        """
        namespace = f"{self.prefix}:"
        keys = []
        for raw in self.redis.keys(self._key(pattern), **kwargs):
            # bytes by default; already str when the client uses decode_responses=True.
            name = raw.decode('utf-8') if isinstance(raw, bytes) else raw
            if name.startswith(namespace):
                name = name[len(namespace):]
            keys.append(name)
        return keys

    def exists(self, *names: KeyT):
        """Return how many of the given (un-prefixed) keys exist in this namespace."""
        return self.redis.exists(*(self._key(name) for name in names))

    def set_dict(self, key: KeyT, dict_value: Union[list, dict], ex: Union[ExpiryT, None] = None):
        """
        Store ``dict_value`` (a dict or list) under ``key`` as JSON.

        :param ex: Optional expiry passed through to ``SET``.
        :return: The result of the underlying ``set`` call.
        """
        return self.set(key, json.dumps(dict_value), ex=ex)

    def get_dict(self, key):
        """Return the JSON-decoded value stored under ``key``, or ``{}`` if it is missing or empty."""
        r = self.get(key)
        if not r:
            return dict()
        return json.loads(r.decode('utf-8'))

    def setp(self, name, value, ex: Union[ExpiryT, None] = None):
        """
        Pickle ``value`` and store it under ``name``.

        :param ex: Optional expiry passed through to ``SET`` (defaults to none).
        :return: The result of the underlying ``SET``.
        """
        return self.redis.set(self._key(name), pickle.dumps(value), ex=ex)

    def getp(self, name: str):
        """
        Load a value previously stored with ``setp``.

        Returns the stored value, or the raw falsy result (``None`` when the key
        is missing) if nothing usable is stored.

        SECURITY: ``pickle.loads`` can execute arbitrary code; only read keys
        that were written by trusted code.
        """
        r = self.redis.get(self._key(name))
        if r:
            return pickle.loads(r)
        return r

    def flush(self):
        """
        Delete every key in this namespace.

        :return: The list of deleted (fully-prefixed, raw) key names.
        """
        flushed = []
        for key in self.redis.scan_iter(f'{self.prefix}:*'):
            flushed.append(key)
            self.redis.delete(key)
        return flushed

    def flushall(self, asynchronous: bool = ..., **kwargs) -> bool:
        """Only clear this namespace (via ``flush``); the arguments are accepted for signature compatibility and ignored."""
        self.flush()
        return True

    def flushdb(self, asynchronous: bool = ..., **kwargs) -> bool:
        """Only clear this namespace (via ``flush``); the arguments are accepted for signature compatibility and ignored."""
        self.flush()
        return True
# Shared module-level client for the 'local_llm' namespace. It is created at
# import time, so a failed connectivity check in RedisCustom.__init__ exits the
# process during import.
redis = RedisCustom(prefix='local_llm')