unify error message handling

parent 957a6cd092
commit 105b66d5e2

@@ -1,7 +1,7 @@
 from typing import Tuple
 
 import flask
-from flask import jsonify
+from flask import jsonify, request
 
 from llm_server import opts
 from llm_server.database.database import log_prompt

@@ -29,17 +29,18 @@ class OobaRequestHandler(RequestHandler):
 
     def handle_ratelimited(self):
         msg = f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.'
-        disable_st_error_formatting = self.request.headers.get('LLM-ST-Errors', False) == 'true'
-        if disable_st_error_formatting:
-            return msg, 429
-        else:
-            backend_response = format_sillytavern_err(msg, 'error')
-            log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response, None, self.parameters, dict(self.request.headers), 429, self.request.url, is_error=True)
-            return jsonify({
-                'results': [{'text': backend_response}]
-            }), 429
+        backend_response = self.handle_error(msg)
+        log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, is_error=True)
+        return backend_response[0], 429  # We only return the response from handle_error(), not the error code
 
-    def handle_error(self, msg: str) -> Tuple[flask.Response, int]:
+    def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
+        disable_st_error_formatting = request.headers.get('LLM-ST-Errors', False) == 'true'
+        if disable_st_error_formatting:
+            # TODO: how to format this
+            response_msg = error_msg
+        else:
+            response_msg = format_sillytavern_err(error_msg, error_type)
+
         return jsonify({
-            'results': [{'text': msg}]
+            'results': [{'text': response_msg}]
         }), 200  # return 200 so we don't trigger an error message in the client's ST

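With the formatting decision now centralized in handle_error(), a client that wants the raw error text instead of SillyTavern-formatted output only has to send the LLM-ST-Errors header. A minimal sketch of such a request follows; the host, port, and URL prefix are assumptions for illustration only, since the diff does not pin them down.

import requests

# Hypothetical base URL; adjust to the actual deployment.
API_URL = 'http://localhost:5000/api/v1/generate'

# 'LLM-ST-Errors: true' asks handle_error() to skip format_sillytavern_err()
# and return the plain error message inside the usual {'results': [...]} body.
r = requests.post(API_URL, json={'prompt': 'Hello.'}, headers={'LLM-ST-Errors': 'true'})
print(r.status_code, r.json())
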
@@ -7,13 +7,12 @@ from flask import Response, jsonify, request
 
 from . import openai_bp
 from ..cache import redis
-from ..helpers.client import format_sillytavern_err
 from ..helpers.http import validate_json
 from ..openai_request_handler import OpenAIRequestHandler
-from ...llm.openai.transform import build_openai_response, generate_oai_string, transform_messages_to_prompt
 from ... import opts
 from ...database.database import log_prompt
 from ...llm.generator import generator
+from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt
 from ...llm.vllm import tokenize
 
 

@@ -21,7 +20,6 @@ from ...llm.vllm import tokenize
 
 @openai_bp.route('/chat/completions', methods=['POST'])
 def openai_chat_completions():
-    disable_st_error_formatting = request.headers.get('LLM-ST-Errors', False) == 'true'
     request_valid_json, request_json_body = validate_json(request)
     if not request_valid_json or not request_json_body.get('messages') or not request_json_body.get('model'):
         return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400

@@ -110,10 +108,6 @@ def openai_chat_completions():
     else:
         try:
             return handler.handle_request()
-        except Exception as e:
-            print(f'EXCEPTION on {request.url}!!!', f'{e.__class__.__name__}: {e}')
+        except Exception:
             traceback.print_exc()
-            if disable_st_error_formatting:
-                return '500', 500
-            else:
-                return build_openai_response('', format_sillytavern_err(f'Server encountered exception.', 'error')), 500
+            return 'Internal server error', 500

@@ -17,7 +17,6 @@ from ...llm.openai.transform import build_openai_response, generate_oai_string
 
 @openai_bp.route('/completions', methods=['POST'])
 def openai_completions():
-    disable_st_error_formatting = request.headers.get('LLM-ST-Errors', False) == 'true'
     request_valid_json, request_json_body = validate_json(request)
     if not request_valid_json or not request_json_body.get('prompt'):
         return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400

@@ -57,10 +56,6 @@ def openai_completions():
             if stats:
                 response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
             return response
-        except Exception as e:
-            print(f'EXCEPTION on {request.url}!!!')
-            print(traceback.format_exc())
-            if disable_st_error_formatting:
-                return '500', 500
-            else:
-                return build_openai_response('', format_sillytavern_err(f'Server encountered exception.', 'error')), 500
+        except Exception:
+            traceback.print_exc()
+            return 'Internal Server Error', 500

@@ -7,9 +7,8 @@ import flask
 from flask import jsonify
 
 from llm_server import opts
-from llm_server.database.database import is_api_key_moderated, log_prompt
+from llm_server.database.database import is_api_key_moderated
 from llm_server.llm.openai.transform import build_openai_response, transform_messages_to_prompt, trim_prompt_to_fit
-from llm_server.routes.helpers.client import format_sillytavern_err
 from llm_server.routes.request_handler import RequestHandler
 from llm_server.threads import add_moderation_task, get_results

@@ -71,17 +70,12 @@ class OpenAIRequestHandler(RequestHandler):
         return backend_response, backend_response_status_code
 
     def handle_ratelimited(self):
-        disable_st_error_formatting = self.request.headers.get('LLM-ST-Errors', False) == 'true'
-        if disable_st_error_formatting:
-            # TODO: format this like OpenAI does
-            return '429', 429
-        else:
-            backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.', 'error')
-            log_prompt(ip=self.client_ip, token=self.token, prompt=self.request_json_body.get('prompt', ''), response=backend_response, gen_time=None, parameters=self.parameters, headers=dict(self.request.headers), backend_response_code=429, request_url=self.request.url, is_error=True)
-            return build_openai_response(self.prompt, backend_response), 429
+        # TODO: return a simulated OpenAI error message
+        # Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.
+        return 'Ratelimited', 429
 
-    def handle_error(self, msg: str) -> Tuple[flask.Response, int]:
-        # return build_openai_response('', msg), 400
+    def handle_error(self, error_msg: str) -> Tuple[flask.Response, int]:
+        # TODO: return a simulated OpenAI error message
         return jsonify({
             "error": {
                 "message": "Invalid request, check your parameters and try again.",

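Both TODOs above point at returning a simulated OpenAI-style error body. Purely for reference, and not something this commit implements, a payload in the general shape OpenAI-compatible clients expect might look like the sketch below; every field beyond 'message' is an assumption based on the commonly seen OpenAI error format.

import json

# Hypothetical simulated error body; field names other than 'message' are assumed.
simulated_openai_error = {
    'error': {
        'message': 'Ratelimited: you are only allowed to have N simultaneous requests at a time.',
        'type': 'rate_limit_error',  # assumed type string, for illustration only
        'param': None,
        'code': None,
    }
}

print(json.dumps(simulated_openai_error, indent=2))
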
@@ -119,12 +119,7 @@ class RequestHandler:
         else:
             # Otherwise, just grab the first and only one.
             combined_error_message = invalid_request_err_msgs[0] + '.'
-        msg = f'Validation Error: {combined_error_message}'
-        disable_st_error_formatting = request.headers.get('LLM-ST-Errors', False) == 'true'
-        if disable_st_error_formatting:
-            backend_response = (Response(msg, 400), 400)
-        else:
-            backend_response = self.handle_error(format_sillytavern_err(msg, 'error'))
+        backend_response = self.handle_error(combined_error_message, 'Validation Error')
 
         if do_log:
             log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True)

@@ -168,12 +163,7 @@ class RequestHandler:
                 error_msg = 'Unknown error.'
             else:
                 error_msg = error_msg.strip('.') + '.'
-
-            disable_st_error_formatting = request.headers.get('LLM-ST-Errors', False) == 'true'
-            if disable_st_error_formatting:
-                backend_response = (Response(error_msg, 400), 400)
-            else:
-                backend_response = self.handle_error(format_sillytavern_err(error_msg, 'error'))
+            backend_response = self.handle_error(error_msg)
             log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True)
             return (False, None, None, 0), backend_response
 

@@ -193,13 +183,8 @@ class RequestHandler:
 
         if return_json_err:
             error_msg = 'The backend did not return valid JSON.'
-            disable_st_error_formatting = request.headers.get('LLM-ST-Errors', False) == 'true'
-            if disable_st_error_formatting:
-                # TODO: how to format this
-                backend_response = (Response(error_msg, 400), 400)
-            else:
-                backend_response = self.handle_error(format_sillytavern_err(error_msg, 'error'))
-            log_prompt(self.client_ip, self.token, prompt, backend_response, elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True)
+            backend_response = self.handle_error(error_msg)
+            log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True)
             return (False, None, None, 0), backend_response
 
         # ===============================================

@@ -223,7 +208,7 @@ class RequestHandler:
     def handle_ratelimited(self) -> Tuple[flask.Response, int]:
         raise NotImplementedError
 
-    def handle_error(self, msg: str) -> Tuple[flask.Response, int]:
+    def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
        raise NotImplementedError
 
 

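For orientation, these hunks converge on a simple contract: the base RequestHandler declares handle_error() (and handle_ratelimited()) as abstract, and each API flavor supplies its own error formatting. The following is a condensed sketch of that pattern, not the literal project code; the OpenAI payload shape in particular is simplified.

from typing import Tuple

import flask
from flask import jsonify


class RequestHandler:
    def handle_ratelimited(self) -> Tuple[flask.Response, int]:
        raise NotImplementedError

    def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
        raise NotImplementedError


class OobaRequestHandler(RequestHandler):
    def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
        # Ooba-style clients read the message out of 'results'; 200 avoids a client-side error popup.
        return jsonify({'results': [{'text': error_msg}]}), 200


class OpenAIRequestHandler(RequestHandler):
    def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
        # OpenAI-style clients expect an 'error' object (payload shape simplified here).
        return jsonify({'error': {'message': error_msg}}), 400


if __name__ == '__main__':
    app = flask.Flask(__name__)
    with app.test_request_context():
        resp, status = OobaRequestHandler().handle_error('Validation Error: prompt is required.')
        # Callers in this commit log resp.data.decode('utf-8') and return resp with their own status code.
        print(status, resp.data.decode('utf-8'))
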
@@ -10,17 +10,13 @@ from ..ooba_request_handler import OobaRequestHandler
 
 @bp.route('/generate', methods=['POST'])
 def generate():
-    disable_st_error_formatting = request.headers.get('LLM-ST-Errors', False) == 'true'
     request_valid_json, request_json_body = validate_json(request)
     if not request_valid_json or not request_json_body.get('prompt'):
         return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
     else:
+        handler = OobaRequestHandler(request)
         try:
-            return OobaRequestHandler(request).handle_request()
-        except Exception as e:
-            print(f'EXCEPTION on {request.url}!!!')
-            print(traceback.format_exc())
-            if disable_st_error_formatting:
-                return '500', 500
-            else:
-                return format_sillytavern_err(f'Server encountered exception.', 'error'), 500
+            return handler.handle_request()
+        except Exception:
+            traceback.print_exc()
+            return handler.handle_error('Server encountered exception.', 'exception')[0], 500

@@ -25,8 +25,6 @@ def stream(ws):
 
     r_headers = dict(request.headers)
     r_url = request.url
-    disable_st_error_formatting = request.headers.get('LLM-ST-Errors', False) == 'true'
-
     message_num = 0
     while ws.connected:
         message = ws.receive()

@@ -135,23 +133,22 @@ def stream(ws):
                     thread.start()
                     thread.join()
             except:
-                if not disable_st_error_formatting:
-                    generated_text = generated_text + '\n\n' + format_sillytavern_err('Encountered error while streaming.', 'error')
-                    traceback.print_exc()
-                    ws.send(json.dumps({
-                        'event': 'text_stream',
-                        'message_num': message_num,
-                        'text': generated_text
-                    }))
+                generated_text = generated_text + '\n\n' + handler.handle_error('Encountered error while streaming.', 'exception')[0].data.decode('utf-8')
+                traceback.print_exc()
+                ws.send(json.dumps({
+                    'event': 'text_stream',
+                    'message_num': message_num,
+                    'text': generated_text
+                }))
 
-                    def background_task_exception():
-                        generated_tokens = tokenize(generated_text)
-                        log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, None, handler.parameters, r_headers, response_status_code, r_url, response_tokens=generated_tokens)
+                def background_task_exception():
+                    generated_tokens = tokenize(generated_text)
+                    log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, None, handler.parameters, r_headers, response_status_code, r_url, response_tokens=generated_tokens)
 
-                    # TODO: use async/await instead of threads
-                    thread = threading.Thread(target=background_task_exception)
-                    thread.start()
-                    thread.join()
+                # TODO: use async/await instead of threads
+                thread = threading.Thread(target=background_task_exception)
+                thread.start()
+                thread.join()
             try:
                 ws.send(json.dumps({
                     'event': 'stream_end',

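Because handle_error() now returns a (flask.Response, status) tuple, callers that need plain text, such as the streaming path above and the log_prompt() calls, take element [0] and decode its .data. A small illustrative sketch of that extraction, using a stand-in payload rather than the real handler:

import json

import flask

app = flask.Flask(__name__)

with app.test_request_context():
    # Stand-in for what handle_error() returns: a (Response, status) tuple.
    backend_response = (flask.jsonify({'results': [{'text': 'Encountered error while streaming.'}]}), 200)

    body = backend_response[0].data.decode('utf-8')  # the raw JSON string that gets logged / appended
    text = json.loads(body)['results'][0]['text']    # the human-readable message inside it

    print(body)
    print(text)
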
|
@ -26,6 +26,7 @@ from llm_server.routes.v1 import bp
|
|||
from llm_server.stream import init_socketio
|
||||
|
||||
# TODO: have the workers handle streaming too
|
||||
# TODO: send extra headers when ratelimited?
|
||||
# TODO: return 200 when returning formatted sillytavern error
|
||||
# TODO: add some sort of loadbalancer to send requests to a group of backends
|
||||
# TODO: allow setting concurrent gens per-backend
|
||||
|