import json
import pickle
import time
from uuid import uuid4

from redis import Redis

from llm_server import opts
from llm_server.routes.cache import redis


def increment_ip_count(client_ip: str, redis_key):
    """Add one to the counter stored for *client_ip* under *redis_key*.

    The whole mapping is read with ``redis.get_dict``, updated in Python and
    written back with ``redis.set_dict``; the two calls are separate, so the
    update is a read-modify-write rather than a single Redis operation.

    Returns the updated mapping.
    """
    ip_count = redis.get_dict(redis_key)
    ip_count[client_ip] = ip_count.get(client_ip, 0) + 1
    redis.set_dict(redis_key, ip_count)
    return ip_count


def decrement_ip_count(client_ip: str, redis_key):
    """Subtract one from the counter stored for *client_ip* under *redis_key*.

    An IP that is not present is left untouched.  Once a counter reaches zero
    (or below) its entry is removed so the mapping only lists active IPs.

    Returns the updated mapping.
    """
    ip_count = redis.get_dict(redis_key)
    if client_ip in ip_count:
        ip_count[client_ip] -= 1
        if ip_count[client_ip] <= 0:
            # Drop exhausted (or stray negative) counters instead of keeping them.
            del ip_count[client_ip]
    redis.set_dict(redis_key, ip_count)
    return ip_count


class RedisPriorityQueue:
    """Priority queue kept in a Redis sorted set, with a per-IP queued counter.

    Entries live in the sorted set ``queue``; the number of queued entries per
    client IP is tracked in the hash ``queued_ip_count``.
    """

    def __init__(self):
        self.redis = Redis(host='localhost', port=6379, db=15)
        self.pubsub = self.redis.pubsub()
        self.pubsub.subscribe('events')

    def put(self, item, priority):
        """Queue *item* and return the ``DataEvent`` tied to it.

        ``item[1]`` is used as the client IP.  When that IP already has
        ``opts.simultaneous_requests_per_ip`` or more queued entries and
        *priority* is not ``0``, the request is rejected and ``None`` is
        returned.

        The entry is stored with score ``-priority``; since ``get`` pops the
        lowest score, larger *priority* values are dequeued first.
        """
        client_ip = item[1]
        raw_count = self.redis.hget('queued_ip_count', client_ip)
        ip_count = int(raw_count) if raw_count else raw_count
        if ip_count and ip_count >= opts.simultaneous_requests_per_ip and priority != 0:
            print(f'Rejecting request from {item[1]} - {ip_count} requests in progress.')
            return None  # reject the request

        # Create the event only for accepted requests: constructing one opens a
        # Redis connection and a subscription, which a rejection would waste.
        # It must still exist (and be subscribed) before the entry is queued.
        event = DataEvent()
        self.redis.zadd('queue', {json.dumps((item, event.event_id)): -priority})
        self.increment_ip_count(client_ip, 'queued_ip_count')
        return event

    def get(self):
        """Block until an entry is available, then return it.

        The returned value is the JSON-decoded ``[item, event_id]`` pair that
        ``put`` stored (tuples come back as lists after the JSON round trip).
        The queued counter of the entry's IP is decremented before returning.
        """
        while True:
            data = self.redis.zpopmin('queue')
            if data:
                # zpopmin yields (member, score) pairs; only the member is needed.
                item = json.loads(data[0][0])
                client_ip = item[0][1]
                self.decrement_ip_count(client_ip, 'queued_ip_count')
                return item
            time.sleep(0.5)  # wait for something to be added to the queue

    def increment_ip_count(self, ip, key):
        """Add one to field *ip* of the Redis hash *key*."""
        self.redis.hincrby(key, ip, 1)

    def decrement_ip_count(self, ip, key):
        """Subtract one from field *ip* of the Redis hash *key*.

        The field is not removed when it reaches zero.
        """
        self.redis.hincrby(key, ip, -1)

    def __len__(self):
        """Return the number of entries currently in the queue."""
        return self.redis.zcard('queue')

    def get_ip_count(self, client_ip: str):
        """Return the queued counter for *client_ip* as a decoded string.

        When the hash holds no (or an empty) value for the IP, the raw result
        of ``hget`` is returned unchanged.
        """
        raw_count = self.redis.hget('queued_ip_count', client_ip)
        return raw_count.decode('utf-8') if raw_count else raw_count


class DataEvent:
    """One-shot result channel built on a Redis pub/sub channel.

    The channel name is the event id; a random UUID is generated when none is
    supplied.  Passing an existing id attaches to that same channel.
    """

    def __init__(self, event_id=None):
        self.event_id = event_id if event_id else str(uuid4())
        self.redis = Redis(host='localhost', port=6379, db=14)
        self.pubsub = self.redis.pubsub()
        self.pubsub.subscribe(self.event_id)

    def set(self, data):
        """Publish *data*, pickled, on this event's channel."""
        self.redis.publish(self.event_id, pickle.dumps(data))

    def wait(self):
        """Block until a message arrives on the channel and return its payload."""
        for item in self.pubsub.listen():
            # Skip subscribe confirmations and other non-message notifications.
            if item['type'] == 'message':
                # SECURITY: pickle.loads executes arbitrary code from the payload.
                # This is only safe while every publisher on this Redis instance
                # is trusted.
                return pickle.loads(item['data'])


priority_queue = RedisPriorityQueue()