cache stats in background
parent edf13db324
commit 3c1254d3bf
@@ -12,7 +12,6 @@ database_path = './proxy-server.db'
 auth_required = False
 log_prompts = False
 frontend_api_client = ''
-base_client_api = None
 http_host = None
 verify_ssl = True
 show_num_prompts = True
@@ -1,5 +1,6 @@
 from flask import Blueprint, request
 
+from ..cache import redis
 from ..helpers.client import format_sillytavern_err
 from ..helpers.http import require_api_key
 from ..openai_request_handler import build_openai_response
@@ -10,13 +11,14 @@ openai_bp = Blueprint('openai/v1/', __name__)
 
 
 @openai_bp.before_request
-def before_request():
+def before_oai_request():
+    # TODO: unify with normal before_request()
     if not opts.http_host:
         opts.http_host = request.headers.get("Host")
     if not opts.enable_openi_compatible_backend:
         return build_openai_response('', format_sillytavern_err('The OpenAI-compatible backend is disabled.', 'Access Denied')), 401
-    if not opts.base_client_api:
-        opts.base_client_api = f'{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}'
+    if not redis.get('base_client_api'):
+        redis.set('base_client_api', f'{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}')
     if request.endpoint != 'v1.get_stats':
         response = require_api_key()
         if response is not None:
@@ -3,7 +3,7 @@ import time
 from typing import Tuple, Union
 
 import flask
-from flask import Response
+from flask import Response, request
 
 from llm_server import opts
 from llm_server.database import log_prompt
@@ -11,7 +11,7 @@ from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
 from llm_server.llm.vllm.vllm_backend import VLLMBackend
 from llm_server.routes.cache import redis
 from llm_server.routes.helpers.client import format_sillytavern_err
-from llm_server.routes.helpers.http import validate_json
+from llm_server.routes.helpers.http import require_api_key, validate_json
 from llm_server.routes.queue import priority_queue
 from llm_server.routes.stats import SemaphoreCheckerThread
 
@@ -178,3 +178,14 @@ def delete_dict_key(d: dict, k: Union[str, list]):
     else:
         raise ValueError
     return d
+
+
+def before_request():
+    if not opts.http_host:
+        opts.http_host = request.headers.get("Host")
+    if not redis.get('base_client_api'):
+        redis.set('base_client_api', f'{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}')
+    if request.endpoint != 'v1.get_stats':
+        response = require_api_key()
+        if response is not None:
+            return response
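
Note: the shared before_request() added above computes the public client API base once and publishes it through Redis instead of keeping it on opts, so every worker process and blueprint sees the same value. A minimal standalone sketch of that pattern, assuming redis-py and a local Redis instance; publish_base_client_api is an illustrative helper, not a function from this repo:

import redis

r = redis.Redis(host='localhost', port=6379)

def publish_base_client_api(host: str, frontend_api_client: str) -> None:
    # Only the first request that finds the key missing pays the cost of setting it;
    # later requests (and other processes) just read the shared value.
    if not r.get('base_client_api'):
        r.set('base_client_api', f"{host}/{frontend_api_client.strip('/')}")
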
@@ -1,22 +1,14 @@
-from flask import Blueprint, request
+from flask import Blueprint
 
-from ..helpers.http import require_api_key
+from ..request_handler import before_request
 from ..server_error import handle_server_error
-from ... import opts
 
 bp = Blueprint('v1', __name__)
 
 
 @bp.before_request
-def before_request():
-    if not opts.http_host:
-        opts.http_host = request.headers.get("Host")
-    if not opts.base_client_api:
-        opts.base_client_api = f'{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}'
-    if request.endpoint != 'v1.get_stats':
-        response = require_api_key()
-        if response is not None:
-            return response
+def before_bp_request():
+    return before_request()
 
 
 @bp.errorhandler(500)
@@ -1,12 +1,14 @@
 import time
 from datetime import datetime
 
+from flask import request
+
 from llm_server import opts
 from llm_server.database import get_distinct_ips_24h, sum_column
 from llm_server.helpers import deep_sort, round_up_base
 from llm_server.llm.info import get_running_model
 from llm_server.netdata import get_power_states
-from llm_server.routes.cache import redis
+from llm_server.routes.cache import cache, redis
 from llm_server.routes.queue import priority_queue
 from llm_server.routes.stats import SemaphoreCheckerThread, calculate_avg_gen_time, get_active_gen_workers, get_total_proompts, server_start_time
 
@@ -34,6 +36,7 @@ def calculate_wait_time(gen_time_calc, proompters_in_queue, concurrent_gens, act
 
 # TODO: have routes/__init__.py point to the latest API version generate_stats()
 
+@cache.memoize(timeout=10)
 def generate_stats():
     model_name, error = get_running_model()  # will return False when the fetch fails
     if isinstance(model_name, bool):
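
The @cache.memoize(timeout=10) decorator added above means generate_stats() is recomputed at most once every ten seconds; callers inside that window get the cached result. A minimal sketch of the same decorator, assuming cache is a Flask-Caching Cache object (which the memoize/cached calls in this commit suggest); the app and the stats function below are placeholders:

from flask import Flask
from flask_caching import Cache

app = Flask(__name__)
cache = Cache(app, config={'CACHE_TYPE': 'SimpleCache'})  # sketch only; the project likely uses a Redis backend

@cache.memoize(timeout=10)
def expensive_stats():
    # Heavy work goes here; repeat calls within 10 seconds return the cached dict.
    return {'uptime': 123}
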
@@ -83,6 +86,10 @@ def generate_stats():
     else:
         netdata_stats = {}
 
+    x = redis.get('base_client_api')
+    base_client_api = x.decode() if x else None
+    del x
+
     output = {
         'stats': {
             'proompters': {
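
The x.decode() if x else None step above exists because redis-py returns raw bytes from get() (or None when the key is missing), so the stored host string has to be decoded before it is interpolated into URLs. A tiny illustration, assuming a local Redis:

import redis

r = redis.Redis()
r.set('base_client_api', 'example.com/api')

raw = r.get('base_client_api')           # b'example.com/api', or None if unset
base_client_api = raw.decode() if raw else None
print(f'https://{base_client_api}')      # https://example.com/api
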
@@ -98,8 +105,8 @@ def generate_stats():
         },
         'online': online,
         'endpoints': {
-            'blocking': f'https://{opts.base_client_api}',
-            'streaming': f'wss://{opts.base_client_api}/v1/stream' if opts.enable_streaming else None,
+            'blocking': f'https://{base_client_api}',
+            'streaming': f'wss://{base_client_api}/v1/stream' if opts.enable_streaming else None,
         },
         'queue': {
             'processing': active_gen_workers,
@@ -5,6 +5,7 @@ from llm_server import opts
 from llm_server.database import weighted_average_column_for_model
 from llm_server.llm.info import get_running_model
 from llm_server.routes.cache import redis
+from llm_server.routes.v1.generate_stats import generate_stats
 
 
 class MainBackgroundThread(Thread):
@@ -60,3 +61,12 @@ class MainBackgroundThread(Thread):
             estimated_avg_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0  # Avoid division by zero
             redis.set('estimated_avg_tps', estimated_avg_tps)
             time.sleep(60)
+
+
+def cache_stats():
+    while True:
+        # If opts.base_client_api is null that means no one has visited the site yet
+        # and the base_client_api hasn't been set. Do nothing until then.
+        if redis.get('base_client_api'):
+            x = generate_stats()
+        time.sleep(5)
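
cache_stats() is the background loop the commit title refers to: it calls the memoized generate_stats() every five seconds so that web requests almost always hit a warm cache. A standalone sketch of that cache-warming pattern, with refresh_stats() standing in for generate_stats():

import threading
import time

def refresh_stats():
    # Placeholder for generate_stats(); recomputing here repopulates the cache.
    pass

def cache_stats_loop():
    while True:
        refresh_stats()
        time.sleep(5)

worker = threading.Thread(target=cache_stats_loop, daemon=True)  # daemon: won't block shutdown
worker.start()
time.sleep(12)  # demo only: let the worker run a couple of refresh cycles
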
server.py: 29 changed lines
@@ -28,7 +28,7 @@ from llm_server.routes.stats import SemaphoreCheckerThread, process_avg_gen_time
 from llm_server.routes.v1 import bp
 from llm_server.routes.v1.generate_stats import generate_stats
 from llm_server.stream import init_socketio
-from llm_server.threads import MainBackgroundThread
+from llm_server.threads import MainBackgroundThread, cache_stats
 
 script_path = os.path.dirname(os.path.realpath(__file__))
 
@@ -93,12 +93,8 @@ if config['average_generation_time_mode'] not in ['database', 'minute']:
     sys.exit(1)
 opts.average_generation_time_mode = config['average_generation_time_mode']
 
+# Start background processes
 start_workers(opts.concurrent_gens)
-
-# cleanup_thread = Thread(target=elapsed_times_cleanup)
-# cleanup_thread.daemon = True
-# cleanup_thread.start()
-# Start the background thread
 process_avg_gen_time_background_thread = Thread(target=process_avg_gen_time)
 process_avg_gen_time_background_thread.daemon = True
 process_avg_gen_time_background_thread.start()
@@ -112,6 +108,11 @@ init_socketio(app)
 app.register_blueprint(bp, url_prefix='/api/v1/')
 app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')
 
+# This needs to be started after Flask is initalized
+stats_updater_thread = Thread(target=cache_stats)
+stats_updater_thread.daemon = True
+stats_updater_thread.start()
+
 
 # print(app.url_map)
 
@@ -121,8 +122,6 @@ app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')
 @app.route('/api/openai')
 @cache.cached(timeout=10)
 def home():
-    if not opts.base_client_api:
-        opts.base_client_api = f'{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}'
     stats = generate_stats()
 
     if not bool(redis.get('backend_online')) or not stats['online']:
@@ -151,6 +150,10 @@ def home():
     if opts.mode == 'vllm':
         mode_info = vllm_info
 
+    x = redis.get('base_client_api')
+    base_client_api = x.decode() if x else None
+    del x
+
     return render_template('home.html',
                            llm_middleware_name=opts.llm_middleware_name,
                            analytics_tracking_code=analytics_tracking_code,
@@ -165,7 +168,7 @@ def home():
                            context_size=opts.context_size,
                            stats_json=json.dumps(stats, indent=4, ensure_ascii=False),
                            extra_info=mode_info,
-                           openai_client_api=f'https://{opts.base_client_api}/openai/v1' if opts.enable_openi_compatible_backend else 'disabled',
+                           openai_client_api=f'https://{base_client_api}/openai/v1' if opts.enable_openi_compatible_backend else 'disabled',
                            expose_openai_system_prompt=opts.expose_openai_system_prompt,
                            enable_streaming=opts.enable_streaming,
                            )
@@ -185,5 +188,13 @@ def server_error(e):
     return handle_server_error(e)
 
 
+@app.before_request
+def before_app_request():
+    if not opts.http_host:
+        opts.http_host = request.headers.get("Host")
+    if not redis.get('base_client_api'):
+        redis.set('base_client_api', f'{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}')
+
+
 if __name__ == "__main__":
     app.run(host='0.0.0.0', threaded=False, processes=15)
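
The @app.before_request hook added above runs for every request, and app-level hooks run before blueprint-level ones, which makes it a single safe place to capture the Host header and seed the Redis key. A minimal sketch of that hook, using a plain dict instead of Redis:

from flask import Flask, request

app = Flask(__name__)
seen = {}

@app.before_request
def before_app_request():
    # Record the externally visible host the first time any request arrives.
    seen.setdefault('http_host', request.headers.get('Host'))

@app.route('/')
def index():
    return seen.get('http_host', '')

if __name__ == '__main__':
    app.run()
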