import threading
import time
import traceback

from llm_server.cluster.cluster_config import cluster_config, get_a_cluster_backend
from llm_server.custom_redis import redis
from llm_server.llm.generator import generator
from llm_server.routes.queue import (
    DataEvent,
    decr_active_workers,
    decrement_ip_count,
    incr_active_workers,
    increment_ip_count,
    priority_queue,
)


def _run_generation(request_json_body, backend_url: str):
    """Run ``generator`` and always return a ``(success, response, error_msg)`` tuple.

    Any exception raised by ``generator`` is logged with its traceback and
    converted into a failure tuple. Without this, the exception would kill
    the calling worker thread and leave the requester waiting on an event
    that is never set.
    """
    try:
        return generator(request_json_body, backend_url)
    except Exception:
        traceback.print_exc()
        # NOTE(review): confirm how consumers of the DataEvent surface error_msg to clients.
        return False, None, 'An internal error occurred while generating.'


def worker():
    """Consume inference requests from ``priority_queue`` forever.

    Each queue item unpacks as
    ``((request_json_body, client_ip, token, parameters, backend_url), event_id, selected_model)``.
    The worker resolves a backend (choosing one for ``selected_model`` when no
    ``backend_url`` was given), waits until that backend has a free generation
    slot, bumps the per-IP and per-backend counters, runs the generation, and
    publishes ``(success, response, error_msg)`` on ``DataEvent(event_id)``.
    """
    while True:
        (request_json_body, client_ip, _token, _parameters, backend_url), event_id, selected_model = priority_queue.get()

        if not backend_url:
            backend_url = get_a_cluster_backend(selected_model)
        else:
            backend_url = cluster_config.validate_backend(backend_url)
        backend_info = cluster_config.get_backend(backend_url)

        if not selected_model:
            selected_model = backend_info['model']

        # This wait time will be "invisible", meaning the worker may as
        # well be still waiting to get an item from the queue.
        need_to_wait(backend_url)

        increment_ip_count(client_ip, 'processing_ips')
        incr_active_workers(selected_model, backend_url)

        print('<---', event_id)

        if not request_json_body:
            # This was a dummy request from the websocket handlers.
            # We're going to let the websocket handler decrement
            # processing_ips and active_gen_workers.
            event = DataEvent(event_id)
            event.set((True, None, None))
            continue

        try:
            success, response, error_msg = _run_generation(request_json_body, backend_url)
            event = DataEvent(event_id)
            event.set((success, response, error_msg))
        finally:
            # Always release the slot and IP count, even if publishing the event fails.
            decrement_ip_count(client_ip, 'processing_ips')
            decr_active_workers(selected_model, backend_url)


def start_workers(num_workers: int):
    """Start ``num_workers`` daemon threads, each running :func:`worker`."""
    started = 0
    for n in range(num_workers):
        threading.Thread(target=worker, name=f'inference-worker-{n}', daemon=True).start()
        started += 1
    print(f'Started {started} inference workers.')


def need_to_wait(backend_url: str):
    """Block until ``backend_url`` has a free generation slot.

    A slot is free when the Redis counter ``active_gen_workers:<backend_url>``
    (default 0) is below the backend's ``concurrent_gens`` setting (default 1).
    The counter is polled every 10 ms and logs a message if the wait exceeded
    0.1 s.
    """
    # We need to check the number of active workers since the streaming endpoint may be doing something.
    concurrent_gens = cluster_config.get_backend(backend_url).get('concurrent_gens', 1)
    s = time.monotonic()
    # Re-read the counter on every poll; reading it once would make this loop
    # spin forever whenever the backend starts out busy.
    while redis.get(f'active_gen_workers:{backend_url}', 0, dtype=int) >= concurrent_gens:
        time.sleep(0.01)
    e = time.monotonic()
    if e - s > 0.1:
        print(f'Worker was delayed {e - s} seconds.')