add model selection to openai endpoint

This commit is contained in:
Cyberes 2023-10-09 23:51:26 -06:00
parent 5f7bf4faca
commit 18e37a72ae
8 changed files with 17 additions and 14 deletions

View File

@ -7,7 +7,7 @@ from llm_server.custom_redis import redis
from llm_server.llm.generator import generator from llm_server.llm.generator import generator
from llm_server.llm.info import get_info from llm_server.llm.info import get_info
from llm_server.routes.queue import priority_queue from llm_server.routes.queue import priority_queue
from llm_server.routes.stats import get_active_gen_workers_model, calculate_wait_time from llm_server.routes.stats import calculate_wait_time, get_active_gen_workers_model
def get_backends_from_model(model_name: str): def get_backends_from_model(model_name: str):
@ -51,8 +51,6 @@ def test_backend(backend_url: str, test_prompt: bool = False):
return True, i return True, i
def get_model_choices(regen: bool = False): def get_model_choices(regen: bool = False):
if not regen: if not regen:
c = redis.getp('model_choices') c = redis.getp('model_choices')
@ -89,7 +87,7 @@ def get_model_choices(regen: bool = False):
'model': model, 'model': model,
'client_api': f'https://{base_client_api}/{model}', 'client_api': f'https://{base_client_api}/{model}',
'ws_client_api': f'wss://{base_client_api}/{model}/v1/stream' if opts.enable_streaming else None, 'ws_client_api': f'wss://{base_client_api}/{model}/v1/stream' if opts.enable_streaming else None,
'openai_client_api': f'https://{base_client_api}/openai/{model}' if opts.enable_openi_compatible_backend else 'disabled', 'openai_client_api': f'https://{base_client_api}/openai/{model}/v1' if opts.enable_openi_compatible_backend else 'disabled',
'backend_count': len(b), 'backend_count': len(b),
'estimated_wait': estimated_wait_sec, 'estimated_wait': estimated_wait_sec,
'queued': proompters_in_queue, 'queued': proompters_in_queue,

View File

@ -1 +0,0 @@
# TODO: give this a better name!

View File

@ -5,9 +5,11 @@ from ..server_error import handle_server_error
from ... import opts from ... import opts
openai_bp = Blueprint('openai/v1/', __name__) openai_bp = Blueprint('openai/v1/', __name__)
openai_model_bp = Blueprint('openai/', __name__)
@openai_bp.before_request @openai_bp.before_request
@openai_model_bp.before_request
def before_oai_request(): def before_oai_request():
if not opts.enable_openi_compatible_backend: if not opts.enable_openi_compatible_backend:
return 'The OpenAI-compatible backend is disabled.', 401 return 'The OpenAI-compatible backend is disabled.', 401
@ -15,6 +17,7 @@ def before_oai_request():
@openai_bp.errorhandler(500) @openai_bp.errorhandler(500)
@openai_model_bp.errorhandler(500)
def handle_error(e): def handle_error(e):
return handle_server_error(e) return handle_server_error(e)

View File

@ -5,7 +5,7 @@ import traceback
from flask import Response, jsonify, request from flask import Response, jsonify, request
from llm_server.custom_redis import redis from llm_server.custom_redis import redis
from . import openai_bp from . import openai_bp, openai_model_bp
from ..helpers.http import validate_json from ..helpers.http import validate_json
from ..openai_request_handler import OpenAIRequestHandler from ..openai_request_handler import OpenAIRequestHandler
from ..queue import priority_queue from ..queue import priority_queue
@ -20,12 +20,13 @@ from ...llm.openai.transform import generate_oai_string, transform_messages_to_p
@openai_bp.route('/chat/completions', methods=['POST']) @openai_bp.route('/chat/completions', methods=['POST'])
def openai_chat_completions(): @openai_model_bp.route('/<model_name>/v1/chat/completions', methods=['POST'])
def openai_chat_completions(model_name=None):
request_valid_json, request_json_body = validate_json(request) request_valid_json, request_json_body = validate_json(request)
if not request_valid_json or not request_json_body.get('messages') or not request_json_body.get('model'): if not request_valid_json or not request_json_body.get('messages') or not request_json_body.get('model'):
return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400 return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400
else: else:
handler = OpenAIRequestHandler(incoming_request=request, incoming_json=request_json_body) handler = OpenAIRequestHandler(incoming_request=request, incoming_json=request_json_body, selected_model=model_name)
if not request_json_body.get('stream'): if not request_json_body.get('stream'):
try: try:
invalid_oai_err_msg = validate_oai(request_json_body) invalid_oai_err_msg = validate_oai(request_json_body)

View File

@ -5,7 +5,7 @@ import simplejson as json
from flask import Response, jsonify, request from flask import Response, jsonify, request
from llm_server.custom_redis import redis from llm_server.custom_redis import redis
from . import openai_bp from . import openai_bp, openai_model_bp
from ..helpers.http import validate_json from ..helpers.http import validate_json
from ..ooba_request_handler import OobaRequestHandler from ..ooba_request_handler import OobaRequestHandler
from ..queue import priority_queue from ..queue import priority_queue
@ -20,12 +20,13 @@ from ...llm.openai.transform import generate_oai_string, trim_string_to_fit
# TODO: add rate-limit headers? # TODO: add rate-limit headers?
@openai_bp.route('/completions', methods=['POST']) @openai_bp.route('/completions', methods=['POST'])
def openai_completions(): @openai_model_bp.route('/<model_name>/v1/completions', methods=['POST'])
def openai_completions(model_name=None):
request_valid_json, request_json_body = validate_json(request) request_valid_json, request_json_body = validate_json(request)
if not request_valid_json or not request_json_body.get('prompt'): if not request_valid_json or not request_json_body.get('prompt'):
return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400 return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
else: else:
handler = OobaRequestHandler(incoming_request=request) handler = OobaRequestHandler(incoming_request=request, incoming_json=request_json_body, selected_model=model_name)
if handler.cluster_backend_info['mode'] != 'vllm': if handler.cluster_backend_info['mode'] != 'vllm':
# TODO: implement other backends # TODO: implement other backends

View File

@ -20,7 +20,7 @@ from llm_server.routes.queue import priority_queue
class RequestHandler: class RequestHandler:
def __init__(self, incoming_request: flask.Request, selected_model: str = None, incoming_json: Union[dict, str] = None): def __init__(self, incoming_request: flask.Request, selected_model: str = None, incoming_json: Union[dict, str] = None):
self.request = incoming_request self.request = incoming_request
self.enable_backend_blind_rrd = request.headers.get('LLM-Blind-RRD', False) == 'true' # self.enable_backend_blind_rrd = request.headers.get('LLM-Blind-RRD', False) == 'true'
# Routes need to validate it, here we just load it # Routes need to validate it, here we just load it
if incoming_json: if incoming_json:

View File

@ -14,7 +14,7 @@ def generate(model_name=None):
if not request_valid_json or not request_json_body.get('prompt'): if not request_valid_json or not request_json_body.get('prompt'):
return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400 return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
else: else:
handler = OobaRequestHandler(request, model_name) handler = OobaRequestHandler(request, selected_model=model_name)
try: try:
return handler.handle_request() return handler.handle_request()
except Exception: except Exception:

View File

@ -24,7 +24,7 @@ from llm_server.database.create import create_db
from llm_server.helpers import auto_set_base_client_api from llm_server.helpers import auto_set_base_client_api
from llm_server.llm.vllm.info import vllm_info from llm_server.llm.vllm.info import vllm_info
from llm_server.pre_fork import server_startup from llm_server.pre_fork import server_startup
from llm_server.routes.openai import openai_bp from llm_server.routes.openai import openai_bp, openai_model_bp
from llm_server.routes.server_error import handle_server_error from llm_server.routes.server_error import handle_server_error
from llm_server.routes.v1 import bp from llm_server.routes.v1 import bp
from llm_server.routes.v1.generate_stats import generate_stats from llm_server.routes.v1.generate_stats import generate_stats
@ -70,6 +70,7 @@ except ModuleNotFoundError as e:
app = Flask(__name__) app = Flask(__name__)
app.register_blueprint(bp, url_prefix='/api/') app.register_blueprint(bp, url_prefix='/api/')
app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/') app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')
app.register_blueprint(openai_model_bp, url_prefix='/api/openai/')
init_socketio(app) init_socketio(app)
flask_cache.init_app(app) flask_cache.init_app(app)
flask_cache.clear() flask_cache.clear()