diff --git a/llm_server/llm/openai/oai_to_vllm.py b/llm_server/llm/openai/oai_to_vllm.py index 1490933..cde5180 100644 --- a/llm_server/llm/openai/oai_to_vllm.py +++ b/llm_server/llm/openai/oai_to_vllm.py @@ -77,3 +77,18 @@ def validate_oai(parameters): if parameters.get('max_tokens', 2) < 1: return format_oai_err(f"{parameters['max_tokens']} is less than the minimum of 1 - 'max_tokens'") + + +def return_invalid_model_err(requested_model: str): + if requested_model: + msg = f"The model `{requested_model}` does not exist" + else: + msg = "The requested model does not exist" + return jsonify({ + "error": { + "message": msg, + "type": "invalid_request_error", + "param": None, + "code": "model_not_found" + } + }), 404 diff --git a/llm_server/routes/openai/chat_completions.py b/llm_server/routes/openai/chat_completions.py index 76f1b6c..d054703 100644 --- a/llm_server/routes/openai/chat_completions.py +++ b/llm_server/routes/openai/chat_completions.py @@ -12,7 +12,7 @@ from ..queue import priority_queue from ... import opts from ...database.log_to_db import log_to_db from ...llm.generator import generator -from ...llm.openai.oai_to_vllm import oai_to_vllm, validate_oai +from ...llm.openai.oai_to_vllm import oai_to_vllm, validate_oai, return_invalid_model_err from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit @@ -27,6 +27,11 @@ def openai_chat_completions(model_name=None): return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400 else: handler = OpenAIRequestHandler(incoming_request=request, incoming_json=request_json_body, selected_model=model_name) + if handler.offline: + msg = return_invalid_model_err(model_name) + print(msg) + return handler.handle_error(msg) + if not request_json_body.get('stream'): try: invalid_oai_err_msg = validate_oai(request_json_body) diff --git a/llm_server/routes/openai/completions.py b/llm_server/routes/openai/completions.py index c8e7f19..3dcde2e 100644 --- a/llm_server/routes/openai/completions.py +++ b/llm_server/routes/openai/completions.py @@ -13,7 +13,7 @@ from ... import opts from ...database.log_to_db import log_to_db from ...llm import get_token_count from ...llm.generator import generator -from ...llm.openai.oai_to_vllm import oai_to_vllm, validate_oai +from ...llm.openai.oai_to_vllm import oai_to_vllm, validate_oai, return_invalid_model_err from ...llm.openai.transform import generate_oai_string, trim_string_to_fit @@ -27,6 +27,10 @@ def openai_completions(model_name=None): return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400 else: handler = OobaRequestHandler(incoming_request=request, incoming_json=request_json_body, selected_model=model_name) + if handler.offline: + msg = return_invalid_model_err(model_name) + print(msg) + return handler.handle_error(msg) if handler.cluster_backend_info['mode'] != 'vllm': # TODO: implement other backends diff --git a/llm_server/routes/openai_request_handler.py b/llm_server/routes/openai_request_handler.py index bc5c6f5..3c2a5b1 100644 --- a/llm_server/routes/openai_request_handler.py +++ b/llm_server/routes/openai_request_handler.py @@ -11,10 +11,10 @@ from flask import Response, jsonify, make_response from llm_server import opts from llm_server.cluster.backend import get_model_choices from llm_server.custom_redis import redis -from llm_server.database.database import is_api_key_moderated, do_db_log +from llm_server.database.database import is_api_key_moderated from llm_server.database.log_to_db import log_to_db from llm_server.llm import get_token_count -from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai +from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai, return_invalid_model_err from llm_server.llm.openai.transform import ANTI_CONTINUATION_RE, ANTI_RESPONSE_RE, generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit from llm_server.routes.request_handler import RequestHandler from llm_server.workers.moderator import add_moderation_task, get_results @@ -27,6 +27,10 @@ class OpenAIRequestHandler(RequestHandler): def handle_request(self) -> Tuple[flask.Response, int]: assert not self.used + if self.offline: + msg = return_invalid_model_err(self.selected_model) + print(msg) + return self.handle_error(msg) if opts.openai_silent_trim: oai_messages = trim_messages_to_fit(self.request.json['messages'], self.cluster_backend_info['model_config']['max_position_embeddings'], self.backend_url) @@ -63,7 +67,7 @@ class OpenAIRequestHandler(RequestHandler): if not self.prompt: # TODO: format this as an openai error message - return 'Invalid prompt', 400 + return Response('Invalid prompt'), 400 llm_request = {**self.parameters, 'prompt': self.prompt} (success, _, _, _), (backend_response, backend_response_status_code) = self.generate_response(llm_request) diff --git a/llm_server/workers/inferencer.py b/llm_server/workers/inferencer.py index f06ff8c..79eed01 100644 --- a/llm_server/workers/inferencer.py +++ b/llm_server/workers/inferencer.py @@ -1,5 +1,6 @@ import threading import time +import traceback from uuid import uuid4 from llm_server.cluster.cluster_config import cluster_config @@ -41,6 +42,8 @@ def worker(backend_url): success, response, error_msg = generator(request_json_body, backend_url) event = DataEvent(event_id) event.set((success, response, error_msg)) + except: + traceback.print_exc() finally: decrement_ip_count(client_ip, 'processing_ips') decr_active_workers(selected_model, backend_url)