154 lines
4.5 KiB
Python
154 lines
4.5 KiB
Python
import sys
|
|
import traceback
|
|
from typing import Callable, List, Mapping, Union
|
|
|
|
import redis as redis_pkg
|
|
import simplejson as json
|
|
from flask_caching import Cache
|
|
from redis import Redis
|
|
from redis.typing import AnyKeyT, EncodableT, ExpiryT, FieldT, KeyT, ZScoreBoundT
|
|
|
|
# Flask-Caching instance backed by the local Redis server (database 0); its
# entries are namespaced with the "local_llm_flask" key prefix.
flask_cache = Cache(
    config={
        'CACHE_TYPE': 'RedisCache',
        'CACHE_REDIS_URL': 'redis://localhost:6379/0',
        'CACHE_KEY_PREFIX': 'local_llm_flask',
    }
)
|
|
|
|
# One 31-day month expressed in seconds (= 2,678,400). Spelled out as a
# product so the value cannot silently drift from a mistyped literal.
ONE_MONTH_SECONDS = 31 * 24 * 60 * 60
|
|
|
|
|
|
class RedisWrapper:
    """
    A wrapper class to set prefixes to keys.

    Every key handed to a method of this class is rewritten to
    ``"<prefix>:<key>"`` before it reaches Redis, so several components can
    share one Redis database without their keys colliding. Unless stated
    otherwise, each method returns whatever the underlying redis-py call
    returns.

    :param prefix: namespace prepended (with a ``:`` separator) to every key.
    :param kwargs: forwarded unchanged to ``redis.Redis(...)``.
    """

    def __init__(self, prefix, **kwargs):
        self.redis = Redis(**kwargs)
        self.prefix = prefix
        # Fail fast at construction time if the server is unreachable. PING
        # checks the connection without writing a throw-away key.
        try:
            self.redis.ping()
        except redis_pkg.exceptions.ConnectionError as e:
            # Diagnostics belong on stderr, not mixed into normal output.
            print('Failed to connect to the Redis server:', e, file=sys.stderr)
            print('Did you install and start the Redis server?', file=sys.stderr)
            sys.exit(1)

    def _key(self, key):
        """Return ``key`` namespaced under this wrapper's prefix."""
        return f"{self.prefix}:{key}"

    def set(self, key, value, ex: Union[ExpiryT, None] = None):
        """Store ``value`` at ``key``; ``ex`` is an optional expiry forwarded to redis-py."""
        return self.redis.set(self._key(key), value, ex=ex)

    def get(self, key, dtype=None, default=None):
        """
        Fetch the value stored at ``key``, optionally converting it.

        :param key: un-prefixed key name.
        :param dtype: convert to this type. ``str`` decodes UTF-8, ``dict`` and
            ``list`` parse the UTF-8 text as JSON, and any other callable is
            applied directly to the raw value (e.g. ``int(b'5')``).
        :param default: returned when the key is missing or its value is empty.
        :return: the converted value; the raw value when ``dtype`` is not
            given or the conversion fails; ``default`` when the stored value
            is missing (``None``) or empty.
        """
        d = self.redis.get(self._key(key))
        if not d:
            # Missing keys and empty values both fall back to ``default``.
            return default
        if dtype:
            try:
                if dtype == str:
                    return d.decode('utf-8')
                if dtype in (dict, list):
                    return json.loads(d.decode('utf-8'))
                return dtype(d)
            except Exception:
                # Best effort: report the failed conversion and fall through
                # to returning the raw value. ``Exception`` (not a bare
                # ``except``) lets KeyboardInterrupt/SystemExit propagate.
                traceback.print_exc()
        return d

    def incr(self, key, amount=1):
        """Increment the integer at ``key`` by ``amount``."""
        return self.redis.incr(self._key(key), amount)

    def decr(self, key, amount=1):
        """Decrement the integer at ``key`` by ``amount``."""
        return self.redis.decr(self._key(key), amount)

    def sadd(self, key: str, *values: FieldT):
        """Add ``values`` to the set at ``key``."""
        return self.redis.sadd(self._key(key), *values)

    def srem(self, key: str, *values: FieldT):
        """Remove ``values`` from the set at ``key``."""
        return self.redis.srem(self._key(key), *values)

    def sismember(self, key: str, value: str):
        """Test whether ``value`` is a member of the set at ``key``."""
        return self.redis.sismember(self._key(key), value)

    def lindex(self, name: str, index: int):
        """Return the element at ``index`` of the list ``name``."""
        return self.redis.lindex(self._key(name), index)

    def lrem(self, name: str, count: int, value: str):
        """Remove up to ``count`` occurrences of ``value`` from the list ``name``."""
        return self.redis.lrem(self._key(name), count, value)

    def rpush(self, name: str, *values: FieldT):
        """Append ``values`` to the tail of the list ``name``."""
        return self.redis.rpush(self._key(name), *values)

    def llen(self, name: str):
        """Return the length of the list ``name``."""
        return self.redis.llen(self._key(name))

    def zrangebyscore(
        self,
        name: KeyT,
        min: ZScoreBoundT,
        max: ZScoreBoundT,
        start: Union[int, None] = None,
        num: Union[int, None] = None,
        withscores: bool = False,
        score_cast_func: Union[type, Callable] = float,
    ):
        """Return members of the sorted set ``name`` with scores in ``[min, max]``."""
        # Keyword arguments keep this robust to redis-py parameter reordering.
        return self.redis.zrangebyscore(
            self._key(name),
            min,
            max,
            start=start,
            num=num,
            withscores=withscores,
            score_cast_func=score_cast_func,
        )

    def zremrangebyscore(self, name: KeyT, min: ZScoreBoundT, max: ZScoreBoundT):
        """Remove members of the sorted set ``name`` with scores in ``[min, max]``."""
        return self.redis.zremrangebyscore(self._key(name), min, max)

    def hincrby(self, name: str, key: str, amount: int = 1):
        """Increment the integer field ``key`` of the hash ``name`` by ``amount``."""
        return self.redis.hincrby(self._key(name), key, amount)

    def hdel(self, name: str, *keys: str):
        """Delete the fields ``keys`` from the hash ``name``."""
        return self.redis.hdel(self._key(name), *keys)

    def hget(self, name: str, key: str):
        """Return the value of field ``key`` in the hash ``name``."""
        return self.redis.hget(self._key(name), key)

    def zadd(
        self,
        name: KeyT,
        mapping: Mapping[AnyKeyT, EncodableT],
        nx: bool = False,
        xx: bool = False,
        ch: bool = False,
        incr: bool = False,
        gt: bool = False,
        lt: bool = False,
    ):
        """Add ``mapping`` (member -> score) to the sorted set ``name``."""
        # Keyword arguments keep this robust to redis-py parameter reordering.
        return self.redis.zadd(
            self._key(name),
            mapping,
            nx=nx,
            xx=xx,
            ch=ch,
            incr=incr,
            gt=gt,
            lt=lt,
        )

    def hkeys(self, name: str):
        """Return the field names of the hash ``name``."""
        return self.redis.hkeys(self._key(name))

    def set_dict(self, key: str, dict_value: Union[list, dict], ex: Union[ExpiryT, None] = None):
        """Serialize ``dict_value`` as JSON and store it at ``key`` (optional expiry ``ex``)."""
        return self.set(key, json.dumps(dict_value), ex=ex)

    def get_dict(self, key):
        """
        Return the JSON value stored at ``key``.

        :return: the parsed JSON value, or an empty ``dict`` when the key is
            missing or its value is empty.
        """
        r = self.get(key)
        if not r:
            return dict()
        return json.loads(r.decode("utf-8"))

    def flush(self, batch_size: int = 500):
        """
        Delete every key under this wrapper's prefix.

        Keys are discovered incrementally with ``SCAN`` (via ``scan_iter``) and
        deleted in batches of ``batch_size`` rather than one round trip per key.

        NOTE(review): the prefix is used verbatim in a glob pattern, so a
        prefix containing ``*``, ``?`` or ``[`` would match foreign keys.

        :param batch_size: maximum number of keys per DEL command.
        :return: list of the deleted (prefixed) keys, as yielded by ``scan_iter``.
        """
        flushed = []
        batch = []
        for key in self.redis.scan_iter(f'{self.prefix}:*'):
            flushed.append(key)
            batch.append(key)
            if len(batch) >= batch_size:
                self.redis.delete(*batch)
                batch.clear()
        if batch:
            self.redis.delete(*batch)
        return flushed
|
|
|
|
|
|
# Module-level wrapper instance: every key it touches is stored as
# "local_llm:<key>". Constructing it checks the Redis connection and exits
# the process if the server cannot be reached.
redis = RedisWrapper(prefix='local_llm')
|