import inspect
import pickle
import sys
import traceback
from typing import Callable, List, Mapping, Optional, Union

import redis as redis_pkg
import simplejson as json
from flask_caching import Cache
from redis import Redis
from redis.typing import AnyKeyT, EncodableT, ExpiryT, FieldT, KeyT, PatternT, ZScoreBoundT

# Flask cache backed by Redis database 15 on localhost.
flask_cache = Cache(config={'CACHE_TYPE': 'RedisCache',
                            'CACHE_REDIS_URL': 'redis://localhost:6379/15',
                            'CACHE_KEY_PREFIX': 'local_llm_flask'})

# NOTE(review): 31 days is 2678400 s; this value is 400 s shorter. Kept as-is
# so existing expirations do not change — confirm the intended value.
ONE_MONTH_SECONDS = 2678000


class RedisCustom(Redis):
    """
    A wrapper class to set prefixes to keys.

    Every method defined here rewrites its key argument(s) to
    ``"<prefix>:<key>"`` and forwards the call to an inner ``Redis`` client
    stored in ``self.redis``.

    NOTE: only the methods defined in this class apply the prefix. Any other
    ``Redis`` method is inherited and runs on the base client created by
    ``super().__init__()`` (default connection settings, no prefix). Likewise,
    commands queued on the object returned by ``pipeline()`` are NOT prefixed.
    """

    # Sentinel key written by __init__ to verify connectivity; hidden by keys().
    _CHECK_KEY = '____'

    def __init__(self, prefix, **kwargs):
        """Create the inner client and verify that the server is reachable.

        :param prefix: namespace prepended to every key as ``"<prefix>:"``.
        :param kwargs: forwarded unchanged to ``redis.Redis`` for the inner client.
        Prints a hint and exits the process with status 1 when the initial
        write fails with ``redis.exceptions.ConnectionError``.
        """
        super().__init__()
        self.redis = Redis(**kwargs)
        self.prefix = prefix
        try:
            # A real round-trip: the client does not connect until a command runs.
            self.set(self._CHECK_KEY, 1)
        except redis_pkg.exceptions.ConnectionError as e:
            print('Failed to connect to the Redis server:', e)
            print('Did you install and start the Redis server?')
            sys.exit(1)

    def _key(self, key):
        """Return ``key`` namespaced as ``"<prefix>:<key>"``."""
        return f"{self.prefix}:{key}"

    def set(self, key, value, ex: Union[ExpiryT, None] = None):
        """SET the prefixed ``key`` to ``value``; ``ex`` is forwarded as the expiry."""
        return self.redis.set(self._key(key), value, ex=ex)

    def get(self, key, default=None, dtype=None):
        """Read ``key`` and optionally convert the stored bytes.

        :param key: unprefixed key name.
        :param default: returned when the key is missing or its value is empty.
            Must be a value, not a type — pass types through ``dtype``.
        :param dtype: ``str`` decodes UTF-8; ``dict``/``list`` parse JSON;
            any other callable is applied to the raw value. If conversion
            fails, the traceback is printed and the raw value is returned.
        :raises TypeError: if ``default`` is a class.
        """
        # TODO: use pickle
        if inspect.isclass(default):
            # Presumably guards against get(key, int) where dtype=int was meant.
            raise TypeError(
                f'default must be a value, not a type ({default!r}); '
                'pass types via dtype='
            )
        d = self.redis.get(self._key(key))
        if dtype and d:
            try:
                if dtype == str:
                    return d.decode('utf-8')
                if dtype in (dict, list):
                    return json.loads(d.decode('utf-8'))
                return dtype(d)
            except Exception:
                # Best effort: report the failure and fall back to the raw value.
                traceback.print_exc()
        return d if d else default

    def incr(self, key, amount=1):
        """Increment the integer at the prefixed ``key`` by ``amount``."""
        return self.redis.incr(self._key(key), amount)

    def decr(self, key, amount=1):
        """Decrement the integer at the prefixed ``key`` by ``amount``."""
        return self.redis.decr(self._key(key), amount)

    def sadd(self, key: str, *values: FieldT):
        """Add ``values`` to the set at the prefixed ``key``."""
        return self.redis.sadd(self._key(key), *values)

    def srem(self, key: str, *values: FieldT):
        """Remove ``values`` from the set at the prefixed ``key``."""
        return self.redis.srem(self._key(key), *values)

    def sismember(self, key: str, value: str):
        """Test whether ``value`` is in the set at the prefixed ``key``."""
        return self.redis.sismember(self._key(key), value)

    def lindex(self, name: str, index: int):
        """Return the element at ``index`` of the list at the prefixed ``name``."""
        return self.redis.lindex(self._key(name), index)

    def lrem(self, name: str, count: int, value: str):
        """Remove up to ``count`` occurrences of ``value`` from the prefixed list."""
        return self.redis.lrem(self._key(name), count, value)

    def rpush(self, name: str, *values: FieldT):
        """Append ``values`` to the tail of the prefixed list."""
        return self.redis.rpush(self._key(name), *values)

    def llen(self, name: str):
        """Return the length of the prefixed list."""
        return self.redis.llen(self._key(name))

    def zrangebyscore(
            self,
            name: KeyT,
            min: ZScoreBoundT,
            max: ZScoreBoundT,
            start: Union[int, None] = None,
            num: Union[int, None] = None,
            withscores: bool = False,
            score_cast_func: Union[type, Callable] = float,
    ):
        """Return members of the prefixed sorted set with scores in [min, max]."""
        # Keyword arguments so the call does not depend on redis-py's positional order.
        return self.redis.zrangebyscore(
            self._key(name), min, max,
            start=start, num=num, withscores=withscores,
            score_cast_func=score_cast_func,
        )

    def zremrangebyscore(self, name: KeyT, min: ZScoreBoundT, max: ZScoreBoundT):
        """Remove members of the prefixed sorted set with scores in [min, max]."""
        return self.redis.zremrangebyscore(self._key(name), min, max)

    def hincrby(self, name: str, key: str, amount: int = 1):
        """Increment field ``key`` of the prefixed hash by ``amount``."""
        return self.redis.hincrby(self._key(name), key, amount)

    def zcard(self, name: KeyT):
        """Return the number of members in the prefixed sorted set."""
        return self.redis.zcard(self._key(name))

    def hdel(self, name: str, *keys: str):
        """Delete fields ``keys`` from the prefixed hash."""
        return self.redis.hdel(self._key(name), *keys)

    def hget(self, name: str, key: str):
        """Return field ``key`` of the prefixed hash."""
        return self.redis.hget(self._key(name), key)

    def zadd(
            self,
            name: KeyT,
            mapping: Mapping[AnyKeyT, EncodableT],
            nx: bool = False,
            xx: bool = False,
            ch: bool = False,
            incr: bool = False,
            gt: bool = False,
            lt: bool = False,
    ):
        """Add ``mapping`` (member -> score) to the prefixed sorted set."""
        return self.redis.zadd(
            self._key(name), mapping,
            nx=nx, xx=xx, ch=ch, incr=incr, gt=gt, lt=lt,
        )

    def lpush(self, name: str, *values: FieldT):
        """Prepend ``values`` to the head of the prefixed list."""
        return self.redis.lpush(self._key(name), *values)

    def hset(
            self,
            name: str,
            key: Optional[str] = None,
            value=None,
            mapping: Optional[dict] = None,
            items: Optional[list] = None,
    ):
        """Set one field (``key``/``value``) and/or many (``mapping``/``items``) in the prefixed hash."""
        return self.redis.hset(
            self._key(name), key=key, value=value, mapping=mapping, items=items,
        )

    def hkeys(self, name: str):
        """Return the field names of the prefixed hash."""
        return self.redis.hkeys(self._key(name))

    def hmget(self, name: str, keys: List, *args: List):
        """Return the values of several fields of the prefixed hash."""
        return self.redis.hmget(self._key(name), keys, *args)

    def hgetall(self, name: str):
        """Return all fields and values of the prefixed hash."""
        return self.redis.hgetall(self._key(name))

    def keys(self, pattern: PatternT = "*", **kwargs):
        """Return keys under this prefix that match ``pattern``, with the prefix removed.

        Results are ``str``. The connectivity sentinel key is omitted.
        NOTE: delegates to the Redis KEYS command; for large databases
        ``scan_iter`` is the non-blocking alternative.
        """
        raw_keys = self.redis.keys(self._key(pattern), **kwargs)
        head = f"{self.prefix}:"
        keys = []
        for raw in raw_keys:
            # The client may be configured with decode_responses=True (str results).
            name = raw.decode('utf-8') if isinstance(raw, bytes) else raw
            if not name.startswith(head):
                continue
            # Slice off the exact prefix so prefixes containing ':' are handled.
            k = name[len(head):]
            if k != self._CHECK_KEY:
                keys.append(k)
        return keys

    def pipeline(self, transaction=True, shard_hint=None):
        """Return a pipeline on the inner client.

        NOTE: commands queued on the returned pipeline are NOT prefixed;
        callers must apply ``"<prefix>:"`` to keys themselves.
        """
        return self.redis.pipeline(transaction=transaction, shard_hint=shard_hint)

    def smembers(self, name: str):
        """Return all members of the prefixed set."""
        return self.redis.smembers(self._key(name))

    def spop(self, name: str, count: Optional[int] = None):
        """Remove and return ``count`` random members of the prefixed set."""
        return self.redis.spop(self._key(name), count)

    def rpoplpush(self, src, dst):
        """Pop the tail of list ``src`` and push it onto the head of list ``dst``."""
        # Both keys are namespaced like every other list command (rpush, lpush, ...).
        return self.redis.rpoplpush(self._key(src), self._key(dst))

    def zpopmin(self, name: KeyT, count: Union[int, None] = None):
        """Remove and return up to ``count`` lowest-scored members of the prefixed sorted set."""
        return self.redis.zpopmin(self._key(name), count)

    def exists(self, *names: KeyT):
        """Forward EXISTS for the prefixed ``names`` and return the server's reply."""
        return self.redis.exists(*(self._key(name) for name in names))

    def set_dict(self, key: str, dict_value: Union[list, dict], ex: Union[ExpiryT, None] = None):
        """Store ``dict_value`` JSON-encoded under ``key``; ``ex`` is the optional expiry."""
        return self.set(key, json.dumps(dict_value), ex=ex)

    def get_dict(self, key):
        """Return the JSON value stored under ``key``, or an empty dict if missing/empty."""
        r = self.get(key)
        if not r:
            return dict()
        return json.loads(r.decode("utf-8"))

    def setp(self, name, value):
        """Store ``value`` pickled under ``name``."""
        self.redis.set(self._key(name), pickle.dumps(value))

    def getp(self, name: str):
        """Return the unpickled value under ``name``; return the raw reply when falsy.

        NOTE(security): ``pickle.loads`` can execute arbitrary code. Only use
        this when everything that can write to this Redis instance is trusted.
        """
        r = self.redis.get(self._key(name))
        if r:
            return pickle.loads(r)
        return r

    def flush(self):
        """Delete every key under this prefix and return the list of deleted raw keys."""
        flushed = []
        for key in self.redis.scan_iter(f'{self.prefix}:*'):
            flushed.append(key)
            self.redis.delete(key)
        return flushed

    def flushall(self, asynchronous: bool = False, **kwargs) -> bool:
        """Delete only this prefix's keys (not the whole server); arguments are ignored."""
        self.flush()
        return True

    def flushdb(self, asynchronous: bool = False, **kwargs) -> bool:
        """Delete only this prefix's keys (not the whole database); arguments are ignored."""
        self.flush()
        return True

    def lrange(self, name: str, start: int, end: int):
        """Return elements ``start``..``end`` (inclusive) of the prefixed list."""
        return self.redis.lrange(self._key(name), start, end)

    def delete(self, *names: KeyT):
        """Delete the prefixed ``names``."""
        return self.redis.delete(*[self._key(i) for i in names])

    def lpop(self, name: str, count: Optional[int] = None):
        """Remove and return element(s) from the head of the prefixed list."""
        return self.redis.lpop(self._key(name), count)

    def zrange(
            self,
            name: KeyT,
            start: int,
            end: int,
            desc: bool = False,
            withscores: bool = False,
            score_cast_func: Union[type, Callable] = float,
            byscore: bool = False,
            bylex: bool = False,
            offset: int = None,
            num: int = None,
    ):
        """Return a range of members from the prefixed sorted set."""
        return self.redis.zrange(
            self._key(name), start, end,
            desc=desc, withscores=withscores, score_cast_func=score_cast_func,
            byscore=byscore, bylex=bylex, offset=offset, num=num,
        )

    def zrem(self, name: KeyT, *values: FieldT):
        """Remove ``values`` from the prefixed sorted set."""
        return self.redis.zrem(self._key(name), *values)


# Module-wide client; __init__ writes a sentinel key (or exits) at import time.
redis = RedisCustom('local_llm')