add estimated wait time and other time tracking stats

This commit is contained in:
Cyberes 2023-08-23 21:33:52 -06:00
parent 0aa52863bc
commit de19af900f
5 changed files with 168 additions and 91 deletions

View File

@ -1,7 +1,9 @@
import heapq
import threading
import time
from llm_server.llm.generator import generator
from llm_server.routes.stats import generation_elapsed, generation_elapsed_lock
class PriorityQueue:
@ -40,7 +42,14 @@ class DataEvent(threading.Event):
def worker():
while True:
priority, index, (request_json_body, client_ip, token, parameters), event = priority_queue.get()
start_time = time.time()
success, response, error_msg = generator(request_json_body)
end_time = time.time()
elapsed_time = end_time - start_time
with generation_elapsed_lock:
generation_elapsed.append((end_time, elapsed_time))
event.data = (success, response, error_msg)
event.set()

View File

@ -1,13 +1,55 @@
import time
from datetime import datetime
from threading import Semaphore, Thread
from threading import Lock, Thread
from llm_server.opts import concurrent_gens
from llm_server.routes.cache import redis
# proompters_1_min = 0
concurrent_semaphore = Semaphore(concurrent_gens)
start_time = datetime.now()
# concurrent_semaphore = Semaphore(concurrent_gens)
server_start_time = datetime.now()
# TODO: have a background thread put the averages in a variable so we don't end up with massive arrays
wait_in_queue_elapsed = []
wait_in_queue_elapsed_lock = Lock()
generation_elapsed = []
generation_elapsed_lock = Lock()
def elapsed_times_cleanup():
global wait_in_queue_elapsed
while True:
current_time = time.time()
with wait_in_queue_elapsed_lock:
global wait_in_queue_elapsed
wait_in_queue_elapsed = [(end_time, elapsed_time) for end_time, elapsed_time in wait_in_queue_elapsed if current_time - end_time <= 60]
time.sleep(1)
def calculate_avg_gen_time():
# Get the average generation time from Redis
average_generation_time = redis.get('average_generation_time')
if average_generation_time is None:
return 0
else:
return float(average_generation_time)
def process_avg_gen_time():
while True:
with generation_elapsed_lock:
# Get the data from the last minute
one_minute_ago = time.time() - 60
recent_data = [elapsed for end, elapsed in generation_elapsed if end >= one_minute_ago]
# Calculate the average
if len(recent_data) == 0:
average_generation_time = 0
else:
average_generation_time = sum(recent_data) / len(recent_data)
redis.set('average_generation_time', average_generation_time)
time.sleep(5)
def get_count():

View File

@ -3,7 +3,7 @@ import time
from flask import jsonify, request
from llm_server.routes.stats import SemaphoreCheckerThread, concurrent_semaphore
from llm_server.routes.stats import SemaphoreCheckerThread, wait_in_queue_elapsed, wait_in_queue_elapsed_lock
from . import bp
from ..cache import redis
from ..helpers.client import format_sillytavern_err
@ -18,94 +18,102 @@ DEFAULT_PRIORITY = 9999
@bp.route('/generate', methods=['POST'])
def generate():
start_time = time.time()
request_valid_json, request_json_body = validate_json(request.data)
if not request_valid_json:
return jsonify({'code': 400, 'error': 'Invalid JSON'}), 400
with concurrent_semaphore:
if request.headers.get('cf-connecting-ip'):
client_ip = request.headers.get('cf-connecting-ip')
elif request.headers.get('x-forwarded-for'):
client_ip = request.headers.get('x-forwarded-for')
if request.headers.get('cf-connecting-ip'):
client_ip = request.headers.get('cf-connecting-ip')
elif request.headers.get('x-forwarded-for'):
client_ip = request.headers.get('x-forwarded-for')
else:
client_ip = request.remote_addr
SemaphoreCheckerThread.recent_prompters[client_ip] = time.time()
parameters = request_json_body.copy()
del parameters['prompt']
token = request.headers.get('X-Api-Key')
priority = None
if token:
conn = sqlite3.connect(opts.database_path)
cursor = conn.cursor()
cursor.execute("SELECT priority FROM token_auth WHERE token = ?", (token,))
result = cursor.fetchone()
if result:
priority = result[0]
conn.close()
if priority is None:
priority = DEFAULT_PRIORITY
else:
print(f'Token {token} was given priority {priority}.')
# success, response, error_msg = generator(request_json_body)
event = priority_queue.put((request_json_body, client_ip, token, parameters), priority)
event.wait()
success, response, error_msg = event.data
# Add the elapsed time to a global list
end_time = time.time()
elapsed_time = end_time - start_time
# print('elapsed:', elapsed_time)
with wait_in_queue_elapsed_lock:
wait_in_queue_elapsed.append((end_time, elapsed_time))
if not success:
if opts.mode == 'oobabooga':
backend_response = format_sillytavern_err(f'Failed to reach the backend ({opts.mode}): {error_msg}', 'error')
response_json_body = {
'results': [
{
'text': backend_response,
}
],
}
else:
client_ip = request.remote_addr
raise Exception
SemaphoreCheckerThread.recent_prompters[client_ip] = time.time()
parameters = request_json_body.copy()
del parameters['prompt']
token = request.headers.get('X-Api-Key')
priority = None
if token:
conn = sqlite3.connect(opts.database_path)
cursor = conn.cursor()
cursor.execute("SELECT priority FROM token_auth WHERE token = ?", (token,))
result = cursor.fetchone()
if result:
priority = result[0]
conn.close()
if priority is None:
priority = DEFAULT_PRIORITY
else:
print(f'Token {token} was given priority {priority}.')
# success, response, error_msg = generator(request_json_body)
event = priority_queue.put((request_json_body, client_ip, token, parameters), priority)
event.wait()
success, response, error_msg = event.data
if not success:
log_prompt(opts.database_path, client_ip, token, request_json_body['prompt'], backend_response, parameters, dict(request.headers), response.status_code)
return jsonify({
'code': 500,
'error': 'failed to reach backend',
**response_json_body
}), 200
response_valid_json, response_json_body = validate_json(response)
if response_valid_json:
redis.incr('proompts')
backend_response = safe_list_get(response_json_body.get('results', []), 0, {}).get('text')
if not backend_response:
if opts.mode == 'oobabooga':
backend_response = format_sillytavern_err(f'Failed to reach the backend ({opts.mode}): {error_msg}', 'error')
response_json_body = {
'results': [
{
'text': backend_response,
}
],
}
backend_response = format_sillytavern_err(
f'Backend (oobabooga) returned an empty string. This can happen when your parameters are incorrect. Make sure your context size is no greater than {opts.context_size}. Furthermore, oobabooga does not support concurrent requests so all users have to wait in line and the backend server may have glitched for a moment. Please try again.',
'error')
response_json_body['results'][0]['text'] = backend_response
else:
raise Exception
log_prompt(opts.database_path, client_ip, token, request_json_body['prompt'], backend_response, parameters, dict(request.headers), response.status_code)
return jsonify({
'code': 500,
'error': 'failed to reach backend',
**response_json_body
}), 200
response_valid_json, response_json_body = validate_json(response)
if response_valid_json:
redis.incr('proompts')
backend_response = safe_list_get(response_json_body.get('results', []), 0, {}).get('text')
if not backend_response:
if opts.mode == 'oobabooga':
backend_response = format_sillytavern_err(
f'Backend (oobabooga) returned an empty string. This can happen when your parameters are incorrect. Make sure your context size is no greater than {opts.context_size}. Furthermore, oobabooga does not support concurrent requests so all users have to wait in line and the backend server may have glitched for a moment. Please try again.',
'error')
response_json_body['results'][0]['text'] = backend_response
else:
raise Exception
log_prompt(opts.database_path, client_ip, token, request_json_body['prompt'], backend_response, parameters, dict(request.headers), response.status_code)
return jsonify({
**response_json_body
}), 200
log_prompt(opts.database_path, client_ip, token, request_json_body['prompt'], backend_response, parameters, dict(request.headers), response.status_code)
return jsonify({
**response_json_body
}), 200
else:
if opts.mode == 'oobabooga':
backend_response = format_sillytavern_err(f'The backend did not return valid JSON.', 'error')
response_json_body = {
'results': [
{
'text': backend_response,
}
],
}
else:
if opts.mode == 'oobabooga':
backend_response = format_sillytavern_err(f'The backend did not return valid JSON.', 'error')
response_json_body = {
'results': [
{
'text': backend_response,
}
],
}
else:
raise Exception
log_prompt(opts.database_path, client_ip, token, request_json_body['prompt'], backend_response, parameters, dict(request.headers), response.status_code)
return jsonify({
'code': 500,
'error': 'the backend did not return valid JSON',
**response_json_body
}), 200
raise Exception
log_prompt(opts.database_path, client_ip, token, request_json_body['prompt'], backend_response, parameters, dict(request.headers), response.status_code)
return jsonify({
'code': 500,
'error': 'the backend did not return valid JSON',
**response_json_body
}), 200

View File

@ -4,11 +4,10 @@ from datetime import datetime
from flask import jsonify, request
from llm_server import opts
from llm_server.routes.v1.generate import concurrent_semaphore
from . import bp
from .. import stats
from ..cache import cache
from ..stats import SemaphoreCheckerThread
from ..queue import priority_queue
from ..stats import SemaphoreCheckerThread, calculate_avg_gen_time
from ...llm.info import get_running_model
@ -21,12 +20,22 @@ def get_stats():
else:
online = True
# t = elapsed_times.copy() # copy since we do multiple operations and don't want it to change
# if len(t) == 0:
# estimated_wait = 0
# else:
# waits = [elapsed for end, elapsed in t]
# estimated_wait = int(sum(waits) / len(waits))
average_generation_time = int(calculate_avg_gen_time())
return jsonify({
'stats': {
'proompters_now': opts.concurrent_gens - concurrent_semaphore._value,
'prompts_in_queue': len(priority_queue),
'proompters_1_min': SemaphoreCheckerThread.proompters_1_min,
'total_proompts': stats.get_count(),
'uptime': int((datetime.now() - stats.start_time).total_seconds()),
'uptime': int((datetime.now() - stats.server_start_time).total_seconds()),
'average_generation_elapsed_sec': average_generation_time,
},
'online': online,
'mode': opts.mode,
@ -34,6 +43,7 @@ def get_stats():
'endpoints': {
'blocking': f'https://{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}',
},
'estimated_wait_sec': int(average_generation_time * len(priority_queue)),
'timestamp': int(time.time()),
'openaiKeys': '',
'anthropicKeys': '',

View File

@ -1,6 +1,7 @@
import os
import sys
from pathlib import Path
from threading import Thread
from flask import Flask, jsonify
@ -11,7 +12,7 @@ from llm_server.helpers import resolve_path
from llm_server.routes.cache import cache
from llm_server.routes.helpers.http import cache_control
from llm_server.routes.queue import start_workers
from llm_server.routes.stats import SemaphoreCheckerThread, concurrent_semaphore
from llm_server.routes.stats import SemaphoreCheckerThread, elapsed_times_cleanup, process_avg_gen_time
from llm_server.routes.v1 import bp
script_path = os.path.dirname(os.path.realpath(__file__))
@ -56,6 +57,13 @@ if not opts.verify_ssl:
start_workers(opts.concurrent_gens)
# cleanup_thread = Thread(target=elapsed_times_cleanup)
# cleanup_thread.daemon = True
# cleanup_thread.start()
# Start the background thread
process_avg_gen_time_background_thread = Thread(target=process_avg_gen_time)
process_avg_gen_time_background_thread.daemon = True
process_avg_gen_time_background_thread.start()
SemaphoreCheckerThread().start()
app = Flask(__name__)