add estimated wait time and other time tracking stats

Cyberes 2023-08-23 21:33:52 -06:00
parent 0aa52863bc
commit de19af900f
5 changed files with 168 additions and 91 deletions

View File

@@ -1,7 +1,9 @@
 import heapq
 import threading
+import time

 from llm_server.llm.generator import generator
+from llm_server.routes.stats import generation_elapsed, generation_elapsed_lock


 class PriorityQueue:
@@ -40,7 +42,14 @@ class DataEvent(threading.Event):
 def worker():
     while True:
         priority, index, (request_json_body, client_ip, token, parameters), event = priority_queue.get()
+        start_time = time.time()

         success, response, error_msg = generator(request_json_body)

+        end_time = time.time()
+        elapsed_time = end_time - start_time
+        with generation_elapsed_lock:
+            generation_elapsed.append((end_time, elapsed_time))

         event.data = (success, response, error_msg)
         event.set()
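
The worker now wall-clocks each generator() call and appends an (end_time, elapsed_time) tuple to the shared generation_elapsed list while holding its lock. As a rough standalone illustration of what a consumer can do with those tuples (the real consumer is process_avg_gen_time in the next file), the sketch below reduces them to a one-minute rolling average; record() and recent_average() are hypothetical helper names, not part of the commit.

import time
from threading import Lock

generation_elapsed = []           # same shape as the list worker() appends to
generation_elapsed_lock = Lock()

def record(elapsed_time):
    # Mirrors what worker() does once generator() returns.
    end_time = time.time()
    with generation_elapsed_lock:
        generation_elapsed.append((end_time, elapsed_time))

def recent_average(window_sec=60):
    # Average only the samples whose end_time falls inside the window.
    cutoff = time.time() - window_sec
    with generation_elapsed_lock:
        recent = [elapsed for end, elapsed in generation_elapsed if end >= cutoff]
    return sum(recent) / len(recent) if recent else 0.0

record(12.5)
record(9.0)
print(recent_average())  # 10.75 while both samples are under a minute old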

View File

@@ -1,13 +1,55 @@
 import time
 from datetime import datetime
-from threading import Semaphore, Thread
+from threading import Lock, Thread

-from llm_server.opts import concurrent_gens
 from llm_server.routes.cache import redis

 # proompters_1_min = 0
-concurrent_semaphore = Semaphore(concurrent_gens)
-start_time = datetime.now()
+# concurrent_semaphore = Semaphore(concurrent_gens)
+server_start_time = datetime.now()
+
+# TODO: have a background thread put the averages in a variable so we don't end up with massive arrays
+wait_in_queue_elapsed = []
+wait_in_queue_elapsed_lock = Lock()
+
+generation_elapsed = []
+generation_elapsed_lock = Lock()
+
+
+def elapsed_times_cleanup():
+    global wait_in_queue_elapsed
+    while True:
+        current_time = time.time()
+        with wait_in_queue_elapsed_lock:
+            global wait_in_queue_elapsed
+            wait_in_queue_elapsed = [(end_time, elapsed_time) for end_time, elapsed_time in wait_in_queue_elapsed if current_time - end_time <= 60]
+        time.sleep(1)
+
+
+def calculate_avg_gen_time():
+    # Get the average generation time from Redis
+    average_generation_time = redis.get('average_generation_time')
+    if average_generation_time is None:
+        return 0
+    else:
+        return float(average_generation_time)
+
+
+def process_avg_gen_time():
+    while True:
+        with generation_elapsed_lock:
+            # Get the data from the last minute
+            one_minute_ago = time.time() - 60
+            recent_data = [elapsed for end, elapsed in generation_elapsed if end >= one_minute_ago]
+
+            # Calculate the average
+            if len(recent_data) == 0:
+                average_generation_time = 0
+            else:
+                average_generation_time = sum(recent_data) / len(recent_data)
+
+            redis.set('average_generation_time', average_generation_time)
+        time.sleep(5)


 def get_count():
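
The stats module now keeps two in-memory sample lists (queue wait and generation time) plus a background loop, process_avg_gen_time(), that writes a one-minute average to Redis every five seconds, so readers of calculate_avg_gen_time() only pay a cheap redis.get. The sketch below shows the same get/set round-trip with a plain redis-py client; the connection details are placeholders, and the project's own client lives in llm_server.routes.cache, whose wrapper is assumed to behave the same way.

import redis

r = redis.Redis(host='localhost', port=6379)   # placeholder connection details

r.set('average_generation_time', 10.75)        # stored as the string b'10.75'
raw = r.get('average_generation_time')         # bytes, or None if never set
average_generation_time = float(raw) if raw is not None else 0
print(average_generation_time)                 # 10.75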

View File

@@ -3,7 +3,7 @@ import time

 from flask import jsonify, request

-from llm_server.routes.stats import SemaphoreCheckerThread, concurrent_semaphore
+from llm_server.routes.stats import SemaphoreCheckerThread, wait_in_queue_elapsed, wait_in_queue_elapsed_lock
 from . import bp
 from ..cache import redis
 from ..helpers.client import format_sillytavern_err
@@ -18,94 +18,102 @@ DEFAULT_PRIORITY = 9999

 @bp.route('/generate', methods=['POST'])
 def generate():
+    start_time = time.time()
     request_valid_json, request_json_body = validate_json(request.data)
     if not request_valid_json:
         return jsonify({'code': 400, 'error': 'Invalid JSON'}), 400

-    with concurrent_semaphore:
     if request.headers.get('cf-connecting-ip'):
         client_ip = request.headers.get('cf-connecting-ip')
     elif request.headers.get('x-forwarded-for'):
         client_ip = request.headers.get('x-forwarded-for')
     else:
         client_ip = request.remote_addr

     SemaphoreCheckerThread.recent_prompters[client_ip] = time.time()

     parameters = request_json_body.copy()
     del parameters['prompt']

     token = request.headers.get('X-Api-Key')
     priority = None
     if token:
         conn = sqlite3.connect(opts.database_path)
         cursor = conn.cursor()
         cursor.execute("SELECT priority FROM token_auth WHERE token = ?", (token,))
         result = cursor.fetchone()
         if result:
             priority = result[0]
         conn.close()

     if priority is None:
         priority = DEFAULT_PRIORITY
     else:
         print(f'Token {token} was given priority {priority}.')

     # success, response, error_msg = generator(request_json_body)
     event = priority_queue.put((request_json_body, client_ip, token, parameters), priority)
     event.wait()
     success, response, error_msg = event.data

+    # Add the elapsed time to a global list
+    end_time = time.time()
+    elapsed_time = end_time - start_time
+    # print('elapsed:', elapsed_time)
+    with wait_in_queue_elapsed_lock:
+        wait_in_queue_elapsed.append((end_time, elapsed_time))
+
     if not success:
         if opts.mode == 'oobabooga':
             backend_response = format_sillytavern_err(f'Failed to reach the backend ({opts.mode}): {error_msg}', 'error')
             response_json_body = {
                 'results': [
                     {
                         'text': backend_response,
                     }
                 ],
             }
         else:
             raise Exception
         log_prompt(opts.database_path, client_ip, token, request_json_body['prompt'], backend_response, parameters, dict(request.headers), response.status_code)
         return jsonify({
             'code': 500,
             'error': 'failed to reach backend',
             **response_json_body
         }), 200

     response_valid_json, response_json_body = validate_json(response)
     if response_valid_json:
         redis.incr('proompts')

         backend_response = safe_list_get(response_json_body.get('results', []), 0, {}).get('text')
         if not backend_response:
             if opts.mode == 'oobabooga':
                 backend_response = format_sillytavern_err(
                     f'Backend (oobabooga) returned an empty string. This can happen when your parameters are incorrect. Make sure your context size is no greater than {opts.context_size}. Furthermore, oobabooga does not support concurrent requests so all users have to wait in line and the backend server may have glitched for a moment. Please try again.',
                     'error')
                 response_json_body['results'][0]['text'] = backend_response
             else:
                 raise Exception

         log_prompt(opts.database_path, client_ip, token, request_json_body['prompt'], backend_response, parameters, dict(request.headers), response.status_code)
         return jsonify({
             **response_json_body
         }), 200
     else:
         if opts.mode == 'oobabooga':
             backend_response = format_sillytavern_err(f'The backend did not return valid JSON.', 'error')
             response_json_body = {
                 'results': [
                     {
                         'text': backend_response,
                     }
                 ],
             }
         else:
             raise Exception
         log_prompt(opts.database_path, client_ip, token, request_json_body['prompt'], backend_response, parameters, dict(request.headers), response.status_code)
         return jsonify({
             'code': 500,
             'error': 'the backend did not return valid JSON',
             **response_json_body
         }), 200
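
start_time is taken as soon as the request arrives, so the (end_time, elapsed_time) tuple appended to wait_in_queue_elapsed covers both the time spent waiting in priority_queue and the generation itself. The priority that decides the queue position comes from the token_auth lookup above; as a small hypothetical helper (not part of the commit), assuming token_auth has at least the token and priority columns used by the SELECT, an operator could bump a key like this:

import sqlite3

def set_token_priority(database_path, token, priority):
    # Give an API key a better queue position; assuming the heapq-backed
    # PriorityQueue treats smaller numbers as higher priority, a low number
    # outranks the DEFAULT_PRIORITY of 9999 used for anonymous requests.
    conn = sqlite3.connect(database_path)
    try:
        conn.execute("UPDATE token_auth SET priority = ? WHERE token = ?", (priority, token))
        conn.commit()
    finally:
        conn.close()

set_token_priority('/path/to/llm-server.db', 'my-api-key', 1)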

View File

@@ -4,11 +4,10 @@ from datetime import datetime
 from flask import jsonify, request

 from llm_server import opts
-from llm_server.routes.v1.generate import concurrent_semaphore
 from . import bp
 from .. import stats
-from ..cache import cache
-from ..stats import SemaphoreCheckerThread
+from ..queue import priority_queue
+from ..stats import SemaphoreCheckerThread, calculate_avg_gen_time
 from ...llm.info import get_running_model
@@ -21,12 +20,22 @@ def get_stats():
     else:
         online = True

+    # t = elapsed_times.copy()  # copy since we do multiple operations and don't want it to change
+    # if len(t) == 0:
+    #     estimated_wait = 0
+    # else:
+    #     waits = [elapsed for end, elapsed in t]
+    #     estimated_wait = int(sum(waits) / len(waits))
+
+    average_generation_time = int(calculate_avg_gen_time())
+
     return jsonify({
         'stats': {
-            'proompters_now': opts.concurrent_gens - concurrent_semaphore._value,
+            'prompts_in_queue': len(priority_queue),
             'proompters_1_min': SemaphoreCheckerThread.proompters_1_min,
             'total_proompts': stats.get_count(),
-            'uptime': int((datetime.now() - stats.start_time).total_seconds()),
+            'uptime': int((datetime.now() - stats.server_start_time).total_seconds()),
+            'average_generation_elapsed_sec': average_generation_time,
         },
         'online': online,
         'mode': opts.mode,
@@ -34,6 +43,7 @@ def get_stats():
         'endpoints': {
             'blocking': f'https://{request.headers.get("Host")}/{opts.frontend_api_client.strip("/")}',
         },
+        'estimated_wait_sec': int(average_generation_time * len(priority_queue)),
         'timestamp': int(time.time()),
         'openaiKeys': '',
         'anthropicKeys': '',
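
The stats payload now exposes the queue length, the one-minute average generation time, and a naive wait estimate computed as average_generation_elapsed_sec × prompts_in_queue (so a 12 s average with 3 queued prompts reports roughly 36 s). A quick client-side check might look like the sketch below; the exact URL depends on how the blueprint is mounted, so '/api/v1/stats' is only an assumption.

import requests

resp = requests.get('https://example.com/api/v1/stats', timeout=10).json()

print(resp['stats']['prompts_in_queue'])
print(resp['stats']['average_generation_elapsed_sec'])
print(resp['estimated_wait_sec'])   # average_generation_elapsed_sec * prompts_in_queue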

View File

@@ -1,6 +1,7 @@
 import os
 import sys
 from pathlib import Path
+from threading import Thread

 from flask import Flask, jsonify
@@ -11,7 +12,7 @@ from llm_server.helpers import resolve_path
 from llm_server.routes.cache import cache
 from llm_server.routes.helpers.http import cache_control
 from llm_server.routes.queue import start_workers
-from llm_server.routes.stats import SemaphoreCheckerThread, concurrent_semaphore
+from llm_server.routes.stats import SemaphoreCheckerThread, elapsed_times_cleanup, process_avg_gen_time
 from llm_server.routes.v1 import bp

 script_path = os.path.dirname(os.path.realpath(__file__))
@@ -56,6 +57,13 @@ if not opts.verify_ssl:
 start_workers(opts.concurrent_gens)

+# cleanup_thread = Thread(target=elapsed_times_cleanup)
+# cleanup_thread.daemon = True
+# cleanup_thread.start()
+
+# Start the background thread
+process_avg_gen_time_background_thread = Thread(target=process_avg_gen_time)
+process_avg_gen_time_background_thread.daemon = True
+process_avg_gen_time_background_thread.start()
+
 SemaphoreCheckerThread().start()

 app = Flask(__name__)