diff --git a/llm_server/llm/llm_backend.py b/llm_server/llm/llm_backend.py
index 2ac2beb..f864b18 100644
--- a/llm_server/llm/llm_backend.py
+++ b/llm_server/llm/llm_backend.py
@@ -3,7 +3,6 @@ from typing import Tuple, Union
 import flask
 
 from llm_server.cluster.cluster_config import cluster_config
-from llm_server.custom_redis import redis
 from llm_server.llm import get_token_count
 
 
diff --git a/llm_server/messages.py b/llm_server/messages.py
new file mode 100644
index 0000000..c7e3eb7
--- /dev/null
+++ b/llm_server/messages.py
@@ -0,0 +1 @@
+BACKEND_OFFLINE = 'The model you requested is not a valid choice. Please retry your query.'
diff --git a/llm_server/routes/ooba_request_handler.py b/llm_server/routes/ooba_request_handler.py
index ab20fbd..6966e32 100644
--- a/llm_server/routes/ooba_request_handler.py
+++ b/llm_server/routes/ooba_request_handler.py
@@ -3,7 +3,7 @@ from typing import Tuple
 import flask
 from flask import jsonify, request
 
-from llm_server import opts
+from llm_server import messages, opts
 from llm_server.database.log_to_db import log_to_db
 from llm_server.routes.helpers.client import format_sillytavern_err
 from llm_server.routes.request_handler import RequestHandler
@@ -16,9 +16,8 @@ class OobaRequestHandler(RequestHandler):
     def handle_request(self, return_ok: bool = True):
         assert not self.used
         if self.offline:
-            msg = 'The model you requested is not a valid choice. Please retry your query.'
-            print(msg)
-            self.handle_error(msg)
+            print(messages.BACKEND_OFFLINE)
+            self.handle_error(messages.BACKEND_OFFLINE)
 
         request_valid, invalid_response = self.validate_request()
         if not request_valid:
diff --git a/llm_server/routes/openai/chat_completions.py b/llm_server/routes/openai/chat_completions.py
index 99b7488..4018fee 100644
--- a/llm_server/routes/openai/chat_completions.py
+++ b/llm_server/routes/openai/chat_completions.py
@@ -12,7 +12,7 @@ from ..queue import priority_queue
 from ... import opts
 from ...database.log_to_db import log_to_db
 from ...llm.generator import generator
-from ...llm.openai.oai_to_vllm import oai_to_vllm, validate_oai, return_invalid_model_err
+from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai
 from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
 
 
@@ -106,6 +106,10 @@ def openai_chat_completions(model_name=None):
             break
         time.sleep(0.1)
 
+    # Double check the model is still online
+    if not handler.check_online():
+        return return_invalid_model_err(handler.request_json_body['model'])
+
     try:
         r_headers = dict(request.headers)
         r_url = request.url
diff --git a/llm_server/routes/openai/completions.py b/llm_server/routes/openai/completions.py
index 8851fcc..4dda2f2 100644
--- a/llm_server/routes/openai/completions.py
+++ b/llm_server/routes/openai/completions.py
@@ -13,7 +13,7 @@ from ... import opts
 from ...database.log_to_db import log_to_db
 from ...llm import get_token_count
 from ...llm.generator import generator
-from ...llm.openai.oai_to_vllm import oai_to_vllm, validate_oai, return_invalid_model_err
+from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai
 from ...llm.openai.transform import generate_oai_string, trim_string_to_fit
 
 
@@ -131,6 +131,10 @@ def openai_completions(model_name=None):
             break
         time.sleep(0.1)
 
+    # Double check the model is still online
+    if not handler.check_online():
+        return return_invalid_model_err(handler.request_json_body['model'])
+
     try:
         response = generator(msg_to_backend, handler.backend_url)
         r_headers = dict(request.headers)
diff --git a/llm_server/routes/request_handler.py b/llm_server/routes/request_handler.py
index 6a6ad4a..4011030 100644
--- a/llm_server/routes/request_handler.py
+++ b/llm_server/routes/request_handler.py
@@ -58,6 +58,10 @@ class RequestHandler:
         # "recent_prompters" is only used for stats.
         redis.zadd('recent_prompters', {self.client_ip: time.time()})
 
+    def check_online(self) -> bool:
+        self.cluster_backend_info = cluster_config.get_backend(self.backend_url)
+        return self.cluster_backend_info['online']
+
     def get_auth_token(self):
         if self.request_json_body.get('X-API-KEY'):
             return self.request_json_body['X-API-KEY']
diff --git a/llm_server/routes/v1/generate_stream.py b/llm_server/routes/v1/generate_stream.py
index c55e36f..b918106 100644
--- a/llm_server/routes/v1/generate_stream.py
+++ b/llm_server/routes/v1/generate_stream.py
@@ -8,7 +8,7 @@ from . import bp
 from ..helpers.http import require_api_key, validate_json
 from ..ooba_request_handler import OobaRequestHandler
 from ..queue import priority_queue
-from ... import opts
+from ... import messages, opts
 from ...custom_redis import redis
 from ...database.log_to_db import log_to_db
 from ...llm.generator import generator
@@ -147,9 +147,12 @@ def do_stream(ws, model_name):
             break
         time.sleep(0.1)
 
+    # Double check the model is still online
+    if not handler.check_online():
+        return messages.BACKEND_OFFLINE, 404  # TODO: format this error
+
     try:
         response = generator(llm_request, handler.backend_url)
-
         if not response:
             error_msg = 'Failed to reach backend while streaming.'
             print('Streaming failed:', error_msg)
diff --git a/llm_server/workers/inferencer.py b/llm_server/workers/inferencer.py
index 0357eab..0a9d871 100644
--- a/llm_server/workers/inferencer.py
+++ b/llm_server/workers/inferencer.py
@@ -1,12 +1,12 @@
 import threading
 import time
 import traceback
-from uuid import uuid4
 
+from llm_server import messages
 from llm_server.cluster.cluster_config import cluster_config
-from llm_server.custom_redis import redis, RedisCustom
+from llm_server.custom_redis import redis
 from llm_server.llm.generator import generator
-from llm_server.routes.queue import DataEvent, decr_active_workers, decrement_ip_count, incr_active_workers, increment_ip_count, RedisPriorityQueue, PriorityQueue, priority_queue
+from llm_server.routes.queue import DataEvent, RedisPriorityQueue, decr_active_workers, decrement_ip_count, incr_active_workers, increment_ip_count
 
 
 def worker(backend_url):
@@ -14,14 +14,18 @@
     while True:
         (request_json_body, client_ip, token, parameters), event_id, selected_model = queue.get()
         backend_info = cluster_config.get_backend(backend_url)
+
+        if not backend_info['online']:
+            event = DataEvent(event_id)
+            event.set((False, None, messages.BACKEND_OFFLINE))
+            return
+
         if not selected_model:
             selected_model = backend_info['model']
 
         increment_ip_count(client_ip, 'processing_ips')
         incr_active_workers(selected_model, backend_url)
 
-        print('Worker starting processing for', client_ip)
-
         try:
             if not request_json_body:
                 # This was a dummy request from the streaming handlers.
@@ -48,7 +52,6 @@
         finally:
             decrement_ip_count(client_ip, 'processing_ips')
             decr_active_workers(selected_model, backend_url)
-            print('Worker finished processing for', client_ip)
 
 
 def start_workers(cluster: dict):
diff --git a/server.py b/server.py
index 89c71aa..490eebe 100644
--- a/server.py
+++ b/server.py
@@ -40,7 +40,7 @@ from llm_server.sock import init_socketio
 # TODO: if a backend is at its limit of concurrent requests, choose a different one
 
 # Lower priority
-# TODO: fix moderation freezing after a while
+# TODO: make error messages consistent
 # TODO: support logit_bias on OpenAI and Ooba endpoints.
 # TODO: add a way to cancel VLLM gens. Maybe use websockets?
 # TODO: validate openai_silent_trim works as expected and only when enabled
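
As a rough illustration of the offline re-check this commit introduces (not part of the patch itself): after a request leaves the priority queue, the handler refreshes the backend's cluster entry and aborts with the shared messages.BACKEND_OFFLINE string if the backend dropped offline while the request was waiting. In the sketch below, backend_still_online() and finish_request() are hypothetical names used only for illustration; check_online(), cluster_config.get_backend() and messages.BACKEND_OFFLINE come from the diff above.

# Hypothetical sketch only -- mirrors RequestHandler.check_online() and the
# post-queue guard added in this commit; not code from the repository.
from llm_server import messages
from llm_server.cluster.cluster_config import cluster_config


def backend_still_online(backend_url: str) -> bool:
    # Re-read the backend's entry from the cluster config instead of trusting
    # the snapshot taken when the request first arrived.
    backend_info = cluster_config.get_backend(backend_url)
    return backend_info['online']


def finish_request(handler):
    # Same shape as the guard inserted after the queue wait in the streaming
    # endpoint: bail out with the shared offline message if the backend went
    # down while the request sat in the priority queue. (The OpenAI endpoints
    # instead return return_invalid_model_err() for the requested model.)
    if not backend_still_online(handler.backend_url):
        return messages.BACKEND_OFFLINE, 404
    # ...otherwise continue on to call the generator as the endpoints do.

On the worker side the same constant travels back through the DataEvent tuple (event.set((False, None, messages.BACKEND_OFFLINE))), so queue consumers and HTTP handlers report an identical message.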