diff --git a/llm_server/routes/queue.py b/llm_server/routes/queue.py index 5d01e94..84cc614 100644 --- a/llm_server/routes/queue.py +++ b/llm_server/routes/queue.py @@ -85,3 +85,14 @@ class DataEvent: priority_queue = RedisPriorityQueue() + + +def incr_active_workers(): + redis.incr('active_gen_workers') + + +def decr_active_workers(): + redis.decr('active_gen_workers') + new_count = redis.get('active_gen_workers', int, 0) + if new_count < 0: + redis.set('active_gen_workers', 0) diff --git a/llm_server/routes/v1/generate_stream.py b/llm_server/routes/v1/generate_stream.py index 14af81b..f6e978b 100644 --- a/llm_server/routes/v1/generate_stream.py +++ b/llm_server/routes/v1/generate_stream.py @@ -9,7 +9,7 @@ from flask import request from ..cache import redis from ..helpers.http import require_api_key, validate_json from ..ooba_request_handler import OobaRequestHandler -from ..queue import decrement_ip_count, priority_queue +from ..queue import decr_active_workers, decrement_ip_count, priority_queue from ... import opts from ...database.database import log_prompt from ...llm.generator import generator @@ -167,7 +167,7 @@ def stream(ws): finally: # The worker incremented it, we'll decrement it. decrement_ip_count(handler.client_ip, 'processing_ips') - redis.decr('active_gen_workers') + decr_active_workers() try: ws.send(json.dumps({ 'event': 'stream_end', diff --git a/llm_server/workers/blocking.py b/llm_server/workers/blocking.py index 1104065..27b0815 100644 --- a/llm_server/workers/blocking.py +++ b/llm_server/workers/blocking.py @@ -4,7 +4,7 @@ import time from llm_server import opts from llm_server.llm.generator import generator from llm_server.routes.cache import redis -from llm_server.routes.queue import DataEvent, decrement_ip_count, increment_ip_count, priority_queue +from llm_server.routes.queue import DataEvent, decr_active_workers, decrement_ip_count, incr_active_workers, increment_ip_count, priority_queue def worker(): @@ -14,7 +14,7 @@ def worker(): need_to_wait() increment_ip_count(client_ip, 'processing_ips') - redis.incr('active_gen_workers') + incr_active_workers() if not request_json_body: # This was a dummy request from the websocket handler. @@ -27,7 +27,7 @@ def worker(): event.set((success, response, error_msg)) finally: decrement_ip_count(client_ip, 'processing_ips') - redis.decr('active_gen_workers') + decr_active_workers() def start_workers(num_workers: int):