cache stats in background

This commit is contained in:
parent edf13db324
commit 3c1254d3bf
@@ -12,7 +12,6 @@ database_path = './proxy-server.db'
 auth_required = False
 log_prompts = False
 frontend_api_client = ''
-base_client_api = None
 http_host = None
 verify_ssl = True
 show_num_prompts = True
@@ -1,5 +1,6 @@
 from flask import Blueprint, request

+from ..cache import redis
 from ..helpers.client import format_sillytavern_err
 from ..helpers.http import require_api_key
 from ..openai_request_handler import build_openai_response
@@ -10,13 +11,14 @@ openai_bp = Blueprint('openai/v1/', __name__)


 @openai_bp.before_request
-def before_request():
+def before_oai_request():
+    # TODO: unify with normal before_request()
     if not opts.http_host:
         opts.http_host = request.headers.get("Host")
     if not opts.enable_openi_compatible_backend:
         return build_openai_response('', format_sillytavern_err('The OpenAI-compatible backend is disabled.', 'Access Denied')), 401
-    if not opts.base_client_api:
-        opts.base_client_api = f'{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}'
+    if not redis.get('base_client_api'):
+        redis.set('base_client_api', f'{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}')
     if request.endpoint != 'v1.get_stats':
         response = require_api_key()
         if response is not None:
@@ -3,7 +3,7 @@ import time
 from typing import Tuple, Union

 import flask
-from flask import Response
+from flask import Response, request

 from llm_server import opts
 from llm_server.database import log_prompt
@@ -11,7 +11,7 @@ from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
 from llm_server.llm.vllm.vllm_backend import VLLMBackend
 from llm_server.routes.cache import redis
 from llm_server.routes.helpers.client import format_sillytavern_err
-from llm_server.routes.helpers.http import validate_json
+from llm_server.routes.helpers.http import require_api_key, validate_json
 from llm_server.routes.queue import priority_queue
 from llm_server.routes.stats import SemaphoreCheckerThread
@@ -178,3 +178,14 @@ def delete_dict_key(d: dict, k: Union[str, list]):
     else:
         raise ValueError
     return d
+
+
+def before_request():
+    if not opts.http_host:
+        opts.http_host = request.headers.get("Host")
+    if not redis.get('base_client_api'):
+        redis.set('base_client_api', f'{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}')
+    if request.endpoint != 'v1.get_stats':
+        response = require_api_key()
+        if response is not None:
+            return response
@@ -1,22 +1,14 @@
-from flask import Blueprint, request
+from flask import Blueprint

-from ..helpers.http import require_api_key
+from ..request_handler import before_request
 from ..server_error import handle_server_error
-from ... import opts

 bp = Blueprint('v1', __name__)


 @bp.before_request
-def before_request():
-    if not opts.http_host:
-        opts.http_host = request.headers.get("Host")
-    if not opts.base_client_api:
-        opts.base_client_api = f'{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}'
-    if request.endpoint != 'v1.get_stats':
-        response = require_api_key()
-        if response is not None:
-            return response
+def before_bp_request():
+    return before_request()


 @bp.errorhandler(500)
@@ -1,12 +1,14 @@
 import time
 from datetime import datetime

+from flask import request
+
 from llm_server import opts
 from llm_server.database import get_distinct_ips_24h, sum_column
 from llm_server.helpers import deep_sort, round_up_base
 from llm_server.llm.info import get_running_model
 from llm_server.netdata import get_power_states
-from llm_server.routes.cache import redis
+from llm_server.routes.cache import cache, redis
 from llm_server.routes.queue import priority_queue
 from llm_server.routes.stats import SemaphoreCheckerThread, calculate_avg_gen_time, get_active_gen_workers, get_total_proompts, server_start_time
@@ -34,6 +36,7 @@ def calculate_wait_time(gen_time_calc, proompters_in_queue, concurrent_gens, act


 # TODO: have routes/__init__.py point to the latest API version generate_stats()
+@cache.memoize(timeout=10)
 def generate_stats():
     model_name, error = get_running_model()  # will return False when the fetch fails
     if isinstance(model_name, bool):
@@ -83,6 +86,10 @@ def generate_stats():
     else:
         netdata_stats = {}

+    x = redis.get('base_client_api')
+    base_client_api = x.decode() if x else None
+    del x
+
     output = {
         'stats': {
             'proompters': {
@@ -98,8 +105,8 @@ def generate_stats():
         },
         'online': online,
         'endpoints': {
-            'blocking': f'https://{opts.base_client_api}',
-            'streaming': f'wss://{opts.base_client_api}/v1/stream' if opts.enable_streaming else None,
+            'blocking': f'https://{base_client_api}',
+            'streaming': f'wss://{base_client_api}/v1/stream' if opts.enable_streaming else None,
         },
         'queue': {
             'processing': active_gen_workers,
@@ -5,6 +5,7 @@ from llm_server import opts
 from llm_server.database import weighted_average_column_for_model
 from llm_server.llm.info import get_running_model
 from llm_server.routes.cache import redis
+from llm_server.routes.v1.generate_stats import generate_stats


 class MainBackgroundThread(Thread):
|
@ -60,3 +61,12 @@ class MainBackgroundThread(Thread):
|
||||||
estimated_avg_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0 # Avoid division by zero
|
estimated_avg_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0 # Avoid division by zero
|
||||||
redis.set('estimated_avg_tps', estimated_avg_tps)
|
redis.set('estimated_avg_tps', estimated_avg_tps)
|
||||||
time.sleep(60)
|
time.sleep(60)
|
||||||
|
|
||||||
|
|
||||||
|
def cache_stats():
|
||||||
|
while True:
|
||||||
|
# If opts.base_client_api is null that means no one has visited the site yet
|
||||||
|
# and the base_client_api hasn't been set. Do nothing until then.
|
||||||
|
if redis.get('base_client_api'):
|
||||||
|
x = generate_stats()
|
||||||
|
time.sleep(5)
|
||||||
|
|
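The new cache_stats() loop works together with the @cache.memoize(timeout=10) decorator added to generate_stats() above: a daemon thread keeps calling the memoized function every few seconds so its cached result stays warm and request handlers rarely pay the full recomputation cost. Below is a minimal, self-contained sketch of that pattern; compute_stats, cache_stats_loop, and get_cached_stats are illustrative names rather than code from this repository, and a plain dict guarded by a lock stands in for Flask-Caching and Redis.

# Illustrative sketch: a daemon thread periodically refreshes an expensive
# value so readers only ever see a cached copy (assumed names, not repo code).
import threading
import time
from typing import Optional

_stats_cache: dict = {'value': None}   # shared cache written by the updater thread
_stats_lock = threading.Lock()

def compute_stats() -> dict:
    # Stand-in for generate_stats(): pretend this queries the database/backend.
    return {'generated_at': time.time(), 'online': True}

def cache_stats_loop(interval: float = 5.0) -> None:
    # Mirrors the shape of cache_stats(): refresh, then sleep, forever.
    while True:
        stats = compute_stats()
        with _stats_lock:
            _stats_cache['value'] = stats
        time.sleep(interval)

def get_cached_stats() -> Optional[dict]:
    # What a request handler would call instead of recomputing the stats.
    with _stats_lock:
        return _stats_cache['value']

if __name__ == '__main__':
    threading.Thread(target=cache_stats_loop, daemon=True).start()
    time.sleep(0.2)                     # give the first refresh a moment to run
    print(get_cached_stats())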
server.py
@@ -28,7 +28,7 @@ from llm_server.routes.stats import SemaphoreCheckerThread, process_avg_gen_time
 from llm_server.routes.v1 import bp
 from llm_server.routes.v1.generate_stats import generate_stats
 from llm_server.stream import init_socketio
-from llm_server.threads import MainBackgroundThread
+from llm_server.threads import MainBackgroundThread, cache_stats

 script_path = os.path.dirname(os.path.realpath(__file__))
|
@ -93,12 +93,8 @@ if config['average_generation_time_mode'] not in ['database', 'minute']:
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
opts.average_generation_time_mode = config['average_generation_time_mode']
|
opts.average_generation_time_mode = config['average_generation_time_mode']
|
||||||
|
|
||||||
|
# Start background processes
|
||||||
start_workers(opts.concurrent_gens)
|
start_workers(opts.concurrent_gens)
|
||||||
|
|
||||||
# cleanup_thread = Thread(target=elapsed_times_cleanup)
|
|
||||||
# cleanup_thread.daemon = True
|
|
||||||
# cleanup_thread.start()
|
|
||||||
# Start the background thread
|
|
||||||
process_avg_gen_time_background_thread = Thread(target=process_avg_gen_time)
|
process_avg_gen_time_background_thread = Thread(target=process_avg_gen_time)
|
||||||
process_avg_gen_time_background_thread.daemon = True
|
process_avg_gen_time_background_thread.daemon = True
|
||||||
process_avg_gen_time_background_thread.start()
|
process_avg_gen_time_background_thread.start()
|
||||||
|
@@ -112,6 +108,11 @@ init_socketio(app)
 app.register_blueprint(bp, url_prefix='/api/v1/')
 app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')

+# This needs to be started after Flask is initialized
+stats_updater_thread = Thread(target=cache_stats)
+stats_updater_thread.daemon = True
+stats_updater_thread.start()
+
 # print(app.url_map)
@@ -121,8 +122,6 @@ app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')
 @app.route('/api/openai')
 @cache.cached(timeout=10)
 def home():
-    if not opts.base_client_api:
-        opts.base_client_api = f'{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}'
     stats = generate_stats()

     if not bool(redis.get('backend_online')) or not stats['online']:
@@ -151,6 +150,10 @@ def home():
     if opts.mode == 'vllm':
         mode_info = vllm_info

+    x = redis.get('base_client_api')
+    base_client_api = x.decode() if x else None
+    del x
+
     return render_template('home.html',
                            llm_middleware_name=opts.llm_middleware_name,
                            analytics_tracking_code=analytics_tracking_code,
|
@ -165,7 +168,7 @@ def home():
|
||||||
context_size=opts.context_size,
|
context_size=opts.context_size,
|
||||||
stats_json=json.dumps(stats, indent=4, ensure_ascii=False),
|
stats_json=json.dumps(stats, indent=4, ensure_ascii=False),
|
||||||
extra_info=mode_info,
|
extra_info=mode_info,
|
||||||
openai_client_api=f'https://{opts.base_client_api}/openai/v1' if opts.enable_openi_compatible_backend else 'disabled',
|
openai_client_api=f'https://{base_client_api}/openai/v1' if opts.enable_openi_compatible_backend else 'disabled',
|
||||||
expose_openai_system_prompt=opts.expose_openai_system_prompt,
|
expose_openai_system_prompt=opts.expose_openai_system_prompt,
|
||||||
enable_streaming=opts.enable_streaming,
|
enable_streaming=opts.enable_streaming,
|
||||||
)
|
)
|
||||||
|
@ -185,5 +188,13 @@ def server_error(e):
|
||||||
return handle_server_error(e)
|
return handle_server_error(e)
|
||||||
|
|
||||||
|
|
||||||
|
@app.before_request
|
||||||
|
def before_app_request():
|
||||||
|
if not opts.http_host:
|
||||||
|
opts.http_host = request.headers.get("Host")
|
||||||
|
if not redis.get('base_client_api'):
|
||||||
|
redis.set('base_client_api', f'{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
app.run(host='0.0.0.0', threaded=False, processes=15)
|
app.run(host='0.0.0.0', threaded=False, processes=15)
|
||||||
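Note the repeated redis.get('base_client_api') / .decode() step in the hunks above: the externally visible host is now written to Redis by the various before-request hooks and read back wherever a URL needs to be built, which keeps the value consistent across the fifteen worker processes started by app.run(processes=15), where a module-level opts.base_client_api would not be shared. The following is a small sketch of that write/read pair, assuming a redis-py client; the helper names are illustrative, not part of this repository.

# Sketch of the base_client_api handoff through Redis (assumed helper names).
from typing import Optional

import redis

r = redis.Redis(host='localhost', port=6379)

def remember_base_client_api(host: str, frontend_api_client: str) -> None:
    # Write once: later requests (and other worker processes) reuse the value.
    if not r.get('base_client_api'):
        r.set('base_client_api', f'{host}/{frontend_api_client.strip("/")}')

def read_base_client_api() -> Optional[str]:
    # redis-py returns bytes, hence the decode step seen in the diff.
    raw = r.get('base_client_api')
    return raw.decode() if raw else None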