diff --git a/gunicorn.py b/gunicorn.py
new file mode 100644
index 0000000..36501df
--- /dev/null
+++ b/gunicorn.py
@@ -0,0 +1,13 @@
+try:
+    import gevent.monkey
+
+    gevent.monkey.patch_all()
+except ImportError:
+    pass
+
+import server
+
+
+def on_starting(s):
+    server.pre_fork(s)
+    print('Startup complete!')
diff --git a/llm_server/database/database.py b/llm_server/database/database.py
index c913fa5..6bd2026 100644
--- a/llm_server/database/database.py
+++ b/llm_server/database/database.py
@@ -6,6 +6,7 @@ import llm_server
 from llm_server import opts
 from llm_server.database.conn import db_pool
 from llm_server.llm.vllm import tokenize
+from llm_server.routes.cache import redis
 
 
 def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, response_tokens: int = None, is_error: bool = False):
@@ -33,6 +34,8 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
     if token:
         increment_token_uses(token)
 
+    running_model = redis.get('running_model', str, 'ERROR')
+
     timestamp = int(time.time())
     conn = db_pool.connection()
     cursor = conn.cursor()
@@ -42,7 +45,7 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
                 (ip, token, model, backend_mode, backend_url, request_url, generation_time, prompt, prompt_tokens, response, response_tokens, response_status, parameters, headers, timestamp)
                 VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
                 """,
-                       (ip, token, opts.running_model, opts.mode, opts.backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
+                       (ip, token, running_model, opts.mode, opts.backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
     finally:
         cursor.close()
 
diff --git a/llm_server/llm/vllm/generate.py b/llm_server/llm/vllm/generate.py
index 0c957a8..86a27ac 100644
--- a/llm_server/llm/vllm/generate.py
+++ b/llm_server/llm/vllm/generate.py
@@ -10,6 +10,7 @@ import requests
 
 import llm_server
 from llm_server import opts
+from llm_server.routes.cache import redis
 
 
 # TODO: make the VLMM backend return TPS and time elapsed
@@ -49,13 +50,14 @@ def transform_to_text(json_request, api_response):
 
     prompt_tokens = len(llm_server.llm.get_token_count(prompt))
     completion_tokens = len(llm_server.llm.get_token_count(text))
+    running_model = redis.get('running_model', str, 'ERROR')
 
     # https://platform.openai.com/docs/api-reference/making-requests?lang=python
     return {
         "id": str(uuid4()),
         "object": "chat.completion",
         "created": int(time.time()),
-        "model": opts.running_model,
+        "model": running_model,
         "usage": {
             "prompt_tokens": prompt_tokens,
             "completion_tokens": completion_tokens,
diff --git a/llm_server/opts.py b/llm_server/opts.py
index fee8a0e..5de33a5 100644
--- a/llm_server/opts.py
+++ b/llm_server/opts.py
@@ -2,7 +2,7 @@
 
 # TODO: rewrite the config system so I don't have to add every single config default here
 
-running_model = 'none'
+running_model = 'ERROR'
 concurrent_gens = 3
 mode = 'oobabooga'
 backend_url = None
diff --git a/llm_server/routes/openai/chat_completions.py b/llm_server/routes/openai/chat_completions.py
index 0be4c5d..d3fb2c4 100644
--- a/llm_server/routes/openai/chat_completions.py
+++ b/llm_server/routes/openai/chat_completions.py
@@ -6,6 +6,7 @@ import traceback
 from flask import Response, jsonify, request
 
 from . import openai_bp
+from ..cache import redis
 from ..helpers.client import format_sillytavern_err
 from ..helpers.http import validate_json
 from ..openai_request_handler import OpenAIRequestHandler, build_openai_response, generate_oai_string
@@ -50,7 +51,7 @@ def openai_chat_completions():
         response = generator(msg_to_backend)
         r_headers = dict(request.headers)
         r_url = request.url
-        model = opts.running_model if opts.openai_expose_our_model else request_json_body.get('model')
+        model = redis.get('running_model', str, 'ERROR') if opts.openai_expose_our_model else request_json_body.get('model')
         oai_string = generate_oai_string(30)
 
         def generate():
diff --git a/llm_server/routes/openai/models.py b/llm_server/routes/openai/models.py
index f9d8591..d9bcd3e 100644
--- a/llm_server/routes/openai/models.py
+++ b/llm_server/routes/openai/models.py
@@ -4,7 +4,7 @@ import requests
 from flask import jsonify
 
 from . import openai_bp
-from ..cache import ONE_MONTH_SECONDS, cache
+from ..cache import ONE_MONTH_SECONDS, cache, redis
 from ..stats import server_start_time
 from ... import opts
 from ...helpers import jsonify_pretty
@@ -22,6 +22,7 @@ def openai_list_models():
             'type': error.__class__.__name__
         }), 500  # return 500 so Cloudflare doesn't intercept us
     else:
+        running_model = redis.get('running_model', str, 'ERROR')
         oai = fetch_openai_models()
         r = []
         if opts.openai_expose_our_model:
@@ -29,13 +30,13 @@ def openai_list_models():
                 "object": "list",
                 "data": [
                     {
-                        "id": opts.running_model,
+                        "id": running_model,
                         "object": "model",
                         "created": int(server_start_time.timestamp()),
                         "owned_by": opts.llm_middleware_name,
                         "permission": [
                             {
-                                "id": opts.running_model,
+                                "id": running_model,
                                 "object": "model_permission",
                                 "created": int(server_start_time.timestamp()),
                                 "allow_create_engine": False,
diff --git a/llm_server/routes/openai_request_handler.py b/llm_server/routes/openai_request_handler.py
index 833f0f7..69a36a9 100644
--- a/llm_server/routes/openai_request_handler.py
+++ b/llm_server/routes/openai_request_handler.py
@@ -14,6 +14,7 @@ from flask import jsonify
 
 import llm_server
 from llm_server import opts
 from llm_server.database.database import log_prompt
+from llm_server.routes.cache import redis
 from llm_server.routes.helpers.client import format_sillytavern_err
 from llm_server.routes.request_handler import RequestHandler
@@ -157,11 +158,13 @@ def build_openai_response(prompt, response, model=None):
     # TODO: async/await
     prompt_tokens = llm_server.llm.get_token_count(prompt)
     response_tokens = llm_server.llm.get_token_count(response)
+    running_model = redis.get('running_model', str, 'ERROR')
+
     return jsonify({
         "id": f"chatcmpl-{generate_oai_string(30)}",
         "object": "chat.completion",
         "created": int(time.time()),
-        "model": opts.running_model if opts.openai_expose_our_model else model,
+        "model": running_model if opts.openai_expose_our_model else model,
         "choices": [{
             "index": 0,
             "message": {
diff --git a/llm_server/routes/v1/generate_stats.py b/llm_server/routes/v1/generate_stats.py
index 0a7e87a..9fe69f9 100644
--- a/llm_server/routes/v1/generate_stats.py
+++ b/llm_server/routes/v1/generate_stats.py
@@ -46,7 +46,7 @@ def generate_stats(regen: bool = False):
         online = False
     else:
         online = True
-        opts.running_model = model_name
+        redis.set('running_model', model_name)
 
     # t = elapsed_times.copy()  # copy since we do multiple operations and don't want it to change
     # if len(t) == 0:
diff --git a/llm_server/threads.py b/llm_server/threads.py
index be72077..c542009 100644
--- a/llm_server/threads.py
+++ b/llm_server/threads.py
@@ -25,35 +25,36 @@ class MainBackgroundThread(Thread):
 
     def run(self):
         while True:
+            # TODO: unify this
             if opts.mode == 'oobabooga':
-                model, err = get_running_model()
+                running_model, err = get_running_model()
                 if err:
                     print(err)
                     redis.set('backend_online', 0)
                 else:
-                    opts.running_model = model
+                    redis.set('running_model', running_model)
                     redis.set('backend_online', 1)
             elif opts.mode == 'vllm':
-                model, err = get_running_model()
+                running_model, err = get_running_model()
                 if err:
                     print(err)
                     redis.set('backend_online', 0)
                 else:
-                    opts.running_model = model
+                    redis.set('running_model', running_model)
                     redis.set('backend_online', 1)
             else:
                 raise Exception
 
             # exclude_zeros=True filters out rows where an error message was returned. Previously, if there was an error, 0
             # was entered into the column. The new code enters null instead but we need to be backwards compatible for now.
-            average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time', opts.running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
+            average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time', running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
             if average_generation_elapsed_sec:  # returns None on exception
                 redis.set('average_generation_elapsed_sec', average_generation_elapsed_sec)
 
             # overall = average_column_for_model('prompts', 'generation_time', opts.running_model)
             # print(f'Weighted: {average_generation_elapsed_sec}, overall: {overall}')
 
-            average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens', opts.running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
+            average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens', running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
             if average_generation_elapsed_sec:
                 redis.set('average_output_tokens', average_output_tokens)
 
diff --git a/server.py b/server.py
index b3eca0a..4709e32 100644
--- a/server.py
+++ b/server.py
@@ -15,6 +15,8 @@ from llm_server.database.database import get_number_of_rows
 from llm_server.llm import get_token_count
 from llm_server.routes.openai import openai_bp
 from llm_server.routes.server_error import handle_server_error
+from llm_server.routes.v1 import bp
+from llm_server.stream import init_socketio
 
 # TODO: make sure the OpenAI moderation endpoint scans the last n messages rather than only the last one (make that threaded)
 # TODO: support turbo-instruct on openai endpoint
@@ -22,6 +24,7 @@ from llm_server.routes.server_error import handle_server_error
 # TODO: validate system tokens before excluding them
 # TODO: make sure prompts are logged even when the user cancels generation
 # TODO: add some sort of loadbalancer to send requests to a group of backends
+# TODO: use the current estimated wait time for ratelimit headers on openai
 # TODO: make sure log_prompt() is used everywhere, including errors and invalid requests
 # TODO: unify logging thread in a function and use async/await instead
 
@@ -43,13 +46,18 @@ from llm_server.llm.vllm.info import vllm_info
 from llm_server.routes.cache import cache, redis
 from llm_server.routes.queue import start_workers
 from llm_server.routes.stats import SemaphoreCheckerThread, get_active_gen_workers, process_avg_gen_time
-from llm_server.routes.v1 import bp
 from llm_server.routes.v1.generate_stats import generate_stats
-from llm_server.stream import init_socketio
 from llm_server.threads import MainBackgroundThread, cache_stats
 
 script_path = os.path.dirname(os.path.realpath(__file__))
 
+app = Flask(__name__)
+init_socketio(app)
+app.register_blueprint(bp, url_prefix='/api/v1/')
+app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')
+cache.init_app(app)
+cache.clear()  # clear redis cache
+
 config_path_environ = os.getenv("CONFIG_PATH")
 if config_path_environ:
     config_path = config_path_environ
@@ -73,9 +81,6 @@ if config['mode'] not in ['oobabooga', 'vllm']:
     print('Unknown mode:', config['mode'])
     sys.exit(1)
 
-flushed_keys = redis.flush()
-print('Flushed', len(flushed_keys), 'keys from Redis.')
-
 # TODO: this is a MESS
 opts.mode = config['mode']
 opts.auth_required = config['auth_required']
@@ -108,23 +113,12 @@ if opts.openai_expose_our_model and not opts.openai_api_key:
     print('If you set openai_epose_our_model to false, you must set your OpenAI key in openai_api_key.')
     sys.exit(1)
 
-if config['http_host']:
-    http_host = re.sub(r'http(?:s)?://', '', config["http_host"])
-    redis.set('http_host', http_host)
-    redis.set('base_client_api', f'{http_host}/{opts.frontend_api_client.strip("/")}')
-    print('Set host to', redis.get('http_host', str))
-
 opts.verify_ssl = config['verify_ssl']
 if not opts.verify_ssl:
     import urllib3
     urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
 
-redis.set('backend_mode', opts.mode)
-
-if config['load_num_prompts']:
-    redis.set('proompts', get_number_of_rows('prompts'))
-
 if config['average_generation_time_mode'] not in ['database', 'minute']:
     print('Invalid value for config item "average_generation_time_mode":', config['average_generation_time_mode'])
     sys.exit(1)
 
@@ -138,30 +132,36 @@ elif opts.mode == 'vllm':
 else:
     raise Exception
 
-app = Flask(__name__)
-cache.init_app(app)
-cache.clear()  # clear redis cache
 
-# Start background processes
-start_workers(opts.concurrent_gens)
-process_avg_gen_time_background_thread = Thread(target=process_avg_gen_time)
-process_avg_gen_time_background_thread.daemon = True
-process_avg_gen_time_background_thread.start()
-MainBackgroundThread().start()
-SemaphoreCheckerThread().start()
+def pre_fork(server):
+    flushed_keys = redis.flush()
+    print('Flushed', len(flushed_keys), 'keys from Redis.')
 
-# Cache the initial stats
-print('Loading backend stats...')
-generate_stats()
+    redis.set('backend_mode', opts.mode)
+    if config['http_host']:
+        http_host = re.sub(r'http(?:s)?://', '', config["http_host"])
+        redis.set('http_host', http_host)
+        redis.set('base_client_api', f'{http_host}/{opts.frontend_api_client.strip("/")}')
 
-init_socketio(app)
-app.register_blueprint(bp, url_prefix='/api/v1/')
-app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')
+    if config['load_num_prompts']:
+        redis.set('proompts', get_number_of_rows('prompts'))
 
-# This needs to be started after Flask is initalized
-stats_updater_thread = Thread(target=cache_stats)
-stats_updater_thread.daemon = True
-stats_updater_thread.start()
+    # Start background processes
+    start_workers(opts.concurrent_gens)
+    process_avg_gen_time_background_thread = Thread(target=process_avg_gen_time)
+    process_avg_gen_time_background_thread.daemon = True
+    process_avg_gen_time_background_thread.start()
+    MainBackgroundThread().start()
+    SemaphoreCheckerThread().start()
+
+    # This needs to be started after Flask is initalized
+    stats_updater_thread = Thread(target=cache_stats)
+    stats_updater_thread.daemon = True
+    stats_updater_thread.start()
+
+    # Cache the initial stats
+    print('Loading backend stats...')
+    generate_stats()
 
 
 # print(app.url_map)
@@ -177,7 +177,7 @@ def home():
     if not stats['online']:
         running_model = estimated_wait_sec = 'offline'
     else:
-        running_model = opts.running_model
+        running_model = redis.get('running_model', str, 'ERROR')
         active_gen_workers = get_active_gen_workers()
 
         if stats['queue']['queued'] == 0 and active_gen_workers >= opts.concurrent_gens:
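
Note for reviewers: this patch moves the source of truth for the running model name out of the per-process opts.running_model global and into Redis, so every gunicorn worker (and the pre-fork master) reports the same value. A minimal sketch of the resulting read/write pattern, assuming (from the call sites in this diff) that the wrapper exported by llm_server.routes.cache exposes set(key, value) and a typed get(key, dtype, default) that returns the default when the key is missing; the two helper names below are hypothetical and not part of the patch:

    from llm_server.routes.cache import redis

    def publish_running_model(model_name: str) -> None:
        # Writers (MainBackgroundThread, generate_stats) push the backend's
        # reported model name into Redis for all worker processes to read.
        redis.set('running_model', model_name)

    def read_running_model() -> str:
        # Readers fall back to 'ERROR' so a missing key shows up loudly in
        # logged prompts and API responses instead of a stale per-process value.
        return redis.get('running_model', str, 'ERROR')

The 'ERROR' fallback mirrors the new running_model = 'ERROR' placeholder in opts.py: if the key is never set, responses visibly report 'ERROR' rather than silently returning an outdated model name.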