import json
import pickle
import threading
import time
import traceback
from uuid import uuid4

from redis import Redis

from llm_server import opts
from llm_server.llm.generator import generator
from llm_server.routes.cache import redis
from llm_server.routes.stats import generation_elapsed, generation_elapsed_lock

redis.set_dict('processing_ips', {})

# Serializes the read-modify-write cycles in increment_ip_count /
# decrement_ip_count. get_dict and set_dict are two separate calls, so without
# a lock concurrent worker threads can overwrite each other's updates.
# NOTE: this only guards threads within this process.
_ip_count_lock = threading.Lock()


def increment_ip_count(client_ip: str, redis_key):
    """Add one to ``client_ip``'s counter in the dict stored under ``redis_key``.

    Returns the updated dict.
    """
    with _ip_count_lock:
        ip_count = redis.get_dict(redis_key)
        ip_count[client_ip] = ip_count.get(client_ip, 0) + 1
        redis.set_dict(redis_key, ip_count)
    return ip_count


def decrement_ip_count(client_ip: str, redis_key):
    """Subtract one from ``client_ip``'s counter in the dict stored under ``redis_key``.

    The entry is removed once its count reaches zero; an IP that is not present
    is left untouched. Returns the updated dict.
    """
    with _ip_count_lock:
        ip_count = redis.get_dict(redis_key)
        if client_ip in ip_count:
            ip_count[client_ip] -= 1
            if ip_count[client_ip] == 0:
                del ip_count[client_ip]  # Remove the IP from the dictionary if count is 0
        redis.set_dict(redis_key, ip_count)
    return ip_count


class RedisPriorityQueue:
    """Priority queue of pending requests backed by a Redis sorted set.

    Items live in the ``queue`` sorted set of Redis DB 15 with score
    ``-priority``; ``ZPOPMIN``/``BZPOPMIN`` pop the lowest score, so higher
    priorities are served first. The number of queued items per client IP is
    tracked in the ``queued_ip_count`` hash.
    """

    # Width of the zero-padded insertion counter prefixed to every member.
    # Redis orders members with equal scores byte-wise, so the counter must be
    # fixed-width for equal-priority items to pop in FIFO order (an unpadded
    # "[10, ..." sorts before "[2, ...").
    _INDEX_WIDTH = 20

    def __init__(self):
        self._index = 0
        self._lock = threading.Lock()  # guards self._index
        self.redis = Redis(host='localhost', port=6379, db=15)

        # Clear the DB
        for key in self.redis.scan_iter('*'):
            self.redis.delete(key)

        self.pubsub = self.redis.pubsub()
        self.pubsub.subscribe('events')

    def put(self, item, priority):
        """Enqueue ``item`` with ``priority`` and return a DataEvent for its result.

        ``item[1]`` is used as the client IP. Unless ``priority`` is 0, the
        request is rejected (``None`` is returned) when that IP already has
        ``opts.simultaneous_requests_per_ip`` items queued.
        """
        client_ip = item[1]

        # Reserve a slot atomically: HINCRBY returns the new count, so two
        # concurrent requests from the same IP cannot both slip past the limit.
        ip_count = self.redis.hincrby('queued_ip_count', client_ip, 1)
        if priority != 0 and ip_count > opts.simultaneous_requests_per_ip:
            self.decrement_ip_count(client_ip, 'queued_ip_count')  # undo the reservation
            return None  # reject the request

        # Created only for accepted requests; each DataEvent opens a subscription.
        event = DataEvent()
        with self._lock:
            index = self._index
            self._index += 1
        member = f'{index:0{self._INDEX_WIDTH}d}' + json.dumps((index, item, event.event_id))
        self.redis.zadd('queue', {member: -priority})
        return event

    def get(self):
        """Block until an item is available, then remove and return it.

        Returns the decoded ``[index, item, event_id]`` list stored by ``put``.
        """
        while True:
            # Blocks server-side for up to 1 s instead of sleeping, so a newly
            # queued item is picked up immediately.
            popped = self.redis.bzpopmin('queue', timeout=1)
            if popped:
                member = popped[1]
                data = json.loads(member[self._INDEX_WIDTH:])
                client_ip = data[1][1]
                self.decrement_ip_count(client_ip, 'queued_ip_count')
                return data

    def increment_ip_count(self, ip, key):
        """Atomically add one to ``ip``'s field in the Redis hash ``key``."""
        self.redis.hincrby(key, ip, 1)

    def decrement_ip_count(self, ip, key):
        """Atomically subtract one from ``ip``'s field in the Redis hash ``key``."""
        self.redis.hincrby(key, ip, -1)

    def __len__(self):
        return self.redis.zcard('queue')


class DataEvent:
    """One-shot result channel over Redis pub/sub (DB 14), keyed by ``event_id``.

    The constructor subscribes to the channel, so a waiter must be created
    before the matching ``set`` is published.
    """

    def __init__(self, event_id=None):
        self.event_id = event_id if event_id else str(uuid4())
        self.redis = Redis(host='localhost', port=6379, db=14)
        self.pubsub = self.redis.pubsub()
        self.pubsub.subscribe(self.event_id)

    def set(self, data):
        """Publish ``data`` (pickled) on this event's channel."""
        self.redis.publish(self.event_id, pickle.dumps(data))

    def wait(self):
        """Block until a message arrives on this event's channel and return it unpickled."""
        for item in self.pubsub.listen():
            if item['type'] == 'message':
                # NOTE(review): pickle.loads is only safe because publishers are
                # this server's own workers; Redis must not be writable by others.
                return pickle.loads(item['data'])


priority_queue = RedisPriorityQueue()


def worker():
    """Consume requests from ``priority_queue`` forever and publish their results."""
    while True:
        index, (request_json_body, client_ip, token, parameters), event_id = priority_queue.get()

        increment_ip_count(client_ip, 'processing_ips')
        redis.incr('active_gen_workers')
        try:
            start_time = time.time()
            try:
                success, response, error_msg = generator(request_json_body)
            except Exception as e:
                # Without this the waiting request would block in wait() forever
                # and this worker thread would die.
                # NOTE(review): assumes callers check `success` before using
                # `response` — confirm.
                traceback.print_exc()
                success, response, error_msg = False, None, str(e)
            else:
                end_time = time.time()
                elapsed_time = end_time - start_time
                with generation_elapsed_lock:
                    generation_elapsed.append((end_time, elapsed_time))

            event = DataEvent(event_id)
            event.set((success, response, error_msg))
        finally:
            # Always release the counters, even if publishing fails.
            decrement_ip_count(client_ip, 'processing_ips')
            redis.decr('active_gen_workers')


def start_workers(num_workers: int):
    """Start ``num_workers`` threads running ``worker``."""
    for _ in range(num_workers):
        threading.Thread(target=worker).start()