import hashlib
import pickle
import traceback

from llm_server.cluster.redis_cycle import add_backend_cycler, redis_cycle
from llm_server.cluster.stores import redis_running_models
from llm_server.config.global_config import GlobalConfig
from llm_server.custom_redis import RedisCustom
from llm_server.logging import create_logger
from llm_server.routes.helpers.model import estimate_model_size

# Don't try to reorganize this file or else you'll run into circular imports.

_logger = create_logger('redis')


class RedisClusterStore:
    """
    A class used to store the cluster state in Redis.

    Each backend is stored as a Redis hash keyed by its URL. Every field value
    is pickled on write and unpickled on read.
    """

    def __init__(self, name: str, **kwargs):
        """
        :param name: name passed to `RedisCustom`; also used by `all()` to skip the `<name>:____` key.
        :param kwargs: forwarded unchanged to `RedisCustom`.
        """
        self.name = name
        self.config_redis = RedisCustom(name, **kwargs)

    def clear(self):
        """Flush the underlying Redis store."""
        self.config_redis.flush()

    def load(self):
        """
        Register every backend listed in `GlobalConfig.get().cluster`.

        Slashes are stripped from both ends of each backend URL and the
        normalized URL is written back onto the config item. Items that
        normalize to the same URL collapse to the last one seen.
        """
        backends_by_url = {}
        for item in GlobalConfig.get().cluster:
            backend_url = item.backend_url.strip('/')
            item.backend_url = backend_url
            backends_by_url[backend_url] = item
        for backend_url, item in backends_by_url.items():
            self.add_backend(backend_url, item.dict())

    def add_backend(self, name: str, values: dict):
        """
        Store a backend's values, mark it offline, and attach a short hash of its name.

        :param name: the backend identifier (the hash key in Redis).
        :param values: field -> value mapping; each value is pickled before being stored.
        """
        self.config_redis.hset(name, mapping={k: pickle.dumps(v) for k, v in values.items()})
        # A newly added backend is always recorded as offline here.
        self.set_backend_value(name, 'online', False)
        digest = hashlib.sha256(name.encode('utf-8')).hexdigest()
        # The stored hash is the first and last 8 hex characters of the SHA-256 digest.
        self.set_backend_value(name, 'hash', f'{digest[:8]}-{digest[-8:]}')

    def set_backend_value(self, backend: str, key: str, value):
        """
        Set a single field on a backend.

        :param backend: the backend identifier.
        :param key: the field name.
        :param value: any picklable value.
        """
        # By storing the value as a pickle we don't have to cast anything when getting the value from Redis.
        self.config_redis.hset(backend, key, pickle.dumps(value))

    def get_backend(self, name: str) -> dict:
        """
        Fetch every field stored for a backend.

        :param name: the backend identifier.
        :return: field name -> unpickled value. Empty if nothing is stored under `name`.
        """
        raw = self.config_redis.hgetall(name)
        # NOTE(security): pickle.loads executes arbitrary code on malicious input.
        # This assumes the Redis instance is trusted and only written to by this application.
        return {k.decode('utf8'): pickle.loads(v) for k, v in raw.items()}

    def all(self) -> dict:
        """
        Fetch every backend in the store.

        :return: key -> backend field dict, skipping the `<name>:____` key.
        """
        keys = self.config_redis.keys('*')
        if not keys:
            return {}
        skipped_key = f'{self.name}:____'
        return {key: self.get_backend(key) for key in keys if key != skipped_key}

    def validate_backend(self, backend_url: str):
        """
        Returns the backend URL that was given. If that backend is offline, it will select a new one.
        This fallback behavior does NOT take the selected model into account.

        A backend with no stored `online` field (e.g. an unknown URL) is treated as offline.

        :param backend_url: the backend the caller wants to use.
        :return: `backend_url` if it is online, otherwise whatever `get_a_cluster_backend()` returns
                 (which may be None).
        """
        backend_info = self.get_backend(backend_url)
        # .get() so that an unknown backend is redirected instead of raising KeyError.
        if not backend_info.get('online', False):
            old = backend_url
            backend_url = get_a_cluster_backend()
            _logger.debug(f'Backend {old} offline. Request was redirected to {backend_url}')
        return backend_url


cluster_config = RedisClusterStore('cluster_config')


def get_backends():
    """
    Get all the backends in the cluster, split by online status and sorted.

    Online backends are sorted by ascending `priority`, or by descending
    `estimate_model_size(model_config)` when `GlobalConfig.get().prioritize_by_size`
    is set. Offline backends are always sorted by ascending `priority`.

    :return: a pair of lists of backend URLs: `(online, offline)`. If a backend is missing a
             required key, the error is logged and `([], [])` is returned.
    """
    backends = cluster_config.all()
    try:
        online = [(url, info) for url, info in backends.items() if info['online']]
        offline = [(url, info) for url, info in backends.items() if not info['online']]

        # Negated key + reverse=True yields ascending priority while keeping the
        # original relative order of backends that share a priority.
        if not GlobalConfig.get().prioritize_by_size:
            online.sort(key=lambda kv: -kv[1]['priority'], reverse=True)
        else:
            online.sort(key=lambda kv: estimate_model_size(kv[1]['model_config']), reverse=True)
        offline.sort(key=lambda kv: -kv[1]['priority'], reverse=True)

        return [url for url, _ in online], [url for url, _ in offline]
    except KeyError:
        _logger.error(f'Failed to get a backend from the cluster config: {traceback.format_exc()}\nCurrent backends: {backends}')
        # Callers unpack the result, so return an empty pair rather than None.
        return [], []


def get_a_cluster_backend(model=None):
    """
    Get a backend from Redis. If there are no online backends, return None.
    If `model` is not supplied, we will pick one ourselves.

    :param model: optional model name. When given, the backend comes from the Redis
                  cycle for that model; otherwise the first online backend from
                  `get_backends()` is used.
    :return: a backend URL, or None if none is available.
    """
    if model:
        # First, determine if there are multiple backends hosting the same model.
        backends_hosting_model = [i.decode('utf-8') for i in redis_running_models.smembers(model)]

        if not backends_hosting_model:
            # No backend hosting that model
            return None

        # If so, create a Redis "cycle" iterator for those backends.
        add_backend_cycler(model, backends_hosting_model)
        cycled = redis_cycle(model)
        if cycled:
            return cycled[0]
        return None

    online, _ = get_backends()
    if online:
        return online[0]
    return None