import sys
import traceback
from typing import Union

import redis as redis_pkg
import simplejson as json
from flask_caching import Cache
from redis import Redis
from redis.typing import ExpiryT, FieldT

# Flask-Caching instance backed by the local Redis server (database 0).
flask_cache = Cache(config={'CACHE_TYPE': 'RedisCache',
                            'CACHE_REDIS_URL': 'redis://localhost:6379/0',
                            'CACHE_KEY_PREFIX': 'local_llm_flask'})

# 31 days, expressed in seconds.
ONE_MONTH_SECONDS = 31 * 24 * 60 * 60


class RedisWrapper:
    """
    A thin wrapper around ``redis.Redis`` that namespaces every key.

    Every key handed to a method is rewritten as ``"<prefix>:<key>"`` before
    it is sent to Redis, so that several users of the same database do not
    collide.
    """

    def __init__(self, prefix, **kwargs):
        """
        :param prefix: namespace prepended to every key.
        :param kwargs: forwarded unchanged to ``redis.Redis``.

        Writes a probe key (``<prefix>:____``) to check that the server is
        reachable. On a connection error it prints a hint to stderr and
        terminates the process with exit status 1.
        """
        self.redis = Redis(**kwargs)
        self.prefix = prefix
        try:
            self.set('____', 1)
        except redis_pkg.exceptions.ConnectionError as e:
            print('Failed to connect to the Redis server:', e, file=sys.stderr)
            print('Did you install and start the Redis server?', file=sys.stderr)
            sys.exit(1)

    def _key(self, key):
        """Return ``key`` qualified with this wrapper's prefix."""
        return f"{self.prefix}:{key}"

    def set(self, key, value, ex: Union[ExpiryT, None] = None):
        """Store ``value`` under ``key``; ``ex`` is forwarded as the expiry."""
        return self.redis.set(self._key(key), value, ex=ex)

    def get(self, key, dtype=None, default=None):
        """
        Fetch ``key`` and optionally convert the stored value.

        :param key: un-prefixed key.
        :param dtype: type to convert the stored value to. ``str`` decodes it
            as UTF-8, ``dict``/``list`` parse it as UTF-8 JSON, and any other
            callable is applied as ``dtype(value)``.
        :param default: returned when the key is missing or its value is empty.
        :return: the converted value. If conversion fails, the traceback is
            printed and the raw value returned by the client is used instead.
        """
        d = self.redis.get(self._key(key))
        if not d:
            return default
        if dtype:
            try:
                if dtype == str:
                    return d.decode('utf-8')
                if dtype in (dict, list):
                    return json.loads(d.decode('utf-8'))
                return dtype(d)
            except Exception:
                # Best-effort conversion: report the failure, then fall back to
                # the raw value instead of raising to the caller. Unlike a bare
                # ``except:``, this does not swallow KeyboardInterrupt/SystemExit.
                traceback.print_exc()
        return d

    def incr(self, key, amount=1):
        """Increment the integer stored at ``key`` by ``amount``."""
        return self.redis.incr(self._key(key), amount)

    def decr(self, key, amount=1):
        """Decrement the integer stored at ``key`` by ``amount``."""
        return self.redis.decr(self._key(key), amount)

    def sadd(self, key: str, *values: FieldT):
        """Add ``values`` to the set stored at ``key``."""
        return self.redis.sadd(self._key(key), *values)

    def srem(self, key: str, *values: FieldT):
        """Remove ``values`` from the set stored at ``key``."""
        return self.redis.srem(self._key(key), *values)

    def sismember(self, key: str, value: str):
        """Return whether ``value`` is a member of the set stored at ``key``."""
        return self.redis.sismember(self._key(key), value)

    def lindex(self, name: str, index: int):
        """Return the element at ``index`` of the list stored at ``name``."""
        return self.redis.lindex(self._key(name), index)

    def lrem(self, name: str, count: int, value: str):
        """Remove up to ``count`` occurrences of ``value`` from list ``name``."""
        return self.redis.lrem(self._key(name), count, value)

    def rpush(self, name: str, *values: FieldT):
        """Append ``values`` to the tail of the list stored at ``name``."""
        return self.redis.rpush(self._key(name), *values)

    def llen(self, name: str):
        """Return the length of the list stored at ``name``."""
        return self.redis.llen(self._key(name))

    def set_dict(self, key: str, dict_value: Union[list, dict],
                 ex: Union[ExpiryT, None] = None):
        """
        Store a JSON-serialisable ``dict_value`` (dict or list) under ``key``.

        :param key: un-prefixed key.
        :param dict_value: value serialised with ``json.dumps``.
        :param ex: expiry forwarded to :meth:`set`.
        """
        return self.set(key, json.dumps(dict_value), ex=ex)

    def get_dict(self, key):
        """
        Load the JSON value stored at ``key``.

        :return: the parsed JSON (whatever ``set_dict`` stored, so it may be a
            list), or an empty ``dict`` when the key is missing or empty.
            Invalid JSON raises the parser's exception.
        """
        r = self.get(key)
        if not r:
            return dict()
        return json.loads(r.decode('utf-8'))

    def flush(self):
        """
        Delete every key in this wrapper's namespace.

        :return: list of the deleted keys, fully prefixed, exactly as yielded
            by ``scan_iter``.
        """
        flushed = []
        # NOTE(review): the prefix is used as a glob pattern without escaping;
        # a prefix containing '*', '?' or '[' could match keys outside the
        # namespace.
        for key in self.redis.scan_iter(f'{self.prefix}:*'):
            flushed.append(key)
            self.redis.delete(key)
        return flushed


# Shared default instance. Constructing it writes a probe key and exits the
# process if the Redis server cannot be reached.
redis = RedisWrapper('local_llm')