# Apply gevent's monkey patches (when gevent is installed) before the
# remaining imports run. Without gevent the server starts unpatched.
try:
    import gevent.monkey

    gevent.monkey.patch_all()
except ImportError:
    pass

import os
import re
import sys
from pathlib import Path
from threading import Thread

import openai
import simplejson as json
from flask import Flask, jsonify, render_template, request

import llm_server
from llm_server.database.conn import database
from llm_server.database.create import create_db
from llm_server.database.database import get_number_of_rows
from llm_server.llm import get_token_count
from llm_server.routes.openai import openai_bp
from llm_server.routes.server_error import handle_server_error
from llm_server.routes.v1 import bp
from llm_server.stream import init_socketio

# TODO: have the workers handle streaming too
# TODO: add backend fallbacks. Backends at the bottom of the list are higher priority and are fallbacks if the upper ones fail
# TODO: implement background thread to test backends via sending test prompts
# TODO: if backend fails request, mark it as down
# TODO: allow setting concurrent gens per-backend
# TODO: set the max tokens to that of the lowest backend
# TODO: implement RRD backend loadbalancer option
# TODO: simulate OpenAI error messages regardless of endpoint
# TODO: send extra headers when ratelimited?
# TODO: make sure log_prompt() is used everywhere, including errors and invalid requests
# TODO: unify logging thread in a function and use async/await instead

# Done, but need to verify
# TODO: add more excluding to SYSTEM__ tokens
# TODO: return 200 when returning formatted sillytavern error

# Fail fast with install instructions when the vllm package is missing.
try:
    import vllm
except ModuleNotFoundError as e:
    print('Could not import vllm-gptq:', e)
    print('Please see README.md for install instructions.')
    sys.exit(1)

# NOTE(review): the name `config` bound by this import is rebound further down
# by the result of ConfigLoader.load_config().
import config
from llm_server import opts
from llm_server.config import ConfigLoader, config_default_vars, config_required_vars, mode_ui_names
from llm_server.helpers import resolve_path, auto_set_base_client_api
from llm_server.llm.vllm.info import vllm_info
from llm_server.routes.cache import RedisWrapper, flask_cache
from llm_server.llm import redis
from llm_server.routes.queue import start_workers
from llm_server.routes.stats import SemaphoreCheckerThread, get_active_gen_workers, process_avg_gen_time
from llm_server.routes.v1.generate_stats import generate_stats
from llm_server.threads import MainBackgroundThread, cache_stats, start_moderation_workers

script_path = os.path.dirname(os.path.realpath(__file__))

app = Flask(__name__)
init_socketio(app)
app.register_blueprint(bp, url_prefix='/api/v1/')
app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')
flask_cache.init_app(app)
flask_cache.clear()

# The CONFIG_PATH environment variable overrides the bundled config location.
config_path_environ = os.getenv("CONFIG_PATH")
if config_path_environ:
    config_path = config_path_environ
else:
    config_path = Path(script_path, 'config', 'config.yml')

config_loader = ConfigLoader(config_path, config_default_vars, config_required_vars)
success, config, msg = config_loader.load_config()
if not success:
    print('Failed to load config:', msg)
    sys.exit(1)

# Resolve relative directory to the directory of the script
if config['database_path'].startswith('./'):
    # Drop exactly the leading "./". str.strip('./') removes a *set* of
    # characters from both ends and would mangle paths such as "./.cache/db".
    config['database_path'] = resolve_path(script_path, config['database_path'][len('./'):])

database.init_db(config['mysql']['host'], config['mysql']['username'], config['mysql']['password'], config['mysql']['database'])
create_db()

if config['mode'] not in ['oobabooga', 'vllm']:
    print('Unknown mode:', config['mode'])
    sys.exit(1)

# TODO: this is atrocious
opts.mode = config['mode']
opts.auth_required = config['auth_required']
opts.log_prompts = config['log_prompts']
opts.concurrent_gens = config['concurrent_gens']
opts.frontend_api_client = config['frontend_api_client']
opts.context_size = config['token_limit']
opts.show_num_prompts = config['show_num_prompts']
opts.show_uptime = config['show_uptime']
opts.backend_url = config['backend_url'].strip('/')
opts.show_total_output_tokens = config['show_total_output_tokens']
opts.netdata_root = config['netdata_root']
opts.simultaneous_requests_per_ip = config['simultaneous_requests_per_ip']
opts.show_backend_info = config['show_backend_info']
opts.max_new_tokens = config['max_new_tokens']
opts.manual_model_name = config['manual_model_name']
opts.llm_middleware_name = config['llm_middleware_name']
opts.enable_openi_compatible_backend = config['enable_openi_compatible_backend']
opts.openai_system_prompt = config['openai_system_prompt']
opts.expose_openai_system_prompt = config['expose_openai_system_prompt']
opts.enable_streaming = config['enable_streaming']
opts.openai_api_key = config['openai_api_key']
openai.api_key = opts.openai_api_key
opts.admin_token = config['admin_token']
opts.openai_expose_our_model = config['openai_epose_our_model']
opts.openai_force_no_hashes = config['openai_force_no_hashes']
opts.include_system_tokens_in_stats = config['include_system_tokens_in_stats']
opts.openai_moderation_scan_last_n = config['openai_moderation_scan_last_n']
opts.openai_moderation_workers = config['openai_moderation_workers']
opts.openai_org_name = config['openai_org_name']
opts.openai_silent_trim = config['openai_silent_trim']
opts.openai_moderation_enabled = config['openai_moderation_enabled']

# NOTE(review): this exits when the option is truthy and no key is set, while
# the message talks about the option being false. One of the two is wrong;
# confirm the intended condition before changing either.
if opts.openai_expose_our_model and not opts.openai_api_key:
    print('If you set openai_epose_our_model to false, you must set your OpenAI key in openai_api_key.')
    sys.exit(1)

opts.verify_ssl = config['verify_ssl']
if not opts.verify_ssl:
    # Certificate verification is off, so silence urllib3's per-request warning.
    import urllib3

    urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

if config['average_generation_time_mode'] not in ['database', 'minute']:
    print('Invalid value for config item "average_generation_time_mode":', config['average_generation_time_mode'])
    sys.exit(1)
opts.average_generation_time_mode = config['average_generation_time_mode']

# Select the token counter for the configured backend.
if opts.mode == 'oobabooga':
    raise NotImplementedError('The oobabooga mode is not implemented.')
    # llm_server.llm.tokenizer = OobaboogaBackend()
elif opts.mode == 'vllm':
    llm_server.llm.get_token_count = llm_server.llm.vllm.tokenize
else:
    raise Exception(f'Unknown mode: {opts.mode}')


def pre_fork(server):
    """Reset Redis state and start the background workers and threads.

    Called directly with ``None`` when the script is run as ``__main__``.

    :param server: Unused. Presumably the server object handed to a WSGI
        server's pre-fork hook -- TODO confirm.
    """
    # NOTE(review): this rebinds ``llm_server.llm.redis``, but the calls below
    # go through the ``redis`` name imported at module load. Confirm that both
    # refer to the same store.
    llm_server.llm.redis = RedisWrapper('local_llm')

    flushed_keys = redis.flush()
    print('Flushed', len(flushed_keys), 'keys from Redis.')

    redis.set('backend_mode', opts.mode)
    if config['http_host']:
        # Store the host without its scheme; the scheme is added when rendering.
        http_host = re.sub(r'http(?:s)?://', '', config["http_host"])
        redis.set('http_host', http_host)
        redis.set('base_client_api', f'{http_host}/{opts.frontend_api_client.strip("/")}')

    if config['load_num_prompts']:
        redis.set('proompts', get_number_of_rows('prompts'))

    # Start background processes
    start_workers(opts.concurrent_gens)
    start_moderation_workers(opts.openai_moderation_workers)
    process_avg_gen_time_background_thread = Thread(target=process_avg_gen_time)
    process_avg_gen_time_background_thread.daemon = True
    process_avg_gen_time_background_thread.start()
    MainBackgroundThread().start()
    SemaphoreCheckerThread().start()

    # This needs to be started after Flask is initialized
    stats_updater_thread = Thread(target=cache_stats)
    stats_updater_thread.daemon = True
    stats_updater_thread.start()

    # Cache the initial stats
    print('Loading backend stats...')
    generate_stats()


# print(app.url_map)


@app.route('/')
@app.route('/api')
@app.route('/api/openai')
@flask_cache.cached(timeout=10)
def home():
    """Render the landing page with the current model, wait estimate and stats.

    The response is cached for 10 seconds by ``flask_cache``.

    :return: The rendered ``home.html`` template.
    """
    stats = generate_stats()

    if not stats['online']:
        running_model = estimated_wait_sec = 'offline'
    else:
        running_model = redis.get('running_model', str, 'ERROR')

        active_gen_workers = get_active_gen_workers()
        if stats['queue']['queued'] == 0 and active_gen_workers >= opts.concurrent_gens:
            # There will be a wait if the queue is empty but prompts are processing, but we don't
            # know how long.
            estimated_wait_sec = f"less than {stats['stats']['average_generation_elapsed_sec']} seconds"
        else:
            estimated_wait_sec = f"{stats['queue']['estimated_wait_sec']} seconds"

    if config['analytics_tracking_code']:
        # Wrap the configured snippet in a script element. The original literal
        # was an empty f-string, so the configured code was never emitted.
        # NOTE(review): assumes home.html does not add its own <script> wrapper.
        analytics_tracking_code = f"<script>\n{config['analytics_tracking_code']}\n</script>"
    else:
        analytics_tracking_code = ''

    info_html = config['info_html'] or ''
    mode_info = vllm_info if opts.mode == 'vllm' else ''

    base_client_api = redis.get('base_client_api', str)

    return render_template(
        'home.html',
        llm_middleware_name=opts.llm_middleware_name,
        analytics_tracking_code=analytics_tracking_code,
        info_html=info_html,
        current_model=opts.manual_model_name if opts.manual_model_name else running_model,
        client_api=f'https://{base_client_api}',
        ws_client_api=f'wss://{base_client_api}/v1/stream' if opts.enable_streaming else None,
        estimated_wait=estimated_wait_sec,
        mode_name=mode_ui_names[opts.mode][0],
        api_input_textbox=mode_ui_names[opts.mode][1],
        streaming_input_textbox=mode_ui_names[opts.mode][2],
        context_size=opts.context_size,
        stats_json=json.dumps(stats, indent=4, ensure_ascii=False),
        extra_info=mode_info,
        openai_client_api=f'https://{base_client_api}/openai/v1' if opts.enable_openi_compatible_backend else 'disabled',
        expose_openai_system_prompt=opts.expose_openai_system_prompt,
        enable_streaming=opts.enable_streaming,
    )


# TODO: add authenticated route to get the current backend URL.
# Add it to /v1/backend


# The URL variables bind to the function's ``first`` and ``rest`` parameters.
# The previous rules ('/' and '//') never populated them, and '/' duplicated
# the rule already registered for the home page.
@app.route('/<first>')
@app.route('/<first>/<path:rest>')
def fallback(first=None, rest=None):
    """Return a JSON 404 for any path not matched by a more specific route.

    :param first: First path segment of the unmatched URL (unused).
    :param rest: Remainder of the unmatched URL path (unused).
    :return: A ``(response, 404)`` tuple with a JSON error body.
    """
    return jsonify({
        'code': 404,
        'msg': 'not found'
    }), 404


@app.errorhandler(500)
def server_error(e):
    """Delegate HTTP 500 errors to ``handle_server_error``.

    :param e: The error passed in by Flask.
    :return: Whatever ``handle_server_error`` returns.
    """
    return handle_server_error(e)


@app.before_request
def before_app_request():
    """Run ``auto_set_base_client_api`` on the incoming request before each view."""
    auto_set_base_client_api(request)


if __name__ == "__main__":
    # Direct execution: run the start-up hook ourselves, then serve with
    # Flask's built-in server.
    pre_fork(None)
    print('FLASK MODE - Startup complete!')
    app.run(host='0.0.0.0', threaded=False, processes=15)