"""Entry point for the LLM middleware server.

Loads and validates the YAML config, copies settings onto ``llm_server.opts``,
initialises the database and Redis state, starts the worker/background threads,
and builds the Flask app that serves the landing page, the ``/api/v1/``
blueprint and a JSON 404 fallback.
"""
import json
import os
import sys
from pathlib import Path
from threading import Thread

from flask import Flask, jsonify, render_template, request

# NOTE(review): this module is shadowed by the `config` dict loaded below and is
# never used by name afterwards; kept in case importing it has side effects.
import config
from llm_server import opts
from llm_server.config import ConfigLoader, config_default_vars, config_required_vars, mode_ui_names
from llm_server.database import get_number_of_rows, init_db
from llm_server.helpers import resolve_path
from llm_server.llm.hf_textgen.info import hf_textget_info
from llm_server.routes.cache import cache, redis
from llm_server.routes.queue import start_workers
from llm_server.routes.stats import SemaphoreCheckerThread, process_avg_gen_time
from llm_server.routes.v1 import bp
from llm_server.routes.v1.generate_stats import generate_stats
from llm_server.stream import init_socketio
from llm_server.threads import MainBackgroundThread

# Accepted values for config['mode'].
_SUPPORTED_MODES = ('oobabooga', 'hf-textgen', 'vllm')
# Accepted values for config['average_generation_time_mode'].
_AVG_GEN_TIME_MODES = ('database', 'minute')

script_path = os.path.dirname(os.path.realpath(__file__))

# The CONFIG_PATH environment variable, when set, overrides the bundled config file.
config_path_environ = os.getenv("CONFIG_PATH")
if config_path_environ:
    config_path = config_path_environ
else:
    config_path = Path(script_path, 'config', 'config.yml')

config_loader = ConfigLoader(config_path, config_default_vars, config_required_vars)
success, config, msg = config_loader.load_config()
if not success:
    print('Failed to load config:', msg, file=sys.stderr)
    sys.exit(1)

# Validate enumerated settings before touching the database, Redis or any
# threads, so an invalid config exits without side effects.
if config['mode'] not in _SUPPORTED_MODES:
    print('Unknown mode:', config['mode'], file=sys.stderr)
    sys.exit(1)
if config['average_generation_time_mode'] not in _AVG_GEN_TIME_MODES:
    print('Invalid value for config item "average_generation_time_mode":',
          config['average_generation_time_mode'], file=sys.stderr)
    sys.exit(1)

# Resolve a './'-relative database path against the script's directory.
# Slice the prefix off instead of str.strip('./'): strip() removes every
# leading/trailing '.' and '/' character (e.g. './.data/db' -> 'data/db').
if config['database_path'].startswith('./'):
    config['database_path'] = resolve_path(script_path, config['database_path'][2:])
opts.database_path = resolve_path(config['database_path'])
init_db()

# Copy settings onto the shared opts module read by the rest of llm_server.
opts.mode = config['mode']
opts.auth_required = config['auth_required']
opts.log_prompts = config['log_prompts']
opts.concurrent_gens = config['concurrent_gens']
opts.frontend_api_client = config['frontend_api_client']
opts.context_size = config['token_limit']
opts.show_num_prompts = config['show_num_prompts']
opts.show_uptime = config['show_uptime']
opts.backend_url = config['backend_url'].strip('/')
opts.show_total_output_tokens = config['show_total_output_tokens']
opts.netdata_root = config['netdata_root']
opts.simultaneous_requests_per_ip = config['simultaneous_requests_per_ip']
opts.show_backend_info = config['show_backend_info']
opts.max_new_tokens = config['max_new_tokens']
opts.verify_ssl = config['verify_ssl']

if not opts.verify_ssl:
    # Silence urllib3's per-request warning when certificate checks are disabled.
    import urllib3
    urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

# Start every run from an empty Redis.
flushed_keys = redis.flush()
print('Flushed', len(flushed_keys), 'keys from Redis.')

if config['load_num_prompts']:
    # Seed the prompt counter from the number of rows in the prompts table.
    redis.set('proompts', get_number_of_rows('prompts'))

opts.average_generation_time_mode = config['average_generation_time_mode']

start_workers(opts.concurrent_gens)

# Background threads: average-generation-time tracking (daemon, so it does not
# block interpreter exit), the main background loop and the semaphore checker.
process_avg_gen_time_background_thread = Thread(target=process_avg_gen_time, daemon=True)
process_avg_gen_time_background_thread.start()
MainBackgroundThread().start()
SemaphoreCheckerThread().start()

app = Flask(__name__)
cache.init_app(app)
cache.clear()  # clear redis cache
init_socketio(app)
app.register_blueprint(bp, url_prefix='/api/v1/')


@app.route('/')
@app.route('/api')
@cache.cached(timeout=10, query_string=True)
def home():
    """Render the landing page with backend status, wait estimate and API info.

    The rendered page is cached for 10 seconds, keyed on the query string.
    """
    if not opts.base_client_api:
        # Derive the public API base from the Host header of the first request.
        opts.base_client_api = f'{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}'
    stats = generate_stats()

    if not redis.get('backend_online') or not stats['online']:
        running_model = estimated_wait_sec = 'offline'
    else:
        running_model = opts.running_model
        if stats['queue']['queued'] == 0 and stats['queue']['processing'] > 0:
            # There will be a wait if the queue is empty but prompts are
            # processing, but we don't know how long.
            estimated_wait_sec = f"less than {stats['stats']['average_generation_elapsed_sec']} seconds"
        else:
            estimated_wait_sec = f"{stats['queue']['estimated_wait_sec']} seconds"

    # Truthiness rather than len() so a missing/None value is treated as "off".
    if config['analytics_tracking_code']:
        # NOTE(review): the wrapper markup was unreadable in the reviewed copy of
        # this file; this assumes the config value is bare JavaScript that needs
        # a <script> element -- confirm against home.html / the config docs.
        analytics_tracking_code = f"<script>\n{config['analytics_tracking_code']}\n</script>"
    else:
        analytics_tracking_code = ''

    if config['info_html']:
        # NOTE(review): the leading markup was unreadable in the reviewed copy
        # (it rendered as a line break); '<br>' is assumed -- confirm.
        info_html = '<br>\n' + config['info_html']
    else:
        info_html = ''

    # Per-mode UI labels: display name, blocking-API textbox, streaming textbox.
    ui_names = mode_ui_names[opts.mode]
    return render_template('home.html',
                           llm_middleware_name=config['llm_middleware_name'],
                           analytics_tracking_code=analytics_tracking_code,
                           info_html=info_html,
                           current_model=running_model,
                           client_api=stats['endpoints']['blocking'],
                           ws_client_api=stats['endpoints']['streaming'],
                           estimated_wait=estimated_wait_sec,
                           mode_name=ui_names[0],
                           api_input_textbox=ui_names[1],
                           streaming_input_textbox=ui_names[2],
                           context_size=opts.context_size,
                           stats_json=json.dumps(stats, indent=4, ensure_ascii=False),
                           extra_info=hf_textget_info if opts.mode == 'hf-textgen' else '',
                           )


@app.route('/<first>')
@app.route('/<first>/<path:rest>')
def fallback(first=None, rest=None):
    """Catch-all for one- and multi-segment paths; always returns a JSON 404.

    Static rules such as '/api' still go to their own views.
    """
    return jsonify({
        'code': 404,
        'msg': 'not found'
    }), 404


if __name__ == "__main__":
    app.run(host='0.0.0.0')