add model selection to openai endpoint
parent 5f7bf4faca
commit 18e37a72ae
@@ -7,7 +7,7 @@ from llm_server.custom_redis import redis
 from llm_server.llm.generator import generator
 from llm_server.llm.info import get_info
 from llm_server.routes.queue import priority_queue
-from llm_server.routes.stats import get_active_gen_workers_model, calculate_wait_time
+from llm_server.routes.stats import calculate_wait_time, get_active_gen_workers_model
 
 
 def get_backends_from_model(model_name: str):
@@ -51,8 +51,6 @@ def test_backend(backend_url: str, test_prompt: bool = False):
     return True, i
 
 
-
-
 def get_model_choices(regen: bool = False):
     if not regen:
         c = redis.getp('model_choices')
@@ -89,7 +87,7 @@ def get_model_choices(regen: bool = False):
             'model': model,
             'client_api': f'https://{base_client_api}/{model}',
             'ws_client_api': f'wss://{base_client_api}/{model}/v1/stream' if opts.enable_streaming else None,
-            'openai_client_api': f'https://{base_client_api}/openai/{model}' if opts.enable_openi_compatible_backend else 'disabled',
+            'openai_client_api': f'https://{base_client_api}/openai/{model}/v1' if opts.enable_openi_compatible_backend else 'disabled',
             'backend_count': len(b),
             'estimated_wait': estimated_wait_sec,
             'queued': proompters_in_queue,
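Because the advertised openai_client_api now ends in /v1, a stock OpenAI client can be pointed at one specific model. A minimal sketch, assuming the pre-1.0 openai Python package and a hypothetical proxy host, token, and model name:

import openai  # assumption: openai<1.0, which still exposes api_base

openai.api_base = 'https://proxy.example.com/api/openai/llama-2-13b/v1'  # hypothetical host + model
openai.api_key = 'my-token'                                              # hypothetical token

reply = openai.ChatCompletion.create(
    model='llama-2-13b',  # the chat route below still checks that 'model' is present in the body
    messages=[{'role': 'user', 'content': 'Hello!'}],
)
print(reply['choices'][0]['message']['content'])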
@@ -1 +0,0 @@
-# TODO: give this a better name!
@@ -5,9 +5,11 @@ from ..server_error import handle_server_error
 from ... import opts
 
 openai_bp = Blueprint('openai/v1/', __name__)
+openai_model_bp = Blueprint('openai/', __name__)
 
 
 @openai_bp.before_request
+@openai_model_bp.before_request
 def before_oai_request():
     if not opts.enable_openi_compatible_backend:
         return 'The OpenAI-compatible backend is disabled.', 401
@@ -15,6 +17,7 @@ def before_oai_request():
 
 
 @openai_bp.errorhandler(500)
+@openai_model_bp.errorhandler(500)
 def handle_error(e):
     return handle_server_error(e)
 
@@ -5,7 +5,7 @@ import traceback
 from flask import Response, jsonify, request
 
 from llm_server.custom_redis import redis
-from . import openai_bp
+from . import openai_bp, openai_model_bp
 from ..helpers.http import validate_json
 from ..openai_request_handler import OpenAIRequestHandler
 from ..queue import priority_queue
@@ -20,12 +20,13 @@ from ...llm.openai.transform import generate_oai_string, transform_messages_to_p
 
 
 @openai_bp.route('/chat/completions', methods=['POST'])
-def openai_chat_completions():
+@openai_model_bp.route('/<model_name>/v1/chat/completions', methods=['POST'])
+def openai_chat_completions(model_name=None):
     request_valid_json, request_json_body = validate_json(request)
     if not request_valid_json or not request_json_body.get('messages') or not request_json_body.get('model'):
         return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400
     else:
-        handler = OpenAIRequestHandler(incoming_request=request, incoming_json=request_json_body)
+        handler = OpenAIRequestHandler(incoming_request=request, incoming_json=request_json_body, selected_model=model_name)
         if not request_json_body.get('stream'):
             try:
                 invalid_oai_err_msg = validate_oai(request_json_body)
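For reference, the model-scoped chat route resolves the backend from the URL path, while the body is still validated for 'messages' and 'model'. A minimal request sketch against the new route, with a hypothetical host, token, and model name:

import requests

resp = requests.post(
    'https://proxy.example.com/api/openai/llama-2-13b/v1/chat/completions',  # hypothetical host + model
    headers={'Authorization': 'Bearer my-token'},                            # hypothetical auth scheme
    json={
        'model': 'llama-2-13b',  # still required by the validation above
        'messages': [{'role': 'user', 'content': 'Hello!'}],
    },
)
print(resp.status_code, resp.json())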
@@ -5,7 +5,7 @@ import simplejson as json
 from flask import Response, jsonify, request
 
 from llm_server.custom_redis import redis
-from . import openai_bp
+from . import openai_bp, openai_model_bp
 from ..helpers.http import validate_json
 from ..ooba_request_handler import OobaRequestHandler
 from ..queue import priority_queue
@@ -20,12 +20,13 @@ from ...llm.openai.transform import generate_oai_string, trim_string_to_fit
 # TODO: add rate-limit headers?
 
 @openai_bp.route('/completions', methods=['POST'])
-def openai_completions():
+@openai_model_bp.route('/<model_name>/v1/completions', methods=['POST'])
+def openai_completions(model_name=None):
     request_valid_json, request_json_body = validate_json(request)
     if not request_valid_json or not request_json_body.get('prompt'):
         return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
     else:
-        handler = OobaRequestHandler(incoming_request=request)
+        handler = OobaRequestHandler(incoming_request=request, incoming_json=request_json_body, selected_model=model_name)
 
         if handler.cluster_backend_info['mode'] != 'vllm':
             # TODO: implement other backends
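The legacy text-completions route gets the same per-model variant; it only requires a 'prompt' in the body. A minimal request sketch under the same hypothetical host, token, and model name:

import requests

resp = requests.post(
    'https://proxy.example.com/api/openai/llama-2-13b/v1/completions',  # hypothetical host + model
    headers={'Authorization': 'Bearer my-token'},                       # hypothetical auth scheme
    json={'prompt': 'Once upon a time'},
)
print(resp.status_code, resp.json())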
@@ -20,7 +20,7 @@ from llm_server.routes.queue import priority_queue
 class RequestHandler:
     def __init__(self, incoming_request: flask.Request, selected_model: str = None, incoming_json: Union[dict, str] = None):
         self.request = incoming_request
-        self.enable_backend_blind_rrd = request.headers.get('LLM-Blind-RRD', False) == 'true'
+        # self.enable_backend_blind_rrd = request.headers.get('LLM-Blind-RRD', False) == 'true'
 
         # Routes need to validate it, here we just load it
         if incoming_json:
@@ -14,7 +14,7 @@ def generate(model_name=None):
     if not request_valid_json or not request_json_body.get('prompt'):
         return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
     else:
-        handler = OobaRequestHandler(request, model_name)
+        handler = OobaRequestHandler(request, selected_model=model_name)
         try:
             return handler.handle_request()
         except Exception:
@@ -24,7 +24,7 @@ from llm_server.database.create import create_db
 from llm_server.helpers import auto_set_base_client_api
 from llm_server.llm.vllm.info import vllm_info
 from llm_server.pre_fork import server_startup
-from llm_server.routes.openai import openai_bp
+from llm_server.routes.openai import openai_bp, openai_model_bp
 from llm_server.routes.server_error import handle_server_error
 from llm_server.routes.v1 import bp
 from llm_server.routes.v1.generate_stats import generate_stats
@@ -70,6 +70,7 @@ except ModuleNotFoundError as e:
 app = Flask(__name__)
 app.register_blueprint(bp, url_prefix='/api/')
 app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')
+app.register_blueprint(openai_model_bp, url_prefix='/api/openai/')
 init_socketio(app)
 flask_cache.init_app(app)
 flask_cache.clear()
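With openai_bp still mounted at /api/openai/v1/ and the new openai_model_bp at /api/openai/, both route sets coexist under the same prefix tree. A quick sanity-check sketch that lists the routes Flask actually registered (runnable after the register_blueprint() calls above):

# List every OpenAI-related URL rule along with its GET/POST methods.
for rule in app.url_map.iter_rules():
    if str(rule).startswith('/api/openai'):
        print(rule, sorted(rule.methods & {'GET', 'POST'}))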