local-llm-server/llm_server/custom_redis.py

256 lines
7.7 KiB
Python

import pickle
import sys
import traceback
from typing import Callable, List, Mapping, Optional, Union
import redis as redis_pkg
import simplejson as json
from flask_caching import Cache
from redis import Redis
from redis.typing import AnyKeyT, EncodableT, ExpiryT, FieldT, KeyT, PatternT, ZScoreBoundT
# Flask-Caching instance configured to use the Redis backend (DB 15 on
# localhost) with the key prefix 'local_llm_flask'.
flask_cache = Cache(
    config={
        'CACHE_TYPE': 'RedisCache',
        'CACHE_REDIS_URL': 'redis://localhost:6379/15',
        'CACHE_KEY_PREFIX': 'local_llm_flask',
    }
)

# Roughly one month expressed in seconds (just under 31 days, which would be
# 2,678,400 s).
ONE_MONTH_SECONDS = 2_678_000
class RedisCustom(Redis):
    """
    A simple wrapper class for Redis to create a "namespace" within a DB,
    which simplifies key management.

    Every command overridden below transparently prepends ``<prefix>:`` to the
    key name(s) it receives and forwards the call to an internal ``Redis``
    client (``self.redis``).

    Caveats:
        * Commands inherited from ``Redis`` that are NOT overridden here do
          not apply the prefix.
        * ``pipeline()`` returns a plain pipeline of the internal client, so
          commands queued on it are not prefixed either.
    """

    # Sentinel key written at start-up to verify the connection. It is hidden
    # from the results of ``keys()``.
    _PING_KEY = '____'

    def __init__(self, prefix, **kwargs):
        """
        Create the namespaced client and verify the server is reachable.

        :param prefix: namespace prepended (followed by ``:``) to every key.
        :param kwargs: connection options forwarded to ``redis.Redis``.

        Prints a diagnostic and exits the process with status 1 if the Redis
        server cannot be reached.
        """
        # Forward the connection options to the base class as well, so that
        # inherited commands which are not overridden below at least target
        # the same server as the wrapped client instead of the defaults.
        super().__init__(**kwargs)
        self.redis = Redis(**kwargs)
        self.prefix = prefix
        try:
            # Writing a throw-away key doubles as a connectivity check.
            self.set(self._PING_KEY, 1)
        except redis_pkg.exceptions.ConnectionError as e:
            print('Failed to connect to the Redis server:', e)
            print('Did you install and start the Redis server?')
            sys.exit(1)

    def _key(self, key):
        """Return ``key`` qualified with this instance's namespace prefix."""
        return f"{self.prefix}:{key}"

    def set(self, key, value, ex: Union[ExpiryT, None] = None):
        """Set ``key`` to ``value``, optionally expiring after ``ex``."""
        return self.redis.set(self._key(key), value, ex=ex)

    def get(self, key, default=None, dtype=None):
        """
        Fetch ``key`` and optionally convert the stored value.

        :param key: un-prefixed key name.
        :param default: returned when the key is missing or its value is empty.
            Must be a value, not a class.
        :param dtype: optional target type. ``str`` decodes as UTF-8,
            ``dict``/``list`` decode as UTF-8 and parse as JSON, anything else
            is called on the raw value (e.g. ``int``).
        :return: the converted value; the raw stored value when no ``dtype``
            is given or when conversion fails; ``default`` when nothing is
            stored.
        :raises TypeError: if ``default`` is a class (e.g. ``dict`` instead of
            ``dict()``).
        """
        # TODO: use pickle
        if isinstance(default, type):
            raise TypeError(
                f'default must be a value, not a class (got {default!r})'
            )
        d = self.redis.get(self._key(key))
        if not d:
            return default
        if dtype:
            try:
                if dtype == str:
                    return d.decode('utf-8')
                if dtype in (dict, list):
                    return json.loads(d.decode("utf-8"))
                return dtype(d)
            except Exception:
                # Best effort: report the failed conversion and fall back to
                # returning the raw value below.
                traceback.print_exc()
        return d

    def incr(self, key, amount=1):
        """Increment the integer stored at ``key`` by ``amount``."""
        return self.redis.incr(self._key(key), amount)

    def decr(self, key, amount=1):
        """Decrement the integer stored at ``key`` by ``amount``."""
        return self.redis.decr(self._key(key), amount)

    def sadd(self, key: str, *values: FieldT):
        """Add ``values`` to the set at ``key``."""
        return self.redis.sadd(self._key(key), *values)

    def srem(self, key: str, *values: FieldT):
        """Remove ``values`` from the set at ``key``."""
        return self.redis.srem(self._key(key), *values)

    def sismember(self, key: str, value: str):
        """Report whether ``value`` is a member of the set at ``key``."""
        return self.redis.sismember(self._key(key), value)

    def lindex(self, name: str, index: int):
        """Return the element at ``index`` of the list at ``name``."""
        return self.redis.lindex(self._key(name), index)

    def lrem(self, name: str, count: int, value: str):
        """Remove occurrences of ``value`` from the list at ``name``."""
        return self.redis.lrem(self._key(name), count, value)

    def rpush(self, name: str, *values: FieldT):
        """Append ``values`` to the tail of the list at ``name``."""
        return self.redis.rpush(self._key(name), *values)

    def llen(self, name: str):
        """Return the length of the list at ``name``."""
        return self.redis.llen(self._key(name))

    def zrangebyscore(
        self,
        name: KeyT,
        min: ZScoreBoundT,
        max: ZScoreBoundT,
        start: Union[int, None] = None,
        num: Union[int, None] = None,
        withscores: bool = False,
        score_cast_func: Union[type, Callable] = float,
    ):
        """Return members of the sorted set ``name`` scored in [min, max]."""
        return self.redis.zrangebyscore(
            self._key(name),
            min,
            max,
            start=start,
            num=num,
            withscores=withscores,
            score_cast_func=score_cast_func,
        )

    def zremrangebyscore(self, name: KeyT, min: ZScoreBoundT, max: ZScoreBoundT):
        """Remove members of the sorted set ``name`` scored in [min, max]."""
        return self.redis.zremrangebyscore(self._key(name), min, max)

    def hincrby(self, name: str, key: str, amount: int = 1):
        """Increment field ``key`` of the hash ``name`` by ``amount``."""
        return self.redis.hincrby(self._key(name), key, amount)

    def zcard(self, name: KeyT):
        """Return the number of members in the sorted set ``name``."""
        return self.redis.zcard(self._key(name))

    def hdel(self, name: str, *keys: str):
        """Delete fields ``keys`` from the hash ``name``."""
        return self.redis.hdel(self._key(name), *keys)

    def hget(self, name: str, key: str):
        """Return the value of field ``key`` in the hash ``name``."""
        return self.redis.hget(self._key(name), key)

    def zadd(
        self,
        name: KeyT,
        mapping: Mapping[AnyKeyT, EncodableT],
        nx: bool = False,
        xx: bool = False,
        ch: bool = False,
        incr: bool = False,
        gt: bool = False,
        lt: bool = False,
    ):
        """Add the member/score pairs in ``mapping`` to the sorted set ``name``."""
        return self.redis.zadd(
            self._key(name),
            mapping,
            nx=nx,
            xx=xx,
            ch=ch,
            incr=incr,
            gt=gt,
            lt=lt,
        )

    def lpush(self, name: str, *values: FieldT):
        """Prepend ``values`` to the head of the list at ``name``."""
        return self.redis.lpush(self._key(name), *values)

    def hset(
        self,
        name: str,
        key: Optional[str] = None,
        value=None,
        mapping: Optional[dict] = None,
        items: Optional[list] = None,
    ):
        """Set one field, a ``mapping`` or a list of ``items`` on the hash ``name``."""
        return self.redis.hset(
            self._key(name),
            key=key,
            value=value,
            mapping=mapping,
            items=items,
        )

    def hkeys(self, name: str):
        """Return the field names of the hash ``name``."""
        return self.redis.hkeys(self._key(name))

    def hmget(self, name: str, keys: List, *args: List):
        """Return the values of the given fields of the hash ``name``."""
        return self.redis.hmget(self._key(name), keys, *args)

    def hgetall(self, name: str):
        """Return all fields and values of the hash ``name``."""
        return self.redis.hgetall(self._key(name))

    def keys(self, pattern: PatternT = "*", **kwargs):
        """
        List the keys in this namespace that match ``pattern``.

        :return: key names as ``str`` with the namespace prefix removed. The
            internal connectivity sentinel key is omitted.
        """
        namespace = f"{self.prefix}:"
        found = []
        for raw in self.redis.keys(self._key(pattern), **kwargs):
            # The client returns bytes unless it was configured to decode
            # responses; accept both.
            name = raw.decode('utf-8') if isinstance(raw, bytes) else raw
            # Strip exactly the namespace prefix rather than splitting on ':'
            # so that prefixes which themselves contain ':' are handled.
            if name.startswith(namespace):
                name = name[len(namespace):]
            if name != self._PING_KEY:
                found.append(name)
        return found

    def pipeline(self, transaction=True, shard_hint=None):
        """
        Return a pipeline of the wrapped client.

        NOTE: commands queued on the returned pipeline are sent as-is; the
        namespace prefix is NOT applied to their keys.
        """
        return self.redis.pipeline(transaction=transaction, shard_hint=shard_hint)

    def smembers(self, name: str):
        """Return all members of the set at ``name``."""
        return self.redis.smembers(self._key(name))

    def spop(self, name: str, count: Optional[int] = None):
        """Remove and return member(s) from the set at ``name``."""
        return self.redis.spop(self._key(name), count)

    def rpoplpush(self, src, dst):
        """
        Pop the tail of list ``src`` and push it onto the head of list ``dst``.

        Both names are namespaced, consistent with the other list commands
        (``lpush``, ``rpush``, ``lrange``, ...), so the operation acts on the
        same lists those commands create.
        """
        return self.redis.rpoplpush(self._key(src), self._key(dst))

    def zpopmin(self, name: KeyT, count: Union[int, None] = None):
        """Remove and return the lowest-scored member(s) of the sorted set ``name``."""
        return self.redis.zpopmin(self._key(name), count)

    def exists(self, *names: KeyT):
        """Return how many of ``names`` exist in this namespace."""
        return self.redis.exists(*[self._key(name) for name in names])

    def set_dict(self, key: str, dict_value: Union[list, dict], ex: Union[ExpiryT, None] = None):
        """Store ``dict_value`` at ``key`` serialized as JSON."""
        return self.set(key, json.dumps(dict_value), ex=ex)

    def get_dict(self, key):
        """Load the JSON value stored at ``key``; return ``{}`` when nothing is stored."""
        r = self.get(key)
        if not r:
            return dict()
        return json.loads(r.decode("utf-8"))

    def setp(self, name, value):
        """Store ``value`` at ``name`` serialized with ``pickle``."""
        self.redis.set(self._key(name), pickle.dumps(value))

    def getp(self, name: str):
        """
        Load the pickled value stored at ``name``.

        :return: the unpickled object, or the raw (falsy) result when nothing
            is stored.

        SECURITY: ``pickle.loads`` can execute arbitrary code. Only use this
        for keys whose contents are written by trusted code (``setp``).
        """
        r = self.redis.get(self._key(name))
        if r:
            return pickle.loads(r)
        return r

    def flush(self):
        """
        Delete every key in this namespace.

        :return: list of the (prefixed) keys that were deleted.
        """
        flushed = []
        for key in self.redis.scan_iter(f'{self.prefix}:*'):
            flushed.append(key)
            self.redis.delete(key)
        return flushed

    def flushall(self, asynchronous: bool = False, **kwargs) -> bool:
        """
        Delete only this namespace's keys rather than the whole server.

        The arguments are accepted for signature compatibility and ignored.
        """
        self.flush()
        return True

    def flushdb(self, asynchronous: bool = False, **kwargs) -> bool:
        """
        Delete only this namespace's keys rather than the whole database.

        The arguments are accepted for signature compatibility and ignored.
        """
        self.flush()
        return True

    def lrange(self, name: str, start: int, end: int):
        """Return the slice [start, end] of the list at ``name``."""
        return self.redis.lrange(self._key(name), start, end)

    def delete(self, *names: KeyT):
        """Delete the given keys from this namespace."""
        return self.redis.delete(*[self._key(i) for i in names])

    def lpop(self, name: str, count: Optional[int] = None):
        """Remove and return element(s) from the head of the list at ``name``."""
        return self.redis.lpop(self._key(name), count)

    def zrange(
        self,
        name: KeyT,
        start: int,
        end: int,
        desc: bool = False,
        withscores: bool = False,
        score_cast_func: Union[type, Callable] = float,
        byscore: bool = False,
        bylex: bool = False,
        offset: Optional[int] = None,
        num: Optional[int] = None,
    ):
        """Return a range of members from the sorted set ``name``."""
        return self.redis.zrange(
            self._key(name),
            start,
            end,
            desc=desc,
            withscores=withscores,
            score_cast_func=score_cast_func,
            byscore=byscore,
            bylex=bylex,
            offset=offset,
            num=num,
        )

    def zrem(self, name: KeyT, *values: FieldT):
        """Remove ``values`` from the sorted set ``name``."""
        return self.redis.zrem(self._key(name), *values)
# Shared module-level client whose keys live under the 'local_llm' namespace.
# Note: this name shadows the ``redis`` package (imported above as
# ``redis_pkg``). Instantiation contacts the server and exits the process if
# it is unreachable.
redis = RedisCustom(prefix='local_llm')