redo background processes, reorganize server.py

Cyberes 2023-09-27 23:36:44 -06:00
parent 097d614a35
commit e86a5182eb
19 changed files with 344 additions and 285 deletions

daemon.py (new file)

@@ -0,0 +1,44 @@
import time

from llm_server.routes.cache import redis

try:
    import gevent.monkey

    gevent.monkey.patch_all()
except ImportError:
    pass

import os
import sys
from pathlib import Path

from llm_server.config.load import load_config
from llm_server.database.create import create_db
from llm_server.workers.app import start_background

script_path = os.path.dirname(os.path.realpath(__file__))
config_path_environ = os.getenv("CONFIG_PATH")
if config_path_environ:
    config_path = config_path_environ
else:
    config_path = Path(script_path, 'config', 'config.yml')

if __name__ == "__main__":
    flushed_keys = redis.flush()
    print('Flushed', len(flushed_keys), 'keys from Redis.')

    success, config, msg = load_config(config_path, script_path)
    if not success:
        print('Failed to load config:', msg)
        sys.exit(1)
    create_db()

    start_background()

    redis.set('daemon_started', 1)
    print('== Daemon Setup Complete ==\n')

    while True:
        time.sleep(3600)
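The gunicorn workers (via llm_server/pre_fork.py, added below) refuse to start until this daemon_started flag exists. A minimal sketch, not part of this commit, of how any dependent process could wait on the same flag through the shared RedisWrapper; wait_for_daemon() and its timeout are illustrative assumptions:

import time

from llm_server.routes.cache import redis


def wait_for_daemon(timeout: int = 60):
    # Hypothetical helper: poll the flag that daemon.py sets at the end of its setup.
    start = time.time()
    while not redis.get('daemon_started', bool):
        if time.time() - start > timeout:
            raise RuntimeError('daemon.py never set daemon_started in Redis')
        time.sleep(1)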


@@ -5,9 +5,9 @@ try:
 except ImportError:
     pass
 
-import server
+from llm_server.pre_fork import server_startup
 
 
 def on_starting(s):
-    server.pre_fork(s)
+    server_startup(s)
     print('Startup complete!')


llm_server/config/load.py (new file)

@@ -0,0 +1,86 @@
import re
import sys

import openai

from llm_server import opts
from llm_server.config.config import ConfigLoader, config_default_vars, config_required_vars
from llm_server.database.conn import database
from llm_server.database.database import get_number_of_rows
from llm_server.helpers import resolve_path
from llm_server.routes.cache import redis


def load_config(config_path, script_path):
    config_loader = ConfigLoader(config_path, config_default_vars, config_required_vars)
    success, config, msg = config_loader.load_config()
    if not success:
        return success, config, msg

    # Resolve relative directory to the directory of the script
    if config['database_path'].startswith('./'):
        config['database_path'] = resolve_path(script_path, config['database_path'].strip('./'))

    if config['mode'] not in ['oobabooga', 'vllm']:
        print('Unknown mode:', config['mode'])
        sys.exit(1)

    # TODO: this is atrocious
    opts.mode = config['mode']
    opts.auth_required = config['auth_required']
    opts.log_prompts = config['log_prompts']
    opts.concurrent_gens = config['concurrent_gens']
    opts.frontend_api_client = config['frontend_api_client']
    opts.context_size = config['token_limit']
    opts.show_num_prompts = config['show_num_prompts']
    opts.show_uptime = config['show_uptime']
    opts.backend_url = config['backend_url'].strip('/')
    opts.show_total_output_tokens = config['show_total_output_tokens']
    opts.netdata_root = config['netdata_root']
    opts.simultaneous_requests_per_ip = config['simultaneous_requests_per_ip']
    opts.show_backend_info = config['show_backend_info']
    opts.max_new_tokens = config['max_new_tokens']
    opts.manual_model_name = config['manual_model_name']
    opts.llm_middleware_name = config['llm_middleware_name']
    opts.enable_openi_compatible_backend = config['enable_openi_compatible_backend']
    opts.openai_system_prompt = config['openai_system_prompt']
    opts.expose_openai_system_prompt = config['expose_openai_system_prompt']
    opts.enable_streaming = config['enable_streaming']
    opts.openai_api_key = config['openai_api_key']
    openai.api_key = opts.openai_api_key
    opts.admin_token = config['admin_token']
    opts.openai_expose_our_model = config['openai_epose_our_model']
    opts.openai_force_no_hashes = config['openai_force_no_hashes']
    opts.include_system_tokens_in_stats = config['include_system_tokens_in_stats']
    opts.openai_moderation_scan_last_n = config['openai_moderation_scan_last_n']
    opts.openai_moderation_workers = config['openai_moderation_workers']
    opts.openai_org_name = config['openai_org_name']
    opts.openai_silent_trim = config['openai_silent_trim']
    opts.openai_moderation_enabled = config['openai_moderation_enabled']

    if opts.openai_expose_our_model and not opts.openai_api_key:
        print('If you set openai_epose_our_model to false, you must set your OpenAI key in openai_api_key.')
        sys.exit(1)

    opts.verify_ssl = config['verify_ssl']
    if not opts.verify_ssl:
        import urllib3
        urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

    if config['http_host']:
        http_host = re.sub(r'http(?:s)?://', '', config["http_host"])
        redis.set('http_host', http_host)
        redis.set('base_client_api', f'{http_host}/{opts.frontend_api_client.strip("/")}')

    database.init_db(config['mysql']['host'], config['mysql']['username'], config['mysql']['password'], config['mysql']['database'])

    if config['load_num_prompts']:
        redis.set('proompts', get_number_of_rows('prompts'))

    redis.set_dict('recent_prompters', {})
    redis.set_dict('processing_ips', {})
    redis.set_dict('queued_ip_count', {})
    redis.set('backend_mode', opts.mode)

    return success, config, msg
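For reference, a minimal sketch of how load_config() is driven (daemon.py above does essentially this); the install path is an assumption, and the opts attributes printed at the end are only a spot check:

import sys
from pathlib import Path

from llm_server import opts
from llm_server.config.load import load_config

script_path = '/srv/local-llm-server'  # hypothetical install location
config_path = Path(script_path, 'config', 'config.yml')

success, config, msg = load_config(config_path, script_path)
if not success:
    print('Failed to load config:', msg)
    sys.exit(1)

# On success, the parsed values live in the llm_server.opts module globals.
print(opts.mode, opts.backend_url, opts.concurrent_gens)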

llm_server/pre_fork.py (new file)

@@ -0,0 +1,21 @@
import sys

from redis import Redis

from llm_server.routes.cache import redis
from llm_server.routes.v1.generate_stats import generate_stats


def server_startup(s):
    if not redis.get('daemon_started', bool):
        print('Could not find the key daemon_started in Redis. Did you forget to start the daemon process?')
        sys.exit(1)

    # Flush the RedisPriorityQueue database.
    queue_redis = Redis(host='localhost', port=6379, db=15)
    for key in queue_redis.scan_iter('*'):
        queue_redis.delete(key)

    # Cache the initial stats
    print('Loading backend stats...')
    generate_stats()


@@ -57,7 +57,7 @@ def require_api_key(json_body: dict = None):
                 return jsonify({'code': 403, 'message': 'Invalid API key'}), 403
     elif 'Authorization' in request.headers:
         token = parse_token(request.headers['Authorization'])
-        if token.startswith('SYSTEM__') or opts.auth_required:
+        if (token and token.startswith('SYSTEM__')) or opts.auth_required:
            if is_valid_api_key(token):
                return
            else:
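The added `token and ...` guard matters because parse_token() can come back with None when the Authorization header carries no usable token, and None.startswith() raises AttributeError. A contrived sketch of the failure mode; this parse_token is a stand-in for illustration, not the project's implementation:

def parse_token(header: str):
    # Stand-in only: returns None when nothing follows "Bearer".
    value = header.removeprefix('Bearer ').strip()
    return value or None


token = parse_token('Bearer ')                  # None
# token.startswith('SYSTEM__')                  # would raise AttributeError on None
safe = token and token.startswith('SYSTEM__')   # short-circuits to a falsy value instead of raising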


@@ -10,7 +10,7 @@ from llm_server import opts
 from llm_server.database.database import is_api_key_moderated
 from llm_server.llm.openai.transform import build_openai_response, transform_messages_to_prompt, trim_prompt_to_fit
 from llm_server.routes.request_handler import RequestHandler
-from llm_server.threads import add_moderation_task, get_results
+from llm_server.workers.moderator import add_moderation_task, get_results


 class OpenAIRequestHandler(RequestHandler):


@@ -1,6 +1,4 @@
-import time
 from datetime import datetime
-from threading import Thread
 
 from llm_server.routes.cache import redis
@@ -46,27 +44,3 @@ def get_active_gen_workers():
     else:
         count = int(active_gen_workers)
     return count
-
-
-class SemaphoreCheckerThread(Thread):
-    redis.set_dict('recent_prompters', {})
-
-    def __init__(self):
-        Thread.__init__(self)
-        self.daemon = True
-
-    def run(self):
-        while True:
-            current_time = time.time()
-            recent_prompters = redis.get_dict('recent_prompters')
-            new_recent_prompters = {}
-
-            for ip, (timestamp, token) in recent_prompters.items():
-                if token and token.startswith('SYSTEM__'):
-                    continue
-                if current_time - timestamp <= 300:
-                    new_recent_prompters[ip] = timestamp, token
-
-            redis.set_dict('recent_prompters', new_recent_prompters)
-            redis.set('proompters_5_min', len(new_recent_prompters))
-
-            time.sleep(1)


@@ -1,120 +0,0 @@
import json
import threading
import time
import traceback
from threading import Thread

import redis as redis_redis

from llm_server import opts
from llm_server.database.database import weighted_average_column_for_model
from llm_server.llm.info import get_running_model
from llm_server.llm.openai.moderation import check_moderation_endpoint
from llm_server.routes.cache import redis
from llm_server.routes.v1.generate_stats import generate_stats


class MainBackgroundThread(Thread):
    backend_online = False

    # TODO: do I really need to put everything in Redis?
    # TODO: call generate_stats() every minute, cache the results, put results in a DB table, then have other parts of code call this cache

    def __init__(self):
        Thread.__init__(self)
        self.daemon = True
        redis.set('average_generation_elapsed_sec', 0)
        redis.set('estimated_avg_tps', 0)
        redis.set('average_output_tokens', 0)
        redis.set('backend_online', 0)
        redis.set_dict('backend_info', {})

    def run(self):
        while True:
            # TODO: unify this
            if opts.mode == 'oobabooga':
                running_model, err = get_running_model()
                if err:
                    print(err)
                    redis.set('backend_online', 0)
                else:
                    redis.set('running_model', running_model)
                    redis.set('backend_online', 1)
            elif opts.mode == 'vllm':
                running_model, err = get_running_model()
                if err:
                    print(err)
                    redis.set('backend_online', 0)
                else:
                    redis.set('running_model', running_model)
                    redis.set('backend_online', 1)
            else:
                raise Exception

            # exclude_zeros=True filters out rows where an error message was returned. Previously, if there was an error, 0
            # was entered into the column. The new code enters null instead but we need to be backwards compatible for now.
            average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time', running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
            if average_generation_elapsed_sec:  # returns None on exception
                redis.set('average_generation_elapsed_sec', average_generation_elapsed_sec)

            # overall = average_column_for_model('prompts', 'generation_time', opts.running_model)
            # print(f'Weighted: {average_generation_elapsed_sec}, overall: {overall}')

            average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens', running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
            if average_generation_elapsed_sec:
                redis.set('average_output_tokens', average_output_tokens)

            # overall = average_column_for_model('prompts', 'response_tokens', opts.running_model)
            # print(f'Weighted: {average_output_tokens}, overall: {overall}')

            estimated_avg_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0  # Avoid division by zero
            redis.set('estimated_avg_tps', estimated_avg_tps)
            time.sleep(60)


def cache_stats():
    while True:
        generate_stats(regen=True)
        time.sleep(5)


redis_moderation = redis_redis.Redis()


def start_moderation_workers(num_workers):
    for _ in range(num_workers):
        t = threading.Thread(target=moderation_worker)
        t.daemon = True
        t.start()


def moderation_worker():
    while True:
        result = redis_moderation.blpop('queue:msgs_to_check')
        try:
            msg, tag = json.loads(result[1])
            _, categories = check_moderation_endpoint(msg)
            redis_moderation.rpush('queue:flagged_categories', json.dumps((tag, categories)))
        except:
            print(result)
            traceback.print_exc()
            continue


def add_moderation_task(msg, tag):
    redis_moderation.rpush('queue:msgs_to_check', json.dumps((msg, str(tag))))


def get_results(tag, num_tasks):
    tag = str(tag)  # Required for comparison with Redis results.
    flagged_categories = set()
    num_results = 0
    while num_results < num_tasks:
        result = redis_moderation.blpop('queue:flagged_categories')
        result_tag, categories = json.loads(result[1])
        if result_tag == tag:
            if categories:
                for item in categories:
                    flagged_categories.add(item)
            num_results += 1
    return list(flagged_categories)
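This deleted module is split across the new llm_server/workers package below; the Thread subclasses become plain loop functions that the caller wires into daemon threads. A generic sketch of that pattern, with a hypothetical worker name:

import time
from threading import Thread


def heartbeat_loop():
    # Hypothetical worker with the same shape as main_background_thread() below.
    while True:
        # ... poll the backend, write results to Redis ...
        time.sleep(60)


t = Thread(target=heartbeat_loop)
t.daemon = True  # dies with the daemon process, so no explicit shutdown is needed
t.start()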

llm_server/workers/app.py (new file)

@@ -0,0 +1,35 @@
from threading import Thread

from .blocking import start_workers
from .main import main_background_thread
from .moderator import start_moderation_workers
from .printer import console_printer
from .recent import recent_prompters_thread
from .threads import cache_stats
from .. import opts


def start_background():
    start_workers(opts.concurrent_gens)

    t = Thread(target=main_background_thread)
    t.daemon = True
    t.start()
    print('Started the main background thread.')

    start_moderation_workers(opts.openai_moderation_workers)

    t = Thread(target=cache_stats)
    t.daemon = True
    t.start()
    print('Started the stats cacher.')

    t = Thread(target=recent_prompters_thread)
    t.daemon = True
    t.start()
    print('Started the recent proompters thread.')

    t = Thread(target=console_printer)
    t.daemon = True
    t.start()
    print('Started the console printer.')
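start_background() repeats the same Thread/daemon/start/print sequence for each worker; a small helper like the one sketched here (not in the repository, purely illustrative) would collapse that repetition:

from threading import Thread


def start_daemon_thread(target, name: str) -> Thread:
    # Illustrative only: mirrors the repeated block in start_background().
    t = Thread(target=target, name=name)
    t.daemon = True
    t.start()
    print(f'Started {name}.')
    return t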


@@ -1,4 +1,3 @@
-import json
 import threading
 import time
 
@@ -17,12 +16,12 @@ def worker():
         increment_ip_count(client_ip, 'processing_ips')
         redis.incr('active_gen_workers')
 
+        if not request_json_body:
+            # This was a dummy request from the websocket handler.
+            # We're going to let the websocket handler decrement processing_ips and active_gen_workers.
+            continue
+
         try:
-            if not request_json_body:
-                # This was a dummy request from the websocket handler.
-                # We're going to let the websocket handler decrement processing_ips and active_gen_workers.
-                continue
-
             start_time = time.time()
             success, response, error_msg = generator(request_json_body)
             end_time = time.time()


@@ -0,0 +1,56 @@
import time
from threading import Thread

from llm_server import opts
from llm_server.database.database import weighted_average_column_for_model
from llm_server.llm.info import get_running_model
from llm_server.routes.cache import redis


def main_background_thread():
    redis.set('average_generation_elapsed_sec', 0)
    redis.set('estimated_avg_tps', 0)
    redis.set('average_output_tokens', 0)
    redis.set('backend_online', 0)
    redis.set_dict('backend_info', {})

    while True:
        # TODO: unify this
        if opts.mode == 'oobabooga':
            running_model, err = get_running_model()
            if err:
                print(err)
                redis.set('backend_online', 0)
            else:
                redis.set('running_model', running_model)
                redis.set('backend_online', 1)
        elif opts.mode == 'vllm':
            running_model, err = get_running_model()
            if err:
                print(err)
                redis.set('backend_online', 0)
            else:
                redis.set('running_model', running_model)
                redis.set('backend_online', 1)
        else:
            raise Exception

        # exclude_zeros=True filters out rows where an error message was returned. Previously, if there was an error, 0
        # was entered into the column. The new code enters null instead but we need to be backwards compatible for now.
        average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time', running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
        if average_generation_elapsed_sec:  # returns None on exception
            redis.set('average_generation_elapsed_sec', average_generation_elapsed_sec)

        # overall = average_column_for_model('prompts', 'generation_time', opts.running_model)
        # print(f'Weighted: {average_generation_elapsed_sec}, overall: {overall}')

        average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens', running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
        if average_generation_elapsed_sec:
            redis.set('average_output_tokens', average_output_tokens)

        # overall = average_column_for_model('prompts', 'response_tokens', opts.running_model)
        # print(f'Weighted: {average_output_tokens}, overall: {overall}')

        estimated_avg_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0  # Avoid division by zero
        redis.set('estimated_avg_tps', estimated_avg_tps)
        time.sleep(60)
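With illustrative numbers, the cached throughput estimate works out as follows (180 tokens and 12.5 seconds are made-up values, not measurements):

average_output_tokens = 180.0
average_generation_elapsed_sec = 12.5
estimated_avg_tps = round(average_output_tokens / average_generation_elapsed_sec, 2)  # 14.4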


@@ -0,0 +1,51 @@
import json
import threading
import traceback

import redis as redis_redis

from llm_server.llm.openai.moderation import check_moderation_endpoint

redis_moderation = redis_redis.Redis()


def start_moderation_workers(num_workers):
    i = 0
    for _ in range(num_workers):
        t = threading.Thread(target=moderation_worker)
        t.daemon = True
        t.start()
        i += 1
    print(f'Started {i} moderation workers.')


def moderation_worker():
    while True:
        result = redis_moderation.blpop('queue:msgs_to_check')
        try:
            msg, tag = json.loads(result[1])
            _, categories = check_moderation_endpoint(msg)
            redis_moderation.rpush('queue:flagged_categories', json.dumps((tag, categories)))
        except:
            print(result)
            traceback.print_exc()
            continue


def add_moderation_task(msg, tag):
    redis_moderation.rpush('queue:msgs_to_check', json.dumps((msg, str(tag))))


def get_results(tag, num_tasks):
    tag = str(tag)  # Required for comparison with Redis results.
    flagged_categories = set()
    num_results = 0
    while num_results < num_tasks:
        result = redis_moderation.blpop('queue:flagged_categories')
        result_tag, categories = json.loads(result[1])
        if result_tag == tag:
            if categories:
                for item in categories:
                    flagged_categories.add(item)
            num_results += 1
    return list(flagged_categories)
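A sketch of how a request handler drives this queue (OpenAIRequestHandler imports add_moderation_task and get_results, per the import change earlier in this commit); the uuid tag and the message list are illustrative assumptions:

import uuid

from llm_server.workers.moderator import add_moderation_task, get_results

tag = uuid.uuid4()
messages_to_check = ['most recent user message', 'the message before it']  # hypothetical
for msg in messages_to_check:
    add_moderation_task(msg, tag)

# Blocks until one result per submitted task has come back over Redis.
flagged_categories = get_results(tag, len(messages_to_check))
if flagged_categories:
    print('Flagged:', flagged_categories)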


@@ -1,5 +1,4 @@
 import logging
-import threading
 import time
 
 from llm_server.routes.cache import redis
@@ -15,14 +14,9 @@ if not logger.handlers:
 def console_printer():
+    time.sleep(3)
     while True:
         queued_ip_count = sum([v for k, v in redis.get_dict('queued_ip_count').items()])
         processing_count = sum([v for k, v in redis.get_dict('processing_ips').items()])
         logger.info(f'REQUEST QUEUE -> Processing: {processing_count} | Queued: {queued_ip_count}')
-        time.sleep(10)
+        time.sleep(15)
-
-
-def start_console_printer():
-    t = threading.Thread(target=console_printer)
-    t.daemon = True
-    t.start()


@@ -0,0 +1,19 @@
import time

from llm_server.routes.cache import redis


def recent_prompters_thread():
    while True:
        current_time = time.time()
        recent_prompters = redis.get_dict('recent_prompters')
        new_recent_prompters = {}

        for ip, (timestamp, token) in recent_prompters.items():
            if token and token.startswith('SYSTEM__'):
                continue
            if current_time - timestamp <= 300:
                new_recent_prompters[ip] = timestamp, token

        redis.set_dict('recent_prompters', new_recent_prompters)
        redis.set('proompters_5_min', len(new_recent_prompters))

        time.sleep(1)


@@ -0,0 +1,9 @@
import time

from llm_server.routes.v1.generate_stats import generate_stats


def cache_stats():
    while True:
        generate_stats(regen=True)
        time.sleep(5)


@@ -5,7 +5,6 @@ flask_caching
 requests~=2.31.0
 tiktoken~=0.5.0
 gunicorn
-redis~=5.0.0
 gevent~=23.9.0.post1
 async-timeout
 flask-sock
@@ -19,4 +18,4 @@ websockets~=11.0.3
 basicauth~=1.0.0
 openai~=0.28.0
 urllib3~=2.0.4
-rq~=1.15.1
+celery[redis]

server.py

@@ -1,3 +1,5 @@
+from llm_server.config.config import mode_ui_names
+
 try:
     import gevent.monkey
 
@@ -5,28 +7,23 @@ try:
 except ImportError:
     pass
 
-from llm_server.workers.printer import start_console_printer
+from llm_server.pre_fork import server_startup
+from llm_server.config.load import load_config
 import os
-import re
 import sys
 from pathlib import Path
-from threading import Thread
 
-import openai
 import simplejson as json
 from flask import Flask, jsonify, render_template, request
-from redis import Redis
 
 import llm_server
 from llm_server.database.conn import database
 from llm_server.database.create import create_db
-from llm_server.database.database import get_number_of_rows
 from llm_server.llm import get_token_count
 from llm_server.routes.openai import openai_bp
 from llm_server.routes.server_error import handle_server_error
 from llm_server.routes.v1 import bp
 from llm_server.stream import init_socketio
-from llm_server.workers.blocking import start_workers
 
 # TODO: have the workers handle streaming too
 # TODO: add backend fallbacks. Backends at the bottom of the list are higher priority and are fallbacks if the upper ones fail
@@ -62,16 +59,12 @@ except ModuleNotFoundError as e:
 import config
 
 from llm_server import opts
-from llm_server.config import ConfigLoader, config_default_vars, config_required_vars, mode_ui_names
-from llm_server.helpers import resolve_path, auto_set_base_client_api
+from llm_server.helpers import auto_set_base_client_api
 from llm_server.llm.vllm.info import vllm_info
 from llm_server.routes.cache import RedisWrapper, flask_cache
 from llm_server.llm import redis
-from llm_server.routes.stats import SemaphoreCheckerThread, get_active_gen_workers
+from llm_server.routes.stats import get_active_gen_workers
 from llm_server.routes.v1.generate_stats import generate_stats
-from llm_server.threads import MainBackgroundThread, cache_stats, start_moderation_workers
-
-script_path = os.path.dirname(os.path.realpath(__file__))
 
 app = Flask(__name__)
 init_socketio(app)
@@ -80,123 +73,22 @@ app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')
 flask_cache.init_app(app)
 flask_cache.clear()
 
+script_path = os.path.dirname(os.path.realpath(__file__))
 config_path_environ = os.getenv("CONFIG_PATH")
 if config_path_environ:
     config_path = config_path_environ
 else:
     config_path = Path(script_path, 'config', 'config.yml')
 
-config_loader = ConfigLoader(config_path, config_default_vars, config_required_vars)
-success, config, msg = config_loader.load_config()
+success, config, msg = load_config(config_path, script_path)
 if not success:
     print('Failed to load config:', msg)
     sys.exit(1)
 
-# Resolve relative directory to the directory of the script
-if config['database_path'].startswith('./'):
-    config['database_path'] = resolve_path(script_path, config['database_path'].strip('./'))
-
 database.init_db(config['mysql']['host'], config['mysql']['username'], config['mysql']['password'], config['mysql']['database'])
 create_db()
 
-if config['mode'] not in ['oobabooga', 'vllm']:
-    print('Unknown mode:', config['mode'])
-    sys.exit(1)
-
-# TODO: this is atrocious
-opts.mode = config['mode']
-opts.auth_required = config['auth_required']
-opts.log_prompts = config['log_prompts']
-opts.concurrent_gens = config['concurrent_gens']
-opts.frontend_api_client = config['frontend_api_client']
-opts.context_size = config['token_limit']
-opts.show_num_prompts = config['show_num_prompts']
-opts.show_uptime = config['show_uptime']
-opts.backend_url = config['backend_url'].strip('/')
-opts.show_total_output_tokens = config['show_total_output_tokens']
-opts.netdata_root = config['netdata_root']
-opts.simultaneous_requests_per_ip = config['simultaneous_requests_per_ip']
-opts.show_backend_info = config['show_backend_info']
-opts.max_new_tokens = config['max_new_tokens']
-opts.manual_model_name = config['manual_model_name']
-opts.llm_middleware_name = config['llm_middleware_name']
-opts.enable_openi_compatible_backend = config['enable_openi_compatible_backend']
-opts.openai_system_prompt = config['openai_system_prompt']
-opts.expose_openai_system_prompt = config['expose_openai_system_prompt']
-opts.enable_streaming = config['enable_streaming']
-opts.openai_api_key = config['openai_api_key']
-openai.api_key = opts.openai_api_key
-opts.admin_token = config['admin_token']
-opts.openai_expose_our_model = config['openai_epose_our_model']
-opts.openai_force_no_hashes = config['openai_force_no_hashes']
-opts.include_system_tokens_in_stats = config['include_system_tokens_in_stats']
-opts.openai_moderation_scan_last_n = config['openai_moderation_scan_last_n']
-opts.openai_moderation_workers = config['openai_moderation_workers']
-opts.openai_org_name = config['openai_org_name']
-opts.openai_silent_trim = config['openai_silent_trim']
-opts.openai_moderation_enabled = config['openai_moderation_enabled']
-
-if opts.openai_expose_our_model and not opts.openai_api_key:
-    print('If you set openai_epose_our_model to false, you must set your OpenAI key in openai_api_key.')
-    sys.exit(1)
-
-opts.verify_ssl = config['verify_ssl']
-if not opts.verify_ssl:
-    import urllib3
-    urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
-
-if config['average_generation_time_mode'] not in ['database', 'minute']:
-    print('Invalid value for config item "average_generation_time_mode":', config['average_generation_time_mode'])
-    sys.exit(1)
-opts.average_generation_time_mode = config['average_generation_time_mode']
-
-if opts.mode == 'oobabooga':
-    raise NotImplementedError
-    # llm_server.llm.tokenizer = OobaboogaBackend()
-elif opts.mode == 'vllm':
-    llm_server.llm.get_token_count = llm_server.llm.vllm.tokenize
-else:
-    raise Exception
-
-
-def pre_fork(server):
-    llm_server.llm.redis = RedisWrapper('local_llm')
-    flushed_keys = redis.flush()
-    print('Flushed', len(flushed_keys), 'keys from Redis.')
-
-    redis.set_dict('processing_ips', {})
-    redis.set_dict('queued_ip_count', {})
-
-    # Flush the RedisPriorityQueue database.
-    queue_redis = Redis(host='localhost', port=6379, db=15)
-    for key in queue_redis.scan_iter('*'):
-        queue_redis.delete(key)
-
-    redis.set('backend_mode', opts.mode)
-
-    if config['http_host']:
-        http_host = re.sub(r'http(?:s)?://', '', config["http_host"])
-        redis.set('http_host', http_host)
-        redis.set('base_client_api', f'{http_host}/{opts.frontend_api_client.strip("/")}')
-
-    if config['load_num_prompts']:
-        redis.set('proompts', get_number_of_rows('prompts'))
-
-    # Start background processes
-    start_workers(opts.concurrent_gens)
-    start_console_printer()
-    start_moderation_workers(opts.openai_moderation_workers)
-    MainBackgroundThread().start()
-    SemaphoreCheckerThread().start()
-
-    # This needs to be started after Flask is initalized
-    stats_updater_thread = Thread(target=cache_stats)
-    stats_updater_thread.daemon = True
-    stats_updater_thread.start()
-
-    # Cache the initial stats
-    print('Loading backend stats...')
-    generate_stats()
+llm_server.llm.redis = RedisWrapper('local_llm')
+create_db()
 
 # print(app.url_map)
@@ -280,6 +172,6 @@ def before_app_request():
 
 if __name__ == "__main__":
-    pre_fork(None)
+    server_startup(None)
     print('FLASK MODE - Startup complete!')
     app.run(host='0.0.0.0', threaded=False, processes=15)