fix oai exception

This commit is contained in:
Cyberes 2023-10-11 09:20:00 -06:00
parent 7286e38cb0
commit 78114771b0
5 changed files with 36 additions and 5 deletions

View File

@ -77,3 +77,18 @@ def validate_oai(parameters):
if parameters.get('max_tokens', 2) < 1: if parameters.get('max_tokens', 2) < 1:
return format_oai_err(f"{parameters['max_tokens']} is less than the minimum of 1 - 'max_tokens'") return format_oai_err(f"{parameters['max_tokens']} is less than the minimum of 1 - 'max_tokens'")
def return_invalid_model_err(requested_model: str):
if requested_model:
msg = f"The model `{requested_model}` does not exist"
else:
msg = "The requested model does not exist"
return jsonify({
"error": {
"message": msg,
"type": "invalid_request_error",
"param": None,
"code": "model_not_found"
}
}), 404

View File

@ -12,7 +12,7 @@ from ..queue import priority_queue
from ... import opts from ... import opts
from ...database.log_to_db import log_to_db from ...database.log_to_db import log_to_db
from ...llm.generator import generator from ...llm.generator import generator
from ...llm.openai.oai_to_vllm import oai_to_vllm, validate_oai from ...llm.openai.oai_to_vllm import oai_to_vllm, validate_oai, return_invalid_model_err
from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
@ -27,6 +27,11 @@ def openai_chat_completions(model_name=None):
return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400 return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400
else: else:
handler = OpenAIRequestHandler(incoming_request=request, incoming_json=request_json_body, selected_model=model_name) handler = OpenAIRequestHandler(incoming_request=request, incoming_json=request_json_body, selected_model=model_name)
if handler.offline:
msg = return_invalid_model_err(model_name)
print(msg)
return handler.handle_error(msg)
if not request_json_body.get('stream'): if not request_json_body.get('stream'):
try: try:
invalid_oai_err_msg = validate_oai(request_json_body) invalid_oai_err_msg = validate_oai(request_json_body)

View File

@ -13,7 +13,7 @@ from ... import opts
from ...database.log_to_db import log_to_db from ...database.log_to_db import log_to_db
from ...llm import get_token_count from ...llm import get_token_count
from ...llm.generator import generator from ...llm.generator import generator
from ...llm.openai.oai_to_vllm import oai_to_vllm, validate_oai from ...llm.openai.oai_to_vllm import oai_to_vllm, validate_oai, return_invalid_model_err
from ...llm.openai.transform import generate_oai_string, trim_string_to_fit from ...llm.openai.transform import generate_oai_string, trim_string_to_fit
@ -27,6 +27,10 @@ def openai_completions(model_name=None):
return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400 return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
else: else:
handler = OobaRequestHandler(incoming_request=request, incoming_json=request_json_body, selected_model=model_name) handler = OobaRequestHandler(incoming_request=request, incoming_json=request_json_body, selected_model=model_name)
if handler.offline:
msg = return_invalid_model_err(model_name)
print(msg)
return handler.handle_error(msg)
if handler.cluster_backend_info['mode'] != 'vllm': if handler.cluster_backend_info['mode'] != 'vllm':
# TODO: implement other backends # TODO: implement other backends

View File

@ -11,10 +11,10 @@ from flask import Response, jsonify, make_response
from llm_server import opts from llm_server import opts
from llm_server.cluster.backend import get_model_choices from llm_server.cluster.backend import get_model_choices
from llm_server.custom_redis import redis from llm_server.custom_redis import redis
from llm_server.database.database import is_api_key_moderated, do_db_log from llm_server.database.database import is_api_key_moderated
from llm_server.database.log_to_db import log_to_db from llm_server.database.log_to_db import log_to_db
from llm_server.llm import get_token_count from llm_server.llm import get_token_count
from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai, return_invalid_model_err
from llm_server.llm.openai.transform import ANTI_CONTINUATION_RE, ANTI_RESPONSE_RE, generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit from llm_server.llm.openai.transform import ANTI_CONTINUATION_RE, ANTI_RESPONSE_RE, generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
from llm_server.routes.request_handler import RequestHandler from llm_server.routes.request_handler import RequestHandler
from llm_server.workers.moderator import add_moderation_task, get_results from llm_server.workers.moderator import add_moderation_task, get_results
@ -27,6 +27,10 @@ class OpenAIRequestHandler(RequestHandler):
def handle_request(self) -> Tuple[flask.Response, int]: def handle_request(self) -> Tuple[flask.Response, int]:
assert not self.used assert not self.used
if self.offline:
msg = return_invalid_model_err(self.selected_model)
print(msg)
return self.handle_error(msg)
if opts.openai_silent_trim: if opts.openai_silent_trim:
oai_messages = trim_messages_to_fit(self.request.json['messages'], self.cluster_backend_info['model_config']['max_position_embeddings'], self.backend_url) oai_messages = trim_messages_to_fit(self.request.json['messages'], self.cluster_backend_info['model_config']['max_position_embeddings'], self.backend_url)
@ -63,7 +67,7 @@ class OpenAIRequestHandler(RequestHandler):
if not self.prompt: if not self.prompt:
# TODO: format this as an openai error message # TODO: format this as an openai error message
return 'Invalid prompt', 400 return Response('Invalid prompt'), 400
llm_request = {**self.parameters, 'prompt': self.prompt} llm_request = {**self.parameters, 'prompt': self.prompt}
(success, _, _, _), (backend_response, backend_response_status_code) = self.generate_response(llm_request) (success, _, _, _), (backend_response, backend_response_status_code) = self.generate_response(llm_request)

View File

@ -1,5 +1,6 @@
import threading import threading
import time import time
import traceback
from uuid import uuid4 from uuid import uuid4
from llm_server.cluster.cluster_config import cluster_config from llm_server.cluster.cluster_config import cluster_config
@ -41,6 +42,8 @@ def worker(backend_url):
success, response, error_msg = generator(request_json_body, backend_url) success, response, error_msg = generator(request_json_body, backend_url)
event = DataEvent(event_id) event = DataEvent(event_id)
event.set((success, response, error_msg)) event.set((success, response, error_msg))
except:
traceback.print_exc()
finally: finally:
decrement_ip_count(client_ip, 'processing_ips') decrement_ip_count(client_ip, 'processing_ips')
decr_active_workers(selected_model, backend_url) decr_active_workers(selected_model, backend_url)