From e86a5182ebb8e90f4b464ecb52732f0b459201dd Mon Sep 17 00:00:00 2001
From: Cyberes
Date: Wed, 27 Sep 2023 23:36:44 -0600
Subject: [PATCH] redo background processes, reorganize server.py

---
 daemon.py                                   |  44 +++++++
 gunicorn.py                                 |   4 +-
 llm_server/config/__init__.py               |   0
 llm_server/{ => config}/config.py           |   0
 llm_server/config/load.py                   |  86 +++++++++++++
 llm_server/pre_fork.py                      |  21 ++++
 llm_server/routes/helpers/http.py           |   2 +-
 llm_server/routes/openai_request_handler.py |   2 +-
 llm_server/routes/stats.py                  |  26 ----
 llm_server/threads.py                       | 120 ------------------
 llm_server/workers/app.py                   |  35 ++++++
 llm_server/workers/blocking.py              |  11 +-
 llm_server/workers/main.py                  |  56 +++++++++
 llm_server/workers/moderator.py             |  51 ++++++++
 llm_server/workers/printer.py               |  10 +-
 llm_server/workers/recent.py                |  19 +++
 llm_server/workers/threads.py               |   9 ++
 requirements.txt                            |   3 +-
 server.py                                   | 130 ++------------------
 19 files changed, 344 insertions(+), 285 deletions(-)
 create mode 100644 daemon.py
 create mode 100644 llm_server/config/__init__.py
 rename llm_server/{ => config}/config.py (100%)
 create mode 100644 llm_server/config/load.py
 create mode 100644 llm_server/pre_fork.py
 delete mode 100644 llm_server/threads.py
 create mode 100644 llm_server/workers/app.py
 create mode 100644 llm_server/workers/main.py
 create mode 100644 llm_server/workers/moderator.py
 create mode 100644 llm_server/workers/recent.py
 create mode 100644 llm_server/workers/threads.py

diff --git a/daemon.py b/daemon.py
new file mode 100644
index 0000000..20ec300
--- /dev/null
+++ b/daemon.py
@@ -0,0 +1,44 @@
+import time
+
+from llm_server.routes.cache import redis
+
+try:
+    import gevent.monkey
+
+    gevent.monkey.patch_all()
+except ImportError:
+    pass
+
+import os
+import sys
+from pathlib import Path
+
+from llm_server.config.load import load_config
+from llm_server.database.create import create_db
+
+from llm_server.workers.app import start_background
+
+script_path = os.path.dirname(os.path.realpath(__file__))
+config_path_environ = os.getenv("CONFIG_PATH")
+if config_path_environ:
+    config_path = config_path_environ
+else:
+    config_path = Path(script_path, 'config', 'config.yml')
+
+if __name__ == "__main__":
+    flushed_keys = redis.flush()
+    print('Flushed', len(flushed_keys), 'keys from Redis.')
+
+    success, config, msg = load_config(config_path, script_path)
+    if not success:
+        print('Failed to load config:', msg)
+        sys.exit(1)
+
+    create_db()
+    start_background()
+
+    redis.set('daemon_started', 1)
+    print('== Daemon Setup Complete ==\n')
+
+    while True:
+        time.sleep(3600)
diff --git a/gunicorn.py b/gunicorn.py
index 36501df..30f9274 100644
--- a/gunicorn.py
+++ b/gunicorn.py
@@ -5,9 +5,9 @@ try:
 except ImportError:
     pass
 
-import server
+from llm_server.pre_fork import server_startup
 
 
 def on_starting(s):
-    server.pre_fork(s)
+    server_startup(s)
     print('Startup complete!')
diff --git a/llm_server/config/__init__.py b/llm_server/config/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/llm_server/config.py b/llm_server/config/config.py
similarity index 100%
rename from llm_server/config.py
rename to llm_server/config/config.py
diff --git a/llm_server/config/load.py b/llm_server/config/load.py
new file mode 100644
index 0000000..5190984
--- /dev/null
+++ b/llm_server/config/load.py
@@ -0,0 +1,86 @@
+import re
+import sys
+
+import openai
+
+from llm_server import opts
+from llm_server.config.config import ConfigLoader, config_default_vars, config_required_vars
+from llm_server.database.conn import database
+from llm_server.database.database import get_number_of_rows
+from llm_server.helpers import resolve_path
+from llm_server.routes.cache import redis
+
+
+def load_config(config_path, script_path):
+    config_loader = ConfigLoader(config_path, config_default_vars, config_required_vars)
+    success, config, msg = config_loader.load_config()
+    if not success:
+        return success, config, msg
+
+    # Resolve relative directory to the directory of the script
+    if config['database_path'].startswith('./'):
+        config['database_path'] = resolve_path(script_path, config['database_path'].strip('./'))
+
+    if config['mode'] not in ['oobabooga', 'vllm']:
+        print('Unknown mode:', config['mode'])
+        sys.exit(1)
+
+    # TODO: this is atrocious
+    opts.mode = config['mode']
+    opts.auth_required = config['auth_required']
+    opts.log_prompts = config['log_prompts']
+    opts.concurrent_gens = config['concurrent_gens']
+    opts.frontend_api_client = config['frontend_api_client']
+    opts.context_size = config['token_limit']
+    opts.show_num_prompts = config['show_num_prompts']
+    opts.show_uptime = config['show_uptime']
+    opts.backend_url = config['backend_url'].strip('/')
+    opts.show_total_output_tokens = config['show_total_output_tokens']
+    opts.netdata_root = config['netdata_root']
+    opts.simultaneous_requests_per_ip = config['simultaneous_requests_per_ip']
+    opts.show_backend_info = config['show_backend_info']
+    opts.max_new_tokens = config['max_new_tokens']
+    opts.manual_model_name = config['manual_model_name']
+    opts.llm_middleware_name = config['llm_middleware_name']
+    opts.enable_openi_compatible_backend = config['enable_openi_compatible_backend']
+    opts.openai_system_prompt = config['openai_system_prompt']
+    opts.expose_openai_system_prompt = config['expose_openai_system_prompt']
+    opts.enable_streaming = config['enable_streaming']
+    opts.openai_api_key = config['openai_api_key']
+    openai.api_key = opts.openai_api_key
+    opts.admin_token = config['admin_token']
+    opts.openai_expose_our_model = config['openai_epose_our_model']
+    opts.openai_force_no_hashes = config['openai_force_no_hashes']
+    opts.include_system_tokens_in_stats = config['include_system_tokens_in_stats']
+    opts.openai_moderation_scan_last_n = config['openai_moderation_scan_last_n']
+    opts.openai_moderation_workers = config['openai_moderation_workers']
+    opts.openai_org_name = config['openai_org_name']
+    opts.openai_silent_trim = config['openai_silent_trim']
+    opts.openai_moderation_enabled = config['openai_moderation_enabled']
+
+    if opts.openai_expose_our_model and not opts.openai_api_key:
+        print('If you set openai_epose_our_model to false, you must set your OpenAI key in openai_api_key.')
+        sys.exit(1)
+
+    opts.verify_ssl = config['verify_ssl']
+    if not opts.verify_ssl:
+        import urllib3
+
+        urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
+
+    if config['http_host']:
+        http_host = re.sub(r'http(?:s)?://', '', config["http_host"])
+        redis.set('http_host', http_host)
+        redis.set('base_client_api', f'{http_host}/{opts.frontend_api_client.strip("/")}')
+
+    database.init_db(config['mysql']['host'], config['mysql']['username'], config['mysql']['password'], config['mysql']['database'])
+
+    if config['load_num_prompts']:
+        redis.set('proompts', get_number_of_rows('prompts'))
+
+    redis.set_dict('recent_prompters', {})
+    redis.set_dict('processing_ips', {})
+    redis.set_dict('queued_ip_count', {})
+    redis.set('backend_mode', opts.mode)
+
+    return success, config, msg
diff --git a/llm_server/pre_fork.py b/llm_server/pre_fork.py
new file mode 100644
index 0000000..f3ea0f4
--- /dev/null
+++ b/llm_server/pre_fork.py
@@ -0,0 +1,21 @@
+import sys
+
+from redis import Redis
+
+from llm_server.routes.cache import redis
+from llm_server.routes.v1.generate_stats import generate_stats
+
+
+def server_startup(s):
+    if not redis.get('daemon_started', bool):
+        print('Could not find the key daemon_started in Redis. Did you forget to start the daemon process?')
+        sys.exit(1)
+
+    # Flush the RedisPriorityQueue database.
+    queue_redis = Redis(host='localhost', port=6379, db=15)
+    for key in queue_redis.scan_iter('*'):
+        queue_redis.delete(key)
+
+    # Cache the initial stats
+    print('Loading backend stats...')
+    generate_stats()
diff --git a/llm_server/routes/helpers/http.py b/llm_server/routes/helpers/http.py
index acc7447..2fa1190 100644
--- a/llm_server/routes/helpers/http.py
+++ b/llm_server/routes/helpers/http.py
@@ -57,7 +57,7 @@ def require_api_key(json_body: dict = None):
             return jsonify({'code': 403, 'message': 'Invalid API key'}), 403
     elif 'Authorization' in request.headers:
         token = parse_token(request.headers['Authorization'])
-        if token.startswith('SYSTEM__') or opts.auth_required:
+        if (token and token.startswith('SYSTEM__')) or opts.auth_required:
            if is_valid_api_key(token):
                return
            else:
diff --git a/llm_server/routes/openai_request_handler.py b/llm_server/routes/openai_request_handler.py
index 561320c..5f26a35 100644
--- a/llm_server/routes/openai_request_handler.py
+++ b/llm_server/routes/openai_request_handler.py
@@ -10,7 +10,7 @@ from llm_server import opts
 from llm_server.database.database import is_api_key_moderated
 from llm_server.llm.openai.transform import build_openai_response, transform_messages_to_prompt, trim_prompt_to_fit
 from llm_server.routes.request_handler import RequestHandler
-from llm_server.threads import add_moderation_task, get_results
+from llm_server.workers.moderator import add_moderation_task, get_results
 
 
 class OpenAIRequestHandler(RequestHandler):
diff --git a/llm_server/routes/stats.py b/llm_server/routes/stats.py
index a16c820..a6e9e17 100644
--- a/llm_server/routes/stats.py
+++ b/llm_server/routes/stats.py
@@ -1,6 +1,4 @@
-import time
 from datetime import datetime
-from threading import Thread
 
 from llm_server.routes.cache import redis
 
@@ -46,27 +44,3 @@ def get_active_gen_workers():
     else:
         count = int(active_gen_workers)
     return count
-
-
-class SemaphoreCheckerThread(Thread):
-    redis.set_dict('recent_prompters', {})
-
-    def __init__(self):
-        Thread.__init__(self)
-        self.daemon = True
-
-    def run(self):
-        while True:
-            current_time = time.time()
-            recent_prompters = redis.get_dict('recent_prompters')
-            new_recent_prompters = {}
-
-            for ip, (timestamp, token) in recent_prompters.items():
-                if token and token.startswith('SYSTEM__'):
-                    continue
-                if current_time - timestamp <= 300:
-                    new_recent_prompters[ip] = timestamp, token
-
-            redis.set_dict('recent_prompters', new_recent_prompters)
-            redis.set('proompters_5_min', len(new_recent_prompters))
-            time.sleep(1)
diff --git a/llm_server/threads.py b/llm_server/threads.py
deleted file mode 100644
index a6d8b90..0000000
--- a/llm_server/threads.py
+++ /dev/null
@@ -1,120 +0,0 @@
-import json
-import threading
-import time
-import traceback
-from threading import Thread
-
-import redis as redis_redis
-
-from llm_server import opts
-from llm_server.database.database import weighted_average_column_for_model
-from llm_server.llm.info import get_running_model
-from llm_server.llm.openai.moderation import check_moderation_endpoint
-from llm_server.routes.cache import redis
-from llm_server.routes.v1.generate_stats import generate_stats
-
-
-class MainBackgroundThread(Thread):
-    backend_online = False
-
-    # TODO: do I really need to put everything in Redis?
-    # TODO: call generate_stats() every minute, cache the results, put results in a DB table, then have other parts of code call this cache
-
-    def __init__(self):
-        Thread.__init__(self)
-        self.daemon = True
-        redis.set('average_generation_elapsed_sec', 0)
-        redis.set('estimated_avg_tps', 0)
-        redis.set('average_output_tokens', 0)
-        redis.set('backend_online', 0)
-        redis.set_dict('backend_info', {})
-
-    def run(self):
-        while True:
-            # TODO: unify this
-            if opts.mode == 'oobabooga':
-                running_model, err = get_running_model()
-                if err:
-                    print(err)
-                    redis.set('backend_online', 0)
-                else:
-                    redis.set('running_model', running_model)
-                    redis.set('backend_online', 1)
-            elif opts.mode == 'vllm':
-                running_model, err = get_running_model()
-                if err:
-                    print(err)
-                    redis.set('backend_online', 0)
-                else:
-                    redis.set('running_model', running_model)
-                    redis.set('backend_online', 1)
-            else:
-                raise Exception
-
-            # exclude_zeros=True filters out rows where an error message was returned. Previously, if there was an error, 0
-            # was entered into the column. The new code enters null instead but we need to be backwards compatible for now.
-            average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time', running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
-            if average_generation_elapsed_sec:  # returns None on exception
-                redis.set('average_generation_elapsed_sec', average_generation_elapsed_sec)
-
-            # overall = average_column_for_model('prompts', 'generation_time', opts.running_model)
-            # print(f'Weighted: {average_generation_elapsed_sec}, overall: {overall}')
-
-            average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens', running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
-            if average_generation_elapsed_sec:
-                redis.set('average_output_tokens', average_output_tokens)
-
-            # overall = average_column_for_model('prompts', 'response_tokens', opts.running_model)
-            # print(f'Weighted: {average_output_tokens}, overall: {overall}')
-
-            estimated_avg_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0  # Avoid division by zero
-            redis.set('estimated_avg_tps', estimated_avg_tps)
-            time.sleep(60)
-
-
-def cache_stats():
-    while True:
-        generate_stats(regen=True)
-        time.sleep(5)
-
-
-redis_moderation = redis_redis.Redis()
-
-
-def start_moderation_workers(num_workers):
-    for _ in range(num_workers):
-        t = threading.Thread(target=moderation_worker)
-        t.daemon = True
-        t.start()
-
-
-def moderation_worker():
-    while True:
-        result = redis_moderation.blpop('queue:msgs_to_check')
-        try:
-            msg, tag = json.loads(result[1])
-            _, categories = check_moderation_endpoint(msg)
-            redis_moderation.rpush('queue:flagged_categories', json.dumps((tag, categories)))
-        except:
-            print(result)
-            traceback.print_exc()
-            continue
-
-
-def add_moderation_task(msg, tag):
-    redis_moderation.rpush('queue:msgs_to_check', json.dumps((msg, str(tag))))
-
-
-def get_results(tag, num_tasks):
-    tag = str(tag)  # Required for comparison with Redis results.
-    flagged_categories = set()
-    num_results = 0
-    while num_results < num_tasks:
-        result = redis_moderation.blpop('queue:flagged_categories')
-        result_tag, categories = json.loads(result[1])
-        if result_tag == tag:
-            if categories:
-                for item in categories:
-                    flagged_categories.add(item)
-            num_results += 1
-    return list(flagged_categories)
diff --git a/llm_server/workers/app.py b/llm_server/workers/app.py
new file mode 100644
index 0000000..fda6fb3
--- /dev/null
+++ b/llm_server/workers/app.py
@@ -0,0 +1,35 @@
+from threading import Thread
+
+from .blocking import start_workers
+from .main import main_background_thread
+from .moderator import start_moderation_workers
+from .printer import console_printer
+from .recent import recent_prompters_thread
+from .threads import cache_stats
+from .. import opts
+
+
+def start_background():
+    start_workers(opts.concurrent_gens)
+
+    t = Thread(target=main_background_thread)
+    t.daemon = True
+    t.start()
+    print('Started the main background thread.')
+
+    start_moderation_workers(opts.openai_moderation_workers)
+
+    t = Thread(target=cache_stats)
+    t.daemon = True
+    t.start()
+    print('Started the stats cacher.')
+
+    t = Thread(target=recent_prompters_thread)
+    t.daemon = True
+    t.start()
+    print('Started the recent proompters thread.')
+
+    t = Thread(target=console_printer)
+    t.daemon = True
+    t.start()
+    print('Started the console printer.')
diff --git a/llm_server/workers/blocking.py b/llm_server/workers/blocking.py
index 91a04c4..23112d9 100644
--- a/llm_server/workers/blocking.py
+++ b/llm_server/workers/blocking.py
@@ -1,4 +1,3 @@
-import json
 import threading
 import time
 
@@ -17,12 +16,12 @@ def worker():
         increment_ip_count(client_ip, 'processing_ips')
         redis.incr('active_gen_workers')
 
-        if not request_json_body:
-            # This was a dummy request from the websocket handler.
-            # We're going to let the websocket handler decrement processing_ips and active_gen_workers.
-            continue
-
         try:
+            if not request_json_body:
+                # This was a dummy request from the websocket handler.
+                # We're going to let the websocket handler decrement processing_ips and active_gen_workers.
+                continue
+
             start_time = time.time()
             success, response, error_msg = generator(request_json_body)
             end_time = time.time()
diff --git a/llm_server/workers/main.py b/llm_server/workers/main.py
new file mode 100644
index 0000000..747f699
--- /dev/null
+++ b/llm_server/workers/main.py
@@ -0,0 +1,56 @@
+import time
+from threading import Thread
+
+from llm_server import opts
+from llm_server.database.database import weighted_average_column_for_model
+from llm_server.llm.info import get_running_model
+from llm_server.routes.cache import redis
+
+
+def main_background_thread():
+    redis.set('average_generation_elapsed_sec', 0)
+    redis.set('estimated_avg_tps', 0)
+    redis.set('average_output_tokens', 0)
+    redis.set('backend_online', 0)
+    redis.set_dict('backend_info', {})
+
+    while True:
+        # TODO: unify this
+        if opts.mode == 'oobabooga':
+            running_model, err = get_running_model()
+            if err:
+                print(err)
+                redis.set('backend_online', 0)
+            else:
+                redis.set('running_model', running_model)
+                redis.set('backend_online', 1)
+        elif opts.mode == 'vllm':
+            running_model, err = get_running_model()
+            if err:
+                print(err)
+                redis.set('backend_online', 0)
+            else:
+                redis.set('running_model', running_model)
+                redis.set('backend_online', 1)
+        else:
+            raise Exception
+
+        # exclude_zeros=True filters out rows where an error message was returned. Previously, if there was an error, 0
+        # was entered into the column. The new code enters null instead but we need to be backwards compatible for now.
+        average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time', running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
+        if average_generation_elapsed_sec:  # returns None on exception
+            redis.set('average_generation_elapsed_sec', average_generation_elapsed_sec)
+
+        # overall = average_column_for_model('prompts', 'generation_time', opts.running_model)
+        # print(f'Weighted: {average_generation_elapsed_sec}, overall: {overall}')
+
+        average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens', running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
+        if average_generation_elapsed_sec:
+            redis.set('average_output_tokens', average_output_tokens)
+
+        # overall = average_column_for_model('prompts', 'response_tokens', opts.running_model)
+        # print(f'Weighted: {average_output_tokens}, overall: {overall}')
+
+        estimated_avg_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0  # Avoid division by zero
+        redis.set('estimated_avg_tps', estimated_avg_tps)
+        time.sleep(60)
diff --git a/llm_server/workers/moderator.py b/llm_server/workers/moderator.py
new file mode 100644
index 0000000..4457d05
--- /dev/null
+++ b/llm_server/workers/moderator.py
@@ -0,0 +1,51 @@
+import json
+import threading
+import traceback
+
+import redis as redis_redis
+
+from llm_server.llm.openai.moderation import check_moderation_endpoint
+
+redis_moderation = redis_redis.Redis()
+
+
+def start_moderation_workers(num_workers):
+    i = 0
+    for _ in range(num_workers):
+        t = threading.Thread(target=moderation_worker)
+        t.daemon = True
+        t.start()
+        i += 1
+    print(f'Started {i} moderation workers.')
+
+
+def moderation_worker():
+    while True:
+        result = redis_moderation.blpop('queue:msgs_to_check')
+        try:
+            msg, tag = json.loads(result[1])
+            _, categories = check_moderation_endpoint(msg)
+            redis_moderation.rpush('queue:flagged_categories', json.dumps((tag, categories)))
+        except:
+            print(result)
+            traceback.print_exc()
+            continue
+
+
+def add_moderation_task(msg, tag):
+    redis_moderation.rpush('queue:msgs_to_check', json.dumps((msg, str(tag))))
+
+
+def get_results(tag, num_tasks):
+    tag = str(tag)  # Required for comparison with Redis results.
+    flagged_categories = set()
+    num_results = 0
+    while num_results < num_tasks:
+        result = redis_moderation.blpop('queue:flagged_categories')
+        result_tag, categories = json.loads(result[1])
+        if result_tag == tag:
+            if categories:
+                for item in categories:
+                    flagged_categories.add(item)
+            num_results += 1
+    return list(flagged_categories)
diff --git a/llm_server/workers/printer.py b/llm_server/workers/printer.py
index 40d3c88..36920a5 100644
--- a/llm_server/workers/printer.py
+++ b/llm_server/workers/printer.py
@@ -1,5 +1,4 @@
 import logging
-import threading
 import time
 
 from llm_server.routes.cache import redis
@@ -15,14 +14,9 @@ if not logger.handlers:
 
 
 def console_printer():
+    time.sleep(3)
     while True:
         queued_ip_count = sum([v for k, v in redis.get_dict('queued_ip_count').items()])
         processing_count = sum([v for k, v in redis.get_dict('processing_ips').items()])
         logger.info(f'REQUEST QUEUE -> Processing: {processing_count} | Queued: {queued_ip_count}')
-        time.sleep(10)
-
-
-def start_console_printer():
-    t = threading.Thread(target=console_printer)
-    t.daemon = True
-    t.start()
+        time.sleep(15)
diff --git a/llm_server/workers/recent.py b/llm_server/workers/recent.py
new file mode 100644
index 0000000..ce5d20f
--- /dev/null
+++ b/llm_server/workers/recent.py
@@ -0,0 +1,19 @@
+import time
+
+from llm_server.routes.cache import redis
+
+
+def recent_prompters_thread():
+    while True:
+        current_time = time.time()
+        recent_prompters = redis.get_dict('recent_prompters')
+        new_recent_prompters = {}
+        for ip, (timestamp, token) in recent_prompters.items():
+            if token and token.startswith('SYSTEM__'):
+                continue
+            if current_time - timestamp <= 300:
+                new_recent_prompters[ip] = timestamp, token
+
+        redis.set_dict('recent_prompters', new_recent_prompters)
+        redis.set('proompters_5_min', len(new_recent_prompters))
+        time.sleep(1)
diff --git a/llm_server/workers/threads.py b/llm_server/workers/threads.py
new file mode 100644
index 0000000..d1c5183
--- /dev/null
+++ b/llm_server/workers/threads.py
@@ -0,0 +1,9 @@
+import time
+
+from llm_server.routes.v1.generate_stats import generate_stats
+
+
+def cache_stats():
+    while True:
+        generate_stats(regen=True)
+        time.sleep(5)
diff --git a/requirements.txt b/requirements.txt
index 4773cab..9b0c8eb 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,7 +5,6 @@ flask_caching
 requests~=2.31.0
 tiktoken~=0.5.0
 gunicorn
-redis~=5.0.0
 gevent~=23.9.0.post1
 async-timeout
 flask-sock
@@ -19,4 +18,4 @@ websockets~=11.0.3
 basicauth~=1.0.0
 openai~=0.28.0
 urllib3~=2.0.4
-rq~=1.15.1
\ No newline at end of file
+celery[redis]
diff --git a/server.py b/server.py
index 548c6bf..66a0479 100644
--- a/server.py
+++ b/server.py
@@ -1,3 +1,5 @@
+from llm_server.config.config import mode_ui_names
+
 try:
     import gevent.monkey
 
@@ -5,28 +7,23 @@ try:
 except ImportError:
     pass
 
-from llm_server.workers.printer import start_console_printer
+from llm_server.pre_fork import server_startup
+from llm_server.config.load import load_config
 
 import os
-import re
 import sys
 from pathlib import Path
-from threading import Thread
 
-import openai
 import simplejson as json
 from flask import Flask, jsonify, render_template, request
-from redis import Redis
 
 import llm_server
 from llm_server.database.conn import database
 from llm_server.database.create import create_db
-from llm_server.database.database import get_number_of_rows
 from llm_server.llm import get_token_count
 from llm_server.routes.openai import openai_bp
 from llm_server.routes.server_error import handle_server_error
 from llm_server.routes.v1 import bp
 from llm_server.stream import init_socketio
-from llm_server.workers.blocking import start_workers
 # TODO: have the workers handle streaming too
 # TODO: add backend fallbacks. Backends at the bottom of the list are higher priority and are fallbacks if the upper ones fail
@@ -62,16 +59,12 @@ except ModuleNotFoundError as e:
     import config
 
 from llm_server import opts
-from llm_server.config import ConfigLoader, config_default_vars, config_required_vars, mode_ui_names
-from llm_server.helpers import resolve_path, auto_set_base_client_api
+from llm_server.helpers import auto_set_base_client_api
 from llm_server.llm.vllm.info import vllm_info
 from llm_server.routes.cache import RedisWrapper, flask_cache
 from llm_server.llm import redis
-from llm_server.routes.stats import SemaphoreCheckerThread, get_active_gen_workers
+from llm_server.routes.stats import get_active_gen_workers
 from llm_server.routes.v1.generate_stats import generate_stats
-from llm_server.threads import MainBackgroundThread, cache_stats, start_moderation_workers
-
-script_path = os.path.dirname(os.path.realpath(__file__))
 
 app = Flask(__name__)
 init_socketio(app)
@@ -80,123 +73,22 @@ app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')
 flask_cache.init_app(app)
 flask_cache.clear()
 
+script_path = os.path.dirname(os.path.realpath(__file__))
 config_path_environ = os.getenv("CONFIG_PATH")
 if config_path_environ:
     config_path = config_path_environ
 else:
     config_path = Path(script_path, 'config', 'config.yml')
 
-config_loader = ConfigLoader(config_path, config_default_vars, config_required_vars)
-success, config, msg = config_loader.load_config()
+success, config, msg = load_config(config_path, script_path)
 if not success:
     print('Failed to load config:', msg)
     sys.exit(1)
 
-# Resolve relative directory to the directory of the script
-if config['database_path'].startswith('./'):
-    config['database_path'] = resolve_path(script_path, config['database_path'].strip('./'))
-
 database.init_db(config['mysql']['host'], config['mysql']['username'], config['mysql']['password'], config['mysql']['database'])
 create_db()
-
-if config['mode'] not in ['oobabooga', 'vllm']:
-    print('Unknown mode:', config['mode'])
-    sys.exit(1)
-
-# TODO: this is atrocious
-opts.mode = config['mode']
-opts.auth_required = config['auth_required']
-opts.log_prompts = config['log_prompts']
-opts.concurrent_gens = config['concurrent_gens']
-opts.frontend_api_client = config['frontend_api_client']
-opts.context_size = config['token_limit']
-opts.show_num_prompts = config['show_num_prompts']
-opts.show_uptime = config['show_uptime']
-opts.backend_url = config['backend_url'].strip('/')
-opts.show_total_output_tokens = config['show_total_output_tokens']
-opts.netdata_root = config['netdata_root']
-opts.simultaneous_requests_per_ip = config['simultaneous_requests_per_ip']
-opts.show_backend_info = config['show_backend_info']
-opts.max_new_tokens = config['max_new_tokens']
-opts.manual_model_name = config['manual_model_name']
-opts.llm_middleware_name = config['llm_middleware_name']
-opts.enable_openi_compatible_backend = config['enable_openi_compatible_backend']
-opts.openai_system_prompt = config['openai_system_prompt']
-opts.expose_openai_system_prompt = config['expose_openai_system_prompt']
-opts.enable_streaming = config['enable_streaming']
-opts.openai_api_key = config['openai_api_key']
-openai.api_key = opts.openai_api_key
-opts.admin_token = config['admin_token']
-opts.openai_expose_our_model = config['openai_epose_our_model']
-opts.openai_force_no_hashes = config['openai_force_no_hashes']
-opts.include_system_tokens_in_stats = config['include_system_tokens_in_stats']
-opts.openai_moderation_scan_last_n = config['openai_moderation_scan_last_n']
-opts.openai_moderation_workers = config['openai_moderation_workers']
-opts.openai_org_name = config['openai_org_name']
-opts.openai_silent_trim = config['openai_silent_trim']
-opts.openai_moderation_enabled = config['openai_moderation_enabled']
-
-if opts.openai_expose_our_model and not opts.openai_api_key:
-    print('If you set openai_epose_our_model to false, you must set your OpenAI key in openai_api_key.')
-    sys.exit(1)
-
-opts.verify_ssl = config['verify_ssl']
-if not opts.verify_ssl:
-    import urllib3
-
-    urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
-
-if config['average_generation_time_mode'] not in ['database', 'minute']:
-    print('Invalid value for config item "average_generation_time_mode":', config['average_generation_time_mode'])
-    sys.exit(1)
-opts.average_generation_time_mode = config['average_generation_time_mode']
-
-if opts.mode == 'oobabooga':
-    raise NotImplementedError
-    # llm_server.llm.tokenizer = OobaboogaBackend()
-elif opts.mode == 'vllm':
-    llm_server.llm.get_token_count = llm_server.llm.vllm.tokenize
-else:
-    raise Exception
-
-
-def pre_fork(server):
-    llm_server.llm.redis = RedisWrapper('local_llm')
-    flushed_keys = redis.flush()
-    print('Flushed', len(flushed_keys), 'keys from Redis.')
-
-    redis.set_dict('processing_ips', {})
-    redis.set_dict('queued_ip_count', {})
-
-    # Flush the RedisPriorityQueue database.
-    queue_redis = Redis(host='localhost', port=6379, db=15)
-    for key in queue_redis.scan_iter('*'):
-        queue_redis.delete(key)
-
-    redis.set('backend_mode', opts.mode)
-    if config['http_host']:
-        http_host = re.sub(r'http(?:s)?://', '', config["http_host"])
-        redis.set('http_host', http_host)
-        redis.set('base_client_api', f'{http_host}/{opts.frontend_api_client.strip("/")}')
-
-    if config['load_num_prompts']:
-        redis.set('proompts', get_number_of_rows('prompts'))
-
-    # Start background processes
-    start_workers(opts.concurrent_gens)
-    start_console_printer()
-    start_moderation_workers(opts.openai_moderation_workers)
-    MainBackgroundThread().start()
-    SemaphoreCheckerThread().start()
-
-    # This needs to be started after Flask is initalized
-    stats_updater_thread = Thread(target=cache_stats)
-    stats_updater_thread.daemon = True
-    stats_updater_thread.start()
-
-    # Cache the initial stats
-    print('Loading backend stats...')
-    generate_stats()
+llm_server.llm.redis = RedisWrapper('local_llm')
+create_db()
 
 
 # print(app.url_map)
@@ -280,6 +172,6 @@ def before_app_request():
 
 
 if __name__ == "__main__":
-    pre_fork(None)
+    server_startup(None)
     print('FLASK MODE - Startup complete!')
     app.run(host='0.0.0.0', threaded=False, processes=15)
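With this split the background daemon has to be running before the web workers come up: server_startup() in llm_server/pre_fork.py exits unless daemon.py has already set the daemon_started key in Redis. A minimal launch sketch under that assumption; the exact gunicorn flags (worker count, bind address, worker class) are not specified by this patch and are only illustrative:

    python3 daemon.py                    # flushes Redis, loads the config, starts the background workers, sets daemon_started
    gunicorn -c gunicorn.py server:app   # on_starting() -> server_startup() checks daemon_started before serving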