diff --git a/llm_server/routes/openai/chat_completions.py b/llm_server/routes/openai/chat_completions.py
index 4018fee..426fd98 100644
--- a/llm_server/routes/openai/chat_completions.py
+++ b/llm_server/routes/openai/chat_completions.py
@@ -102,8 +102,12 @@ def openai_chat_completions(model_name=None):
         pubsub = redis.pubsub()
         pubsub.subscribe(event_id)
         for item in pubsub.listen():
-            if item['type'] == 'message' and item['data'].decode('utf-8') == 'begin':
-                break
+            if item['type'] == 'message':
+                msg = item['data'].decode('utf-8')
+                if msg == 'begin':
+                    break
+                elif msg == 'offline':
+                    return return_invalid_model_err(handler.request_json_body['model'])
             time.sleep(0.1)
 
         # Double check the model is still online
diff --git a/llm_server/routes/openai/completions.py b/llm_server/routes/openai/completions.py
index 4dda2f2..2cd8578 100644
--- a/llm_server/routes/openai/completions.py
+++ b/llm_server/routes/openai/completions.py
@@ -127,8 +127,12 @@ def openai_completions(model_name=None):
         pubsub = redis.pubsub()
         pubsub.subscribe(event_id)
         for item in pubsub.listen():
-            if item['type'] == 'message' and item['data'].decode('utf-8') == 'begin':
-                break
+            if item['type'] == 'message':
+                msg = item['data'].decode('utf-8')
+                if msg == 'begin':
+                    break
+                elif msg == 'offline':
+                    return return_invalid_model_err(handler.request_json_body['model'])
             time.sleep(0.1)
 
         # Double check the model is still online
diff --git a/llm_server/routes/v1/generate_stream.py b/llm_server/routes/v1/generate_stream.py
index b918106..79be511 100644
--- a/llm_server/routes/v1/generate_stream.py
+++ b/llm_server/routes/v1/generate_stream.py
@@ -143,8 +143,12 @@ def do_stream(ws, model_name):
         pubsub = redis.pubsub()
         pubsub.subscribe(event_id)
         for item in pubsub.listen():
-            if item['type'] == 'message' and item['data'].decode('utf-8') == 'begin':
-                break
+            if item['type'] == 'message':
+                msg = item['data'].decode('utf-8')
+                if msg == 'begin':
+                    break
+                elif msg == 'offline':
+                    return messages.BACKEND_OFFLINE, 404  # TODO: format this error
             time.sleep(0.1)
 
         # Double check the model is still online
diff --git a/llm_server/workers/inferencer.py b/llm_server/workers/inferencer.py
index 0a9d871..d65d125 100644
--- a/llm_server/workers/inferencer.py
+++ b/llm_server/workers/inferencer.py
@@ -2,7 +2,6 @@ import threading
 import time
 import traceback
 
-from llm_server import messages
 from llm_server.cluster.cluster_config import cluster_config
 from llm_server.custom_redis import redis
 from llm_server.llm.generator import generator
@@ -14,10 +13,11 @@ def worker(backend_url):
     while True:
         (request_json_body, client_ip, token, parameters), event_id, selected_model = queue.get()
         backend_info = cluster_config.get_backend(backend_url)
+        pubsub = redis.pubsub()
+        pubsub.subscribe(event_id)
 
         if not backend_info['online']:
-            event = DataEvent(event_id)
-            event.set((False, None, messages.BACKEND_OFFLINE))
+            redis.publish(event_id, 'offline')
             return
 
         if not selected_model:
@@ -34,8 +34,6 @@ def worker(backend_url):
         # is finished. Since a lot of ratelimiting and stats are
         # based off the number of active workers, we must keep
         # the generation based off the workers.
-        pubsub = redis.pubsub()
-        pubsub.subscribe(event_id)
         redis.publish(event_id, 'begin')
         for item in pubsub.listen():
             if item['type'] == 'message' and item['data'].decode('utf-8') == 'finished':