diff --git a/llm_server/routes/ooba_request_handler.py b/llm_server/routes/ooba_request_handler.py
index 1f186f0..7dbbd27 100644
--- a/llm_server/routes/ooba_request_handler.py
+++ b/llm_server/routes/ooba_request_handler.py
@@ -1,7 +1,7 @@
 from typing import Tuple
 
 import flask
-from flask import jsonify
+from flask import jsonify, request
 
 from llm_server import opts
 from llm_server.database.database import log_prompt
@@ -29,17 +29,18 @@ class OobaRequestHandler(RequestHandler):
 
     def handle_ratelimited(self):
         msg = f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.'
-        disable_st_error_formatting = self.request.headers.get('LLM-ST-Errors', False) == 'true'
-        if disable_st_error_formatting:
-            return msg, 429
-        else:
-            backend_response = format_sillytavern_err(msg, 'error')
-            log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response, None, self.parameters, dict(self.request.headers), 429, self.request.url, is_error=True)
-            return jsonify({
-                'results': [{'text': backend_response}]
-            }), 429
+        backend_response = self.handle_error(msg)
+        log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, is_error=True)
+        return backend_response[0], 429  # We only return the response from handle_error(), not the error code
+
+    def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
+        disable_st_error_formatting = request.headers.get('LLM-ST-Errors', False) == 'true'
+        if disable_st_error_formatting:
+            # TODO: how to format this
+            response_msg = error_msg
+        else:
+            response_msg = format_sillytavern_err(error_msg, error_type)
 
-    def handle_error(self, msg: str) -> Tuple[flask.Response, int]:
         return jsonify({
-            'results': [{'text': msg}]
+            'results': [{'text': response_msg}]
         }), 200  # return 200 so we don't trigger an error message in the client's ST
diff --git a/llm_server/routes/openai/chat_completions.py b/llm_server/routes/openai/chat_completions.py
index 2f05b65..cc27dce 100644
--- a/llm_server/routes/openai/chat_completions.py
+++ b/llm_server/routes/openai/chat_completions.py
@@ -7,13 +7,12 @@ from flask import Response, jsonify, request
 
 from . import openai_bp
 from ..cache import redis
-from ..helpers.client import format_sillytavern_err
 from ..helpers.http import validate_json
 from ..openai_request_handler import OpenAIRequestHandler
-from ...llm.openai.transform import build_openai_response, generate_oai_string, transform_messages_to_prompt
 from ... import opts
 from ...database.database import log_prompt
 from ...llm.generator import generator
+from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt
 from ...llm.vllm import tokenize
 
 
@@ -21,7 +20,6 @@ from ...llm.vllm import tokenize
 
 @openai_bp.route('/chat/completions', methods=['POST'])
 def openai_chat_completions():
-    disable_st_error_formatting = request.headers.get('LLM-ST-Errors', False) == 'true'
     request_valid_json, request_json_body = validate_json(request)
     if not request_valid_json or not request_json_body.get('messages') or not request_json_body.get('model'):
         return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400
@@ -110,10 +108,6 @@ def openai_chat_completions():
     else:
         try:
             return handler.handle_request()
-        except Exception as e:
-            print(f'EXCEPTION on {request.url}!!!', f'{e.__class__.__name__}: {e}')
+        except Exception:
             traceback.print_exc()
-            if disable_st_error_formatting:
-                return '500', 500
-            else:
-                return build_openai_response('', format_sillytavern_err(f'Server encountered exception.', 'error')), 500
+            return 'Internal server error', 500
diff --git a/llm_server/routes/openai/completions.py b/llm_server/routes/openai/completions.py
index 84e4542..503f628 100644
--- a/llm_server/routes/openai/completions.py
+++ b/llm_server/routes/openai/completions.py
@@ -17,7 +17,6 @@ from ...llm.openai.transform import build_openai_response, generate_oai_string
 
 @openai_bp.route('/completions', methods=['POST'])
 def openai_completions():
-    disable_st_error_formatting = request.headers.get('LLM-ST-Errors', False) == 'true'
     request_valid_json, request_json_body = validate_json(request)
     if not request_valid_json or not request_json_body.get('prompt'):
         return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
@@ -57,10 +56,6 @@ def openai_completions():
             if stats:
                 response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
             return response
-        except Exception as e:
-            print(f'EXCEPTION on {request.url}!!!')
-            print(traceback.format_exc())
-            if disable_st_error_formatting:
-                return '500', 500
-            else:
-                return build_openai_response('', format_sillytavern_err(f'Server encountered exception.', 'error')), 500
+        except Exception:
+            traceback.print_exc()
+            return 'Internal Server Error', 500
diff --git a/llm_server/routes/openai_request_handler.py b/llm_server/routes/openai_request_handler.py
index 1aedf26..561320c 100644
--- a/llm_server/routes/openai_request_handler.py
+++ b/llm_server/routes/openai_request_handler.py
@@ -7,9 +7,8 @@ import flask
 from flask import jsonify
 
 from llm_server import opts
-from llm_server.database.database import is_api_key_moderated, log_prompt
+from llm_server.database.database import is_api_key_moderated
 from llm_server.llm.openai.transform import build_openai_response, transform_messages_to_prompt, trim_prompt_to_fit
-from llm_server.routes.helpers.client import format_sillytavern_err
 from llm_server.routes.request_handler import RequestHandler
 from llm_server.threads import add_moderation_task, get_results
 
@@ -71,17 +70,12 @@ class OpenAIRequestHandler(RequestHandler):
         return backend_response, backend_response_status_code
 
     def handle_ratelimited(self):
-        disable_st_error_formatting = self.request.headers.get('LLM-ST-Errors', False) == 'true'
-        if disable_st_error_formatting:
-            # TODO: format this like OpenAI does
-            return '429', 429
-        else:
-            backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.', 'error')
-            log_prompt(ip=self.client_ip, token=self.token, prompt=self.request_json_body.get('prompt', ''), response=backend_response, gen_time=None, parameters=self.parameters, headers=dict(self.request.headers), backend_response_code=429, request_url=self.request.url, is_error=True)
-            return build_openai_response(self.prompt, backend_response), 429
+        # TODO: return a simulated OpenAI error message
+        # Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.
+        return 'Ratelimited', 429
 
-    def handle_error(self, msg: str) -> Tuple[flask.Response, int]:
-        # return build_openai_response('', msg), 400
+    def handle_error(self, error_msg: str) -> Tuple[flask.Response, int]:
+        # TODO: return a simulated OpenAI error message
         return jsonify({
             "error": {
                 "message": "Invalid request, check your parameters and try again.",
diff --git a/llm_server/routes/request_handler.py b/llm_server/routes/request_handler.py
index 90ca620..45fa8f8 100644
--- a/llm_server/routes/request_handler.py
+++ b/llm_server/routes/request_handler.py
@@ -119,12 +119,7 @@ class RequestHandler:
             else:
                 # Otherwise, just grab the first and only one.
                 combined_error_message = invalid_request_err_msgs[0] + '.'
-            msg = f'Validation Error: {combined_error_message}'
-            disable_st_error_formatting = request.headers.get('LLM-ST-Errors', False) == 'true'
-            if disable_st_error_formatting:
-                backend_response = (Response(msg, 400), 400)
-            else:
-                backend_response = self.handle_error(format_sillytavern_err(msg, 'error'))
+            backend_response = self.handle_error(combined_error_message, 'Validation Error')
             if do_log:
                 log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True)
 
@@ -168,12 +163,7 @@ class RequestHandler:
                 error_msg = 'Unknown error.'
             else:
                 error_msg = error_msg.strip('.') + '.'
-
-            disable_st_error_formatting = request.headers.get('LLM-ST-Errors', False) == 'true'
-            if disable_st_error_formatting:
-                backend_response = (Response(error_msg, 400), 400)
-            else:
-                backend_response = self.handle_error(format_sillytavern_err(error_msg, 'error'))
+            backend_response = self.handle_error(error_msg)
             log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True)
             return (False, None, None, 0), backend_response
 
@@ -193,13 +183,8 @@ class RequestHandler:
 
         if return_json_err:
             error_msg = 'The backend did not return valid JSON.'
-            disable_st_error_formatting = request.headers.get('LLM-ST-Errors', False) == 'true'
-            if disable_st_error_formatting:
-                # TODO: how to format this
-                backend_response = (Response(error_msg, 400), 400)
-            else:
-                backend_response = self.handle_error(format_sillytavern_err(error_msg, 'error'))
-            log_prompt(self.client_ip, self.token, prompt, backend_response, elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True)
+            backend_response = self.handle_error(error_msg)
+            log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True)
             return (False, None, None, 0), backend_response
 
         # ===============================================
@@ -223,7 +208,7 @@ class RequestHandler:
 
     def handle_ratelimited(self) -> Tuple[flask.Response, int]:
         raise NotImplementedError
 
-    def handle_error(self, msg: str) -> Tuple[flask.Response, int]:
+    def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
         raise NotImplementedError
 
diff --git a/llm_server/routes/v1/generate.py b/llm_server/routes/v1/generate.py
index 49fc43e..715288f 100644
--- a/llm_server/routes/v1/generate.py
+++ b/llm_server/routes/v1/generate.py
@@ -10,17 +10,13 @@ from ..ooba_request_handler import OobaRequestHandler
 
 @bp.route('/generate', methods=['POST'])
 def generate():
-    disable_st_error_formatting = request.headers.get('LLM-ST-Errors', False) == 'true'
     request_valid_json, request_json_body = validate_json(request)
     if not request_valid_json or not request_json_body.get('prompt'):
         return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
     else:
+        handler = OobaRequestHandler(request)
         try:
-            return OobaRequestHandler(request).handle_request()
-        except Exception as e:
-            print(f'EXCEPTION on {request.url}!!!')
-            print(traceback.format_exc())
-            if disable_st_error_formatting:
-                return '500', 500
-            else:
-                return format_sillytavern_err(f'Server encountered exception.', 'error'), 500
+            return handler.handle_request()
+        except Exception:
+            traceback.print_exc()
+            return handler.handle_error('Server encountered exception.', 'exception')[0], 500
diff --git a/llm_server/routes/v1/generate_stream.py b/llm_server/routes/v1/generate_stream.py
index 789fb4f..aa73120 100644
--- a/llm_server/routes/v1/generate_stream.py
+++ b/llm_server/routes/v1/generate_stream.py
@@ -25,8 +25,6 @@ def stream(ws):
     r_headers = dict(request.headers)
     r_url = request.url
 
-    disable_st_error_formatting = request.headers.get('LLM-ST-Errors', False) == 'true'
-
     message_num = 0
     while ws.connected:
         message = ws.receive()
@@ -135,23 +133,22 @@ def stream(ws):
                 thread.start()
                 thread.join()
             except:
-                if not disable_st_error_formatting:
-                    generated_text = generated_text + '\n\n' + format_sillytavern_err('Encountered error while streaming.', 'error')
-                traceback.print_exc()
-                ws.send(json.dumps({
-                    'event': 'text_stream',
-                    'message_num': message_num,
-                    'text': generated_text
-                }))
+                generated_text = generated_text + '\n\n' + handler.handle_error('Encountered error while streaming.', 'exception')[0].data.decode('utf-8')
+                traceback.print_exc()
+                ws.send(json.dumps({
+                    'event': 'text_stream',
+                    'message_num': message_num,
+                    'text': generated_text
+                }))
 
-                def background_task_exception():
-                    generated_tokens = tokenize(generated_text)
-                    log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, None, handler.parameters, r_headers, response_status_code, r_url, response_tokens=generated_tokens)
+                def background_task_exception():
+                    generated_tokens = tokenize(generated_text)
+                    log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, None, handler.parameters, r_headers, response_status_code, r_url, response_tokens=generated_tokens)
 
-                # TODO: use async/await instead of threads
-                thread = threading.Thread(target=background_task_exception)
-                thread.start()
-                thread.join()
+                # TODO: use async/await instead of threads
+                thread = threading.Thread(target=background_task_exception)
+                thread.start()
+                thread.join()
 
         try:
             ws.send(json.dumps({
                 'event': 'stream_end',
diff --git a/server.py b/server.py
index 39a2eaa..6a9949f 100644
--- a/server.py
+++ b/server.py
@@ -26,6 +26,7 @@ from llm_server.routes.v1 import bp
 from llm_server.stream import init_socketio
 
 # TODO: have the workers handle streaming too
+# TODO: send extra headers when ratelimited?
 # TODO: return 200 when returning formatted sillytavern error
 # TODO: add some sort of loadbalancer to send requests to a group of backends
 # TODO: allow setting concurrent gens per-backend