From 18e37a72ae9ee0251bf2fca61091c4b187930cd8 Mon Sep 17 00:00:00 2001 From: Cyberes Date: Mon, 9 Oct 2023 23:51:26 -0600 Subject: [PATCH] add model selection to openai endpoint --- llm_server/cluster/backend.py | 6 ++---- llm_server/cluster/model_choices.py | 1 - llm_server/routes/openai/__init__.py | 3 +++ llm_server/routes/openai/chat_completions.py | 7 ++++--- llm_server/routes/openai/completions.py | 7 ++++--- llm_server/routes/request_handler.py | 2 +- llm_server/routes/v1/generate.py | 2 +- server.py | 3 ++- 8 files changed, 17 insertions(+), 14 deletions(-) delete mode 100644 llm_server/cluster/model_choices.py diff --git a/llm_server/cluster/backend.py b/llm_server/cluster/backend.py index c301f93..2a7edc3 100644 --- a/llm_server/cluster/backend.py +++ b/llm_server/cluster/backend.py @@ -7,7 +7,7 @@ from llm_server.custom_redis import redis from llm_server.llm.generator import generator from llm_server.llm.info import get_info from llm_server.routes.queue import priority_queue -from llm_server.routes.stats import get_active_gen_workers_model, calculate_wait_time +from llm_server.routes.stats import calculate_wait_time, get_active_gen_workers_model def get_backends_from_model(model_name: str): @@ -51,8 +51,6 @@ def test_backend(backend_url: str, test_prompt: bool = False): return True, i - - def get_model_choices(regen: bool = False): if not regen: c = redis.getp('model_choices') @@ -89,7 +87,7 @@ def get_model_choices(regen: bool = False): 'model': model, 'client_api': f'https://{base_client_api}/{model}', 'ws_client_api': f'wss://{base_client_api}/{model}/v1/stream' if opts.enable_streaming else None, - 'openai_client_api': f'https://{base_client_api}/openai/{model}' if opts.enable_openi_compatible_backend else 'disabled', + 'openai_client_api': f'https://{base_client_api}/openai/{model}/v1' if opts.enable_openi_compatible_backend else 'disabled', 'backend_count': len(b), 'estimated_wait': estimated_wait_sec, 'queued': proompters_in_queue, diff --git a/llm_server/cluster/model_choices.py b/llm_server/cluster/model_choices.py deleted file mode 100644 index ef93bba..0000000 --- a/llm_server/cluster/model_choices.py +++ /dev/null @@ -1 +0,0 @@ -# TODO: give this a better name! diff --git a/llm_server/routes/openai/__init__.py b/llm_server/routes/openai/__init__.py index 67febc9..3a69aa7 100644 --- a/llm_server/routes/openai/__init__.py +++ b/llm_server/routes/openai/__init__.py @@ -5,9 +5,11 @@ from ..server_error import handle_server_error from ... import opts openai_bp = Blueprint('openai/v1/', __name__) +openai_model_bp = Blueprint('openai/', __name__) @openai_bp.before_request +@openai_model_bp.before_request def before_oai_request(): if not opts.enable_openi_compatible_backend: return 'The OpenAI-compatible backend is disabled.', 401 @@ -15,6 +17,7 @@ def before_oai_request(): @openai_bp.errorhandler(500) +@openai_model_bp.errorhandler(500) def handle_error(e): return handle_server_error(e) diff --git a/llm_server/routes/openai/chat_completions.py b/llm_server/routes/openai/chat_completions.py index b088a18..76f1b6c 100644 --- a/llm_server/routes/openai/chat_completions.py +++ b/llm_server/routes/openai/chat_completions.py @@ -5,7 +5,7 @@ import traceback from flask import Response, jsonify, request from llm_server.custom_redis import redis -from . import openai_bp +from . import openai_bp, openai_model_bp from ..helpers.http import validate_json from ..openai_request_handler import OpenAIRequestHandler from ..queue import priority_queue @@ -20,12 +20,13 @@ from ...llm.openai.transform import generate_oai_string, transform_messages_to_p @openai_bp.route('/chat/completions', methods=['POST']) -def openai_chat_completions(): +@openai_model_bp.route('//v1/chat/completions', methods=['POST']) +def openai_chat_completions(model_name=None): request_valid_json, request_json_body = validate_json(request) if not request_valid_json or not request_json_body.get('messages') or not request_json_body.get('model'): return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400 else: - handler = OpenAIRequestHandler(incoming_request=request, incoming_json=request_json_body) + handler = OpenAIRequestHandler(incoming_request=request, incoming_json=request_json_body, selected_model=model_name) if not request_json_body.get('stream'): try: invalid_oai_err_msg = validate_oai(request_json_body) diff --git a/llm_server/routes/openai/completions.py b/llm_server/routes/openai/completions.py index 8b5d987..c8e7f19 100644 --- a/llm_server/routes/openai/completions.py +++ b/llm_server/routes/openai/completions.py @@ -5,7 +5,7 @@ import simplejson as json from flask import Response, jsonify, request from llm_server.custom_redis import redis -from . import openai_bp +from . import openai_bp, openai_model_bp from ..helpers.http import validate_json from ..ooba_request_handler import OobaRequestHandler from ..queue import priority_queue @@ -20,12 +20,13 @@ from ...llm.openai.transform import generate_oai_string, trim_string_to_fit # TODO: add rate-limit headers? @openai_bp.route('/completions', methods=['POST']) -def openai_completions(): +@openai_model_bp.route('//v1/completions', methods=['POST']) +def openai_completions(model_name=None): request_valid_json, request_json_body = validate_json(request) if not request_valid_json or not request_json_body.get('prompt'): return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400 else: - handler = OobaRequestHandler(incoming_request=request) + handler = OobaRequestHandler(incoming_request=request, incoming_json=request_json_body, selected_model=model_name) if handler.cluster_backend_info['mode'] != 'vllm': # TODO: implement other backends diff --git a/llm_server/routes/request_handler.py b/llm_server/routes/request_handler.py index dd8326b..6a6ad4a 100644 --- a/llm_server/routes/request_handler.py +++ b/llm_server/routes/request_handler.py @@ -20,7 +20,7 @@ from llm_server.routes.queue import priority_queue class RequestHandler: def __init__(self, incoming_request: flask.Request, selected_model: str = None, incoming_json: Union[dict, str] = None): self.request = incoming_request - self.enable_backend_blind_rrd = request.headers.get('LLM-Blind-RRD', False) == 'true' + # self.enable_backend_blind_rrd = request.headers.get('LLM-Blind-RRD', False) == 'true' # Routes need to validate it, here we just load it if incoming_json: diff --git a/llm_server/routes/v1/generate.py b/llm_server/routes/v1/generate.py index 1a63db9..fcdc298 100644 --- a/llm_server/routes/v1/generate.py +++ b/llm_server/routes/v1/generate.py @@ -14,7 +14,7 @@ def generate(model_name=None): if not request_valid_json or not request_json_body.get('prompt'): return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400 else: - handler = OobaRequestHandler(request, model_name) + handler = OobaRequestHandler(request, selected_model=model_name) try: return handler.handle_request() except Exception: diff --git a/server.py b/server.py index c3ed4a2..89c71aa 100644 --- a/server.py +++ b/server.py @@ -24,7 +24,7 @@ from llm_server.database.create import create_db from llm_server.helpers import auto_set_base_client_api from llm_server.llm.vllm.info import vllm_info from llm_server.pre_fork import server_startup -from llm_server.routes.openai import openai_bp +from llm_server.routes.openai import openai_bp, openai_model_bp from llm_server.routes.server_error import handle_server_error from llm_server.routes.v1 import bp from llm_server.routes.v1.generate_stats import generate_stats @@ -70,6 +70,7 @@ except ModuleNotFoundError as e: app = Flask(__name__) app.register_blueprint(bp, url_prefix='/api/') app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/') +app.register_blueprint(openai_model_bp, url_prefix='/api/openai/') init_socketio(app) flask_cache.init_app(app) flask_cache.clear()