import pickle
import time
from typing import Optional, Tuple
from uuid import uuid4

import ujson as json
from redis import Redis

from llm_server.cluster.cluster_config import cluster_config
from llm_server.config.global_config import GlobalConfig
from llm_server.custom_redis import RedisCustom, redis
from llm_server.database.database import get_token_ratelimit
from llm_server.logging import create_logger


def increment_ip_count(client_ip: str, redis_key):
    """Increment the counter stored for ``client_ip`` in the Redis hash ``redis_key``."""
    redis.hincrby(redis_key, client_ip, 1)


def decrement_ip_count(client_ip: str, redis_key):
    """
    Decrement the counter stored for ``client_ip`` in the Redis hash ``redis_key``.

    The hash field is deleted once the counter drops to zero or below, so the
    hash does not accumulate stale entries.
    """
    new_count = redis.hincrby(redis_key, client_ip, -1)
    if new_count <= 0:
        redis.hdel(redis_key, client_ip)


class RedisPriorityQueue:
    """
    A queue for a specific backend.

    Entries live in the sorted set ``'queue'`` (namespaced by ``RedisCustom``)
    as JSON-encoded tuples of
    ``(item, event_id, selected_model, timestamp, do_stream)``.
    """

    def __init__(self, name, db: int = 12):
        self.name = name
        self.redis = RedisCustom(name, db=db)
        self._logger = create_logger('RedisPriorityQueue')

    def put(self, item, priority: int, selected_model: str, do_stream: bool = False):
        """
        Enqueue a request.

        :param item: request tuple; ``item[1]`` is treated as the client IP and
            ``item[2]`` is passed to ``get_token_ratelimit``.
        :param priority: the sorted-set score is ``-priority``, so larger values
            are popped first by ``get()``. A priority of 0 bypasses the per-IP
            simultaneous-request limit.
        :param selected_model: model name stored alongside the item.
        :param do_stream: stored alongside the item.
        :return: a ``DataEvent`` the caller can wait on for the result, or
            ``None`` if the request was rejected by the per-IP limit.
        """
        # TODO: remove this when we're sure nothing strange is happening
        assert item is not None
        assert priority is not None
        assert selected_model is not None

        # Check if the IP is already in the dictionary and if it has reached the limit
        ip_count = self.get_ip_request_count(item[1])
        _, simultaneous_ip = get_token_ratelimit(item[2])
        if ip_count and int(ip_count) >= simultaneous_ip and priority != 0:
            self._logger.debug(f'Rejecting request from {item[1]} - {ip_count} request queued.')
            return None  # reject the request

        timestamp = time.time()
        event = DataEvent()
        self.redis.zadd('queue', {json.dumps((item, event.event_id, selected_model, timestamp, do_stream)): -priority})
        return event

    def get(self):
        """
        Block until an entry is available, then pop and return it.

        Polls every 0.1 s. Returns the decoded
        ``(item, event_id, selected_model, timestamp, do_stream)`` list.
        """
        while True:
            data = self.redis.zpopmin('queue')
            if data:
                # zpopmin returns [(member, score)]; the member is the JSON payload.
                item = json.loads(data[0][0])
                return item
            time.sleep(0.1)  # wait for something to be added to the queue

    def __len__(self):
        return self.redis.zcard('queue')

    def get_ip_request_count(self, client_ip: str):
        """
        Get the number of requests in the queue from a specific IP.

        This is a bit inefficient since we iterate over the entire queue, but it
        keeps the queue as the single source of truth instead of tracking a
        separate hashed set, which can get confusing. If we run into slowdowns
        in the future, we should go back to the hashed set approach.

        :param client_ip: IP address to count; compared against ``item[1]`` of
            each queued entry.
        :return: number of queued entries from ``client_ip``.
        :raises Exception: if the scan takes longer than 0.5 seconds (a
            deliberate tripwire for the slowdown mentioned above).
        """
        start_time = time.time()
        items = self.redis.zrange('queue', 0, -1)
        count = 0
        for item in items:
            item_data = json.loads(item)
            if item_data[0][1] == client_ip:
                count += 1
        elapsed_time = time.time() - start_time
        if elapsed_time > 0.5:
            raise Exception(f"!!! get_ip_request_count took {elapsed_time} seconds to execute !!!")
        return count

    def flush(self):
        self.redis.flush()

    def items(self):
        """Return every raw (JSON-encoded) entry in the queue, in score order."""
        return self.redis.zrange('queue', 0, -1)

    def cleanup(self):
        """
        Drop entries older than ``backend_generate_request_timeout`` seconds.

        For each dropped entry, the waiting producer is notified with
        ``(False, None, 'closed')`` through its ``DataEvent``.
        """
        now = time.time()
        timeout = GlobalConfig.get().backend_generate_request_timeout
        for item in self.items():
            item_data = json.loads(item)
            timestamp = item_data[-2]
            if now - timestamp > timeout:
                # ZREM takes members only; the previous stray `0` argument was
                # treated as an extra member name to delete.
                self.redis.zrem('queue', item)
                event_id = item_data[1]
                event = DataEvent(event_id)
                event.set((False, None, 'closed'))
                # This DataEvent only publishes, so release the subscription
                # its constructor opened instead of leaking a connection.
                event.close()
                self._logger.debug(f'Removed timed-out item from queue: {event_id}')


class DataEvent:
    """
    Simplify pub/sub communication between producers and consumers.

    Each event is a Redis pub/sub channel named after ``event_id``. The
    constructor subscribes to that channel. ``set()`` publishes a pickled
    payload, and ``wait()`` blocks until a message arrives.
    """

    def __init__(self, event_id: str = None):
        self.event_id = event_id if event_id else str(uuid4())
        self.redis = Redis(host='localhost', port=6379, db=14)
        self.pubsub = self.redis.pubsub()
        self.pubsub.subscribe(self.event_id)

    def set(self, data):
        """Publish ``data`` (pickled) on this event's channel."""
        self.redis.publish(self.event_id, pickle.dumps(data))

    def wait(self):
        """Block until a message is published on this event's channel and return its unpickled payload."""
        for item in self.pubsub.listen():
            if item['type'] == 'message':
                # NOTE(security): pickle.loads is only safe as long as nothing
                # untrusted can publish to this Redis instance.
                return pickle.loads(item['data'])

    def close(self):
        """Unsubscribe and release the pub/sub connection opened by ``__init__``."""
        self.pubsub.close()


def update_active_workers(key: str, operation: str):
    """
    Increment (``'incr'``) or decrement (``'decr'``) ``active_gen_workers:{key}``.

    The counter is clamped back to 0 if it goes negative.
    """
    redis_key = f'active_gen_workers:{key}'
    if operation == 'incr':
        redis.incr(redis_key)
    elif operation == 'decr':
        redis.decr(redis_key)
    if redis.get(redis_key, default=0, dtype=int) < 0:
        redis.set(redis_key, 0)


def incr_active_workers(selected_model: str, backend_url: str):
    """Increment the active-worker counters for both the model and the backend."""
    update_active_workers(selected_model, 'incr')
    update_active_workers(backend_url, 'incr')


def decr_active_workers(selected_model: str, backend_url: str):
    """Decrement the active-worker counters for both the model and the backend."""
    update_active_workers(selected_model, 'decr')
    update_active_workers(backend_url, 'decr')


class PriorityQueue:
    """
    Helper class to wrangle all the different per-backend queues.
    """

    def __init__(self, backends: Optional[set] = None):
        """
        Only have to load the backends once.

        :param backends: backend URLs to register in the ``'backends'`` set.
            Registration is additive; existing entries are kept.
        """
        self.redis = Redis(host='localhost', port=6379, db=9)
        if backends:
            for item in backends:
                self.redis.sadd('backends', item)

    def get_backends(self):
        """Return the registered backend URLs as a set of ``str``."""
        return {x.decode('utf-8') for x in self.redis.smembers('backends')}

    def get_queued_ip_count(self, client_ip: str):
        """Count queued requests from ``client_ip`` across all backend queues."""
        return sum(
            RedisPriorityQueue(backend_url).get_ip_request_count(client_ip)
            for backend_url in self.get_backends()
        )

    def put(self, backend_url, item: Tuple[dict, str, str, dict], priority: int, selected_model: str, do_stream: bool = False):
        """Enqueue ``item`` on ``backend_url``'s queue; see ``RedisPriorityQueue.put``."""
        queue = RedisPriorityQueue(backend_url)
        return queue.put(item, priority, selected_model, do_stream)

    def activity(self):
        """Return sorted ``(worker, status)`` pairs from the ``worker_status`` store."""
        lines = []
        status_redis = RedisCustom('worker_status')
        for worker in status_redis.keys():
            lines.append((worker, status_redis.getp(worker)))
        return sorted(lines)

    def len(self, model_name):
        """Total queued entries across all backends whose configured model is ``model_name``."""
        backends_with_models = {
            k for k in self.get_backends()
            if cluster_config.get_backend(k).get('model') == model_name
        }
        return sum(len(RedisPriorityQueue(backend_url)) for backend_url in backends_with_models)

    def __len__(self):
        """Total queued entries across all registered backends."""
        return sum(len(RedisPriorityQueue(backend_url)) for backend_url in self.get_backends())

    def flush(self):
        """
        Flush every registered backend's queue.

        The previous implementation read raw keys from this db and tried to
        JSON-decode them. That always failed, because ``'backends'`` is a set
        and decoded JSON has no ``flush()``.
        """
        for backend_url in self.get_backends():
            RedisPriorityQueue(backend_url).flush()

    def flush_db(self):
        self.redis.flushdb()


priority_queue = PriorityQueue()