local-llm-server/server.py

270 lines
10 KiB
Python
Raw Normal View History

2023-09-26 22:09:11 -06:00
# Optional gevent support: monkey-patch the standard library at the very top of
# the module, before the other imports below run (presumably so they pick up the
# patched, cooperative versions — keep this block first). If gevent is not
# installed, the ImportError is ignored and the server runs unpatched.
try:
    import gevent.monkey
    gevent.monkey.patch_all()
except ImportError:
    pass
2023-08-21 21:28:52 -06:00
import os
import re
2023-08-21 21:28:52 -06:00
import sys
from pathlib import Path
from threading import Thread
2023-09-23 21:17:13 -06:00
import openai
import simplejson as json
2023-08-23 23:11:12 -06:00
from flask import Flask, jsonify, render_template, request
2023-08-21 21:28:52 -06:00
import llm_server
from llm_server.database.conn import database
from llm_server.database.create import create_db
from llm_server.database.database import get_number_of_rows
from llm_server.llm import get_token_count
2023-09-12 16:40:09 -06:00
from llm_server.routes.openai import openai_bp
from llm_server.routes.server_error import handle_server_error
2023-09-26 13:32:33 -06:00
from llm_server.routes.v1 import bp
from llm_server.stream import init_socketio
2023-09-27 14:36:49 -06:00
# TODO: have the workers handle streaming too
2023-09-27 16:12:36 -06:00
# TODO: add backend fallbacks. Backends at the bottom of the list are higher priority and are fallbacks if the upper ones fail
# TODO: implement background thread to test backends via sending test prompts
# TODO: if backend fails request, mark it as down
2023-09-26 22:09:11 -06:00
# TODO: allow setting concurrent gens per-backend
2023-09-27 16:12:36 -06:00
# TODO: set the max tokens to that of the lowest backend
# TODO: implement RRD backend loadbalancer option
2023-09-27 14:36:49 -06:00
# TODO: simulate OpenAI error messages regardless of endpoint
2023-09-27 16:12:36 -06:00
# TODO: send extra headers when ratelimited?
2023-09-25 18:18:29 -06:00
# TODO: make sure log_prompt() is used everywhere, including errors and invalid requests
# TODO: unify logging thread in a function and use async/await instead
2023-09-27 16:12:36 -06:00
# Done, but need to verify
# TODO: add more excluding to SYSTEM__ tokens
# TODO: return 200 when returning formatted sillytavern error
2023-09-12 01:05:03 -06:00
# vllm is a hard requirement: if it cannot be imported, explain why and abort
# startup with exit status 1.
try:
    import vllm
except ModuleNotFoundError as import_error:
    print('Could not import vllm-gptq:', import_error)
    print('Please see README.md for install instructions.')
    raise SystemExit(1)
2023-08-23 23:11:12 -06:00
import config
2023-08-21 21:28:52 -06:00
from llm_server import opts
2023-08-23 23:11:12 -06:00
from llm_server.config import ConfigLoader, config_default_vars, config_required_vars, mode_ui_names
2023-09-23 23:24:08 -06:00
from llm_server.helpers import resolve_path, auto_set_base_client_api
from llm_server.llm.vllm.info import vllm_info
2023-09-26 22:09:11 -06:00
from llm_server.routes.cache import RedisWrapper, flask_cache
from llm_server.llm import redis
2023-08-23 20:33:49 -06:00
from llm_server.routes.queue import start_workers
2023-09-23 21:17:13 -06:00
from llm_server.routes.stats import SemaphoreCheckerThread, get_active_gen_workers, process_avg_gen_time
2023-08-23 23:11:12 -06:00
from llm_server.routes.v1.generate_stats import generate_stats
2023-09-26 22:09:11 -06:00
from llm_server.threads import MainBackgroundThread, cache_stats, start_moderation_workers
2023-08-21 21:28:52 -06:00
2023-08-21 23:07:12 -06:00
# Absolute directory containing this file (symlinks resolved). Used below to
# locate the default config file and to resolve './'-relative database paths.
script_path = os.path.dirname(os.path.realpath(__file__))

# Create the Flask app, initialise the streaming layer
# (llm_server.stream.init_socketio), and mount the two API blueprints.
app = Flask(__name__)
init_socketio(app)
app.register_blueprint(bp, url_prefix='/api/v1/')
app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')

# Bind the shared cache to this app, then empty it so every launch starts
# without stale cached responses. init_app must run before clear().
flask_cache.init_app(app)
flask_cache.clear()
2023-09-26 13:32:33 -06:00
2023-08-21 21:28:52 -06:00
# Config file location: $CONFIG_PATH when set, otherwise config/config.yml next
# to this script.
config_path_environ = os.getenv("CONFIG_PATH")
if config_path_environ:
    config_path = config_path_environ
else:
    config_path = Path(script_path, 'config', 'config.yml')

# Load and validate the config (defaults filled in, required keys checked); abort
# startup on failure.
# NOTE(review): this rebinds the module-level name `config`, which previously
# referred to the module imported with `import config` above.
config_loader = ConfigLoader(config_path, config_default_vars, config_required_vars)
success, config, msg = config_loader.load_config()
if not success:
    print('Failed to load config:', msg)
    sys.exit(1)

# Resolve a './'-relative database path against the script's directory.
# Slice off exactly the two-character './' prefix. str.strip('./') would remove
# ANY run of '.' and '/' characters from BOTH ends, turning './.data.db' into
# 'data.db' and './../db.sqlite' into 'db.sqlite'.
if config['database_path'].startswith('./'):
    config['database_path'] = resolve_path(script_path, config['database_path'][2:])

# Connect to MySQL with the configured credentials and ensure the schema exists.
database.init_db(config['mysql']['host'], config['mysql']['username'], config['mysql']['password'], config['mysql']['database'])
create_db()
2023-08-21 21:28:52 -06:00
# Only the two known backend modes are accepted; anything else aborts startup.
if config['mode'] not in ('oobabooga', 'vllm'):
    print('Unknown mode:', config['mode'])
    raise SystemExit(1)
2023-09-23 23:24:08 -06:00
2023-09-26 22:09:11 -06:00
# Copy settings from the loaded config dict onto the shared `opts` module.
# TODO: this hand-written key-by-key copy is error-prone; consider a
# table-driven loader. A missing key raises KeyError here and aborts startup.
opts.mode = config['mode']
opts.auth_required = config['auth_required']
opts.log_prompts = config['log_prompts']
opts.concurrent_gens = config['concurrent_gens']
opts.frontend_api_client = config['frontend_api_client']
# Note the rename: the config key is 'token_limit', the option is context_size.
opts.context_size = config['token_limit']
opts.show_num_prompts = config['show_num_prompts']
opts.show_uptime = config['show_uptime']
# strip('/') removes leading and trailing slashes from the backend URL.
opts.backend_url = config['backend_url'].strip('/')
opts.show_total_output_tokens = config['show_total_output_tokens']
opts.netdata_root = config['netdata_root']
opts.simultaneous_requests_per_ip = config['simultaneous_requests_per_ip']
opts.show_backend_info = config['show_backend_info']
opts.max_new_tokens = config['max_new_tokens']
opts.manual_model_name = config['manual_model_name']
opts.llm_middleware_name = config['llm_middleware_name']
opts.enable_openi_compatible_backend = config['enable_openi_compatible_backend']
opts.openai_system_prompt = config['openai_system_prompt']
opts.expose_openai_system_prompt = config['expose_openai_system_prompt']
opts.enable_streaming = config['enable_streaming']
opts.openai_api_key = config['openai_api_key']
# Also hand the key to the `openai` client library globally.
openai.api_key = opts.openai_api_key
opts.admin_token = config['admin_token']
# The config key is spelled 'openai_epose_our_model' (sic); it must match the
# config file, so do not "fix" the spelling here alone.
opts.openai_expose_our_model = config['openai_epose_our_model']
opts.openai_force_no_hashes = config['openai_force_no_hashes']
opts.include_system_tokens_in_stats = config['include_system_tokens_in_stats']
opts.openai_moderation_scan_last_n = config['openai_moderation_scan_last_n']
opts.openai_moderation_workers = config['openai_moderation_workers']
opts.openai_org_name = config['openai_org_name']
opts.openai_silent_trim = config['openai_silent_trim']
opts.openai_moderation_enabled = config['openai_moderation_enabled']
# NOTE(review): the condition fires when openai_expose_our_model is TRUE, but the
# message talks about setting it to FALSE. One of the two is inverted; confirm
# which behavior is intended before relying on this check.
if opts.openai_expose_our_model and not opts.openai_api_key:
    print('If you set openai_epose_our_model to false, you must set your OpenAI key in openai_api_key.')
    sys.exit(1)
2023-08-23 16:11:32 -06:00
# SSL verification toggle (presumably consulted by outgoing HTTP calls elsewhere;
# confirm against callers).
opts.verify_ssl = config['verify_ssl']
if not opts.verify_ssl:
    # Verification is off, so mute urllib3's InsecureRequestWarning.
    import urllib3
    urllib3.disable_warnings(category=urllib3.exceptions.InsecureRequestWarning)
2023-08-21 21:28:52 -06:00
# Accept only the two supported averaging strategies; anything else is fatal.
if config['average_generation_time_mode'] in ('database', 'minute'):
    opts.average_generation_time_mode = config['average_generation_time_mode']
else:
    print('Invalid value for config item "average_generation_time_mode":', config['average_generation_time_mode'])
    raise SystemExit(1)
# Select the token-counting function for the configured backend.
if opts.mode == 'oobabooga':
    # TODO: implement an oobabooga tokenizer
    # (e.g. llm_server.llm.tokenizer = OobaboogaBackend()).
    raise NotImplementedError('The "oobabooga" backend mode is not implemented yet.')
elif opts.mode == 'vllm':
    # Rebinding the attribute only affects code that looks up
    # llm_server.llm.get_token_count at call time. Names bound earlier with
    # `from llm_server.llm import get_token_count` keep the old function.
    llm_server.llm.get_token_count = llm_server.llm.vllm.tokenize
else:
    # Not expected to be reachable: config['mode'] was validated earlier in this
    # file. Kept as a safety net, now with a message naming the bad value.
    raise Exception(f'Unknown mode: {opts.mode}')
2023-09-26 13:32:33 -06:00
def pre_fork(server):
    """Run one-time startup work before serving requests.

    Resets and seeds Redis state, starts the generation/moderation workers and
    background threads, then primes the stats cache.

    Args:
        server: Not used by this function. The ``__main__`` block passes None;
            other callers presumably pass a WSGI-server object (confirm against
            the caller).
    """
    # Install a fresh Redis wrapper on llm_server.llm AND rebind this module's
    # `redis` global to the same object. The module-level
    # `from llm_server.llm import redis` bound whatever object existed at import
    # time. Without this rebinding, the rest of this function and home() would
    # keep using that stale object instead of the wrapper installed here.
    global redis
    redis = llm_server.llm.redis = RedisWrapper('local_llm')

    flushed_keys = redis.flush()
    print('Flushed', len(flushed_keys), 'keys from Redis.')

    redis.set('backend_mode', opts.mode)
    if config['http_host']:
        # Store the host without its scheme; home() prepends https:// / wss://.
        http_host = re.sub(r'https?://', '', config["http_host"])
        redis.set('http_host', http_host)
        # rstrip prevents a '//' when http_host is configured with a trailing slash.
        redis.set('base_client_api', f'{http_host.rstrip("/")}/{opts.frontend_api_client.strip("/")}')

    if config['load_num_prompts']:
        redis.set('proompts', get_number_of_rows('prompts'))

    # Start background processes.
    start_workers(opts.concurrent_gens)
    start_moderation_workers(opts.openai_moderation_workers)
    process_avg_gen_time_background_thread = Thread(target=process_avg_gen_time, daemon=True)
    process_avg_gen_time_background_thread.start()
    MainBackgroundThread().start()
    SemaphoreCheckerThread().start()

    # This needs to be started after Flask is initialized.
    stats_updater_thread = Thread(target=cache_stats, daemon=True)
    stats_updater_thread.start()

    # Cache the initial stats.
    print('Loading backend stats...')
    generate_stats()
2023-09-17 18:55:36 -06:00
2023-08-21 22:49:44 -06:00
# print(app.url_map)
2023-08-21 21:28:52 -06:00
@app.route('/')
@app.route('/api')
@app.route('/api/openai')
@flask_cache.cached(timeout=10)
def home():
    """Render the landing page (home.html) with model, queue and endpoint info.

    Served at '/', '/api' and '/api/openai'. The response is cached by
    flask_cache for 10 seconds.
    """
    stats = generate_stats()

    if not stats['online']:
        running_model = estimated_wait_sec = 'offline'
    else:
        running_model = redis.get('running_model', str, 'ERROR')
        active_gen_workers = get_active_gen_workers()
        if stats['queue']['queued'] == 0 and active_gen_workers >= opts.concurrent_gens:
            # There will be a wait if the queue is empty but prompts are processing, but we don't
            # know how long.
            estimated_wait_sec = f"less than {stats['stats']['average_generation_elapsed_sec']} seconds"
        else:
            estimated_wait_sec = f"{stats['queue']['estimated_wait_sec']} seconds"

    # Use truthiness (like info_html below) so an unset/None tracking code yields
    # '' instead of raising TypeError from len(None).
    tracking_code = config['analytics_tracking_code']
    analytics_tracking_code = f"<script>\n{tracking_code}\n</script>" if tracking_code else ''
    info_html = config['info_html'] or ''
    mode_info = vllm_info if opts.mode == 'vllm' else ''
    base_client_api = redis.get('base_client_api', str)

    return render_template('home.html',
                           llm_middleware_name=opts.llm_middleware_name,
                           analytics_tracking_code=analytics_tracking_code,
                           info_html=info_html,
                           current_model=opts.manual_model_name if opts.manual_model_name else running_model,
                           client_api=f'https://{base_client_api}',
                           ws_client_api=f'wss://{base_client_api}/v1/stream' if opts.enable_streaming else None,
                           estimated_wait=estimated_wait_sec,
                           mode_name=mode_ui_names[opts.mode][0],
                           api_input_textbox=mode_ui_names[opts.mode][1],
                           streaming_input_textbox=mode_ui_names[opts.mode][2],
                           context_size=opts.context_size,
                           stats_json=json.dumps(stats, indent=4, ensure_ascii=False),
                           extra_info=mode_info,
                           openai_client_api=f'https://{base_client_api}/openai/v1' if opts.enable_openi_compatible_backend else 'disabled',
                           expose_openai_system_prompt=opts.expose_openai_system_prompt,
                           enable_streaming=opts.enable_streaming,
                           )
# TODO: add authenticated route to get the current backend URL. Add it to /v1/backend
2023-08-21 21:28:52 -06:00
@app.route('/<first>')
@app.route('/<first>/<path:rest>')
def fallback(first=None, rest=None):
    """Catch-all route for paths nothing else matched: reply with a JSON 404."""
    not_found_body = {'code': 404, 'msg': 'not found'}
    return jsonify(not_found_body), 404
@app.errorhandler(500)
def server_error(e):
    """Delegate HTTP 500 responses to the shared handle_server_error helper."""
    response = handle_server_error(e)
    return response
2023-09-17 18:55:36 -06:00
@app.before_request
def before_app_request():
    """Pass every incoming request to auto_set_base_client_api before routing.

    Returns None explicitly; a non-None return from a before_request hook would
    short-circuit the request.
    """
    auto_set_base_client_api(request)
    return None
2023-09-17 18:55:36 -06:00
2023-08-21 21:28:52 -06:00
if __name__ == "__main__":
    # Development entry point using Flask's built-in server: run the startup
    # routine directly (no server object here), then serve.
    pre_fork(server=None)
    print('FLASK MODE - Startup complete!')
    app.run(
        host='0.0.0.0',
        processes=15,
        threaded=False,
    )