import json
import pickle
import time
from uuid import uuid4

from redis import Redis

from llm_server.custom_redis import RedisCustom, redis
from llm_server.database.database import get_token_ratelimit


def increment_ip_count(client_ip: str, redis_key):
    """Add one to the ``client_ip`` field of the hash stored at ``redis_key``."""
    redis.hincrby(redis_key, client_ip, 1)


def decrement_ip_count(client_ip: str, redis_key):
    """Subtract one from the ``client_ip`` field of the hash at ``redis_key``.

    The field is deleted once its value drops to zero or below, so the hash
    only holds IPs that still have a positive count.
    """
    new_count = redis.hincrby(redis_key, client_ip, -1)
    if new_count <= 0:
        redis.hdel(redis_key, client_ip)


class RedisPriorityQueue:
    """Priority queue stored in a Redis sorted set, with per-IP queue limits.

    Entries live in the sorted set ``queue``; the number of queued entries per
    client IP is tracked in the hash ``queued_ip_count``.
    """

    def __init__(self, name: str = 'priority_queue', db: int = 12):
        """Open the backing ``RedisCustom`` store and subscribe to ``events``.

        Args:
            name: Name passed to ``RedisCustom``.
            db: Redis database number passed to ``RedisCustom``.
        """
        self.redis = RedisCustom(name, db=db)
        self.pubsub = self.redis.pubsub()
        self.pubsub.subscribe('events')

    def put(self, item, priority, selected_model):
        """Queue ``item`` unless its client IP is already at its limit.

        Args:
            item: Indexable request record. ``item[1]`` is the client IP and
                ``item[2]`` is passed to ``get_token_ratelimit`` (presumably
                the request's token -- confirm against the caller). Must be
                JSON-serializable.
            priority: Numeric priority; larger values are dequeued first.
                A priority of ``0`` is exempt from the per-IP limit.
            selected_model: Model name stored alongside the item.

        Returns:
            A ``DataEvent`` to wait on for the result, or ``None`` when the
            request is rejected.
        """
        # Check if the IP is already in the dictionary and if it has reached the limit
        ip_count = self.redis.hget('queued_ip_count', item[1])
        ip_count = int(ip_count) if ip_count else 0
        _, simultaneous_ip = get_token_ratelimit(item[2])
        if ip_count and ip_count >= simultaneous_ip and priority != 0:
            print(f'Rejecting request from {item[1]} - {ip_count} requests in progress.')
            return None  # reject the request

        # Create the event only for accepted requests: constructing a
        # DataEvent opens a Redis connection and a pubsub subscription, which
        # would be wasted on a request that is about to be rejected.
        event = DataEvent()

        # Scores are negated so that zpopmin() (lowest score first) returns
        # the highest-priority entry.
        self.redis.zadd('queue', {json.dumps((item, event.event_id, selected_model)): -priority})
        self.increment_ip_count(item[1], 'queued_ip_count')
        return event

    def get(self):
        """Block until an entry is available, then pop and return it.

        Returns:
            The decoded JSON entry ``[item, event_id, selected_model]`` as
            written by ``put``.
        """
        while True:
            data = self.redis.zpopmin('queue')
            if data:
                # zpopmin yields (member, score) pairs; only the member is needed.
                item = json.loads(data[0][0])
                client_ip = item[0][1]
                self.decrement_ip_count(client_ip, 'queued_ip_count')
                return item
            time.sleep(0.1)  # wait for something to be added to the queue

    def increment_ip_count(self, client_ip: str, redis_key):
        """Add one to the ``client_ip`` field of the hash at ``redis_key``."""
        self.redis.hincrby(redis_key, client_ip, 1)

    def decrement_ip_count(self, client_ip: str, redis_key):
        """Subtract one from ``client_ip``'s count, deleting it at zero or below."""
        new_count = self.redis.hincrby(redis_key, client_ip, -1)
        if new_count <= 0:
            self.redis.hdel(redis_key, client_ip)

    def __len__(self):
        """Return the total number of queued entries."""
        return self.redis.zcard('queue')

    def len(self, model_name):
        """Return the number of queued entries whose model equals ``model_name``.

        This decodes every entry in the queue, so its cost grows with the
        queue length.
        """
        return sum(
            1
            for key in self.redis.zrange('queue', 0, -1)
            if json.loads(key)[2] == model_name
        )

    def get_queued_ip_count(self, client_ip: str):
        """Return how many entries ``client_ip`` currently has queued.

        Returns:
            The stored count as an ``int``, or ``0`` when the IP has no entry.
        """
        q = self.redis.hget('queued_ip_count', client_ip)
        if not q:
            return 0
        # Previously this branch also returned 0, so the real count was never
        # reported to callers.
        return int(q)

    def flush(self):
        """Flush the backing ``RedisCustom`` store."""
        self.redis.flush()


class DataEvent:
    """One-shot result channel implemented on top of Redis pub/sub.

    Each event owns a channel named after its ``event_id``. ``set`` publishes
    a pickled payload to that channel and ``wait`` blocks until a payload
    arrives.
    """

    def __init__(self, event_id=None):
        """Connect to Redis and subscribe to this event's channel.

        Args:
            event_id: Channel name to use. A random UUID4 string is generated
                when omitted or falsy.
        """
        self.event_id = event_id if event_id else str(uuid4())
        # Connection parameters are hard-coded: local Redis, database 14.
        self.redis = Redis(host='localhost', port=6379, db=14)
        self.pubsub = self.redis.pubsub()
        self.pubsub.subscribe(self.event_id)

    def set(self, data):
        """Publish ``data`` (pickled) to this event's channel."""
        self.redis.publish(self.event_id, pickle.dumps(data))

    def wait(self):
        """Block until a message arrives on the channel and return its payload."""
        for item in self.pubsub.listen():
            # Skip subscribe confirmations and other non-message frames.
            if item['type'] == 'message':
                # SECURITY NOTE: pickle.loads executes arbitrary code embedded
                # in the payload. This is only safe while every publisher on
                # this Redis instance is trusted.
                return pickle.loads(item['data'])


priority_queue = RedisPriorityQueue()


def update_active_workers(key: str, operation: str):
    """Increment or decrement the ``active_gen_workers:<key>`` counter.

    Args:
        key: Suffix identifying the counter.
        operation: ``'incr'`` or ``'decr'``. Any other value is ignored.
    """
    counter_key = f'active_gen_workers:{key}'
    if operation == 'incr':
        redis.incr(counter_key)
    elif operation == 'decr':
        redis.decr(counter_key)
        # Clamp at zero so an unmatched decrement cannot leave the counter negative.
        if redis.get(counter_key, default=0, dtype=int) < 0:
            redis.set(counter_key, 0)


def incr_active_workers(selected_model: str, backend_url: str):
    """Increment the active-worker counters for both the model and the backend."""
    update_active_workers(selected_model, 'incr')
    update_active_workers(backend_url, 'incr')


def decr_active_workers(selected_model: str, backend_url: str):
    """Decrement the active-worker counters for both the model and the backend."""
    update_active_workers(selected_model, 'decr')
    update_active_workers(backend_url, 'decr')