import json
import pickle
import time
from typing import Tuple
from uuid import uuid4

from redis import Redis

from llm_server.cluster.cluster_config import cluster_config
from llm_server.custom_redis import RedisCustom, redis
from llm_server.database.database import get_token_ratelimit


def increment_ip_count(client_ip: str, redis_key):
    """Increment the per-IP request counter stored in the Redis hash `redis_key`."""
    redis.hincrby(redis_key, client_ip, 1)


def decrement_ip_count(client_ip: str, redis_key):
    """Decrement the per-IP request counter; delete the field once it reaches zero."""
    new_count = redis.hincrby(redis_key, client_ip, -1)
    if new_count <= 0:
        redis.hdel(redis_key, client_ip)


class RedisPriorityQueue:
    """A Redis-backed priority queue (sorted set 'queue') with a per-IP
    admission limit tracked in the hash 'queued_ip_count'.

    Items are stored as JSON tuples ``(item, event_id, selected_model)``;
    higher `priority` values are served first.
    """

    def __init__(self, name, db: int = 12):
        self.redis = RedisCustom(name, db=db)

    def put(self, item, priority, selected_model):
        """Enqueue `item` and return a DataEvent the caller can wait on,
        or None if the request was rejected by the per-IP limit.

        `item[1]` is the client IP and `item[2]` the API token — TODO confirm
        against callers; this layout is assumed from the reads below.
        Priority-0 requests bypass the per-IP limit.
        """
        event = DataEvent()
        # Reject if this IP already has simultaneous_ip or more requests queued.
        ip_count = self.redis.hget('queued_ip_count', item[1])
        if ip_count:
            ip_count = int(ip_count)
            _, simultaneous_ip = get_token_ratelimit(item[2])
            if ip_count >= simultaneous_ip and priority != 0:
                print(f'Rejecting request from {item[1]} - {ip_count} requests in progress.')
                return None  # reject the request
        # Negate the priority: zpopmin pops the smallest score, so the
        # highest-priority item comes out first.
        self.redis.zadd('queue', {json.dumps((item, event.event_id, selected_model)): -priority})
        self.increment_ip_count(item[1], 'queued_ip_count')
        return event

    def get(self):
        """Block until an item is available, pop the highest-priority one,
        decrement its IP's queued count, and return the decoded tuple."""
        while True:
            data = self.redis.zpopmin('queue')
            if data:
                item = json.loads(data[0][0])
                client_ip = item[0][1]
                self.decrement_ip_count(client_ip, 'queued_ip_count')
                return item
            time.sleep(0.1)  # wait for something to be added to the queue

    def print_all_items(self):
        """Debug helper: print every queued entry as UTF-8 text."""
        for entry in self.redis.zrange('queue', 0, -1):
            print(entry.decode('utf-8'))

    def increment_ip_count(self, client_ip: str, redis_key):
        """Increment this queue's per-IP counter in the hash `redis_key`."""
        self.redis.hincrby(redis_key, client_ip, 1)

    def decrement_ip_count(self, client_ip: str, redis_key):
        """Decrement the per-IP counter; delete the field once it hits zero."""
        new_count = self.redis.hincrby(redis_key, client_ip, -1)
        if new_count <= 0:
            self.redis.hdel(redis_key, client_ip)

    def __len__(self):
        return self.redis.zcard('queue')

    def get_queued_ip_count(self, client_ip: str) -> int:
        """Return how many requests `client_ip` currently has queued.

        BUG FIX: the original returned 0 unconditionally (both branches were
        `return 0`), which disabled per-IP accounting for every consumer of
        this method.
        """
        q = self.redis.hget('queued_ip_count', client_ip)
        return int(q) if q else 0

    def flush(self):
        self.redis.flush()


class DataEvent:
    """One-shot pub/sub channel used to hand a result back to the producer.

    The producer constructs a DataEvent (subscribing to a UUID channel),
    a worker later calls set() with the result, and the producer's wait()
    returns it.
    """

    def __init__(self, event_id=None):
        self.event_id = event_id if event_id else str(uuid4())
        self.redis = Redis(host='localhost', port=6379, db=14)
        self.pubsub = self.redis.pubsub()
        self.pubsub.subscribe(self.event_id)

    def set(self, data):
        """Publish `data` (pickled) to this event's channel."""
        self.redis.publish(self.event_id, pickle.dumps(data))

    def wait(self):
        """Block until a message arrives on this event's channel and return it.

        SECURITY NOTE: pickle.loads on channel data is only safe because both
        ends are trusted local workers on the same Redis instance; never feed
        this channel from untrusted input.
        """
        for item in self.pubsub.listen():
            if item['type'] == 'message':
                return pickle.loads(item['data'])


def update_active_workers(key: str, operation: str):
    """Increment or decrement the 'active_gen_workers:{key}' gauge.

    On 'decr', the counter is clamped at zero so transient double-decrements
    never leave a negative reading.
    NOTE(review): the clamp is placed inside the 'decr' branch — the original
    source's indentation was mangled, but clamping only makes sense after a
    decrement; confirm against version history if behavior looks off.
    """
    if operation == 'incr':
        redis.incr(f'active_gen_workers:{key}')
    elif operation == 'decr':
        redis.decr(f'active_gen_workers:{key}')
        if redis.get(f'active_gen_workers:{key}', default=0, dtype=int) < 0:
            redis.set(f'active_gen_workers:{key}', 0)


def incr_active_workers(selected_model: str, backend_url: str):
    """Mark one more active worker for both the model and its backend URL."""
    update_active_workers(selected_model, 'incr')
    update_active_workers(backend_url, 'incr')


def decr_active_workers(selected_model: str, backend_url: str):
    """Mark one fewer active worker for both the model and its backend URL."""
    update_active_workers(selected_model, 'decr')
    update_active_workers(backend_url, 'decr')


class PriorityQueue:
    """Facade over the per-backend RedisPriorityQueues.

    Keeps the list of backend URLs in Redis db 9 under the 'backends' list
    and fans queue operations out to each backend's own queue.
    """

    def __init__(self, backends: list = None):
        """
        Only have to load the backends once.
        :param backends: optional list of backend URLs to register.
        """
        self.redis = Redis(host='localhost', port=6379, db=9)
        if backends:
            for item in backends:
                self.redis.lpush('backends', item)

    def get_backends(self):
        """Return the registered backend URLs as strings."""
        return [x.decode('utf-8') for x in self.redis.lrange('backends', 0, -1)]

    def get_queued_ip_count(self, client_ip: str) -> int:
        """Total number of requests `client_ip` has queued across all backends."""
        count = 0
        for backend_url in self.get_backends():
            count += RedisPriorityQueue(backend_url).get_queued_ip_count(client_ip)
        return count

    def put(self, backend_url, item: Tuple[dict, str, str, dict], priority: int, selected_model: str):
        """Enqueue `item` on the given backend's queue; returns its DataEvent
        (or None if the backend queue rejected the request)."""
        queue = RedisPriorityQueue(backend_url)
        return queue.put(item, priority, selected_model)

    def len(self, model_name) -> int:
        """Total queue length across all backends serving `model_name`."""
        count = 0
        backends_with_model = [
            k for k in self.get_backends()
            if cluster_config.get_backend(k).get('model') == model_name
        ]
        for backend_url in backends_with_model:
            count += len(RedisPriorityQueue(backend_url))
        return count

    def __len__(self):
        return sum(len(RedisPriorityQueue(b)) for b in self.get_backends())

    def flush(self):
        """Empty every backend's queue.

        BUG FIX: the original json.loads'd each key's value and called
        .flush() on the resulting plain dict/list — a guaranteed
        AttributeError. Flush each backend's RedisPriorityQueue instead.
        """
        for backend_url in self.get_backends():
            RedisPriorityQueue(backend_url).flush()

    def flush_db(self):
        """Drop everything in this facade's Redis database (db 9)."""
        self.redis.flushdb()


priority_queue = PriorityQueue()