import heapq
import threading
import time

from llm_server import opts
from llm_server.llm.generator import generator
from llm_server.routes.cache import redis
from llm_server.routes.stats import generation_elapsed, generation_elapsed_lock

# Reset the per-IP counter of requests currently being processed.
redis.set_dict('processing_ips', {})


def increment_ip_count(client_ip: str, redis_key):
    # Note: this read-modify-write is not atomic in Redis itself; callers
    # serialize access with an in-process lock.
    ip_count = redis.get_dict(redis_key)
    ip_count[client_ip] = ip_count.get(client_ip, 0) + 1
    redis.set_dict(redis_key, ip_count)
    return ip_count


def decrement_ip_count(client_ip: str, redis_key):
    ip_count = redis.get_dict(redis_key)
    if client_ip in ip_count:
        ip_count[client_ip] -= 1
        if ip_count[client_ip] == 0:
            del ip_count[client_ip]  # Remove the IP from the dictionary once its count hits 0
    redis.set_dict(redis_key, ip_count)
    return ip_count


class DataEvent(threading.Event):
    """An Event that also carries the worker's result back to the waiter."""

    def __init__(self):
        super().__init__()
        self.data = None


class PriorityQueue:
    def __init__(self):
        self._queue = []
        self._index = 0  # Monotonically increasing tie-breaker so equal priorities stay FIFO.
        self._cv = threading.Condition()
        self._lock = threading.Lock()
        redis.set_dict('queued_ip_count', {})

    def put(self, item, priority):
        # item = (request_json_body, client_ip, token, parameters); item[1] is the client IP.
        event = DataEvent()
        with self._cv:
            # Reject the request if this IP already has the maximum number of
            # queued requests (priority 0 bypasses the limit).
            ip_count = redis.get_dict('queued_ip_count')
            if item[1] in ip_count and ip_count[item[1]] >= opts.simultaneous_requests_per_ip and priority != 0:
                return None
            # heapq is a min-heap, so negate the priority to pop the highest priority first.
            heapq.heappush(self._queue, (-priority, self._index, item, event))
            self._index += 1
            with self._lock:
                increment_ip_count(item[1], 'queued_ip_count')
            self._cv.notify()
        return event

    def get(self):
        with self._cv:
            # Block until an item is available.
            while len(self._queue) == 0:
                self._cv.wait()
            _, _, item, event = heapq.heappop(self._queue)
            with self._lock:
                decrement_ip_count(item[1], 'queued_ip_count')
        return item, event

    def __len__(self):
        with self._cv:
            return len(self._queue)


priority_queue = PriorityQueue()


def worker():
    while True:
        (request_json_body, client_ip, token, parameters), event = priority_queue.get()
        increment_ip_count(client_ip, 'processing_ips')
        redis.incr('active_gen_workers')

        start_time = time.time()
        success, response, error_msg = generator(request_json_body)
        end_time = time.time()

        elapsed_time = end_time - start_time
        with generation_elapsed_lock:
            generation_elapsed.append((end_time, elapsed_time))

        # Hand the result back to the request handler waiting on this event.
        event.data = (success, response, error_msg)
        event.set()

        decrement_ip_count(client_ip, 'processing_ips')
        redis.decr('active_gen_workers')


def start_workers(num_workers: int):
    for _ in range(num_workers):
        # Daemon threads so the infinite worker loops don't block interpreter shutdown.
        threading.Thread(target=worker, daemon=True).start()
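
# Illustrative usage sketch (an assumption about how the route handlers call
# into this module, not wiring that exists in this file): a handler enqueues a
# request and blocks on the returned DataEvent until a worker fills in .data.
# The request body, IP, and priority below are placeholder values, and running
# this requires a reachable Redis instance and generation backend.
if __name__ == '__main__':
    start_workers(2)
    event = priority_queue.put(({'prompt': 'hello'}, '203.0.113.7', None, {}), priority=1)
    if event is None:
        print('rejected: too many queued requests from this IP')
    else:
        event.wait()  # Blocks until a worker calls event.set().
        success, response, error_msg = event.data
        print(success, error_msg)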