cache stats in background

Cyberes 2023-09-17 18:55:36 -06:00
parent edf13db324
commit 3c1254d3bf
7 changed files with 62 additions and 30 deletions
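
The thread of this change: base_client_api moves out of the opts module (a per-process global) and into Redis, and generate_stats() is now memoized and recomputed by a background thread instead of on every request. A minimal sketch of the Redis-backed shared-value idiom used throughout the diff, assuming a local Redis instance; the key name matches the commit, the rest is illustrative:

import redis

r = redis.Redis(host='localhost', port=6379)

# First writer wins; every process and thread sees the same value.
if not r.get('base_client_api'):
    r.set('base_client_api', 'example.com/api')

# redis-py returns bytes (or None), hence the decode-and-guard below.
raw = r.get('base_client_api')
base_client_api = raw.decode() if raw else None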


@@ -12,7 +12,6 @@ database_path = './proxy-server.db'
 auth_required = False
 log_prompts = False
 frontend_api_client = ''
-base_client_api = None
 http_host = None
 verify_ssl = True
 show_num_prompts = True


@@ -1,5 +1,6 @@
 from flask import Blueprint, request
+from ..cache import redis
 from ..helpers.client import format_sillytavern_err
 from ..helpers.http import require_api_key
 from ..openai_request_handler import build_openai_response
@@ -10,13 +11,14 @@ openai_bp = Blueprint('openai/v1/', __name__)
 @openai_bp.before_request
-def before_request():
+def before_oai_request():
+    # TODO: unify with normal before_request()
     if not opts.http_host:
         opts.http_host = request.headers.get("Host")
     if not opts.enable_openi_compatible_backend:
         return build_openai_response('', format_sillytavern_err('The OpenAI-compatible backend is disabled.', 'Access Denied')), 401
-    if not opts.base_client_api:
-        opts.base_client_api = f'{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}'
+    if not redis.get('base_client_api'):
+        redis.set('base_client_api', f'{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}')
     if request.endpoint != 'v1.get_stats':
         response = require_api_key()
         if response is not None:
             return response
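
For context on the hook above: a function registered with before_request runs ahead of every view, and if it returns anything other than None, Flask sends that return value as the response and never calls the view. That is how this handler can refuse the whole request with a 401. A minimal sketch with made-up names:

from flask import Flask

app = Flask(__name__)

@app.before_request
def check_auth():
    authorized = False  # stand-in for a real API-key check
    if not authorized:
        # A (body, status) tuple returned here short-circuits the request.
        return 'access denied', 401
    # Falling through (returning None) lets the view run.

@app.route('/')
def index():
    return 'ok'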


@@ -3,7 +3,7 @@ import time
 from typing import Tuple, Union
 import flask
-from flask import Response
+from flask import Response, request
 from llm_server import opts
 from llm_server.database import log_prompt
@@ -11,7 +11,7 @@ from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
 from llm_server.llm.vllm.vllm_backend import VLLMBackend
 from llm_server.routes.cache import redis
 from llm_server.routes.helpers.client import format_sillytavern_err
-from llm_server.routes.helpers.http import validate_json
+from llm_server.routes.helpers.http import require_api_key, validate_json
 from llm_server.routes.queue import priority_queue
 from llm_server.routes.stats import SemaphoreCheckerThread
@@ -178,3 +178,14 @@ def delete_dict_key(d: dict, k: Union[str, list]):
     else:
         raise ValueError
     return d
+
+
+def before_request():
+    if not opts.http_host:
+        opts.http_host = request.headers.get("Host")
+    if not redis.get('base_client_api'):
+        redis.set('base_client_api', f'{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}')
+    if request.endpoint != 'v1.get_stats':
+        response = require_api_key()
+        if response is not None:
+            return response
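
The request.endpoint comparison works because Flask names a blueprint view's endpoint '<blueprint name>.<view function name>', so the stats route registered on the 'v1' blueprint is 'v1.get_stats' and gets exempted from the API-key check. A small sketch demonstrating the naming (routes here are hypothetical):

from flask import Blueprint, Flask, request

bp = Blueprint('v1', __name__)

@bp.route('/stats')
def get_stats():
    return 'stats'

app = Flask(__name__)
app.register_blueprint(bp, url_prefix='/api/v1')

with app.test_request_context('/api/v1/stats'):
    print(request.endpoint)  # -> 'v1.get_stats'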


@@ -1,22 +1,14 @@
-from flask import Blueprint, request
+from flask import Blueprint
-from ..helpers.http import require_api_key
+from ..request_handler import before_request
 from ..server_error import handle_server_error
-from ... import opts

 bp = Blueprint('v1', __name__)


 @bp.before_request
-def before_request():
-    if not opts.http_host:
-        opts.http_host = request.headers.get("Host")
-    if not opts.base_client_api:
-        opts.base_client_api = f'{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}'
-    if request.endpoint != 'v1.get_stats':
-        response = require_api_key()
-        if response is not None:
-            return response
+def before_bp_request():
+    return before_request()


 @bp.errorhandler(500)


@@ -1,12 +1,14 @@
 import time
 from datetime import datetime
 from flask import request
 from llm_server import opts
 from llm_server.database import get_distinct_ips_24h, sum_column
 from llm_server.helpers import deep_sort, round_up_base
 from llm_server.llm.info import get_running_model
 from llm_server.netdata import get_power_states
-from llm_server.routes.cache import redis
+from llm_server.routes.cache import cache, redis
 from llm_server.routes.queue import priority_queue
 from llm_server.routes.stats import SemaphoreCheckerThread, calculate_avg_gen_time, get_active_gen_workers, get_total_proompts, server_start_time
@@ -34,6 +36,7 @@ def calculate_wait_time(gen_time_calc, proompters_in_queue, concurrent_gens, active_gen_workers):
 # TODO: have routes/__init__.py point to the latest API version generate_stats()
+@cache.memoize(timeout=10)
 def generate_stats():
     model_name, error = get_running_model()  # will return False when the fetch fails
     if isinstance(model_name, bool):
@@ -83,6 +86,10 @@ def generate_stats():
     else:
         netdata_stats = {}

+    x = redis.get('base_client_api')
+    base_client_api = x.decode() if x else None
+    del x
+
     output = {
         'stats': {
             'proompters': {
@@ -98,8 +105,8 @@ def generate_stats():
         },
         'online': online,
         'endpoints': {
-            'blocking': f'https://{opts.base_client_api}',
-            'streaming': f'wss://{opts.base_client_api}/v1/stream' if opts.enable_streaming else None,
+            'blocking': f'https://{base_client_api}',
+            'streaming': f'wss://{base_client_api}/v1/stream' if opts.enable_streaming else None,
         },
         'queue': {
             'processing': active_gen_workers,
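
Two idioms worth noting in this file: @cache.memoize(timeout=10) is Flask-Caching's memoization decorator, so generate_stats() runs at most once per 10 seconds and every other caller gets the stored result; and because redis-py returns bytes, the value is decoded before being interpolated into the endpoint URLs. A minimal memoize sketch, assuming Flask-Caching with its in-memory SimpleCache backend (the project presumably wires it to Redis):

import time

from flask import Flask
from flask_caching import Cache

app = Flask(__name__)
cache = Cache(app, config={'CACHE_TYPE': 'SimpleCache'})

@cache.memoize(timeout=10)
def expensive_stats():
    # The body runs at most once per 10 s per argument set.
    return {'generated_at': time.time()}

with app.app_context():
    assert expensive_stats() == expensive_stats()  # second call is served from the cache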


@@ -5,6 +5,7 @@ from llm_server import opts
 from llm_server.database import weighted_average_column_for_model
 from llm_server.llm.info import get_running_model
 from llm_server.routes.cache import redis
+from llm_server.routes.v1.generate_stats import generate_stats


 class MainBackgroundThread(Thread):
@@ -60,3 +61,12 @@ class MainBackgroundThread(Thread):
            estimated_avg_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0  # Avoid division by zero
            redis.set('estimated_avg_tps', estimated_avg_tps)
            time.sleep(60)
+
+
+def cache_stats():
+    while True:
+        # If base_client_api is null in Redis, no one has visited the site yet
+        # and the value hasn't been set. Do nothing until then.
+        if redis.get('base_client_api'):
+            x = generate_stats()
+        time.sleep(5)
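
The seemingly unused x = generate_stats() is the point of the loop: calling the memoized function every 5 seconds keeps its 10-second cache permanently warm, so request handlers read a precomputed result instead of triggering the expensive collection themselves. The same warm-the-cache pattern in isolation (all names hypothetical):

import threading
import time

def compute_stats():
    # Stand-in for a memoized, expensive stats function.
    return {'generated_at': time.time()}

def warm_cache(refresh, interval=5):
    # Invoke the cached function on a timer purely for its side effect of
    # repopulating the cache; the return value is deliberately discarded.
    while True:
        refresh()
        time.sleep(interval)

threading.Thread(target=warm_cache, args=(compute_stats,), daemon=True).start()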


@@ -28,7 +28,7 @@ from llm_server.routes.stats import SemaphoreCheckerThread, process_avg_gen_time
 from llm_server.routes.v1 import bp
 from llm_server.routes.v1.generate_stats import generate_stats
 from llm_server.stream import init_socketio
-from llm_server.threads import MainBackgroundThread
+from llm_server.threads import MainBackgroundThread, cache_stats

 script_path = os.path.dirname(os.path.realpath(__file__))
@@ -93,12 +93,8 @@ if config['average_generation_time_mode'] not in ['database', 'minute']:
     sys.exit(1)
 opts.average_generation_time_mode = config['average_generation_time_mode']

 # Start background processes
 start_workers(opts.concurrent_gens)
-# cleanup_thread = Thread(target=elapsed_times_cleanup)
-# cleanup_thread.daemon = True
-# cleanup_thread.start()
-# Start the background thread
 process_avg_gen_time_background_thread = Thread(target=process_avg_gen_time)
 process_avg_gen_time_background_thread.daemon = True
 process_avg_gen_time_background_thread.start()
@@ -112,6 +108,11 @@ init_socketio(app)
 app.register_blueprint(bp, url_prefix='/api/v1/')
 app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')

+# This needs to be started after Flask is initialized
+stats_updater_thread = Thread(target=cache_stats)
+stats_updater_thread.daemon = True
+stats_updater_thread.start()
+
 # print(app.url_map)
@@ -121,8 +122,6 @@ app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')
 @app.route('/api/openai')
 @cache.cached(timeout=10)
 def home():
-    if not opts.base_client_api:
-        opts.base_client_api = f'{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}'
     stats = generate_stats()
     if not bool(redis.get('backend_online')) or not stats['online']:
@@ -151,6 +150,10 @@ def home():
     if opts.mode == 'vllm':
         mode_info = vllm_info

+    x = redis.get('base_client_api')
+    base_client_api = x.decode() if x else None
+    del x
+
     return render_template('home.html',
                            llm_middleware_name=opts.llm_middleware_name,
                            analytics_tracking_code=analytics_tracking_code,
@@ -165,7 +168,7 @@ def home():
                            context_size=opts.context_size,
                            stats_json=json.dumps(stats, indent=4, ensure_ascii=False),
                            extra_info=mode_info,
-                           openai_client_api=f'https://{opts.base_client_api}/openai/v1' if opts.enable_openi_compatible_backend else 'disabled',
+                           openai_client_api=f'https://{base_client_api}/openai/v1' if opts.enable_openi_compatible_backend else 'disabled',
                            expose_openai_system_prompt=opts.expose_openai_system_prompt,
                            enable_streaming=opts.enable_streaming,
                            )
@@ -185,5 +188,13 @@ def server_error(e):
     return handle_server_error(e)


+@app.before_request
+def before_app_request():
+    if not opts.http_host:
+        opts.http_host = request.headers.get("Host")
+    if not redis.get('base_client_api'):
+        redis.set('base_client_api', f'{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}')
+
+
 if __name__ == "__main__":
     app.run(host='0.0.0.0', threaded=False, processes=15)
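
A consequence of deriving base_client_api from the Host header: the value cannot exist until the first request arrives, which is why the hook above writes it to Redis on first contact and why cache_stats() idles until the key appears. A dependency-free sketch of the same first-request capture, with app.config standing in for Redis:

from flask import Flask, request

app = Flask(__name__)

@app.before_request
def record_base_url():
    # Capture the public-facing host from the first request we see.
    host = request.headers.get('Host')
    app.config.setdefault('BASE_CLIENT_API', f'{host}/api')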