178 lines
6.4 KiB
Python
178 lines
6.4 KiB
Python
# gevent's monkey-patching must happen before any other module is imported, so
# that stdlib modules (socket, ssl, threading, ...) pulled in by later imports
# are patched as well. gevent is optional here: if it is not installed the
# ImportError is ignored and the server runs unpatched.
try:
    import gevent.monkey

    gevent.monkey.patch_all()
except ImportError:
    pass

# Imported only after the (optional) monkey-patch above.
from llm_server.config.config import mode_ui_names
|
|
|
|
from llm_server.pre_fork import server_startup
|
|
from llm_server.config.load import load_config
|
|
import os
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
import simplejson as json
|
|
from flask import Flask, jsonify, render_template, request
|
|
|
|
import llm_server
|
|
from llm_server.database.conn import database
|
|
from llm_server.database.create import create_db
|
|
from llm_server.llm import get_token_count
|
|
from llm_server.routes.openai import openai_bp
|
|
from llm_server.routes.server_error import handle_server_error
|
|
from llm_server.routes.v1 import bp
|
|
from llm_server.stream import init_socketio
|
|
|
|
# TODO: have the workers handle streaming too
|
|
# TODO: add backend fallbacks. Backends at the bottom of the list are higher priority and are fallbacks if the upper ones fail
|
|
# TODO: implement background thread to test backends via sending test prompts
|
|
# TODO: if backend fails request, mark it as down
|
|
# TODO: allow setting concurrent gens per-backend
|
|
# TODO: set the max tokens to that of the lowest backend
|
|
# TODO: implement RRD backend loadbalancer option
|
|
# TODO: have VLLM reject a request if it already has n == concurrent_gens running
|
|
# TODO: add a way to cancel VLLM gens. Maybe use websockets?
|
|
# TODO: use coloredlogs
|
|
|
|
# Lower priority
|
|
# TODO: the processing stat showed -1 and I had to restart the server
|
|
# TODO: simulate OpenAI error messages regardless of endpoint
|
|
# TODO: send extra headers when ratelimited?
|
|
# TODO: make sure log_prompt() is used everywhere, including errors and invalid requests
|
|
# TODO: unify logging thread in a function and use async/await instead
|
|
# TODO: move the netdata stats to a separate part of the stats and have it set to the currently selected backend
|
|
# TODO: have VLLM reply with stats (TPS, generated token count, processing time)
|
|
# TODO: add config reloading via stored redis variables
|
|
|
|
# Done, but need to verify
|
|
# TODO: add more excluding to SYSTEM__ tokens
|
|
# TODO: return 200 when returning formatted sillytavern error
|
|
|
|
# vllm is required: if it cannot be imported, print install guidance and
# abort startup with exit status 1.
try:
    import vllm
except ModuleNotFoundError as import_error:
    print('Could not import vllm-gptq:', import_error)
    print('Please see README.md for install instructions.')
    sys.exit(1)
|
|
|
|
import config
|
|
from llm_server import opts
|
|
from llm_server.helpers import auto_set_base_client_api
|
|
from llm_server.llm.vllm.info import vllm_info
|
|
from llm_server.routes.cache import RedisWrapper, flask_cache
|
|
from llm_server.llm import redis
|
|
from llm_server.routes.stats import get_active_gen_workers
|
|
from llm_server.routes.v1.generate_stats import generate_stats
|
|
|
|
app = Flask(__name__)
init_socketio(app)

# Mount the native v1 API first, then the OpenAI-compatible API.
for _blueprint, _url_prefix in ((bp, '/api/v1/'), (openai_bp, '/api/openai/v1/')):
    app.register_blueprint(_blueprint, url_prefix=_url_prefix)

# Attach the response cache to the app, then start from an empty cache.
flask_cache.init_app(app)
flask_cache.clear()
|
|
|
|
# Resolve the config file: a non-empty $CONFIG_PATH wins; otherwise use
# config/config.yml next to this script.
script_path = os.path.dirname(os.path.realpath(__file__))
config_path_environ = os.getenv("CONFIG_PATH")
if config_path_environ:
    config_path = config_path_environ
else:
    config_path = Path(script_path, 'config', 'config.yml')

# NOTE: this rebinds the module-level name `config` (shadowing the earlier
# `import config`) to the loaded configuration mapping used below and in routes.
success, config, msg = load_config(config_path, script_path)
if not success:
    print('Failed to load config:', msg)
    sys.exit(1)

database.init_db(config['mysql']['host'], config['mysql']['username'], config['mysql']['password'], config['mysql']['database'])
# Create the schema once. The original script called create_db() a second time
# after the Redis assignment below; that duplicate call was redundant.
create_db()

# Replace the shared Redis wrapper on the llm_server.llm module so code that
# looks it up via the module attribute sees this instance.
llm_server.llm.redis = RedisWrapper('local_llm')
|
|
|
|
|
|
# print(app.url_map)
|
|
|
|
|
|
@app.route('/')
@app.route('/api')
@app.route('/api/openai')
@flask_cache.cached(timeout=10)
def home():
    """Render the landing page (home.html) with live backend stats.

    Served at '/', '/api' and '/api/openai'. The rendered response is cached
    for 10 seconds by the decorator, so the figures shown can lag by that much.
    """
    # Look the Redis wrapper up through the module attribute at call time.
    # `from llm_server.llm import redis` binds whatever object existed at import
    # time, but startup later assigns a new RedisWrapper to
    # `llm_server.llm.redis`; reading the attribute here always uses the
    # current one.
    redis_client = llm_server.llm.redis

    stats = generate_stats()

    if not stats['online']:
        running_model = estimated_wait_sec = 'offline'
    else:
        running_model = redis_client.get('running_model', str, 'ERROR')

        active_gen_workers = get_active_gen_workers()
        if stats['queue']['queued'] == 0 and active_gen_workers >= opts.concurrent_gens:
            # There will be a wait if the queue is empty but prompts are processing, but we don't
            # know how long.
            estimated_wait_sec = f"less than {stats['stats']['average_generation_elapsed_sec']} seconds"
        else:
            estimated_wait_sec = f"{stats['queue']['estimated_wait_sec']} seconds"

    # Truthiness rather than len(): an empty *or missing (None)* tracking code
    # renders nothing instead of raising TypeError.
    tracking_code = config['analytics_tracking_code']
    analytics_tracking_code = f"<script>\n{tracking_code}\n</script>" if tracking_code else ''

    info_html = config['info_html'] or ''

    mode_info = vllm_info if opts.mode == 'vllm' else ''

    base_client_api = redis_client.get('base_client_api', str)

    # Per-mode UI labels: [0] mode name, [1] API input textbox, [2] streaming input textbox.
    ui_names = mode_ui_names[opts.mode]

    return render_template('home.html',
                           llm_middleware_name=opts.llm_middleware_name,
                           analytics_tracking_code=analytics_tracking_code,
                           info_html=info_html,
                           current_model=opts.manual_model_name or running_model,
                           client_api=f'https://{base_client_api}',
                           ws_client_api=f'wss://{base_client_api}/v1/stream' if opts.enable_streaming else None,
                           estimated_wait=estimated_wait_sec,
                           mode_name=ui_names[0],
                           api_input_textbox=ui_names[1],
                           streaming_input_textbox=ui_names[2],
                           context_size=opts.context_size,
                           stats_json=json.dumps(stats, indent=4, ensure_ascii=False),
                           extra_info=mode_info,
                           openai_client_api=f'https://{base_client_api}/openai/v1' if opts.enable_openi_compatible_backend else 'disabled',
                           expose_openai_system_prompt=opts.expose_openai_system_prompt,
                           enable_streaming=opts.enable_streaming,
                           )
|
|
|
|
|
|
# TODO: add authenticated route to get the current backend URL. Add it to /v1/backend
|
|
|
|
@app.route('/<first>')
@app.route('/<first>/<path:rest>')
def fallback(first=None, rest=None):
    """Catch-all route for paths no other rule matched; replies with a JSON 404."""
    not_found_body = {'code': 404, 'msg': 'not found'}
    return jsonify(not_found_body), 404
|
|
|
|
|
|
@app.errorhandler(500)
def server_error(e):
    """Hand HTTP 500 errors to the shared handle_server_error helper."""
    response = handle_server_error(e)
    return response
|
|
|
|
|
|
@app.before_request
def before_app_request():
    """Pass every incoming request to auto_set_base_client_api before routing.

    NOTE(review): presumably this records the client-facing base URL that
    home() later reads from Redis as 'base_client_api' — confirm in helpers.
    """
    incoming = request
    auto_set_base_client_api(incoming)
|
|
|
|
|
|
if __name__ == "__main__":
    # Direct-run entry point: perform startup, then serve with Flask's
    # built-in server using 15 worker processes (threading disabled).
    server_startup(None)
    print('FLASK MODE - Startup complete!')
    app.run(
        host='0.0.0.0',
        threaded=False,
        processes=15,
    )
|