redo background processes, reorganize server.py

Cyberes 2023-09-27 23:36:44 -06:00
parent 097d614a35
commit e86a5182eb
19 changed files with 344 additions and 285 deletions

daemon.py (new file, +44 lines)

@@ -0,0 +1,44 @@
import time
from llm_server.routes.cache import redis
try:
import gevent.monkey
gevent.monkey.patch_all()
except ImportError:
pass
import os
import sys
from pathlib import Path
from llm_server.config.load import load_config
from llm_server.database.create import create_db
from llm_server.workers.app import start_background
script_path = os.path.dirname(os.path.realpath(__file__))
config_path_environ = os.getenv("CONFIG_PATH")
if config_path_environ:
config_path = config_path_environ
else:
config_path = Path(script_path, 'config', 'config.yml')
if __name__ == "__main__":
flushed_keys = redis.flush()
print('Flushed', len(flushed_keys), 'keys from Redis.')
success, config, msg = load_config(config_path, script_path)
if not success:
print('Failed to load config:', msg)
sys.exit(1)
create_db()
start_background()
redis.set('daemon_started', 1)
print('== Daemon Setup Complete ==\n')
while True:
time.sleep(3600)
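
The daemon now owns Redis setup and the background workers, while the Flask app only verifies that the daemon has already run (see the gunicorn hook and pre_fork.py below). A rough launch-order sketch, assuming the app is still served by gunicorn as server:app and that a gunicorn config file defines the on_starting hook; the file names passed to gunicorn here are placeholders, not taken from this commit:

# Hypothetical launcher illustrating the intended startup order.
import subprocess
import time

# Start the daemon first: it flushes Redis, loads the config, starts the
# background workers, and sets the 'daemon_started' flag.
daemon = subprocess.Popen(['python', 'daemon.py'])
time.sleep(5)  # crude wait; server_startup() exits if the flag is missing

# Then the web frontend. The gunicorn config path is a placeholder.
web = subprocess.Popen(['gunicorn', '-c', 'gunicorn_conf.py', 'server:app'])
web.wait()
daemon.terminate()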

(modified file)

@@ -5,9 +5,9 @@ try:
except ImportError:
pass
import server
from llm_server.pre_fork import server_startup
def on_starting(s):
server.pre_fork(s)
server_startup(s)
print('Startup complete!')

llm_server/config/load.py (new file, +86 lines)

@@ -0,0 +1,86 @@
import re
import sys
import openai
from llm_server import opts
from llm_server.config.config import ConfigLoader, config_default_vars, config_required_vars
from llm_server.database.conn import database
from llm_server.database.database import get_number_of_rows
from llm_server.helpers import resolve_path
from llm_server.routes.cache import redis
def load_config(config_path, script_path):
config_loader = ConfigLoader(config_path, config_default_vars, config_required_vars)
success, config, msg = config_loader.load_config()
if not success:
return success, config, msg
# Resolve relative directory to the directory of the script
if config['database_path'].startswith('./'):
config['database_path'] = resolve_path(script_path, config['database_path'].strip('./'))
if config['mode'] not in ['oobabooga', 'vllm']:
print('Unknown mode:', config['mode'])
sys.exit(1)
# TODO: this is atrocious
opts.mode = config['mode']
opts.auth_required = config['auth_required']
opts.log_prompts = config['log_prompts']
opts.concurrent_gens = config['concurrent_gens']
opts.frontend_api_client = config['frontend_api_client']
opts.context_size = config['token_limit']
opts.show_num_prompts = config['show_num_prompts']
opts.show_uptime = config['show_uptime']
opts.backend_url = config['backend_url'].strip('/')
opts.show_total_output_tokens = config['show_total_output_tokens']
opts.netdata_root = config['netdata_root']
opts.simultaneous_requests_per_ip = config['simultaneous_requests_per_ip']
opts.show_backend_info = config['show_backend_info']
opts.max_new_tokens = config['max_new_tokens']
opts.manual_model_name = config['manual_model_name']
opts.llm_middleware_name = config['llm_middleware_name']
opts.enable_openi_compatible_backend = config['enable_openi_compatible_backend']
opts.openai_system_prompt = config['openai_system_prompt']
opts.expose_openai_system_prompt = config['expose_openai_system_prompt']
opts.enable_streaming = config['enable_streaming']
opts.openai_api_key = config['openai_api_key']
openai.api_key = opts.openai_api_key
opts.admin_token = config['admin_token']
opts.openai_expose_our_model = config['openai_epose_our_model']
opts.openai_force_no_hashes = config['openai_force_no_hashes']
opts.include_system_tokens_in_stats = config['include_system_tokens_in_stats']
opts.openai_moderation_scan_last_n = config['openai_moderation_scan_last_n']
opts.openai_moderation_workers = config['openai_moderation_workers']
opts.openai_org_name = config['openai_org_name']
opts.openai_silent_trim = config['openai_silent_trim']
opts.openai_moderation_enabled = config['openai_moderation_enabled']
if opts.openai_expose_our_model and not opts.openai_api_key:
        print('If you set openai_epose_our_model to true, you must also set your OpenAI key in openai_api_key.')
sys.exit(1)
opts.verify_ssl = config['verify_ssl']
if not opts.verify_ssl:
import urllib3
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
if config['http_host']:
http_host = re.sub(r'http(?:s)?://', '', config["http_host"])
redis.set('http_host', http_host)
redis.set('base_client_api', f'{http_host}/{opts.frontend_api_client.strip("/")}')
database.init_db(config['mysql']['host'], config['mysql']['username'], config['mysql']['password'], config['mysql']['database'])
if config['load_num_prompts']:
redis.set('proompts', get_number_of_rows('prompts'))
redis.set_dict('recent_prompters', {})
redis.set_dict('processing_ips', {})
redis.set_dict('queued_ip_count', {})
redis.set('backend_mode', opts.mode)
return success, config, msg

llm_server/pre_fork.py (new file, +21 lines)

@@ -0,0 +1,21 @@
import sys
from redis import Redis
from llm_server.routes.cache import redis
from llm_server.routes.v1.generate_stats import generate_stats
def server_startup(s):
if not redis.get('daemon_started', bool):
print('Could not find the key daemon_started in Redis. Did you forget to start the daemon process?')
sys.exit(1)
# Flush the RedisPriorityQueue database.
queue_redis = Redis(host='localhost', port=6379, db=15)
for key in queue_redis.scan_iter('*'):
queue_redis.delete(key)
# Cache the initial stats
print('Loading backend stats...')
generate_stats()

(modified file)

@@ -57,7 +57,7 @@ def require_api_key(json_body: dict = None):
return jsonify({'code': 403, 'message': 'Invalid API key'}), 403
elif 'Authorization' in request.headers:
token = parse_token(request.headers['Authorization'])
if token.startswith('SYSTEM__') or opts.auth_required:
if (token and token.startswith('SYSTEM__')) or opts.auth_required:
if is_valid_api_key(token):
return
else:

(modified file)

@@ -10,7 +10,7 @@ from llm_server import opts
from llm_server.database.database import is_api_key_moderated
from llm_server.llm.openai.transform import build_openai_response, transform_messages_to_prompt, trim_prompt_to_fit
from llm_server.routes.request_handler import RequestHandler
from llm_server.threads import add_moderation_task, get_results
from llm_server.workers.moderator import add_moderation_task, get_results
class OpenAIRequestHandler(RequestHandler):

(modified file)

@@ -1,6 +1,4 @@
import time
from datetime import datetime
from threading import Thread
from llm_server.routes.cache import redis
@@ -46,27 +44,3 @@ def get_active_gen_workers():
else:
count = int(active_gen_workers)
return count
class SemaphoreCheckerThread(Thread):
redis.set_dict('recent_prompters', {})
def __init__(self):
Thread.__init__(self)
self.daemon = True
def run(self):
while True:
current_time = time.time()
recent_prompters = redis.get_dict('recent_prompters')
new_recent_prompters = {}
for ip, (timestamp, token) in recent_prompters.items():
if token and token.startswith('SYSTEM__'):
continue
if current_time - timestamp <= 300:
new_recent_prompters[ip] = timestamp, token
redis.set_dict('recent_prompters', new_recent_prompters)
redis.set('proompters_5_min', len(new_recent_prompters))
time.sleep(1)

(deleted file, -120 lines)

@@ -1,120 +0,0 @@
import json
import threading
import time
import traceback
from threading import Thread
import redis as redis_redis
from llm_server import opts
from llm_server.database.database import weighted_average_column_for_model
from llm_server.llm.info import get_running_model
from llm_server.llm.openai.moderation import check_moderation_endpoint
from llm_server.routes.cache import redis
from llm_server.routes.v1.generate_stats import generate_stats
class MainBackgroundThread(Thread):
backend_online = False
# TODO: do I really need to put everything in Redis?
# TODO: call generate_stats() every minute, cache the results, put results in a DB table, then have other parts of code call this cache
def __init__(self):
Thread.__init__(self)
self.daemon = True
redis.set('average_generation_elapsed_sec', 0)
redis.set('estimated_avg_tps', 0)
redis.set('average_output_tokens', 0)
redis.set('backend_online', 0)
redis.set_dict('backend_info', {})
def run(self):
while True:
# TODO: unify this
if opts.mode == 'oobabooga':
running_model, err = get_running_model()
if err:
print(err)
redis.set('backend_online', 0)
else:
redis.set('running_model', running_model)
redis.set('backend_online', 1)
elif opts.mode == 'vllm':
running_model, err = get_running_model()
if err:
print(err)
redis.set('backend_online', 0)
else:
redis.set('running_model', running_model)
redis.set('backend_online', 1)
else:
raise Exception
# exclude_zeros=True filters out rows where an error message was returned. Previously, if there was an error, 0
# was entered into the column. The new code enters null instead but we need to be backwards compatible for now.
average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time', running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
if average_generation_elapsed_sec: # returns None on exception
redis.set('average_generation_elapsed_sec', average_generation_elapsed_sec)
# overall = average_column_for_model('prompts', 'generation_time', opts.running_model)
# print(f'Weighted: {average_generation_elapsed_sec}, overall: {overall}')
average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens', running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
if average_generation_elapsed_sec:
redis.set('average_output_tokens', average_output_tokens)
# overall = average_column_for_model('prompts', 'response_tokens', opts.running_model)
# print(f'Weighted: {average_output_tokens}, overall: {overall}')
estimated_avg_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0 # Avoid division by zero
redis.set('estimated_avg_tps', estimated_avg_tps)
time.sleep(60)
def cache_stats():
while True:
generate_stats(regen=True)
time.sleep(5)
redis_moderation = redis_redis.Redis()
def start_moderation_workers(num_workers):
for _ in range(num_workers):
t = threading.Thread(target=moderation_worker)
t.daemon = True
t.start()
def moderation_worker():
while True:
result = redis_moderation.blpop('queue:msgs_to_check')
try:
msg, tag = json.loads(result[1])
_, categories = check_moderation_endpoint(msg)
redis_moderation.rpush('queue:flagged_categories', json.dumps((tag, categories)))
except:
print(result)
traceback.print_exc()
continue
def add_moderation_task(msg, tag):
redis_moderation.rpush('queue:msgs_to_check', json.dumps((msg, str(tag))))
def get_results(tag, num_tasks):
tag = str(tag) # Required for comparison with Redis results.
flagged_categories = set()
num_results = 0
while num_results < num_tasks:
result = redis_moderation.blpop('queue:flagged_categories')
result_tag, categories = json.loads(result[1])
if result_tag == tag:
if categories:
for item in categories:
flagged_categories.add(item)
num_results += 1
return list(flagged_categories)

llm_server/workers/app.py (new file, +35 lines)

@@ -0,0 +1,35 @@
from threading import Thread
from .blocking import start_workers
from .main import main_background_thread
from .moderator import start_moderation_workers
from .printer import console_printer
from .recent import recent_prompters_thread
from .threads import cache_stats
from .. import opts
def start_background():
start_workers(opts.concurrent_gens)
t = Thread(target=main_background_thread)
t.daemon = True
t.start()
print('Started the main background thread.')
start_moderation_workers(opts.openai_moderation_workers)
t = Thread(target=cache_stats)
t.daemon = True
t.start()
print('Started the stats cacher.')
t = Thread(target=recent_prompters_thread)
t.daemon = True
t.start()
print('Started the recent proompters thread.')
t = Thread(target=console_printer)
t.daemon = True
t.start()
print('Started the console printer.')

(modified file)

@@ -1,4 +1,3 @@
import json
import threading
import time
@@ -17,12 +16,12 @@ def worker():
increment_ip_count(client_ip, 'processing_ips')
redis.incr('active_gen_workers')
try:
if not request_json_body:
# This was a dummy request from the websocket handler.
# We're going to let the websocket handler decrement processing_ips and active_gen_workers.
continue
try:
start_time = time.time()
success, response, error_msg = generator(request_json_body)
end_time = time.time()

(new file, +56 lines)

@@ -0,0 +1,56 @@
import time
from threading import Thread
from llm_server import opts
from llm_server.database.database import weighted_average_column_for_model
from llm_server.llm.info import get_running_model
from llm_server.routes.cache import redis
def main_background_thread():
redis.set('average_generation_elapsed_sec', 0)
redis.set('estimated_avg_tps', 0)
redis.set('average_output_tokens', 0)
redis.set('backend_online', 0)
redis.set_dict('backend_info', {})
while True:
# TODO: unify this
if opts.mode == 'oobabooga':
running_model, err = get_running_model()
if err:
print(err)
redis.set('backend_online', 0)
else:
redis.set('running_model', running_model)
redis.set('backend_online', 1)
elif opts.mode == 'vllm':
running_model, err = get_running_model()
if err:
print(err)
redis.set('backend_online', 0)
else:
redis.set('running_model', running_model)
redis.set('backend_online', 1)
else:
raise Exception
# exclude_zeros=True filters out rows where an error message was returned. Previously, if there was an error, 0
# was entered into the column. The new code enters null instead but we need to be backwards compatible for now.
average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time', running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
if average_generation_elapsed_sec: # returns None on exception
redis.set('average_generation_elapsed_sec', average_generation_elapsed_sec)
# overall = average_column_for_model('prompts', 'generation_time', opts.running_model)
# print(f'Weighted: {average_generation_elapsed_sec}, overall: {overall}')
average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens', running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
if average_generation_elapsed_sec:
redis.set('average_output_tokens', average_output_tokens)
# overall = average_column_for_model('prompts', 'response_tokens', opts.running_model)
# print(f'Weighted: {average_output_tokens}, overall: {overall}')
estimated_avg_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0 # Avoid division by zero
redis.set('estimated_avg_tps', estimated_avg_tps)
time.sleep(60)

(new file, +51 lines)

@@ -0,0 +1,51 @@
import json
import threading
import traceback
import redis as redis_redis
from llm_server.llm.openai.moderation import check_moderation_endpoint
redis_moderation = redis_redis.Redis()
def start_moderation_workers(num_workers):
i = 0
for _ in range(num_workers):
t = threading.Thread(target=moderation_worker)
t.daemon = True
t.start()
i += 1
print(f'Started {i} moderation workers.')
def moderation_worker():
while True:
result = redis_moderation.blpop('queue:msgs_to_check')
try:
msg, tag = json.loads(result[1])
_, categories = check_moderation_endpoint(msg)
redis_moderation.rpush('queue:flagged_categories', json.dumps((tag, categories)))
except:
print(result)
traceback.print_exc()
continue
def add_moderation_task(msg, tag):
redis_moderation.rpush('queue:msgs_to_check', json.dumps((msg, str(tag))))
def get_results(tag, num_tasks):
tag = str(tag) # Required for comparison with Redis results.
flagged_categories = set()
num_results = 0
while num_results < num_tasks:
result = redis_moderation.blpop('queue:flagged_categories')
result_tag, categories = json.loads(result[1])
if result_tag == tag:
if categories:
for item in categories:
flagged_categories.add(item)
num_results += 1
return list(flagged_categories)
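
The moderation pool above communicates only through Redis lists, so a caller tags its batch of messages and then blocks until it has one result per message. A minimal sketch of the calling pattern, assuming the caller builds its own message list (the helper name and example messages are illustrative, not part of this commit):

import uuid

from llm_server.workers.moderator import add_moderation_task, get_results

def moderate_messages(messages):
    tag = uuid.uuid4()  # one tag groups this batch so get_results() can match the replies
    for msg in messages:
        add_moderation_task(msg, tag)  # fan out to the moderation workers via Redis
    # Blocks until one 'queue:flagged_categories' entry per queued message has been seen for this tag.
    return get_results(tag, len(messages))

flagged = moderate_messages(['first user message', 'second user message'])
print('Flagged categories:', flagged)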

(modified file)

@@ -1,5 +1,4 @@
import logging
import threading
import time
from llm_server.routes.cache import redis
@@ -15,14 +14,9 @@ if not logger.handlers:
def console_printer():
time.sleep(3)
while True:
queued_ip_count = sum([v for k, v in redis.get_dict('queued_ip_count').items()])
processing_count = sum([v for k, v in redis.get_dict('processing_ips').items()])
logger.info(f'REQUEST QUEUE -> Processing: {processing_count} | Queued: {queued_ip_count}')
time.sleep(10)
def start_console_printer():
t = threading.Thread(target=console_printer)
t.daemon = True
t.start()
time.sleep(15)

(new file, +19 lines)

@@ -0,0 +1,19 @@
import time
from llm_server.routes.cache import redis
def recent_prompters_thread():
    while True:
        current_time = time.time()
        recent_prompters = redis.get_dict('recent_prompters')
        new_recent_prompters = {}
        for ip, (timestamp, token) in recent_prompters.items():
            if token and token.startswith('SYSTEM__'):
                continue
            if current_time - timestamp <= 300:
                new_recent_prompters[ip] = timestamp, token
        redis.set_dict('recent_prompters', new_recent_prompters)
        redis.set('proompters_5_min', len(new_recent_prompters))
        time.sleep(1)
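
For context, recent_prompters is a Redis-backed dict of ip -> (timestamp, token) pairs, and the loop above only prunes it once per second. A hypothetical write path (the helper name and call site are assumptions, not shown in this diff) would look something like:

import time

from llm_server.routes.cache import redis

def record_prompter(client_ip, token=None):
    # Assumed helper: record the latest prompt time per IP so the pruning loop
    # above can publish a rolling 5-minute unique-prompter count ('proompters_5_min').
    recent_prompters = redis.get_dict('recent_prompters')
    recent_prompters[client_ip] = (time.time(), token)
    redis.set_dict('recent_prompters', recent_prompters)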

(new file, +9 lines)

@@ -0,0 +1,9 @@
import time
from llm_server.routes.v1.generate_stats import generate_stats
def cache_stats():
while True:
generate_stats(regen=True)
time.sleep(5)

(modified file)

@@ -5,7 +5,6 @@ flask_caching
requests~=2.31.0
tiktoken~=0.5.0
gunicorn
redis~=5.0.0
gevent~=23.9.0.post1
async-timeout
flask-sock
@@ -19,4 +18,4 @@ websockets~=11.0.3
basicauth~=1.0.0
openai~=0.28.0
urllib3~=2.0.4
rq~=1.15.1
celery[redis]

server.py (modified, 128 changed lines)

@@ -1,3 +1,5 @@
from llm_server.config.config import mode_ui_names
try:
import gevent.monkey
@@ -5,28 +7,23 @@ try:
except ImportError:
pass
from llm_server.workers.printer import start_console_printer
from llm_server.pre_fork import server_startup
from llm_server.config.load import load_config
import os
import re
import sys
from pathlib import Path
from threading import Thread
import openai
import simplejson as json
from flask import Flask, jsonify, render_template, request
from redis import Redis
import llm_server
from llm_server.database.conn import database
from llm_server.database.create import create_db
from llm_server.database.database import get_number_of_rows
from llm_server.llm import get_token_count
from llm_server.routes.openai import openai_bp
from llm_server.routes.server_error import handle_server_error
from llm_server.routes.v1 import bp
from llm_server.stream import init_socketio
from llm_server.workers.blocking import start_workers
# TODO: have the workers handle streaming too
# TODO: add backend fallbacks. Backends at the bottom of the list are higher priority and are fallbacks if the upper ones fail
@@ -62,16 +59,12 @@ except ModuleNotFoundError as e:
import config
from llm_server import opts
from llm_server.config import ConfigLoader, config_default_vars, config_required_vars, mode_ui_names
from llm_server.helpers import resolve_path, auto_set_base_client_api
from llm_server.helpers import auto_set_base_client_api
from llm_server.llm.vllm.info import vllm_info
from llm_server.routes.cache import RedisWrapper, flask_cache
from llm_server.llm import redis
from llm_server.routes.stats import SemaphoreCheckerThread, get_active_gen_workers
from llm_server.routes.stats import get_active_gen_workers
from llm_server.routes.v1.generate_stats import generate_stats
from llm_server.threads import MainBackgroundThread, cache_stats, start_moderation_workers
script_path = os.path.dirname(os.path.realpath(__file__))
app = Flask(__name__)
init_socketio(app)
@@ -80,123 +73,22 @@ app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')
flask_cache.init_app(app)
flask_cache.clear()
script_path = os.path.dirname(os.path.realpath(__file__))
config_path_environ = os.getenv("CONFIG_PATH")
if config_path_environ:
config_path = config_path_environ
else:
config_path = Path(script_path, 'config', 'config.yml')
config_loader = ConfigLoader(config_path, config_default_vars, config_required_vars)
success, config, msg = config_loader.load_config()
success, config, msg = load_config(config_path, script_path)
if not success:
print('Failed to load config:', msg)
sys.exit(1)
# Resolve relative directory to the directory of the script
if config['database_path'].startswith('./'):
config['database_path'] = resolve_path(script_path, config['database_path'].strip('./'))
database.init_db(config['mysql']['host'], config['mysql']['username'], config['mysql']['password'], config['mysql']['database'])
create_db()
if config['mode'] not in ['oobabooga', 'vllm']:
print('Unknown mode:', config['mode'])
sys.exit(1)
# TODO: this is atrocious
opts.mode = config['mode']
opts.auth_required = config['auth_required']
opts.log_prompts = config['log_prompts']
opts.concurrent_gens = config['concurrent_gens']
opts.frontend_api_client = config['frontend_api_client']
opts.context_size = config['token_limit']
opts.show_num_prompts = config['show_num_prompts']
opts.show_uptime = config['show_uptime']
opts.backend_url = config['backend_url'].strip('/')
opts.show_total_output_tokens = config['show_total_output_tokens']
opts.netdata_root = config['netdata_root']
opts.simultaneous_requests_per_ip = config['simultaneous_requests_per_ip']
opts.show_backend_info = config['show_backend_info']
opts.max_new_tokens = config['max_new_tokens']
opts.manual_model_name = config['manual_model_name']
opts.llm_middleware_name = config['llm_middleware_name']
opts.enable_openi_compatible_backend = config['enable_openi_compatible_backend']
opts.openai_system_prompt = config['openai_system_prompt']
opts.expose_openai_system_prompt = config['expose_openai_system_prompt']
opts.enable_streaming = config['enable_streaming']
opts.openai_api_key = config['openai_api_key']
openai.api_key = opts.openai_api_key
opts.admin_token = config['admin_token']
opts.openai_expose_our_model = config['openai_epose_our_model']
opts.openai_force_no_hashes = config['openai_force_no_hashes']
opts.include_system_tokens_in_stats = config['include_system_tokens_in_stats']
opts.openai_moderation_scan_last_n = config['openai_moderation_scan_last_n']
opts.openai_moderation_workers = config['openai_moderation_workers']
opts.openai_org_name = config['openai_org_name']
opts.openai_silent_trim = config['openai_silent_trim']
opts.openai_moderation_enabled = config['openai_moderation_enabled']
if opts.openai_expose_our_model and not opts.openai_api_key:
print('If you set openai_epose_our_model to false, you must set your OpenAI key in openai_api_key.')
sys.exit(1)
opts.verify_ssl = config['verify_ssl']
if not opts.verify_ssl:
import urllib3
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
if config['average_generation_time_mode'] not in ['database', 'minute']:
print('Invalid value for config item "average_generation_time_mode":', config['average_generation_time_mode'])
sys.exit(1)
opts.average_generation_time_mode = config['average_generation_time_mode']
if opts.mode == 'oobabooga':
raise NotImplementedError
# llm_server.llm.tokenizer = OobaboogaBackend()
elif opts.mode == 'vllm':
llm_server.llm.get_token_count = llm_server.llm.vllm.tokenize
else:
raise Exception
def pre_fork(server):
llm_server.llm.redis = RedisWrapper('local_llm')
flushed_keys = redis.flush()
print('Flushed', len(flushed_keys), 'keys from Redis.')
redis.set_dict('processing_ips', {})
redis.set_dict('queued_ip_count', {})
# Flush the RedisPriorityQueue database.
queue_redis = Redis(host='localhost', port=6379, db=15)
for key in queue_redis.scan_iter('*'):
queue_redis.delete(key)
redis.set('backend_mode', opts.mode)
if config['http_host']:
http_host = re.sub(r'http(?:s)?://', '', config["http_host"])
redis.set('http_host', http_host)
redis.set('base_client_api', f'{http_host}/{opts.frontend_api_client.strip("/")}')
if config['load_num_prompts']:
redis.set('proompts', get_number_of_rows('prompts'))
# Start background processes
start_workers(opts.concurrent_gens)
start_console_printer()
start_moderation_workers(opts.openai_moderation_workers)
MainBackgroundThread().start()
SemaphoreCheckerThread().start()
# This needs to be started after Flask is initalized
stats_updater_thread = Thread(target=cache_stats)
stats_updater_thread.daemon = True
stats_updater_thread.start()
# Cache the initial stats
print('Loading backend stats...')
generate_stats()
create_db()
# print(app.url_map)
@@ -280,6 +172,6 @@ def before_app_request():
if __name__ == "__main__":
pre_fork(None)
server_startup(None)
print('FLASK MODE - Startup complete!')
app.run(host='0.0.0.0', threaded=False, processes=15)