local-llm-server/server.py

159 lines
5.8 KiB
Python
Raw Normal View History

2023-08-23 23:11:12 -06:00
import json
2023-08-21 21:28:52 -06:00
import os
import sys
from pathlib import Path
from threading import Thread
2023-08-21 21:28:52 -06:00
2023-08-23 23:11:12 -06:00
from flask import Flask, jsonify, render_template, request
2023-08-21 21:28:52 -06:00
2023-08-23 23:11:12 -06:00
import config
2023-08-21 21:28:52 -06:00
from llm_server import opts
2023-08-23 23:11:12 -06:00
from llm_server.config import ConfigLoader, config_default_vars, config_required_vars, mode_ui_names
from llm_server.database import get_number_of_rows, init_db
2023-08-21 21:28:52 -06:00
from llm_server.helpers import resolve_path
2023-08-29 14:00:35 -06:00
from llm_server.llm.hf_textgen.info import hf_textget_info
from llm_server.routes.cache import cache, redis
2023-08-23 20:33:49 -06:00
from llm_server.routes.queue import start_workers
2023-08-23 22:21:59 -06:00
from llm_server.routes.stats import SemaphoreCheckerThread, process_avg_gen_time
2023-08-21 21:28:52 -06:00
from llm_server.routes.v1 import bp
2023-08-23 23:11:12 -06:00
from llm_server.routes.v1.generate_stats import generate_stats
2023-08-29 17:56:12 -06:00
from llm_server.stream import init_socketio
2023-08-24 20:43:11 -06:00
from llm_server.threads import MainBackgroundThread
2023-08-21 21:28:52 -06:00
2023-08-21 23:07:12 -06:00
# Directory containing this script; base for the default config path and for
# './'-relative paths found in the config.
script_path = os.path.dirname(os.path.realpath(__file__))

# The CONFIG_PATH environment variable overrides the default config/config.yml
# located next to this script.
config_path_environ = os.getenv("CONFIG_PATH")
if config_path_environ:
    config_path = config_path_environ
else:
    config_path = Path(script_path, 'config', 'config.yml')

config_loader = ConfigLoader(config_path, config_default_vars, config_required_vars)

# NOTE: this rebinds the module-level name `config` to the loaded settings,
# shadowing the `import config` module imported at the top of the file.
success, config, msg = config_loader.load_config()
if not success:
    print('Failed to load config:', msg)
    sys.exit(1)

# Resolve a './'-relative database path against the script directory.
# Slice off exactly the two-character './' prefix. str.strip('./') would also
# remove any further leading/trailing '.' and '/' characters, e.g.
# './.data/db' -> 'data/db' or './../db' -> 'db'.
if config['database_path'].startswith('./'):
    config['database_path'] = resolve_path(script_path, config['database_path'][2:])

opts.database_path = resolve_path(config['database_path'])
init_db()
2023-08-21 21:28:52 -06:00
# Validate the backend mode before anything depends on it. This exits like the
# other config validation in this file instead of continuing with an unknown
# mode, which would otherwise only fail later (e.g. the home page indexes
# mode_ui_names[opts.mode]).
if config['mode'] not in ['oobabooga', 'hf-textgen']:
    print('Unknown mode:', config['mode'])
    sys.exit(1)

# Copy the loaded settings onto the shared `opts` module.
opts.mode = config['mode']
opts.auth_required = config['auth_required']
opts.log_prompts = config['log_prompts']
opts.concurrent_gens = config['concurrent_gens']
opts.frontend_api_client = config['frontend_api_client']
opts.context_size = config['token_limit']
opts.show_num_prompts = config['show_num_prompts']
opts.show_uptime = config['show_uptime']
opts.backend_url = config['backend_url'].strip('/')
opts.show_total_output_tokens = config['show_total_output_tokens']
opts.netdata_root = config['netdata_root']
opts.ip_in_queue_max = config['ip_in_queue_max']
opts.show_backend_info = config['show_backend_info']

opts.verify_ssl = config['verify_ssl']
if not opts.verify_ssl:
    # With verification turned off, silence urllib3's per-request
    # InsecureRequestWarning spam.
    import urllib3
    urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
2023-08-21 21:28:52 -06:00
# Validate before touching Redis, so an invalid config exits without first
# wiping the Redis state.
if config['average_generation_time_mode'] not in ['database', 'minute']:
    print('Invalid value for config item "average_generation_time_mode":', config['average_generation_time_mode'])
    sys.exit(1)
opts.average_generation_time_mode = config['average_generation_time_mode']

# Start every launch from an empty Redis keyspace.
flushed_keys = redis.flush()
print('Flushed', len(flushed_keys), 'keys from Redis.')
if config['load_num_prompts']:
    # Seed the 'proompts' key with the row count of the 'prompts' table.
    redis.set('proompts', get_number_of_rows('prompts'))

start_workers(opts.concurrent_gens)

# Daemon thread running process_avg_gen_time, so it does not keep the
# process alive on shutdown.
process_avg_gen_time_background_thread = Thread(target=process_avg_gen_time, daemon=True)
process_avg_gen_time_background_thread.start()

MainBackgroundThread().start()
SemaphoreCheckerThread().start()
2023-08-21 21:28:52 -06:00
# Flask application. Initialisation order is kept as-is: cache first (then
# cleared so no stale cached pages survive a restart), then socket.io, then
# the v1 API blueprint.
app = Flask(__name__)
cache.init_app(app)
cache.clear()  # clear redis cache
init_socketio(app)
app.register_blueprint(bp, url_prefix='/api/v1/')
2023-08-21 21:28:52 -06:00
@app.route('/')
@app.route('/api')
@cache.cached(timeout=10, query_string=True)
def home():
    """Render ``home.html`` with backend status, wait estimate and API usage info.

    Served at both ``/`` and ``/api``. The rendered response is cached for
    10 seconds, keyed on the query string.
    """
    if not opts.base_client_api:
        # Computed once, from the first request that reaches this view, and
        # reused afterwards. request.host falls back to the configured server
        # name when the client sends no Host header, instead of producing the
        # literal string "None/...".
        # NOTE(review): the value is still client-controlled (Host header) and
        # persists for every later visitor; consider configuring it explicitly.
        opts.base_client_api = f'{request.host}/{opts.frontend_api_client.strip("/")}'
    stats = generate_stats()

    if not redis.get('backend_online') or not stats['online']:
        running_model = estimated_wait_sec = 'offline'
    else:
        running_model = opts.running_model
        if stats['queue']['queued'] == 0 and stats['queue']['processing'] > 0:
            # There will be a wait if the queue is empty but prompts are
            # processing, but we don't know how long.
            estimated_wait_sec = f"less than {stats['stats']['average_generation_elapsed_sec']} seconds"
        else:
            estimated_wait_sec = f"{stats['queue']['estimated_wait_sec']} seconds"

    # Truthiness (not len()) so both an empty string and a null config value
    # disable the snippet, matching how info_html is handled below.
    if config['analytics_tracking_code']:
        analytics_tracking_code = f"<script>\n{config['analytics_tracking_code']}\n</script>"
    else:
        analytics_tracking_code = ''

    if config['info_html']:
        info_html = '<br>\n' + config['info_html']
    else:
        info_html = ''

    # Looked up once; entries are used as (mode name, API input textbox,
    # streaming input textbox) below.
    mode_ui = mode_ui_names[opts.mode]

    return render_template(
        'home.html',
        llm_middleware_name=config['llm_middleware_name'],
        analytics_tracking_code=analytics_tracking_code,
        info_html=info_html,
        current_model=running_model,
        client_api=f'https://{opts.base_client_api}',
        ws_client_api=f'wss://{opts.base_client_api}/v1/stream',
        estimated_wait=estimated_wait_sec,
        mode_name=mode_ui[0],
        api_input_textbox=mode_ui[1],
        streaming_input_textbox=mode_ui[2],
        context_size=opts.context_size,
        stats_json=json.dumps(stats, indent=4, ensure_ascii=False),
        extra_info=hf_textget_info if opts.mode == 'hf-textgen' else '',
    )
2023-08-21 21:28:52 -06:00
@app.route('/<first>')
@app.route('/<first>/<path:rest>')
def fallback(first=None, rest=None):
    """Answer any one- or multi-segment path not claimed by a more specific route.

    The captured path segments are ignored; the response is always a JSON
    body with HTTP status 404.
    """
    not_found_body = {'error': 404, 'msg': 'not found'}
    status_code = 404
    return jsonify(not_found_body), status_code


if __name__ == "__main__":
    # Flask's built-in server, listening on all network interfaces.
    app.run(host='0.0.0.0')