import json
import pickle
import time
from uuid import uuid4

from redis import Redis

from llm_server import opts
from llm_server.routes.cache import redis


def increment_ip_count(client_ip: str, redis_key):
    """Increment ``client_ip``'s field in the hash ``redis_key`` on the shared cache connection."""
    redis.hincrby(redis_key, client_ip, 1)


def decrement_ip_count(client_ip: str, redis_key):
    """Decrement ``client_ip``'s field in the hash ``redis_key`` on the shared cache connection.

    The field is deleted once the count drops to zero or below so the hash
    does not accumulate stale entries.
    """
    new_count = redis.hincrby(redis_key, client_ip, -1)
    if new_count <= 0:
        redis.hdel(redis_key, client_ip)


class RedisPriorityQueue:
    """Priority queue stored in a Redis sorted set (``'queue'`` in db 15).

    Alongside the queue it keeps a hash ``'queued_ip_count'`` mapping each
    client IP to how many of its items are currently queued.
    """

    def __init__(self):
        self.redis = Redis(host='localhost', port=6379, db=15)
        self.pubsub = self.redis.pubsub()
        self.pubsub.subscribe('events')

    def put(self, item, priority):
        """Enqueue ``item`` with ``priority`` and return a ``DataEvent`` for its result.

        ``item[1]`` is treated as the client IP. If that IP already has at least
        ``opts.simultaneous_requests_per_ip`` items queued and ``priority`` is
        non-zero, the request is rejected and ``None`` is returned. A priority of
        0 bypasses the per-IP limit.
        """
        client_ip = item[1]
        raw_count = self.redis.hget('queued_ip_count', client_ip)
        ip_count = int(raw_count) if raw_count else 0
        if ip_count and ip_count >= opts.simultaneous_requests_per_ip and priority != 0:
            print(f'Rejecting request from {client_ip} - {ip_count} requests in progress.')
            return None  # reject the request

        # Create the event only after the request is accepted: DataEvent opens a
        # Redis client and subscribes to a channel in its constructor, which is
        # wasted work for a rejected request.
        event = DataEvent()
        # zpopmin returns the lowest score first, so the priority is negated to
        # make higher priorities come out first.
        # NOTE(review): equal scores are ordered lexicographically by member
        # (the JSON string), not by insertion time.
        self.redis.zadd('queue', {json.dumps((item, event.event_id)): -priority})
        self.increment_ip_count(client_ip, 'queued_ip_count')
        return event

    def get(self):
        """Block until the queue is non-empty, then pop and return the best entry.

        Returns the JSON-decoded ``(item, event_id)`` pair written by ``put``
        (the JSON round trip turns the tuple into a list). The owning IP's
        queued count is decremented.
        """
        while True:
            data = self.redis.zpopmin('queue')
            if data:
                item = json.loads(data[0][0])
                client_ip = item[0][1]
                self.decrement_ip_count(client_ip, 'queued_ip_count')
                return item
            time.sleep(0.1)  # poll until something is added to the queue

    def increment_ip_count(self, client_ip: str, redis_key):
        """Increment ``client_ip``'s field in the hash ``redis_key`` on this queue's connection."""
        self.redis.hincrby(redis_key, client_ip, 1)

    def decrement_ip_count(self, client_ip: str, redis_key):
        """Decrement ``client_ip``'s field in ``redis_key``, deleting it at zero or below."""
        new_count = self.redis.hincrby(redis_key, client_ip, -1)
        if new_count <= 0:
            self.redis.hdel(redis_key, client_ip)

    def __len__(self):
        """Return the number of entries currently in the queue."""
        return self.redis.zcard('queue')

    def get_queued_ip_count(self, client_ip: str):
        """Return how many items ``client_ip`` currently has queued (0 if none)."""
        q = self.redis.hget('queued_ip_count', client_ip)
        if not q:
            return 0
        # Previously this always returned 0, hiding the stored count.
        return int(q)


class DataEvent:
    """One-shot result channel built on Redis pub/sub (db 14), keyed by ``event_id``."""

    def __init__(self, event_id=None):
        self.event_id = event_id if event_id else str(uuid4())
        self.redis = Redis(host='localhost', port=6379, db=14)
        self.pubsub = self.redis.pubsub()
        self.pubsub.subscribe(self.event_id)

    def set(self, data):
        """Publish ``data`` (pickled) on this event's channel."""
        self.redis.publish(self.event_id, pickle.dumps(data))

    def wait(self):
        """Block until a message arrives on this event's channel and return its unpickled payload."""
        for item in self.pubsub.listen():
            if item['type'] == 'message':
                # NOTE(review): pickle.loads executes arbitrary code for crafted
                # input; this is only safe while the Redis instance is trusted.
                return pickle.loads(item['data'])


priority_queue = RedisPriorityQueue()


def incr_active_workers():
    """Increment the shared ``active_gen_workers`` counter."""
    redis.incr('active_gen_workers')


def decr_active_workers():
    """Decrement the shared ``active_gen_workers`` counter, clamping it at zero."""
    redis.decr('active_gen_workers')
    # NOTE(review): assumes the cache wrapper's get() accepts (key, dtype, default)
    # and returns the converted value — confirm against llm_server.routes.cache.
    new_count = redis.get('active_gen_workers', int, 0)
    if new_count < 0:
        redis.set('active_gen_workers', 0)