redo background processes, reorganize server.py
parent 097d614a35
commit e86a5182eb
@@ -0,0 +1,44 @@
+import time
+
+from llm_server.routes.cache import redis
+
+try:
+    import gevent.monkey
+
+    gevent.monkey.patch_all()
+except ImportError:
+    pass
+
+import os
+import sys
+from pathlib import Path
+
+from llm_server.config.load import load_config
+from llm_server.database.create import create_db
+
+from llm_server.workers.app import start_background
+
+script_path = os.path.dirname(os.path.realpath(__file__))
+config_path_environ = os.getenv("CONFIG_PATH")
+if config_path_environ:
+    config_path = config_path_environ
+else:
+    config_path = Path(script_path, 'config', 'config.yml')
+
+if __name__ == "__main__":
+    flushed_keys = redis.flush()
+    print('Flushed', len(flushed_keys), 'keys from Redis.')
+
+    success, config, msg = load_config(config_path, script_path)
+    if not success:
+        print('Failed to load config:', msg)
+        sys.exit(1)
+
+    create_db()
+    start_background()
+
+    redis.set('daemon_started', 1)
+    print('== Daemon Setup Complete ==\n')
+
+    while True:
+        time.sleep(3600)
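The new daemon script advertises readiness through the daemon_started key that server_startup() later checks before letting gunicorn boot (see the pre_fork.py hunk below). A minimal sketch of that contract, reusing the project's redis.get(key, type) call shown elsewhere in this diff; the polling helper itself is illustrative only and not part of the commit:

    import sys
    import time

    from llm_server.routes.cache import redis

    def wait_for_daemon(timeout=60):
        # Poll the flag the daemon sets once config loading, create_db(),
        # and start_background() have all finished.
        deadline = time.time() + timeout
        while not redis.get('daemon_started', bool):
            if time.time() > deadline:
                sys.exit('daemon_started never appeared in Redis; start the daemon first.')
            time.sleep(1)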
@@ -5,9 +5,9 @@ try:
 except ImportError:
     pass
 
-import server
+from llm_server.pre_fork import server_startup
 
 
 def on_starting(s):
-    server.pre_fork(s)
+    server_startup(s)
     print('Startup complete!')
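Assembled from the fragments above, the resulting gunicorn config would look roughly like this (a sketch; the file may contain other settings not shown in the diff). Gunicorn invokes the on_starting(server) hook once in the master process, before any workers fork:

    try:
        import gevent.monkey

        gevent.monkey.patch_all()
    except ImportError:
        pass

    from llm_server.pre_fork import server_startup


    def on_starting(s):
        # Runs once in the gunicorn master, before workers are forked.
        server_startup(s)
        print('Startup complete!')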
@@ -0,0 +1,86 @@
+import re
+import sys
+
+import openai
+
+from llm_server import opts
+from llm_server.config.config import ConfigLoader, config_default_vars, config_required_vars
+from llm_server.database.conn import database
+from llm_server.database.database import get_number_of_rows
+from llm_server.helpers import resolve_path
+from llm_server.routes.cache import redis
+
+
+def load_config(config_path, script_path):
+    config_loader = ConfigLoader(config_path, config_default_vars, config_required_vars)
+    success, config, msg = config_loader.load_config()
+    if not success:
+        return success, config, msg
+
+    # Resolve relative directory to the directory of the script
+    if config['database_path'].startswith('./'):
+        config['database_path'] = resolve_path(script_path, config['database_path'].strip('./'))
+
+    if config['mode'] not in ['oobabooga', 'vllm']:
+        print('Unknown mode:', config['mode'])
+        sys.exit(1)
+
+    # TODO: this is atrocious
+    opts.mode = config['mode']
+    opts.auth_required = config['auth_required']
+    opts.log_prompts = config['log_prompts']
+    opts.concurrent_gens = config['concurrent_gens']
+    opts.frontend_api_client = config['frontend_api_client']
+    opts.context_size = config['token_limit']
+    opts.show_num_prompts = config['show_num_prompts']
+    opts.show_uptime = config['show_uptime']
+    opts.backend_url = config['backend_url'].strip('/')
+    opts.show_total_output_tokens = config['show_total_output_tokens']
+    opts.netdata_root = config['netdata_root']
+    opts.simultaneous_requests_per_ip = config['simultaneous_requests_per_ip']
+    opts.show_backend_info = config['show_backend_info']
+    opts.max_new_tokens = config['max_new_tokens']
+    opts.manual_model_name = config['manual_model_name']
+    opts.llm_middleware_name = config['llm_middleware_name']
+    opts.enable_openi_compatible_backend = config['enable_openi_compatible_backend']
+    opts.openai_system_prompt = config['openai_system_prompt']
+    opts.expose_openai_system_prompt = config['expose_openai_system_prompt']
+    opts.enable_streaming = config['enable_streaming']
+    opts.openai_api_key = config['openai_api_key']
+    openai.api_key = opts.openai_api_key
+    opts.admin_token = config['admin_token']
+    opts.openai_expose_our_model = config['openai_epose_our_model']
+    opts.openai_force_no_hashes = config['openai_force_no_hashes']
+    opts.include_system_tokens_in_stats = config['include_system_tokens_in_stats']
+    opts.openai_moderation_scan_last_n = config['openai_moderation_scan_last_n']
+    opts.openai_moderation_workers = config['openai_moderation_workers']
+    opts.openai_org_name = config['openai_org_name']
+    opts.openai_silent_trim = config['openai_silent_trim']
+    opts.openai_moderation_enabled = config['openai_moderation_enabled']
+
+    if opts.openai_expose_our_model and not opts.openai_api_key:
+        print('If you set openai_epose_our_model to false, you must set your OpenAI key in openai_api_key.')
+        sys.exit(1)
+
+    opts.verify_ssl = config['verify_ssl']
+    if not opts.verify_ssl:
+        import urllib3
+
+        urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
+
+    if config['http_host']:
+        http_host = re.sub(r'http(?:s)?://', '', config["http_host"])
+        redis.set('http_host', http_host)
+        redis.set('base_client_api', f'{http_host}/{opts.frontend_api_client.strip("/")}')
+
+    database.init_db(config['mysql']['host'], config['mysql']['username'], config['mysql']['password'], config['mysql']['database'])
+
+    if config['load_num_prompts']:
+        redis.set('proompts', get_number_of_rows('prompts'))
+
+    redis.set_dict('recent_prompters', {})
+    redis.set_dict('processing_ips', {})
+    redis.set_dict('queued_ip_count', {})
+    redis.set('backend_mode', opts.mode)
+
+    return success, config, msg
@@ -0,0 +1,21 @@
+import sys
+
+from redis import Redis
+
+from llm_server.routes.cache import redis
+from llm_server.routes.v1.generate_stats import generate_stats
+
+
+def server_startup(s):
+    if not redis.get('daemon_started', bool):
+        print('Could not find the key daemon_started in Redis. Did you forget to start the daemon process?')
+        sys.exit(1)
+
+    # Flush the RedisPriorityQueue database.
+    queue_redis = Redis(host='localhost', port=6379, db=15)
+    for key in queue_redis.scan_iter('*'):
+        queue_redis.delete(key)
+
+    # Cache the initial stats
+    print('Loading backend stats...')
+    generate_stats()
@@ -57,7 +57,7 @@ def require_api_key(json_body: dict = None):
             return jsonify({'code': 403, 'message': 'Invalid API key'}), 403
     elif 'Authorization' in request.headers:
         token = parse_token(request.headers['Authorization'])
-        if token.startswith('SYSTEM__') or opts.auth_required:
+        if (token and token.startswith('SYSTEM__')) or opts.auth_required:
             if is_valid_api_key(token):
                 return
             else:
@@ -10,7 +10,7 @@ from llm_server import opts
 from llm_server.database.database import is_api_key_moderated
 from llm_server.llm.openai.transform import build_openai_response, transform_messages_to_prompt, trim_prompt_to_fit
 from llm_server.routes.request_handler import RequestHandler
-from llm_server.threads import add_moderation_task, get_results
+from llm_server.workers.moderator import add_moderation_task, get_results
 
 
 class OpenAIRequestHandler(RequestHandler):
@@ -1,6 +1,4 @@
-import time
 from datetime import datetime
-from threading import Thread
 
 from llm_server.routes.cache import redis
 
@@ -46,27 +44,3 @@ def get_active_gen_workers():
     else:
         count = int(active_gen_workers)
     return count
-
-
-class SemaphoreCheckerThread(Thread):
-    redis.set_dict('recent_prompters', {})
-
-    def __init__(self):
-        Thread.__init__(self)
-        self.daemon = True
-
-    def run(self):
-        while True:
-            current_time = time.time()
-            recent_prompters = redis.get_dict('recent_prompters')
-            new_recent_prompters = {}
-
-            for ip, (timestamp, token) in recent_prompters.items():
-                if token and token.startswith('SYSTEM__'):
-                    continue
-                if current_time - timestamp <= 300:
-                    new_recent_prompters[ip] = timestamp, token
-
-            redis.set_dict('recent_prompters', new_recent_prompters)
-            redis.set('proompters_5_min', len(new_recent_prompters))
-            time.sleep(1)
@@ -1,120 +0,0 @@
-import json
-import threading
-import time
-import traceback
-from threading import Thread
-
-import redis as redis_redis
-
-from llm_server import opts
-from llm_server.database.database import weighted_average_column_for_model
-from llm_server.llm.info import get_running_model
-from llm_server.llm.openai.moderation import check_moderation_endpoint
-from llm_server.routes.cache import redis
-from llm_server.routes.v1.generate_stats import generate_stats
-
-
-class MainBackgroundThread(Thread):
-    backend_online = False
-
-    # TODO: do I really need to put everything in Redis?
-    # TODO: call generate_stats() every minute, cache the results, put results in a DB table, then have other parts of code call this cache
-
-    def __init__(self):
-        Thread.__init__(self)
-        self.daemon = True
-        redis.set('average_generation_elapsed_sec', 0)
-        redis.set('estimated_avg_tps', 0)
-        redis.set('average_output_tokens', 0)
-        redis.set('backend_online', 0)
-        redis.set_dict('backend_info', {})
-
-    def run(self):
-        while True:
-            # TODO: unify this
-            if opts.mode == 'oobabooga':
-                running_model, err = get_running_model()
-                if err:
-                    print(err)
-                    redis.set('backend_online', 0)
-                else:
-                    redis.set('running_model', running_model)
-                    redis.set('backend_online', 1)
-            elif opts.mode == 'vllm':
-                running_model, err = get_running_model()
-                if err:
-                    print(err)
-                    redis.set('backend_online', 0)
-                else:
-                    redis.set('running_model', running_model)
-                    redis.set('backend_online', 1)
-            else:
-                raise Exception
-
-            # exclude_zeros=True filters out rows where an error message was returned. Previously, if there was an error, 0
-            # was entered into the column. The new code enters null instead but we need to be backwards compatible for now.
-            average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time', running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
-            if average_generation_elapsed_sec:  # returns None on exception
-                redis.set('average_generation_elapsed_sec', average_generation_elapsed_sec)
-
-            # overall = average_column_for_model('prompts', 'generation_time', opts.running_model)
-            # print(f'Weighted: {average_generation_elapsed_sec}, overall: {overall}')
-
-            average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens', running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
-            if average_generation_elapsed_sec:
-                redis.set('average_output_tokens', average_output_tokens)
-
-            # overall = average_column_for_model('prompts', 'response_tokens', opts.running_model)
-            # print(f'Weighted: {average_output_tokens}, overall: {overall}')
-
-            estimated_avg_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0  # Avoid division by zero
-            redis.set('estimated_avg_tps', estimated_avg_tps)
-            time.sleep(60)
-
-
-def cache_stats():
-    while True:
-        generate_stats(regen=True)
-        time.sleep(5)
-
-
-redis_moderation = redis_redis.Redis()
-
-
-def start_moderation_workers(num_workers):
-    for _ in range(num_workers):
-        t = threading.Thread(target=moderation_worker)
-        t.daemon = True
-        t.start()
-
-
-def moderation_worker():
-    while True:
-        result = redis_moderation.blpop('queue:msgs_to_check')
-        try:
-            msg, tag = json.loads(result[1])
-            _, categories = check_moderation_endpoint(msg)
-            redis_moderation.rpush('queue:flagged_categories', json.dumps((tag, categories)))
-        except:
-            print(result)
-            traceback.print_exc()
-            continue
-
-
-def add_moderation_task(msg, tag):
-    redis_moderation.rpush('queue:msgs_to_check', json.dumps((msg, str(tag))))
-
-
-def get_results(tag, num_tasks):
-    tag = str(tag)  # Required for comparison with Redis results.
-    flagged_categories = set()
-    num_results = 0
-    while num_results < num_tasks:
-        result = redis_moderation.blpop('queue:flagged_categories')
-        result_tag, categories = json.loads(result[1])
-        if result_tag == tag:
-            if categories:
-                for item in categories:
-                    flagged_categories.add(item)
-            num_results += 1
-    return list(flagged_categories)
@@ -0,0 +1,35 @@
+from threading import Thread
+
+from .blocking import start_workers
+from .main import main_background_thread
+from .moderator import start_moderation_workers
+from .printer import console_printer
+from .recent import recent_prompters_thread
+from .threads import cache_stats
+from .. import opts
+
+
+def start_background():
+    start_workers(opts.concurrent_gens)
+
+    t = Thread(target=main_background_thread)
+    t.daemon = True
+    t.start()
+    print('Started the main background thread.')
+
+    start_moderation_workers(opts.openai_moderation_workers)
+
+    t = Thread(target=cache_stats)
+    t.daemon = True
+    t.start()
+    print('Started the stats cacher.')
+
+    t = Thread(target=recent_prompters_thread)
+    t.daemon = True
+    t.start()
+    print('Started the recent proompters thread.')
+
+    t = Thread(target=console_printer)
+    t.daemon = True
+    t.start()
+    print('Started the console printer.')
@@ -1,4 +1,3 @@
-import json
 import threading
 import time
 
@@ -17,12 +16,12 @@ def worker():
         increment_ip_count(client_ip, 'processing_ips')
         redis.incr('active_gen_workers')
 
-        if not request_json_body:
-            # This was a dummy request from the websocket handler.
-            # We're going to let the websocket handler decrement processing_ips and active_gen_workers.
-            continue
-
         try:
+            if not request_json_body:
+                # This was a dummy request from the websocket handler.
+                # We're going to let the websocket handler decrement processing_ips and active_gen_workers.
+                continue
+
             start_time = time.time()
             success, response, error_msg = generator(request_json_body)
             end_time = time.time()
@@ -0,0 +1,56 @@
+import time
+from threading import Thread
+
+from llm_server import opts
+from llm_server.database.database import weighted_average_column_for_model
+from llm_server.llm.info import get_running_model
+from llm_server.routes.cache import redis
+
+
+def main_background_thread():
+    redis.set('average_generation_elapsed_sec', 0)
+    redis.set('estimated_avg_tps', 0)
+    redis.set('average_output_tokens', 0)
+    redis.set('backend_online', 0)
+    redis.set_dict('backend_info', {})
+
+    while True:
+        # TODO: unify this
+        if opts.mode == 'oobabooga':
+            running_model, err = get_running_model()
+            if err:
+                print(err)
+                redis.set('backend_online', 0)
+            else:
+                redis.set('running_model', running_model)
+                redis.set('backend_online', 1)
+        elif opts.mode == 'vllm':
+            running_model, err = get_running_model()
+            if err:
+                print(err)
+                redis.set('backend_online', 0)
+            else:
+                redis.set('running_model', running_model)
+                redis.set('backend_online', 1)
+        else:
+            raise Exception
+
+        # exclude_zeros=True filters out rows where an error message was returned. Previously, if there was an error, 0
+        # was entered into the column. The new code enters null instead but we need to be backwards compatible for now.
+        average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time', running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
+        if average_generation_elapsed_sec:  # returns None on exception
+            redis.set('average_generation_elapsed_sec', average_generation_elapsed_sec)
+
+        # overall = average_column_for_model('prompts', 'generation_time', opts.running_model)
+        # print(f'Weighted: {average_generation_elapsed_sec}, overall: {overall}')
+
+        average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens', running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
+        if average_generation_elapsed_sec:
+            redis.set('average_output_tokens', average_output_tokens)
+
+        # overall = average_column_for_model('prompts', 'response_tokens', opts.running_model)
+        # print(f'Weighted: {average_output_tokens}, overall: {overall}')
+
+        estimated_avg_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0  # Avoid division by zero
+        redis.set('estimated_avg_tps', estimated_avg_tps)
+        time.sleep(60)
@@ -0,0 +1,51 @@
+import json
+import threading
+import traceback
+
+import redis as redis_redis
+
+from llm_server.llm.openai.moderation import check_moderation_endpoint
+
+redis_moderation = redis_redis.Redis()
+
+
+def start_moderation_workers(num_workers):
+    i = 0
+    for _ in range(num_workers):
+        t = threading.Thread(target=moderation_worker)
+        t.daemon = True
+        t.start()
+        i += 1
+    print(f'Started {i} moderation workers.')
+
+
+def moderation_worker():
+    while True:
+        result = redis_moderation.blpop('queue:msgs_to_check')
+        try:
+            msg, tag = json.loads(result[1])
+            _, categories = check_moderation_endpoint(msg)
+            redis_moderation.rpush('queue:flagged_categories', json.dumps((tag, categories)))
+        except:
+            print(result)
+            traceback.print_exc()
+            continue
+
+
+def add_moderation_task(msg, tag):
+    redis_moderation.rpush('queue:msgs_to_check', json.dumps((msg, str(tag))))
+
+
+def get_results(tag, num_tasks):
+    tag = str(tag)  # Required for comparison with Redis results.
+    flagged_categories = set()
+    num_results = 0
+    while num_results < num_tasks:
+        result = redis_moderation.blpop('queue:flagged_categories')
+        result_tag, categories = json.loads(result[1])
+        if result_tag == tag:
+            if categories:
+                for item in categories:
+                    flagged_categories.add(item)
+            num_results += 1
+    return list(flagged_categories)
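For context, a hypothetical caller of the moderation queue above: tag the batch, push each message, then block until every result carrying that tag has come back. The moderate_messages helper and the uuid tag are illustrative only and not part of the repo:

    import uuid

    from llm_server.workers.moderator import add_moderation_task, get_results

    def moderate_messages(messages):
        tag = uuid.uuid4()  # any unique value works; get_results() compares str(tag)
        for msg in messages:
            add_moderation_task(msg, tag)
        # Union of flagged categories across all checked messages.
        return get_results(tag, len(messages))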
@@ -1,5 +1,4 @@
 import logging
-import threading
 import time
 
 from llm_server.routes.cache import redis
@@ -15,14 +14,9 @@ if not logger.handlers:
 
 
 def console_printer():
+    time.sleep(3)
     while True:
         queued_ip_count = sum([v for k, v in redis.get_dict('queued_ip_count').items()])
         processing_count = sum([v for k, v in redis.get_dict('processing_ips').items()])
         logger.info(f'REQUEST QUEUE -> Processing: {processing_count} | Queued: {queued_ip_count}')
-        time.sleep(10)
-
-
-def start_console_printer():
-    t = threading.Thread(target=console_printer)
-    t.daemon = True
-    t.start()
+        time.sleep(15)
@@ -0,0 +1,19 @@
+import time
+
+from llm_server.routes.cache import redis
+
+
+def recent_prompters_thread():
+    current_time = time.time()
+    recent_prompters = redis.get_dict('recent_prompters')
+    new_recent_prompters = {}
+
+    for ip, (timestamp, token) in recent_prompters.items():
+        if token and token.startswith('SYSTEM__'):
+            continue
+        if current_time - timestamp <= 300:
+            new_recent_prompters[ip] = timestamp, token
+
+    redis.set_dict('recent_prompters', new_recent_prompters)
+    redis.set('proompters_5_min', len(new_recent_prompters))
+    time.sleep(1)
@@ -0,0 +1,9 @@
+import time
+
+from llm_server.routes.v1.generate_stats import generate_stats
+
+
+def cache_stats():
+    while True:
+        generate_stats(regen=True)
+        time.sleep(5)
@@ -5,7 +5,6 @@ flask_caching
 requests~=2.31.0
 tiktoken~=0.5.0
 gunicorn
-redis~=5.0.0
 gevent~=23.9.0.post1
 async-timeout
 flask-sock
@@ -19,4 +18,4 @@ websockets~=11.0.3
 basicauth~=1.0.0
 openai~=0.28.0
 urllib3~=2.0.4
-rq~=1.15.1
+celery[redis]
server.py (130 lines changed)
@@ -1,3 +1,5 @@
+from llm_server.config.config import mode_ui_names
+
 try:
     import gevent.monkey
 
@@ -5,28 +7,23 @@ try:
 except ImportError:
     pass
 
-from llm_server.workers.printer import start_console_printer
+from llm_server.pre_fork import server_startup
+from llm_server.config.load import load_config
 import os
-import re
 import sys
 from pathlib import Path
-from threading import Thread
 
-import openai
 import simplejson as json
 from flask import Flask, jsonify, render_template, request
-from redis import Redis
 
 import llm_server
 from llm_server.database.conn import database
 from llm_server.database.create import create_db
-from llm_server.database.database import get_number_of_rows
 from llm_server.llm import get_token_count
 from llm_server.routes.openai import openai_bp
 from llm_server.routes.server_error import handle_server_error
 from llm_server.routes.v1 import bp
 from llm_server.stream import init_socketio
-from llm_server.workers.blocking import start_workers
 
 # TODO: have the workers handle streaming too
 # TODO: add backend fallbacks. Backends at the bottom of the list are higher priority and are fallbacks if the upper ones fail
@@ -62,16 +59,12 @@ except ModuleNotFoundError as e:
 
 import config
 from llm_server import opts
-from llm_server.config import ConfigLoader, config_default_vars, config_required_vars, mode_ui_names
-from llm_server.helpers import resolve_path, auto_set_base_client_api
+from llm_server.helpers import auto_set_base_client_api
 from llm_server.llm.vllm.info import vllm_info
 from llm_server.routes.cache import RedisWrapper, flask_cache
 from llm_server.llm import redis
-from llm_server.routes.stats import SemaphoreCheckerThread, get_active_gen_workers
+from llm_server.routes.stats import get_active_gen_workers
 from llm_server.routes.v1.generate_stats import generate_stats
-from llm_server.threads import MainBackgroundThread, cache_stats, start_moderation_workers
 
-script_path = os.path.dirname(os.path.realpath(__file__))
-
 app = Flask(__name__)
 init_socketio(app)
@@ -80,123 +73,22 @@ app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')
 flask_cache.init_app(app)
 flask_cache.clear()
 
+script_path = os.path.dirname(os.path.realpath(__file__))
 config_path_environ = os.getenv("CONFIG_PATH")
 if config_path_environ:
     config_path = config_path_environ
 else:
     config_path = Path(script_path, 'config', 'config.yml')
 
-config_loader = ConfigLoader(config_path, config_default_vars, config_required_vars)
-success, config, msg = config_loader.load_config()
+success, config, msg = load_config(config_path, script_path)
 if not success:
     print('Failed to load config:', msg)
     sys.exit(1)
 
-# Resolve relative directory to the directory of the script
-if config['database_path'].startswith('./'):
-    config['database_path'] = resolve_path(script_path, config['database_path'].strip('./'))
-
 database.init_db(config['mysql']['host'], config['mysql']['username'], config['mysql']['password'], config['mysql']['database'])
 create_db()
-if config['mode'] not in ['oobabooga', 'vllm']:
-    print('Unknown mode:', config['mode'])
-    sys.exit(1)
-
-# TODO: this is atrocious
-opts.mode = config['mode']
-opts.auth_required = config['auth_required']
-opts.log_prompts = config['log_prompts']
-opts.concurrent_gens = config['concurrent_gens']
-opts.frontend_api_client = config['frontend_api_client']
-opts.context_size = config['token_limit']
-opts.show_num_prompts = config['show_num_prompts']
-opts.show_uptime = config['show_uptime']
-opts.backend_url = config['backend_url'].strip('/')
-opts.show_total_output_tokens = config['show_total_output_tokens']
-opts.netdata_root = config['netdata_root']
-opts.simultaneous_requests_per_ip = config['simultaneous_requests_per_ip']
-opts.show_backend_info = config['show_backend_info']
-opts.max_new_tokens = config['max_new_tokens']
-opts.manual_model_name = config['manual_model_name']
-opts.llm_middleware_name = config['llm_middleware_name']
-opts.enable_openi_compatible_backend = config['enable_openi_compatible_backend']
-opts.openai_system_prompt = config['openai_system_prompt']
-opts.expose_openai_system_prompt = config['expose_openai_system_prompt']
-opts.enable_streaming = config['enable_streaming']
-opts.openai_api_key = config['openai_api_key']
-openai.api_key = opts.openai_api_key
-opts.admin_token = config['admin_token']
-opts.openai_expose_our_model = config['openai_epose_our_model']
-opts.openai_force_no_hashes = config['openai_force_no_hashes']
-opts.include_system_tokens_in_stats = config['include_system_tokens_in_stats']
-opts.openai_moderation_scan_last_n = config['openai_moderation_scan_last_n']
-opts.openai_moderation_workers = config['openai_moderation_workers']
-opts.openai_org_name = config['openai_org_name']
-opts.openai_silent_trim = config['openai_silent_trim']
-opts.openai_moderation_enabled = config['openai_moderation_enabled']
-
-if opts.openai_expose_our_model and not opts.openai_api_key:
-    print('If you set openai_epose_our_model to false, you must set your OpenAI key in openai_api_key.')
-    sys.exit(1)
-
-opts.verify_ssl = config['verify_ssl']
-if not opts.verify_ssl:
-    import urllib3
-
-    urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
-
-if config['average_generation_time_mode'] not in ['database', 'minute']:
-    print('Invalid value for config item "average_generation_time_mode":', config['average_generation_time_mode'])
-    sys.exit(1)
-opts.average_generation_time_mode = config['average_generation_time_mode']
-
-if opts.mode == 'oobabooga':
-    raise NotImplementedError
-    # llm_server.llm.tokenizer = OobaboogaBackend()
-elif opts.mode == 'vllm':
-    llm_server.llm.get_token_count = llm_server.llm.vllm.tokenize
-else:
-    raise Exception
-
-
-def pre_fork(server):
-    llm_server.llm.redis = RedisWrapper('local_llm')
-    flushed_keys = redis.flush()
-    print('Flushed', len(flushed_keys), 'keys from Redis.')
-
-    redis.set_dict('processing_ips', {})
-    redis.set_dict('queued_ip_count', {})
-
-    # Flush the RedisPriorityQueue database.
-    queue_redis = Redis(host='localhost', port=6379, db=15)
-    for key in queue_redis.scan_iter('*'):
-        queue_redis.delete(key)
-
-    redis.set('backend_mode', opts.mode)
-    if config['http_host']:
-        http_host = re.sub(r'http(?:s)?://', '', config["http_host"])
-        redis.set('http_host', http_host)
-        redis.set('base_client_api', f'{http_host}/{opts.frontend_api_client.strip("/")}')
-
-    if config['load_num_prompts']:
-        redis.set('proompts', get_number_of_rows('prompts'))
-
-    # Start background processes
-    start_workers(opts.concurrent_gens)
-    start_console_printer()
-    start_moderation_workers(opts.openai_moderation_workers)
-    MainBackgroundThread().start()
-    SemaphoreCheckerThread().start()
-
-    # This needs to be started after Flask is initalized
-    stats_updater_thread = Thread(target=cache_stats)
-    stats_updater_thread.daemon = True
-    stats_updater_thread.start()
-
-    # Cache the initial stats
-    print('Loading backend stats...')
-    generate_stats()
+llm_server.llm.redis = RedisWrapper('local_llm')
+create_db()
 
 
 # print(app.url_map)
@@ -280,6 +172,6 @@ def before_app_request():
 
 
 if __name__ == "__main__":
-    pre_fork(None)
+    server_startup(None)
     print('FLASK MODE - Startup complete!')
     app.run(host='0.0.0.0', threaded=False, processes=15)