add model selection to openai endpoint
This commit is contained in:
parent
5f7bf4faca
commit
18e37a72ae
|
@ -7,7 +7,7 @@ from llm_server.custom_redis import redis
|
|||
from llm_server.llm.generator import generator
|
||||
from llm_server.llm.info import get_info
|
||||
from llm_server.routes.queue import priority_queue
|
||||
from llm_server.routes.stats import get_active_gen_workers_model, calculate_wait_time
|
||||
from llm_server.routes.stats import calculate_wait_time, get_active_gen_workers_model
|
||||
|
||||
|
||||
def get_backends_from_model(model_name: str):
|
||||
|
@ -51,8 +51,6 @@ def test_backend(backend_url: str, test_prompt: bool = False):
|
|||
return True, i
|
||||
|
||||
|
||||
|
||||
|
||||
def get_model_choices(regen: bool = False):
|
||||
if not regen:
|
||||
c = redis.getp('model_choices')
|
||||
|
@ -89,7 +87,7 @@ def get_model_choices(regen: bool = False):
|
|||
'model': model,
|
||||
'client_api': f'https://{base_client_api}/{model}',
|
||||
'ws_client_api': f'wss://{base_client_api}/{model}/v1/stream' if opts.enable_streaming else None,
|
||||
'openai_client_api': f'https://{base_client_api}/openai/{model}' if opts.enable_openi_compatible_backend else 'disabled',
|
||||
'openai_client_api': f'https://{base_client_api}/openai/{model}/v1' if opts.enable_openi_compatible_backend else 'disabled',
|
||||
'backend_count': len(b),
|
||||
'estimated_wait': estimated_wait_sec,
|
||||
'queued': proompters_in_queue,
|
||||
|
|
|
@ -1 +0,0 @@
|
|||
# TODO: give this a better name!
|
|
@ -5,9 +5,11 @@ from ..server_error import handle_server_error
|
|||
from ... import opts
|
||||
|
||||
openai_bp = Blueprint('openai/v1/', __name__)
|
||||
openai_model_bp = Blueprint('openai/', __name__)
|
||||
|
||||
|
||||
@openai_bp.before_request
|
||||
@openai_model_bp.before_request
|
||||
def before_oai_request():
|
||||
if not opts.enable_openi_compatible_backend:
|
||||
return 'The OpenAI-compatible backend is disabled.', 401
|
||||
|
@ -15,6 +17,7 @@ def before_oai_request():
|
|||
|
||||
|
||||
@openai_bp.errorhandler(500)
|
||||
@openai_model_bp.errorhandler(500)
|
||||
def handle_error(e):
|
||||
return handle_server_error(e)
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ import traceback
|
|||
from flask import Response, jsonify, request
|
||||
|
||||
from llm_server.custom_redis import redis
|
||||
from . import openai_bp
|
||||
from . import openai_bp, openai_model_bp
|
||||
from ..helpers.http import validate_json
|
||||
from ..openai_request_handler import OpenAIRequestHandler
|
||||
from ..queue import priority_queue
|
||||
|
@ -20,12 +20,13 @@ from ...llm.openai.transform import generate_oai_string, transform_messages_to_p
|
|||
|
||||
|
||||
@openai_bp.route('/chat/completions', methods=['POST'])
|
||||
def openai_chat_completions():
|
||||
@openai_model_bp.route('/<model_name>/v1/chat/completions', methods=['POST'])
|
||||
def openai_chat_completions(model_name=None):
|
||||
request_valid_json, request_json_body = validate_json(request)
|
||||
if not request_valid_json or not request_json_body.get('messages') or not request_json_body.get('model'):
|
||||
return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400
|
||||
else:
|
||||
handler = OpenAIRequestHandler(incoming_request=request, incoming_json=request_json_body)
|
||||
handler = OpenAIRequestHandler(incoming_request=request, incoming_json=request_json_body, selected_model=model_name)
|
||||
if not request_json_body.get('stream'):
|
||||
try:
|
||||
invalid_oai_err_msg = validate_oai(request_json_body)
|
||||
|
|
|
@ -5,7 +5,7 @@ import simplejson as json
|
|||
from flask import Response, jsonify, request
|
||||
|
||||
from llm_server.custom_redis import redis
|
||||
from . import openai_bp
|
||||
from . import openai_bp, openai_model_bp
|
||||
from ..helpers.http import validate_json
|
||||
from ..ooba_request_handler import OobaRequestHandler
|
||||
from ..queue import priority_queue
|
||||
|
@ -20,12 +20,13 @@ from ...llm.openai.transform import generate_oai_string, trim_string_to_fit
|
|||
# TODO: add rate-limit headers?
|
||||
|
||||
@openai_bp.route('/completions', methods=['POST'])
|
||||
def openai_completions():
|
||||
@openai_model_bp.route('/<model_name>/v1/completions', methods=['POST'])
|
||||
def openai_completions(model_name=None):
|
||||
request_valid_json, request_json_body = validate_json(request)
|
||||
if not request_valid_json or not request_json_body.get('prompt'):
|
||||
return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
|
||||
else:
|
||||
handler = OobaRequestHandler(incoming_request=request)
|
||||
handler = OobaRequestHandler(incoming_request=request, incoming_json=request_json_body, selected_model=model_name)
|
||||
|
||||
if handler.cluster_backend_info['mode'] != 'vllm':
|
||||
# TODO: implement other backends
|
||||
|
|
|
@ -20,7 +20,7 @@ from llm_server.routes.queue import priority_queue
|
|||
class RequestHandler:
|
||||
def __init__(self, incoming_request: flask.Request, selected_model: str = None, incoming_json: Union[dict, str] = None):
|
||||
self.request = incoming_request
|
||||
self.enable_backend_blind_rrd = request.headers.get('LLM-Blind-RRD', False) == 'true'
|
||||
# self.enable_backend_blind_rrd = request.headers.get('LLM-Blind-RRD', False) == 'true'
|
||||
|
||||
# Routes need to validate it, here we just load it
|
||||
if incoming_json:
|
||||
|
|
|
@ -14,7 +14,7 @@ def generate(model_name=None):
|
|||
if not request_valid_json or not request_json_body.get('prompt'):
|
||||
return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
|
||||
else:
|
||||
handler = OobaRequestHandler(request, model_name)
|
||||
handler = OobaRequestHandler(request, selected_model=model_name)
|
||||
try:
|
||||
return handler.handle_request()
|
||||
except Exception:
|
||||
|
|
|
@ -24,7 +24,7 @@ from llm_server.database.create import create_db
|
|||
from llm_server.helpers import auto_set_base_client_api
|
||||
from llm_server.llm.vllm.info import vllm_info
|
||||
from llm_server.pre_fork import server_startup
|
||||
from llm_server.routes.openai import openai_bp
|
||||
from llm_server.routes.openai import openai_bp, openai_model_bp
|
||||
from llm_server.routes.server_error import handle_server_error
|
||||
from llm_server.routes.v1 import bp
|
||||
from llm_server.routes.v1.generate_stats import generate_stats
|
||||
|
@ -70,6 +70,7 @@ except ModuleNotFoundError as e:
|
|||
app = Flask(__name__)
|
||||
app.register_blueprint(bp, url_prefix='/api/')
|
||||
app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')
|
||||
app.register_blueprint(openai_model_bp, url_prefix='/api/openai/')
|
||||
init_socketio(app)
|
||||
flask_cache.init_app(app)
|
||||
flask_cache.clear()
|
||||
|
|
Reference in New Issue