add estimated wait time and other time tracking stats
parent 0aa52863bc
commit de19af900f
@@ -1,7 +1,9 @@
 import heapq
 import threading
+import time

 from llm_server.llm.generator import generator
+from llm_server.routes.stats import generation_elapsed, generation_elapsed_lock


 class PriorityQueue:
@@ -40,7 +42,14 @@ class DataEvent(threading.Event):
 def worker():
     while True:
         priority, index, (request_json_body, client_ip, token, parameters), event = priority_queue.get()
+        start_time = time.time()
         success, response, error_msg = generator(request_json_body)
+
+        end_time = time.time()
+        elapsed_time = end_time - start_time
+        with generation_elapsed_lock:
+            generation_elapsed.append((end_time, elapsed_time))
+
         event.data = (success, response, error_msg)
         event.set()
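For context on the worker loop above: the route handler enqueues a request and receives an event object, blocks on event.wait(), and the worker publishes its result through event.data before calling event.set(). A minimal sketch of that hand-off, using only the pieces visible in this diff (the queue itself is omitted and the sample result value is made up):

import threading

class DataEvent(threading.Event):
    # carries the worker's result back to the thread waiting on it
    def __init__(self):
        super().__init__()
        self.data = None

def worker_side(event):
    event.data = (True, b'{"results": []}', None)  # publish the result first
    event.set()                                     # then wake the waiting request thread

event = DataEvent()
worker_side(event)      # normally runs on a worker thread
event.wait()
success, response, error_msg = event.data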
@@ -1,13 +1,55 @@
 import time
 from datetime import datetime
-from threading import Semaphore, Thread
+from threading import Lock, Thread

-from llm_server.opts import concurrent_gens
 from llm_server.routes.cache import redis

 # proompters_1_min = 0
-concurrent_semaphore = Semaphore(concurrent_gens)
-start_time = datetime.now()
+# concurrent_semaphore = Semaphore(concurrent_gens)
+server_start_time = datetime.now()
+
+# TODO: have a background thread put the averages in a variable so we don't end up with massive arrays
+
+wait_in_queue_elapsed = []
+wait_in_queue_elapsed_lock = Lock()
+
+generation_elapsed = []
+generation_elapsed_lock = Lock()
+
+
+def elapsed_times_cleanup():
+    global wait_in_queue_elapsed
+    while True:
+        current_time = time.time()
+        with wait_in_queue_elapsed_lock:
+            wait_in_queue_elapsed = [(end_time, elapsed_time) for end_time, elapsed_time in wait_in_queue_elapsed if current_time - end_time <= 60]
+        time.sleep(1)
+
+
+def calculate_avg_gen_time():
+    # Get the average generation time from Redis
+    average_generation_time = redis.get('average_generation_time')
+    if average_generation_time is None:
+        return 0
+    else:
+        return float(average_generation_time)
+
+
+def process_avg_gen_time():
+    while True:
+        with generation_elapsed_lock:
+            # Get the data from the last minute
+            one_minute_ago = time.time() - 60
+            recent_data = [elapsed for end, elapsed in generation_elapsed if end >= one_minute_ago]
+
+            # Calculate the average
+            if len(recent_data) == 0:
+                average_generation_time = 0
+            else:
+                average_generation_time = sum(recent_data) / len(recent_data)
+            redis.set('average_generation_time', average_generation_time)
+        time.sleep(5)
+
+
 def get_count():
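The averaging is split in two so request handlers never scan the raw sample list (this is what the TODO above is about): process_avg_gen_time() periodically reduces the last minute of (end_time, elapsed) samples to a single number stored in Redis, and calculate_avg_gen_time() only reads that key back. A self-contained sketch of the same 60-second window calculation, with made-up sample values:

import time

# (end_time, elapsed_seconds) samples, as appended by the worker
generation_elapsed = [(time.time() - 90, 4.2), (time.time() - 30, 5.5), (time.time() - 5, 6.1)]

one_minute_ago = time.time() - 60
recent_data = [elapsed for end, elapsed in generation_elapsed if end >= one_minute_ago]
average_generation_time = sum(recent_data) / len(recent_data) if recent_data else 0
print(round(average_generation_time, 2))  # 5.8 -- only the two samples from the last minute count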
@@ -3,7 +3,7 @@ import time

 from flask import jsonify, request

-from llm_server.routes.stats import SemaphoreCheckerThread, concurrent_semaphore
+from llm_server.routes.stats import SemaphoreCheckerThread, wait_in_queue_elapsed, wait_in_queue_elapsed_lock
 from . import bp
 from ..cache import redis
 from ..helpers.client import format_sillytavern_err
@@ -18,94 +18,102 @@ DEFAULT_PRIORITY = 9999

 @bp.route('/generate', methods=['POST'])
 def generate():
+    start_time = time.time()
     request_valid_json, request_json_body = validate_json(request.data)
     if not request_valid_json:
         return jsonify({'code': 400, 'error': 'Invalid JSON'}), 400

-    with concurrent_semaphore:
     if request.headers.get('cf-connecting-ip'):
         client_ip = request.headers.get('cf-connecting-ip')
     elif request.headers.get('x-forwarded-for'):
         client_ip = request.headers.get('x-forwarded-for')
     else:
         client_ip = request.remote_addr

     SemaphoreCheckerThread.recent_prompters[client_ip] = time.time()

     parameters = request_json_body.copy()
     del parameters['prompt']

     token = request.headers.get('X-Api-Key')
     priority = None
     if token:
         conn = sqlite3.connect(opts.database_path)
         cursor = conn.cursor()
         cursor.execute("SELECT priority FROM token_auth WHERE token = ?", (token,))
         result = cursor.fetchone()
         if result:
             priority = result[0]
         conn.close()

     if priority is None:
         priority = DEFAULT_PRIORITY
     else:
         print(f'Token {token} was given priority {priority}.')

     # success, response, error_msg = generator(request_json_body)
     event = priority_queue.put((request_json_body, client_ip, token, parameters), priority)
     event.wait()
     success, response, error_msg = event.data

+    # Add the elapsed time to a global list
+    end_time = time.time()
+    elapsed_time = end_time - start_time
+    # print('elapsed:', elapsed_time)
+    with wait_in_queue_elapsed_lock:
+        wait_in_queue_elapsed.append((end_time, elapsed_time))
+
     if not success:
         if opts.mode == 'oobabooga':
             backend_response = format_sillytavern_err(f'Failed to reach the backend ({opts.mode}): {error_msg}', 'error')
             response_json_body = {
                 'results': [
                     {
                         'text': backend_response,
                     }
                 ],
             }
         else:
             raise Exception
         log_prompt(opts.database_path, client_ip, token, request_json_body['prompt'], backend_response, parameters, dict(request.headers), response.status_code)
         return jsonify({
             'code': 500,
             'error': 'failed to reach backend',
             **response_json_body
         }), 200
     response_valid_json, response_json_body = validate_json(response)
     if response_valid_json:
         redis.incr('proompts')
         backend_response = safe_list_get(response_json_body.get('results', []), 0, {}).get('text')
         if not backend_response:
             if opts.mode == 'oobabooga':
                 backend_response = format_sillytavern_err(
                     f'Backend (oobabooga) returned an empty string. This can happen when your parameters are incorrect. Make sure your context size is no greater than {opts.context_size}. Furthermore, oobabooga does not support concurrent requests so all users have to wait in line and the backend server may have glitched for a moment. Please try again.',
                     'error')
                 response_json_body['results'][0]['text'] = backend_response
             else:
                 raise Exception
         log_prompt(opts.database_path, client_ip, token, request_json_body['prompt'], backend_response, parameters, dict(request.headers), response.status_code)
         return jsonify({
             **response_json_body
         }), 200
     else:
         if opts.mode == 'oobabooga':
             backend_response = format_sillytavern_err(f'The backend did not return valid JSON.', 'error')
             response_json_body = {
                 'results': [
                     {
                         'text': backend_response,
                     }
                 ],
             }
         else:
             raise Exception
         log_prompt(opts.database_path, client_ip, token, request_json_body['prompt'], backend_response, parameters, dict(request.headers), response.status_code)
         return jsonify({
             'code': 500,
             'error': 'the backend did not return valid JSON',
             **response_json_body
         }), 200
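The handler now records (end_time, elapsed_time) into wait_in_queue_elapsed the same way the worker records generation times; because start_time is taken at the top of generate(), the elapsed value covers the full round trip (time waiting in the priority queue plus generation). A small sketch of how those samples could be reduced to an average wait, mirroring the still-commented-out block in get_stats() further down (the helper name and sample values are illustrative, not part of the commit):

def average_wait(samples):
    # samples: list of (end_time, elapsed_seconds) tuples, e.g. a copy of wait_in_queue_elapsed
    if len(samples) == 0:
        return 0
    waits = [elapsed for end, elapsed in samples]
    return int(sum(waits) / len(waits))

print(average_wait([(0, 3.2), (0, 4.9)]))  # 4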
@@ -4,11 +4,10 @@ from datetime import datetime
 from flask import jsonify, request

 from llm_server import opts
-from llm_server.routes.v1.generate import concurrent_semaphore
 from . import bp
 from .. import stats
 from ..cache import cache
-from ..stats import SemaphoreCheckerThread
+from ..queue import priority_queue
+from ..stats import SemaphoreCheckerThread, calculate_avg_gen_time
 from ...llm.info import get_running_model
@@ -21,12 +20,22 @@ def get_stats():
     else:
         online = True

+    # t = elapsed_times.copy()  # copy since we do multiple operations and don't want it to change
+    # if len(t) == 0:
+    #     estimated_wait = 0
+    # else:
+    #     waits = [elapsed for end, elapsed in t]
+    #     estimated_wait = int(sum(waits) / len(waits))
+
+    average_generation_time = int(calculate_avg_gen_time())
+
     return jsonify({
         'stats': {
-            'proompters_now': opts.concurrent_gens - concurrent_semaphore._value,
+            'prompts_in_queue': len(priority_queue),
             'proompters_1_min': SemaphoreCheckerThread.proompters_1_min,
             'total_proompts': stats.get_count(),
-            'uptime': int((datetime.now() - stats.start_time).total_seconds()),
+            'uptime': int((datetime.now() - stats.server_start_time).total_seconds()),
+            'average_generation_elapsed_sec': average_generation_time,
         },
         'online': online,
         'mode': opts.mode,
@@ -34,6 +43,7 @@ def get_stats():
         'endpoints': {
             'blocking': f'https://{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}',
         },
+        'estimated_wait_sec': int(average_generation_time * len(priority_queue)),
         'timestamp': int(time.time()),
         'openaiKeys': '∞',
         'anthropicKeys': '∞',
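Putting the new fields together: the endpoint estimates the wait for a newly submitted prompt as the recent average generation time multiplied by the number of prompts already queued. For example, with illustrative numbers rather than measured values:

average_generation_time = int(12.4)   # seconds, as returned by calculate_avg_gen_time()
prompts_in_queue = 3                  # len(priority_queue)
estimated_wait_sec = int(average_generation_time * prompts_in_queue)
print(estimated_wait_sec)  # 36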
server.py
@@ -1,6 +1,7 @@
 import os
 import sys
 from pathlib import Path
+from threading import Thread

 from flask import Flask, jsonify

@@ -11,7 +12,7 @@ from llm_server.helpers import resolve_path
 from llm_server.routes.cache import cache
 from llm_server.routes.helpers.http import cache_control
 from llm_server.routes.queue import start_workers
-from llm_server.routes.stats import SemaphoreCheckerThread, concurrent_semaphore
+from llm_server.routes.stats import SemaphoreCheckerThread, elapsed_times_cleanup, process_avg_gen_time
 from llm_server.routes.v1 import bp

 script_path = os.path.dirname(os.path.realpath(__file__))
@@ -56,6 +57,13 @@ if not opts.verify_ssl:

 start_workers(opts.concurrent_gens)

+# cleanup_thread = Thread(target=elapsed_times_cleanup)
+# cleanup_thread.daemon = True
+# cleanup_thread.start()
+# Start the background thread
+process_avg_gen_time_background_thread = Thread(target=process_avg_gen_time)
+process_avg_gen_time_background_thread.daemon = True
+process_avg_gen_time_background_thread.start()
 SemaphoreCheckerThread().start()

 app = Flask(__name__)
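The averaging loop runs for the lifetime of the process, so it is started as a daemon thread: setting daemon to True keeps the thread from blocking interpreter shutdown. The same pattern in isolation, with a placeholder loop body standing in for process_avg_gen_time():

from threading import Thread
import time

def refresh_stats_loop():
    while True:
        # recompute and cache a statistic, then sleep (placeholder body)
        time.sleep(5)

background_thread = Thread(target=refresh_stats_loop)
background_thread.daemon = True  # do not keep the process alive just for this loop
background_thread.start()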