import simplejson as json from flask import Flask, jsonify, render_template, request, Response from llm_server.cluster.backend import get_model_choices from llm_server.cluster.cluster_config import cluster_config from llm_server.config.config import MODE_UI_NAMES from llm_server.config.global_config import GlobalConfig from llm_server.custom_redis import flask_cache, redis from llm_server.helpers import auto_set_base_client_api from llm_server.llm.vllm.info import vllm_info from llm_server.routes.openai import openai_bp, openai_model_bp from llm_server.routes.server_error import handle_server_error from llm_server.routes.v1 import bp from llm_server.routes.v1.generate_stats import generate_stats from llm_server.sock import init_wssocket # TODO: detect blocking disconnect # TODO: return an `error: True`, error code, and error message rather than just a formatted message # TODO: what happens when all backends are offline? What about the "online" key in the stats page? # TODO: redis SCAN vs KEYS?? # TODO: is frequency penalty the same as ooba repetition penalty??? # TODO: make sure openai_moderation_enabled works on websockets, completions, and chat completions # TODO: insert pydantic object into database # TODO: figure out blocking API disconnect https://news.ycombinator.com/item?id=41168033 # Lower priority # TODO: if a backend is at its limit of concurrent requests, choose a different one # TODO: make error messages consitient # TODO: support logit_bias on OpenAI and Ooba endpoints. # TODO: add a way to cancel VLLM gens. Maybe use websockets? # TODO: validate openai_silent_trim works as expected and only when enabled # TODO: rewrite config storage. Store in redis so we can reload it. # TODO: set VLLM to stream ALL data using socket.io. If the socket disconnects, cancel generation. 
# TODO: estimated wait time needs to account for full concurrent_gens but the queue is less than concurrent_gens
# TODO: the estimated wait time lags behind the stats
# TODO: simulate OpenAI error messages regardless of endpoint
# TODO: send extra headers when ratelimited?
# TODO: make sure log_prompt() is used everywhere, including errors and invalid requests
# TODO: unify logging thread in a function and use async/await instead
# TODO: move the netdata stats to a separate part of the stats and have it set to the currently selected backend
# TODO: have VLLM reply with stats (TPS, generated token count, processing time)
# TODO: add config reloading via stored redis variables

# Done, but need to verify
# TODO: add more excluding to SYSTEM__ tokens
# TODO: return 200 when returning formatted sillytavern error

app = Flask(__name__)

# Fixes ConcurrentObjectUseError
# https://github.com/miguelgrinberg/simple-websocket/issues/24
app.config['SOCK_SERVER_OPTIONS'] = {'ping_interval': 25}

app.register_blueprint(bp, url_prefix='/api/')
app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')
app.register_blueprint(openai_model_bp, url_prefix='/api/openai/')
init_wssocket(app)
flask_cache.init_app(app)
flask_cache.clear()


def _describe_estimated_wait(model_info: dict) -> str:
    """Return the human-readable estimated wait for a model, e.g. ``"12 seconds"``.

    :param model_info: one entry of the ``model_choices`` mapping returned by
        ``get_model_choices()``. Reads its ``queued``, ``processing``,
        ``concurrent_gens`` and ``estimated_wait`` keys (assumed numeric).
    :return: ``"less than N seconds"`` when nothing is queued but every
        generation slot is busy, otherwise ``"N seconds"``.
    """
    estimated_wait = int(model_info['estimated_wait'])
    # There will be a wait if the queue is empty but all concurrent generation
    # slots are processing prompts, but we don't know how long.
    if model_info['queued'] == 0 and model_info['processing'] >= model_info['concurrent_gens']:
        return f"less than {estimated_wait} seconds"
    return f"{estimated_wait} seconds"


@app.route('/')
@app.route('/api')
@app.route('/api/openai')
@flask_cache.cached(timeout=10)
def home():
    """Render the landing page (``home.html``) with cluster status and API info.

    The rendered page is cached by ``flask_cache`` for 10 seconds. While the
    default model has no entry in ``model_choices`` yet, a plain "still
    starting up" message is returned instead of the page.
    """
    config = GlobalConfig.get()
    base_client_api = redis.get('base_client_api', dtype=str)
    stats = generate_stats()
    model_choices, default_model = get_model_choices()

    if default_model:
        default_model_info = model_choices.get(default_model)
        if not default_model_info:
            return 'The server is still starting up. Please wait...'
        default_estimated_wait_sec = _describe_estimated_wait(default_model_info)
    else:
        # No default model means no backend is available; show placeholders.
        default_model_info = {
            'model': 'OFFLINE',
            'processing': '-',
            'queued': '-',
            'context_size': '-',
        }
        default_estimated_wait_sec = 'OFFLINE'

    if default_model_info['context_size'] is None:
        # Sometimes a model doesn't provide the correct config, so the context size is set
        # to None by the daemon.
        default_model_info['context_size'] = '-'

    # Wrap the configured tracking snippet in a <script> tag; empty/unset means no tracking.
    if config.analytics_tracking_code:
        analytics_tracking_code = f"<script>\n{config.analytics_tracking_code}\n</script>"
    else:
        analytics_tracking_code = ''

    # Extra vLLM usage info is shown when any backend in the cluster runs in vllm mode.
    if any(backend['mode'] == 'vllm' for backend in cluster_config.all().values()):
        mode_info = vllm_info
    else:
        mode_info = ''

    mode_names = MODE_UI_NAMES[config.frontend_api_mode]

    return render_template('home.html',
                           llm_middleware_name=config.llm_middleware_name,
                           analytics_tracking_code=analytics_tracking_code,
                           info_html=config.info_html or '',
                           default_model=default_model_info['model'],
                           default_active_gen_workers=default_model_info['processing'],
                           default_proompters_in_queue=default_model_info['queued'],
                           current_model=config.manual_model_name or None,  # else running_model,
                           client_api=f'https://{base_client_api}',
                           ws_client_api=f'wss://{base_client_api}/v1/stream' if config.enable_streaming else 'disabled',
                           default_estimated_wait=default_estimated_wait_sec,
                           mode_name=mode_names.name,
                           api_input_textbox=mode_names.api_name,
                           streaming_input_textbox=mode_names.streaming_name,
                           default_context_size=default_model_info['context_size'],
                           stats_json=json.dumps(stats, indent=4, ensure_ascii=False),
                           extra_info=mode_info,
                           openai_client_api=f'https://{base_client_api}/openai/v1' if config.enable_openi_compatible_backend else 'disabled',
                           expose_openai_system_prompt=config.expose_openai_system_prompt,
                           enable_streaming=config.enable_streaming,
                           model_choices=model_choices,
                           proompters_5_min=stats['stats']['proompters']['5_min'],
                           proompters_24_hrs=stats['stats']['proompters']['24_hrs'],
                           )


@app.route('/robots.txt')
def robots():
    """Serve a permissive robots.txt that allows all crawlers everywhere."""
    # TODO: have config value to deny all
    # TODO: https://developers.google.com/search/docs/crawling-indexing/robots/create-robots-txt
    t = "User-agent: *\nAllow: /"
    return Response(t, content_type='text/plain')


@app.route('/<first>')
@app.route('/<first>/<path:rest>')
def fallback(first=None, rest=None):
    """Return a JSON 404 for any path not claimed by a more specific route.

    :param first: first path segment captured by the catch-all rule.
    :param rest: remainder of the path (may contain slashes), if any.
    """
    return jsonify({
        'code': 404,
        'msg': 'not found'
    }), 404


@app.errorhandler(500)
def server_error(e):
    """Delegate internal server errors to the shared ``handle_server_error`` handler."""
    return handle_server_error(e)


@app.before_request
def before_app_request():
    """Record the client-facing base API address from the incoming request."""
    auto_set_base_client_api(request)


if __name__ == "__main__":
    import sys

    print('Do not run this file directly. Instead, use gunicorn:')
    print("gunicorn -c other/gunicorn_conf.py server:app -b 0.0.0.0:5000 --worker-class gevent --workers 3 --access-logfile '-' --error-logfile '-'")
    sys.exit(1)