local-llm-server/server.py

245 lines
9.1 KiB
Python

import os
import re
import sys
from pathlib import Path
from threading import Thread
import openai
import simplejson as json
from flask import Flask, jsonify, render_template, request
import llm_server
from llm_server.database.conn import db_pool
from llm_server.database.create import create_db
from llm_server.database.database import get_number_of_rows
from llm_server.llm import get_token_count
from llm_server.routes.openai import openai_bp
from llm_server.routes.server_error import handle_server_error
# TODO: allow setting more custom ratelimits per-token
# TODO: add more excluding to SYSTEM__ tokens
# TODO: make sure the OpenAI moderation endpoint scans the last n messages rather than only the last one (make that threaded)
# TODO: support turbo-instruct on openai endpoint
# TODO: option to trim context in openai mode so that we silently fit the model's context
# TODO: validate system tokens before excluding them
# TODO: unify logging thread in a function and use async/await instead
# TODO: make sure log_prompt() is used everywhere, including errors and invalid requests
# vllm is a hard requirement: bail out early with install guidance if it is missing.
try:
    import vllm
except ModuleNotFoundError as import_error:
    print('Could not import vllm-gptq:', import_error)
    print('Please see README.md for install instructions.')
    sys.exit(1)
import config
from llm_server import opts
from llm_server.config import ConfigLoader, config_default_vars, config_required_vars, mode_ui_names
from llm_server.helpers import resolve_path, auto_set_base_client_api
from llm_server.llm.vllm.info import vllm_info
from llm_server.routes.cache import cache, redis
from llm_server.routes.queue import start_workers
from llm_server.routes.stats import SemaphoreCheckerThread, get_active_gen_workers, process_avg_gen_time
from llm_server.routes.v1 import bp
from llm_server.routes.v1.generate_stats import generate_stats
from llm_server.stream import init_socketio
from llm_server.threads import MainBackgroundThread, cache_stats
# Locate the config file: the CONFIG_PATH environment variable wins, otherwise
# fall back to config/config.yml next to this script.
script_path = os.path.dirname(os.path.realpath(__file__))
config_path_environ = os.getenv("CONFIG_PATH")
config_path = config_path_environ if config_path_environ else Path(script_path, 'config', 'config.yml')

# NOTE: this rebinds the module-level name `config`, shadowing the
# `import config` module imported earlier in this file.
config_loader = ConfigLoader(config_path, config_default_vars, config_required_vars)
success, config, msg = config_loader.load_config()
if not success:
    print('Failed to load config:', msg)
    sys.exit(1)
# Resolve a "./"-relative database path against the directory of the script.
# Only the literal "./" prefix is removed: str.strip('./') would treat its
# argument as a character set and also eat further leading/trailing dots and
# slashes (e.g. "./../db" -> "db", "./db/" -> "db").
if config['database_path'].startswith('./'):
    config['database_path'] = resolve_path(script_path, config['database_path'][len('./'):])

db_pool.init_db(config['mysql']['host'], config['mysql']['username'], config['mysql']['password'], config['mysql']['database'])
create_db()

if config['mode'] not in ['oobabooga', 'vllm']:
    print('Unknown mode:', config['mode'])
    sys.exit(1)

# Start from a clean Redis state on every boot.
flushed_keys = redis.flush()
print('Flushed', len(flushed_keys), 'keys from Redis.')
# TODO: this is a MESS
# Copy loaded config values onto the shared `opts` module. Some config keys are
# misspelled (e.g. 'enable_openi_compatible_backend', 'openai_epose_our_model');
# presumably that is how the config schema names them — do not "fix" the key
# strings here without updating llm_server.config as well.
opts.mode = config['mode']
opts.auth_required = config['auth_required']
opts.log_prompts = config['log_prompts']
opts.concurrent_gens = config['concurrent_gens']
opts.frontend_api_client = config['frontend_api_client']
opts.context_size = config['token_limit']
opts.show_num_prompts = config['show_num_prompts']
opts.show_uptime = config['show_uptime']
opts.backend_url = config['backend_url'].strip('/')
opts.show_total_output_tokens = config['show_total_output_tokens']
opts.netdata_root = config['netdata_root']
opts.simultaneous_requests_per_ip = config['simultaneous_requests_per_ip']
opts.show_backend_info = config['show_backend_info']
opts.max_new_tokens = config['max_new_tokens']
opts.manual_model_name = config['manual_model_name']
opts.llm_middleware_name = config['llm_middleware_name']
opts.enable_openi_compatible_backend = config['enable_openi_compatible_backend']
opts.openai_system_prompt = config['openai_system_prompt']
opts.expose_openai_system_prompt = config['expose_openai_system_prompt']
opts.enable_streaming = config['enable_streaming']
opts.openai_api_key = config['openai_api_key']
openai.api_key = opts.openai_api_key
opts.admin_token = config['admin_token']
opts.openai_expose_our_model = config['openai_epose_our_model']
opts.openai_force_no_hashes = config['openai_force_no_hashes']

# Exposing our model requires an OpenAI key. The message previously said
# "false", contradicting the condition actually enforced below.
if opts.openai_expose_our_model and not opts.openai_api_key:
    print('If you set openai_epose_our_model to true, you must set your OpenAI key in openai_api_key.')
    sys.exit(1)
if config['http_host']:
    # Remove only a leading scheme. The previous unanchored pattern would also
    # delete an "http://" / "https://" sequence appearing later in the value.
    http_host = re.sub(r'^https?://', '', config["http_host"])
    redis.set('http_host', http_host)
    redis.set('base_client_api', f'{http_host}/{opts.frontend_api_client.strip("/")}')
    print('Set host to', redis.get('http_host', str))

opts.verify_ssl = config['verify_ssl']
if not opts.verify_ssl:
    # Silence urllib3's per-request warning when certificate checks are disabled.
    import urllib3
    urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

redis.set('backend_mode', opts.mode)

if config['load_num_prompts']:
    redis.set('proompts', get_number_of_rows('prompts'))

if config['average_generation_time_mode'] not in ['database', 'minute']:
    print('Invalid value for config item "average_generation_time_mode":', config['average_generation_time_mode'])
    sys.exit(1)
opts.average_generation_time_mode = config['average_generation_time_mode']
# Wire the backend-specific tokenizer into llm_server.llm. Only vllm is
# implemented; oobabooga is accepted by config validation but not supported yet.
if opts.mode == 'vllm':
    llm_server.llm.get_token_count = llm_server.llm.vllm.tokenize
elif opts.mode == 'oobabooga':
    # llm_server.llm.tokenizer = OobaboogaBackend()
    raise NotImplementedError
else:
    raise Exception
app = Flask(__name__)
cache.init_app(app)
cache.clear()  # clear redis cache

# Start background processes
start_workers(opts.concurrent_gens)
process_avg_gen_time_background_thread = Thread(target=process_avg_gen_time, daemon=True)
process_avg_gen_time_background_thread.start()
MainBackgroundThread().start()
SemaphoreCheckerThread().start()

# Cache the initial stats
print('Loading backend stats...')
generate_stats()

init_socketio(app)
app.register_blueprint(bp, url_prefix='/api/v1/')
app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')

# This needs to be started after Flask is initialized
stats_updater_thread = Thread(target=cache_stats, daemon=True)
stats_updater_thread.start()

# print(app.url_map)
@app.route('/')
@app.route('/api')
@app.route('/api/openai')
@cache.cached(timeout=10)
def home():
    """Render the landing page with backend status and client API endpoints.

    The response is cached for 10 seconds via ``cache.cached``.
    """
    stats = generate_stats()

    if not stats['online']:
        running_model = estimated_wait_sec = 'offline'
    else:
        running_model = opts.running_model
        active_gen_workers = get_active_gen_workers()
        if stats['queue']['queued'] == 0 and active_gen_workers >= opts.concurrent_gens:
            # There will be a wait if the queue is empty but prompts are processing, but we don't
            # know how long.
            estimated_wait_sec = f"less than {stats['stats']['average_generation_elapsed_sec']} seconds"
        else:
            estimated_wait_sec = f"{stats['queue']['estimated_wait_sec']} seconds"

    # Truthiness rather than len(): an unset (None) value must not raise TypeError.
    if config['analytics_tracking_code']:
        analytics_tracking_code = f"<script>\n{config['analytics_tracking_code']}\n</script>"
    else:
        analytics_tracking_code = ''

    info_html = config['info_html'] or ''
    mode_info = vllm_info if opts.mode == 'vllm' else ''
    base_client_api = redis.get('base_client_api', str)

    return render_template('home.html',
                           llm_middleware_name=opts.llm_middleware_name,
                           analytics_tracking_code=analytics_tracking_code,
                           info_html=info_html,
                           current_model=opts.manual_model_name if opts.manual_model_name else running_model,
                           client_api=f'https://{base_client_api}',
                           ws_client_api=f'wss://{base_client_api}/v1/stream' if opts.enable_streaming else None,
                           estimated_wait=estimated_wait_sec,
                           mode_name=mode_ui_names[opts.mode][0],
                           api_input_textbox=mode_ui_names[opts.mode][1],
                           streaming_input_textbox=mode_ui_names[opts.mode][2],
                           context_size=opts.context_size,
                           stats_json=json.dumps(stats, indent=4, ensure_ascii=False),
                           extra_info=mode_info,
                           openai_client_api=f'https://{base_client_api}/openai/v1' if opts.enable_openi_compatible_backend else 'disabled',
                           expose_openai_system_prompt=opts.expose_openai_system_prompt,
                           enable_streaming=opts.enable_streaming,
                           )
# TODO: add authenticated route to get the current backend URL. Add it to /v1/backend
@app.route('/<first>')
@app.route('/<first>/<path:rest>')
def fallback(first=None, rest=None):
    """Catch-all route: reply to any otherwise-unmatched path with a JSON 404."""
    not_found_body = {'code': 404, 'msg': 'not found'}
    return jsonify(not_found_body), 404
@app.errorhandler(500)
def server_error(e):
    """Delegate HTTP 500 handling to the shared handle_server_error helper."""
    response = handle_server_error(e)
    return response
@app.before_request
def before_app_request():
    """Hook run before every request: hand the current request to auto_set_base_client_api."""
    current_request = request
    auto_set_base_client_api(current_request)
if __name__ == "__main__":
    # Development server only: Werkzeug forks per request (up to 15 processes)
    # instead of using threads.
    app.run(host='0.0.0.0', processes=15, threaded=False)