try: import gevent.monkey gevent.monkey.patch_all() except ImportError: pass import os import sys from pathlib import Path import simplejson as json from flask import Flask, jsonify, render_template, request from llm_server.cluster.cluster_config import cluster_config from llm_server.cluster.model_choices import get_model_choices from llm_server.config.config import mode_ui_names from llm_server.config.load import load_config from llm_server.database.conn import database from llm_server.database.create import create_db from llm_server.pre_fork import server_startup from llm_server.routes.openai import openai_bp from llm_server.routes.server_error import handle_server_error from llm_server.routes.v1 import old_v1_bp from llm_server.routes.v2 import bp from llm_server.sock import init_socketio # TODO: per-backend workers # TODO: allow setting concurrent gens per-backend # TODO: set the max tokens to that of the lowest backend # TODO: implement RRD backend loadbalancer option # TODO: have VLLM reject a request if it already has n == concurrent_gens running # TODO: add a way to cancel VLLM gens. Maybe use websockets? # TODO: use coloredlogs # TODO: need to update opts. for workers # TODO: add a healthcheck to VLLM # TODO: allow choosing the model by the URL path # TODO: have VLLM report context size, uptime # Lower priority # TODO: set VLLM to stream ALL data using socket.io. If the socket disconnects, cancel generation. # TODO: estiamted wait time needs to account for full concurrent_gens but the queue is less than concurrent_gens # TODO: the estiamted wait time lags behind the stats # TODO: simulate OpenAI error messages regardless of endpoint # TODO: send extra headers when ratelimited? 
# TODO: make sure log_prompt() is used everywhere, including errors and invalid requests
# TODO: unify logging thread in a function and use async/await instead
# TODO: move the netdata stats to a separate part of the stats and have it set to the currently selected backend
# TODO: have VLLM reply with stats (TPS, generated token count, processing time)
# TODO: add config reloading via stored redis variables

# Done, but need to verify
# TODO: add more excluding to SYSTEM__ tokens
# TODO: return 200 when returning formatted sillytavern error

# vLLM is required: exit immediately with install instructions instead of failing later.
try:
    import vllm
except ModuleNotFoundError as e:
    print('Could not import vllm-gptq:', e)
    print('Please see README.md for install instructions.')
    sys.exit(1)

# NOTE(review): the name `config` is rebound below to the dict returned by load_config(),
# so this module object is not reachable under that name afterwards.
import config
from llm_server import opts
from llm_server.helpers import auto_set_base_client_api
from llm_server.llm.vllm.info import vllm_info
from llm_server.custom_redis import flask_cache
from llm_server.llm import redis
from llm_server.routes.v2.generate_stats import generate_stats

app = Flask(__name__)
init_socketio(app)
app.register_blueprint(bp, url_prefix='/api/v2/')
app.register_blueprint(old_v1_bp, url_prefix='/api/v1/')
app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')
flask_cache.init_app(app)
flask_cache.clear()

# Config location: $CONFIG_PATH when set and non-empty, otherwise config/config.yml
# relative to the directory containing this script.
script_path = os.path.dirname(os.path.realpath(__file__))
config_path_environ = os.getenv("CONFIG_PATH")
if config_path_environ:
    config_path = config_path_environ
else:
    config_path = Path(script_path, 'config', 'config.yml')

success, config, msg = load_config(config_path)
if not success:
    print('Failed to load config:', msg)
    sys.exit(1)

database.init_db(config['mysql']['host'], config['mysql']['username'], config['mysql']['password'], config['mysql']['database'])
create_db()


@app.route('/')
@app.route('/api')
@app.route('/api/openai')
@flask_cache.cached(timeout=10)
def home():
    """Render the landing page (home.html) with cluster stats and client connection info.

    Served at ``/``, ``/api`` and ``/api/openai``. The rendered response is cached by
    ``flask_cache`` with ``timeout=10``.
    """
    # NOTE(review): presumably populated by auto_set_base_client_api() in before_app_request; confirm.
    base_client_api = redis.get('base_client_api', dtype=str)
    stats = generate_stats()
    model_choices, default_backend_info = get_model_choices()

    queued = default_backend_info['queued']
    processing = default_backend_info['processing']
    estimated_wait = default_backend_info['estimated_wait']
    # The second operand must test `processing`, not `queued`: `queued == 0 and
    # queued >= concurrent_gens` can never both hold for a positive concurrent_gens.
    if queued == 0 and processing >= opts.concurrent_gens:
        # There will be a wait if the queue is empty but prompts are processing, but we don't
        # know how long.
        default_estimated_wait_sec = f"less than {estimated_wait} seconds"
    else:
        default_estimated_wait_sec = f"{estimated_wait} seconds"

    # Wrap the operator-supplied analytics snippet in a <script> tag; an empty or unset
    # value renders nothing.
    tracking_snippet = config['analytics_tracking_code']
    if tracking_snippet:
        analytics_tracking_code = f"<script>\n{tracking_snippet}\n</script>"
    else:
        analytics_tracking_code = ''

    info_html = config['info_html'] or ''

    # Show the vLLM info block if any configured backend runs in 'vllm' mode.
    has_vllm_backend = any(backend['mode'] == 'vllm' for backend in cluster_config.all().values())
    mode_info = vllm_info if has_vllm_backend else ''

    mode_names = mode_ui_names[opts.mode]
    return render_template('home.html',
                           llm_middleware_name=opts.llm_middleware_name,
                           analytics_tracking_code=analytics_tracking_code,
                           info_html=info_html,
                           default_model=default_backend_info['model'],
                           default_active_gen_workers=processing,
                           default_proompters_in_queue=queued,
                           current_model=opts.manual_model_name or None,
                           client_api=f'https://{base_client_api}/v2',
                           ws_client_api=f'wss://{base_client_api}/v2/stream' if opts.enable_streaming else 'disabled',
                           default_estimated_wait=default_estimated_wait_sec,
                           mode_name=mode_names[0],
                           api_input_textbox=mode_names[1],
                           streaming_input_textbox=mode_names[2],
                           default_context_size=default_backend_info['context_size'],
                           stats_json=json.dumps(stats, indent=4, ensure_ascii=False),
                           extra_info=mode_info,
                           openai_client_api=f'https://{base_client_api}/openai/v1' if opts.enable_openi_compatible_backend else 'disabled',
                           expose_openai_system_prompt=opts.expose_openai_system_prompt,
                           enable_streaming=opts.enable_streaming,
                           model_choices=model_choices,
                           proompters_5_min=stats['stats']['proompters']['5_min'],
                           proompters_24_hrs=stats['stats']['proompters']['24_hrs'],
                           )


# Catch-all rules with variable segments. Werkzeug ranks rules with static segments ahead of
# converter-only rules, so the explicit routes and blueprints above still take precedence.
@app.route('/<first>')
@app.route('/<first>/<path:rest>')
def fallback(first=None, rest=None):
    """Return a JSON 404 for one- or multi-segment paths handled by no other route.

    ``first`` and ``rest`` are bound by the URL rules above and are intentionally unused.
    """
    return jsonify({
        'code': 404,
        'msg': 'not found'
    }), 404


@app.errorhandler(500)
def server_error(e):
    """Delegate HTTP 500 errors to the shared handle_server_error() handler."""
    return handle_server_error(e)


@app.before_request
def before_app_request():
    """Pass every incoming request to auto_set_base_client_api() before it is routed."""
    auto_set_base_client_api(request)


if __name__ == "__main__":
    # Development entry point: run startup tasks, then serve with Flask's built-in server.
    server_startup(None)
    print('FLASK MODE - Startup complete!')
    app.run(host='0.0.0.0', threaded=False, processes=15)