redo background processes, reorganize server.py
parent 097d614a35
commit e86a5182eb
@@ -0,0 +1,44 @@
import time

from llm_server.routes.cache import redis

try:
    import gevent.monkey

    gevent.monkey.patch_all()
except ImportError:
    pass

import os
import sys
from pathlib import Path

from llm_server.config.load import load_config
from llm_server.database.create import create_db

from llm_server.workers.app import start_background

script_path = os.path.dirname(os.path.realpath(__file__))
config_path_environ = os.getenv("CONFIG_PATH")
if config_path_environ:
    config_path = config_path_environ
else:
    config_path = Path(script_path, 'config', 'config.yml')

if __name__ == "__main__":
    flushed_keys = redis.flush()
    print('Flushed', len(flushed_keys), 'keys from Redis.')

    success, config, msg = load_config(config_path, script_path)
    if not success:
        print('Failed to load config:', msg)
        sys.exit(1)

    create_db()
    start_background()

    redis.set('daemon_started', 1)
    print('== Daemon Setup Complete ==\n')

    while True:
        time.sleep(3600)
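A minimal sketch of how a dependent process could wait for the `daemon_started` flag set above; `redis.get('daemon_started', bool)` matches the call used in `server_startup` further down, but the polling loop itself is illustrative and not part of this commit.

import time

from llm_server.routes.cache import redis

# Poll until the daemon process has finished its setup and set the flag.
while not redis.get('daemon_started', bool):
    print('Waiting for the daemon to finish setup...')
    time.sleep(1)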
@@ -5,9 +5,9 @@ try:
except ImportError:
    pass

import server
from llm_server.pre_fork import server_startup


def on_starting(s):
    server.pre_fork(s)
    server_startup(s)
    print('Startup complete!')
@@ -0,0 +1,86 @@
import re
import sys

import openai

from llm_server import opts
from llm_server.config.config import ConfigLoader, config_default_vars, config_required_vars
from llm_server.database.conn import database
from llm_server.database.database import get_number_of_rows
from llm_server.helpers import resolve_path
from llm_server.routes.cache import redis


def load_config(config_path, script_path):
    config_loader = ConfigLoader(config_path, config_default_vars, config_required_vars)
    success, config, msg = config_loader.load_config()
    if not success:
        return success, config, msg

    # Resolve relative directory to the directory of the script
    if config['database_path'].startswith('./'):
        config['database_path'] = resolve_path(script_path, config['database_path'].strip('./'))

    if config['mode'] not in ['oobabooga', 'vllm']:
        print('Unknown mode:', config['mode'])
        sys.exit(1)

    # TODO: this is atrocious
    opts.mode = config['mode']
    opts.auth_required = config['auth_required']
    opts.log_prompts = config['log_prompts']
    opts.concurrent_gens = config['concurrent_gens']
    opts.frontend_api_client = config['frontend_api_client']
    opts.context_size = config['token_limit']
    opts.show_num_prompts = config['show_num_prompts']
    opts.show_uptime = config['show_uptime']
    opts.backend_url = config['backend_url'].strip('/')
    opts.show_total_output_tokens = config['show_total_output_tokens']
    opts.netdata_root = config['netdata_root']
    opts.simultaneous_requests_per_ip = config['simultaneous_requests_per_ip']
    opts.show_backend_info = config['show_backend_info']
    opts.max_new_tokens = config['max_new_tokens']
    opts.manual_model_name = config['manual_model_name']
    opts.llm_middleware_name = config['llm_middleware_name']
    opts.enable_openi_compatible_backend = config['enable_openi_compatible_backend']
    opts.openai_system_prompt = config['openai_system_prompt']
    opts.expose_openai_system_prompt = config['expose_openai_system_prompt']
    opts.enable_streaming = config['enable_streaming']
    opts.openai_api_key = config['openai_api_key']
    openai.api_key = opts.openai_api_key
    opts.admin_token = config['admin_token']
    opts.openai_expose_our_model = config['openai_epose_our_model']
    opts.openai_force_no_hashes = config['openai_force_no_hashes']
    opts.include_system_tokens_in_stats = config['include_system_tokens_in_stats']
    opts.openai_moderation_scan_last_n = config['openai_moderation_scan_last_n']
    opts.openai_moderation_workers = config['openai_moderation_workers']
    opts.openai_org_name = config['openai_org_name']
    opts.openai_silent_trim = config['openai_silent_trim']
    opts.openai_moderation_enabled = config['openai_moderation_enabled']

    if opts.openai_expose_our_model and not opts.openai_api_key:
        print('If you set openai_epose_our_model to false, you must set your OpenAI key in openai_api_key.')
        sys.exit(1)

    opts.verify_ssl = config['verify_ssl']
    if not opts.verify_ssl:
        import urllib3

        urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

    if config['http_host']:
        http_host = re.sub(r'http(?:s)?://', '', config["http_host"])
        redis.set('http_host', http_host)
        redis.set('base_client_api', f'{http_host}/{opts.frontend_api_client.strip("/")}')

    database.init_db(config['mysql']['host'], config['mysql']['username'], config['mysql']['password'], config['mysql']['database'])

    if config['load_num_prompts']:
        redis.set('proompts', get_number_of_rows('prompts'))

    redis.set_dict('recent_prompters', {})
    redis.set_dict('processing_ips', {})
    redis.set_dict('queued_ip_count', {})
    redis.set('backend_mode', opts.mode)

    return success, config, msg
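For reference, the `re.sub(r'http(?:s)?://', '', ...)` call above only strips the URL scheme from the configured host. A standalone illustration with example values (not taken from any real config):

import re

for value in ('https://llm.example.com', 'http://10.0.0.5:5000', 'llm.example.com'):
    # Prints the host with any http:// or https:// prefix removed.
    print(re.sub(r'http(?:s)?://', '', value))
# llm.example.com
# 10.0.0.5:5000
# llm.example.com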
@@ -0,0 +1,21 @@
import sys

from redis import Redis

from llm_server.routes.cache import redis
from llm_server.routes.v1.generate_stats import generate_stats


def server_startup(s):
    if not redis.get('daemon_started', bool):
        print('Could not find the key daemon_started in Redis. Did you forget to start the daemon process?')
        sys.exit(1)

    # Flush the RedisPriorityQueue database.
    queue_redis = Redis(host='localhost', port=6379, db=15)
    for key in queue_redis.scan_iter('*'):
        queue_redis.delete(key)

    # Cache the initial stats
    print('Loading backend stats...')
    generate_stats()
@@ -57,7 +57,7 @@ def require_api_key(json_body: dict = None):
            return jsonify({'code': 403, 'message': 'Invalid API key'}), 403
    elif 'Authorization' in request.headers:
        token = parse_token(request.headers['Authorization'])
        if token.startswith('SYSTEM__') or opts.auth_required:
        if (token and token.startswith('SYSTEM__')) or opts.auth_required:
            if is_valid_api_key(token):
                return
            else:
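The added `token and` guard matters when `parse_token` yields a falsy value such as None (an assumption about `parse_token`, which is not shown in this hunk). A standalone illustration of the difference:

token = None

# Old condition: calling .startswith() on None raises AttributeError.
try:
    token.startswith('SYSTEM__')
except AttributeError as error:
    print('old check fails:', error)

# New condition: `token and ...` short-circuits, so .startswith() is never called on None.
print('new check result:', bool(token and token.startswith('SYSTEM__')))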
@@ -10,7 +10,7 @@ from llm_server import opts
from llm_server.database.database import is_api_key_moderated
from llm_server.llm.openai.transform import build_openai_response, transform_messages_to_prompt, trim_prompt_to_fit
from llm_server.routes.request_handler import RequestHandler
from llm_server.threads import add_moderation_task, get_results
from llm_server.workers.moderator import add_moderation_task, get_results


class OpenAIRequestHandler(RequestHandler):
@@ -1,6 +1,4 @@
import time
from datetime import datetime
from threading import Thread

from llm_server.routes.cache import redis
@@ -46,27 +44,3 @@ def get_active_gen_workers():
    else:
        count = int(active_gen_workers)
    return count


class SemaphoreCheckerThread(Thread):
    redis.set_dict('recent_prompters', {})

    def __init__(self):
        Thread.__init__(self)
        self.daemon = True

    def run(self):
        while True:
            current_time = time.time()
            recent_prompters = redis.get_dict('recent_prompters')
            new_recent_prompters = {}

            for ip, (timestamp, token) in recent_prompters.items():
                if token and token.startswith('SYSTEM__'):
                    continue
                if current_time - timestamp <= 300:
                    new_recent_prompters[ip] = timestamp, token

            redis.set_dict('recent_prompters', new_recent_prompters)
            redis.set('proompters_5_min', len(new_recent_prompters))
            time.sleep(1)
@@ -1,120 +0,0 @@
import json
import threading
import time
import traceback
from threading import Thread

import redis as redis_redis

from llm_server import opts
from llm_server.database.database import weighted_average_column_for_model
from llm_server.llm.info import get_running_model
from llm_server.llm.openai.moderation import check_moderation_endpoint
from llm_server.routes.cache import redis
from llm_server.routes.v1.generate_stats import generate_stats


class MainBackgroundThread(Thread):
    backend_online = False

    # TODO: do I really need to put everything in Redis?
    # TODO: call generate_stats() every minute, cache the results, put results in a DB table, then have other parts of code call this cache

    def __init__(self):
        Thread.__init__(self)
        self.daemon = True
        redis.set('average_generation_elapsed_sec', 0)
        redis.set('estimated_avg_tps', 0)
        redis.set('average_output_tokens', 0)
        redis.set('backend_online', 0)
        redis.set_dict('backend_info', {})

    def run(self):
        while True:
            # TODO: unify this
            if opts.mode == 'oobabooga':
                running_model, err = get_running_model()
                if err:
                    print(err)
                    redis.set('backend_online', 0)
                else:
                    redis.set('running_model', running_model)
                    redis.set('backend_online', 1)
            elif opts.mode == 'vllm':
                running_model, err = get_running_model()
                if err:
                    print(err)
                    redis.set('backend_online', 0)
                else:
                    redis.set('running_model', running_model)
                    redis.set('backend_online', 1)
            else:
                raise Exception

            # exclude_zeros=True filters out rows where an error message was returned. Previously, if there was an error, 0
            # was entered into the column. The new code enters null instead but we need to be backwards compatible for now.
            average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time', running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
            if average_generation_elapsed_sec:  # returns None on exception
                redis.set('average_generation_elapsed_sec', average_generation_elapsed_sec)

            # overall = average_column_for_model('prompts', 'generation_time', opts.running_model)
            # print(f'Weighted: {average_generation_elapsed_sec}, overall: {overall}')

            average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens', running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
            if average_generation_elapsed_sec:
                redis.set('average_output_tokens', average_output_tokens)

            # overall = average_column_for_model('prompts', 'response_tokens', opts.running_model)
            # print(f'Weighted: {average_output_tokens}, overall: {overall}')

            estimated_avg_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0  # Avoid division by zero
            redis.set('estimated_avg_tps', estimated_avg_tps)
            time.sleep(60)


def cache_stats():
    while True:
        generate_stats(regen=True)
        time.sleep(5)


redis_moderation = redis_redis.Redis()


def start_moderation_workers(num_workers):
    for _ in range(num_workers):
        t = threading.Thread(target=moderation_worker)
        t.daemon = True
        t.start()


def moderation_worker():
    while True:
        result = redis_moderation.blpop('queue:msgs_to_check')
        try:
            msg, tag = json.loads(result[1])
            _, categories = check_moderation_endpoint(msg)
            redis_moderation.rpush('queue:flagged_categories', json.dumps((tag, categories)))
        except:
            print(result)
            traceback.print_exc()
            continue


def add_moderation_task(msg, tag):
    redis_moderation.rpush('queue:msgs_to_check', json.dumps((msg, str(tag))))


def get_results(tag, num_tasks):
    tag = str(tag)  # Required for comparison with Redis results.
    flagged_categories = set()
    num_results = 0
    while num_results < num_tasks:
        result = redis_moderation.blpop('queue:flagged_categories')
        result_tag, categories = json.loads(result[1])
        if result_tag == tag:
            if categories:
                for item in categories:
                    flagged_categories.add(item)
            num_results += 1
    return list(flagged_categories)
@@ -0,0 +1,35 @@
from threading import Thread

from .blocking import start_workers
from .main import main_background_thread
from .moderator import start_moderation_workers
from .printer import console_printer
from .recent import recent_prompters_thread
from .threads import cache_stats
from .. import opts


def start_background():
    start_workers(opts.concurrent_gens)

    t = Thread(target=main_background_thread)
    t.daemon = True
    t.start()
    print('Started the main background thread.')

    start_moderation_workers(opts.openai_moderation_workers)

    t = Thread(target=cache_stats)
    t.daemon = True
    t.start()
    print('Started the stats cacher.')

    t = Thread(target=recent_prompters_thread)
    t.daemon = True
    t.start()
    print('Started the recent proompters thread.')

    t = Thread(target=console_printer)
    t.daemon = True
    t.start()
    print('Started the console printer.')
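The same three lines (create a `Thread`, mark it daemon, start it) repeat for each worker above. A small helper capturing that pattern, offered only as a sketch and not part of this commit; the helper name is hypothetical:

from threading import Thread


def start_daemon_thread(target, name):
    # Daemon threads exit with the main process, matching t.daemon = True above.
    t = Thread(target=target, name=name, daemon=True)
    t.start()
    print(f'Started {name}.')
    return t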
@@ -1,4 +1,3 @@
import json
import threading
import time
@@ -17,12 +16,12 @@ def worker():
        increment_ip_count(client_ip, 'processing_ips')
        redis.incr('active_gen_workers')

        try:
            if not request_json_body:
                # This was a dummy request from the websocket handler.
                # We're going to let the websocket handler decrement processing_ips and active_gen_workers.
                continue

            try:
                start_time = time.time()
                success, response, error_msg = generator(request_json_body)
                end_time = time.time()
@@ -0,0 +1,56 @@
import time
from threading import Thread

from llm_server import opts
from llm_server.database.database import weighted_average_column_for_model
from llm_server.llm.info import get_running_model
from llm_server.routes.cache import redis


def main_background_thread():
    redis.set('average_generation_elapsed_sec', 0)
    redis.set('estimated_avg_tps', 0)
    redis.set('average_output_tokens', 0)
    redis.set('backend_online', 0)
    redis.set_dict('backend_info', {})

    while True:
        # TODO: unify this
        if opts.mode == 'oobabooga':
            running_model, err = get_running_model()
            if err:
                print(err)
                redis.set('backend_online', 0)
            else:
                redis.set('running_model', running_model)
                redis.set('backend_online', 1)
        elif opts.mode == 'vllm':
            running_model, err = get_running_model()
            if err:
                print(err)
                redis.set('backend_online', 0)
            else:
                redis.set('running_model', running_model)
                redis.set('backend_online', 1)
        else:
            raise Exception

        # exclude_zeros=True filters out rows where an error message was returned. Previously, if there was an error, 0
        # was entered into the column. The new code enters null instead but we need to be backwards compatible for now.
        average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time', running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
        if average_generation_elapsed_sec:  # returns None on exception
            redis.set('average_generation_elapsed_sec', average_generation_elapsed_sec)

        # overall = average_column_for_model('prompts', 'generation_time', opts.running_model)
        # print(f'Weighted: {average_generation_elapsed_sec}, overall: {overall}')

        average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens', running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
        if average_generation_elapsed_sec:
            redis.set('average_output_tokens', average_output_tokens)

        # overall = average_column_for_model('prompts', 'response_tokens', opts.running_model)
        # print(f'Weighted: {average_output_tokens}, overall: {overall}')

        estimated_avg_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0  # Avoid division by zero
        redis.set('estimated_avg_tps', estimated_avg_tps)
        time.sleep(60)
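A worked example of the `estimated_avg_tps` calculation above, using illustrative numbers rather than values from the database:

average_output_tokens = 180
average_generation_elapsed_sec = 12.0
# round(180 / 12.0, 2) == 15.0 tokens per second; the guard avoids dividing by zero.
estimated_avg_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0
print(estimated_avg_tps)  # 15.0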
@@ -0,0 +1,51 @@
import json
import threading
import traceback

import redis as redis_redis

from llm_server.llm.openai.moderation import check_moderation_endpoint

redis_moderation = redis_redis.Redis()


def start_moderation_workers(num_workers):
    i = 0
    for _ in range(num_workers):
        t = threading.Thread(target=moderation_worker)
        t.daemon = True
        t.start()
        i += 1
    print(f'Started {i} moderation workers.')


def moderation_worker():
    while True:
        result = redis_moderation.blpop('queue:msgs_to_check')
        try:
            msg, tag = json.loads(result[1])
            _, categories = check_moderation_endpoint(msg)
            redis_moderation.rpush('queue:flagged_categories', json.dumps((tag, categories)))
        except:
            print(result)
            traceback.print_exc()
            continue


def add_moderation_task(msg, tag):
    redis_moderation.rpush('queue:msgs_to_check', json.dumps((msg, str(tag))))


def get_results(tag, num_tasks):
    tag = str(tag)  # Required for comparison with Redis results.
    flagged_categories = set()
    num_results = 0
    while num_results < num_tasks:
        result = redis_moderation.blpop('queue:flagged_categories')
        result_tag, categories = json.loads(result[1])
        if result_tag == tag:
            if categories:
                for item in categories:
                    flagged_categories.add(item)
            num_results += 1
    return list(flagged_categories)
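A usage sketch of the tag-based queue (caller code assumed, not shown in this commit): each message is pushed with a shared tag, and `get_results` blocks until that many tagged results have come back from the workers.

import uuid

from llm_server.workers.moderator import add_moderation_task, get_results, start_moderation_workers

start_moderation_workers(2)

tag = uuid.uuid4()  # any unique identifier; it is stringified for comparison
messages = ['first user message', 'second user message']
for msg in messages:
    add_moderation_task(msg, tag)

# Collects the union of flagged categories across both messages.
flagged = get_results(tag, len(messages))
print('flagged categories:', flagged)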
@@ -1,5 +1,4 @@
import logging
import threading
import time

from llm_server.routes.cache import redis
@@ -15,14 +14,9 @@ if not logger.handlers:


def console_printer():
    time.sleep(3)
    while True:
        queued_ip_count = sum([v for k, v in redis.get_dict('queued_ip_count').items()])
        processing_count = sum([v for k, v in redis.get_dict('processing_ips').items()])
        logger.info(f'REQUEST QUEUE -> Processing: {processing_count} | Queued: {queued_ip_count}')
        time.sleep(10)


def start_console_printer():
    t = threading.Thread(target=console_printer)
    t.daemon = True
    t.start()
    time.sleep(15)
@@ -0,0 +1,19 @@
import time

from llm_server.routes.cache import redis


def recent_prompters_thread():
    current_time = time.time()
    recent_prompters = redis.get_dict('recent_prompters')
    new_recent_prompters = {}

    for ip, (timestamp, token) in recent_prompters.items():
        if token and token.startswith('SYSTEM__'):
            continue
        if current_time - timestamp <= 300:
            new_recent_prompters[ip] = timestamp, token

    redis.set_dict('recent_prompters', new_recent_prompters)
    redis.set('proompters_5_min', len(new_recent_prompters))
    time.sleep(1)
@@ -0,0 +1,9 @@
import time

from llm_server.routes.v1.generate_stats import generate_stats


def cache_stats():
    while True:
        generate_stats(regen=True)
        time.sleep(5)
@@ -5,7 +5,6 @@ flask_caching
requests~=2.31.0
tiktoken~=0.5.0
gunicorn
redis~=5.0.0
gevent~=23.9.0.post1
async-timeout
flask-sock
@@ -19,4 +18,4 @@ websockets~=11.0.3
basicauth~=1.0.0
openai~=0.28.0
urllib3~=2.0.4
rq~=1.15.1
celery[redis]
server.py
@@ -1,3 +1,5 @@
from llm_server.config.config import mode_ui_names

try:
    import gevent.monkey
@@ -5,28 +7,23 @@ try:
except ImportError:
    pass

from llm_server.workers.printer import start_console_printer
from llm_server.pre_fork import server_startup
from llm_server.config.load import load_config
import os
import re
import sys
from pathlib import Path
from threading import Thread

import openai
import simplejson as json
from flask import Flask, jsonify, render_template, request
from redis import Redis

import llm_server
from llm_server.database.conn import database
from llm_server.database.create import create_db
from llm_server.database.database import get_number_of_rows
from llm_server.llm import get_token_count
from llm_server.routes.openai import openai_bp
from llm_server.routes.server_error import handle_server_error
from llm_server.routes.v1 import bp
from llm_server.stream import init_socketio
from llm_server.workers.blocking import start_workers

# TODO: have the workers handle streaming too
# TODO: add backend fallbacks. Backends at the bottom of the list are higher priority and are fallbacks if the upper ones fail
@@ -62,16 +59,12 @@ except ModuleNotFoundError as e:

import config
from llm_server import opts
from llm_server.config import ConfigLoader, config_default_vars, config_required_vars, mode_ui_names
from llm_server.helpers import resolve_path, auto_set_base_client_api
from llm_server.helpers import auto_set_base_client_api
from llm_server.llm.vllm.info import vllm_info
from llm_server.routes.cache import RedisWrapper, flask_cache
from llm_server.llm import redis
from llm_server.routes.stats import SemaphoreCheckerThread, get_active_gen_workers
from llm_server.routes.stats import get_active_gen_workers
from llm_server.routes.v1.generate_stats import generate_stats
from llm_server.threads import MainBackgroundThread, cache_stats, start_moderation_workers

script_path = os.path.dirname(os.path.realpath(__file__))

app = Flask(__name__)
init_socketio(app)
@@ -80,123 +73,22 @@ app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')
flask_cache.init_app(app)
flask_cache.clear()

script_path = os.path.dirname(os.path.realpath(__file__))
config_path_environ = os.getenv("CONFIG_PATH")
if config_path_environ:
    config_path = config_path_environ
else:
    config_path = Path(script_path, 'config', 'config.yml')

config_loader = ConfigLoader(config_path, config_default_vars, config_required_vars)
success, config, msg = config_loader.load_config()
success, config, msg = load_config(config_path, script_path)
if not success:
    print('Failed to load config:', msg)
    sys.exit(1)

# Resolve relative directory to the directory of the script
if config['database_path'].startswith('./'):
    config['database_path'] = resolve_path(script_path, config['database_path'].strip('./'))

database.init_db(config['mysql']['host'], config['mysql']['username'], config['mysql']['password'], config['mysql']['database'])
create_db()

if config['mode'] not in ['oobabooga', 'vllm']:
    print('Unknown mode:', config['mode'])
    sys.exit(1)

# TODO: this is atrocious
opts.mode = config['mode']
opts.auth_required = config['auth_required']
opts.log_prompts = config['log_prompts']
opts.concurrent_gens = config['concurrent_gens']
opts.frontend_api_client = config['frontend_api_client']
opts.context_size = config['token_limit']
opts.show_num_prompts = config['show_num_prompts']
opts.show_uptime = config['show_uptime']
opts.backend_url = config['backend_url'].strip('/')
opts.show_total_output_tokens = config['show_total_output_tokens']
opts.netdata_root = config['netdata_root']
opts.simultaneous_requests_per_ip = config['simultaneous_requests_per_ip']
opts.show_backend_info = config['show_backend_info']
opts.max_new_tokens = config['max_new_tokens']
opts.manual_model_name = config['manual_model_name']
opts.llm_middleware_name = config['llm_middleware_name']
opts.enable_openi_compatible_backend = config['enable_openi_compatible_backend']
opts.openai_system_prompt = config['openai_system_prompt']
opts.expose_openai_system_prompt = config['expose_openai_system_prompt']
opts.enable_streaming = config['enable_streaming']
opts.openai_api_key = config['openai_api_key']
openai.api_key = opts.openai_api_key
opts.admin_token = config['admin_token']
opts.openai_expose_our_model = config['openai_epose_our_model']
opts.openai_force_no_hashes = config['openai_force_no_hashes']
opts.include_system_tokens_in_stats = config['include_system_tokens_in_stats']
opts.openai_moderation_scan_last_n = config['openai_moderation_scan_last_n']
opts.openai_moderation_workers = config['openai_moderation_workers']
opts.openai_org_name = config['openai_org_name']
opts.openai_silent_trim = config['openai_silent_trim']
opts.openai_moderation_enabled = config['openai_moderation_enabled']

if opts.openai_expose_our_model and not opts.openai_api_key:
    print('If you set openai_epose_our_model to false, you must set your OpenAI key in openai_api_key.')
    sys.exit(1)

opts.verify_ssl = config['verify_ssl']
if not opts.verify_ssl:
    import urllib3

    urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

if config['average_generation_time_mode'] not in ['database', 'minute']:
    print('Invalid value for config item "average_generation_time_mode":', config['average_generation_time_mode'])
    sys.exit(1)
opts.average_generation_time_mode = config['average_generation_time_mode']

if opts.mode == 'oobabooga':
    raise NotImplementedError
    # llm_server.llm.tokenizer = OobaboogaBackend()
elif opts.mode == 'vllm':
    llm_server.llm.get_token_count = llm_server.llm.vllm.tokenize
else:
    raise Exception


def pre_fork(server):
    llm_server.llm.redis = RedisWrapper('local_llm')
    flushed_keys = redis.flush()
    print('Flushed', len(flushed_keys), 'keys from Redis.')

    redis.set_dict('processing_ips', {})
    redis.set_dict('queued_ip_count', {})

    # Flush the RedisPriorityQueue database.
    queue_redis = Redis(host='localhost', port=6379, db=15)
    for key in queue_redis.scan_iter('*'):
        queue_redis.delete(key)

    redis.set('backend_mode', opts.mode)
    if config['http_host']:
        http_host = re.sub(r'http(?:s)?://', '', config["http_host"])
        redis.set('http_host', http_host)
        redis.set('base_client_api', f'{http_host}/{opts.frontend_api_client.strip("/")}')

    if config['load_num_prompts']:
        redis.set('proompts', get_number_of_rows('prompts'))

    # Start background processes
    start_workers(opts.concurrent_gens)
    start_console_printer()
    start_moderation_workers(opts.openai_moderation_workers)
    MainBackgroundThread().start()
    SemaphoreCheckerThread().start()

    # This needs to be started after Flask is initalized
    stats_updater_thread = Thread(target=cache_stats)
    stats_updater_thread.daemon = True
    stats_updater_thread.start()

    # Cache the initial stats
    print('Loading backend stats...')
    generate_stats()
    create_db()


# print(app.url_map)
@@ -280,6 +172,6 @@ def before_app_request():


if __name__ == "__main__":
    pre_fork(None)
    server_startup(None)
    print('FLASK MODE - Startup complete!')
    app.run(host='0.0.0.0', threaded=False, processes=15)