show total output tokens on stats

This commit is contained in:
Cyberes 2023-08-24 20:43:11 -06:00
parent 9b7bf490a1
commit ec3fe2c2ac
7 changed files with 26 additions and 4 deletions

View File

@ -12,6 +12,7 @@ config_default_vars = {
'analytics_tracking_code': '',
'average_generation_time_mode': 'database',
'info_html': None,
'show_total_output_tokens': True,
}
config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middleware_name']

View File

@ -101,3 +101,12 @@ def average_column(table_name, column_name):
result = cursor.fetchone()
conn.close()
return result[0]
def sum_column(table_name, column_name):
    """Return the SUM of `column_name` across all rows of `table_name`.

    Returns 0 when the table is empty or every value is NULL (SQLite's
    SUM() yields NULL in that case).

    NOTE(review): the table/column names are interpolated into the SQL
    string -- parameter binding cannot be used for identifiers.  They are
    double-quoted here to keep them from being parsed as SQL, but callers
    must still pass trusted, internal identifiers only.
    """
    conn = sqlite3.connect(opts.database_path)
    try:
        cursor = conn.cursor()
        cursor.execute(f'SELECT SUM("{column_name}") FROM "{table_name}"')
        result = cursor.fetchone()
        # SUM() over zero rows is NULL -> report 0 instead of None.
        return result[0] if result[0] is not None else 0
    finally:
        # Close even if the query raises, so connections don't leak.
        conn.close()

View File

@ -1,5 +1,7 @@
# Read-only global variables
# TODO: rewrite the config system so I don't have to add every single config default here
running_model = 'none'
concurrent_gens = 3
mode = 'oobabooga'
@ -15,3 +17,4 @@ verify_ssl = True
show_num_prompts = True
show_uptime = True
average_generation_time_mode = 'database'
show_total_output_tokens = True

View File

@ -0,0 +1 @@
# TODO: move the inference API to /api/infer/ and the stats api to /api/v1/stats

View File

@ -2,6 +2,7 @@ import time
from datetime import datetime
from llm_server import opts
from llm_server.database import sum_column
from llm_server.helpers import deep_sort
from llm_server.llm.info import get_running_model
from llm_server.routes.cache import redis
@ -9,6 +10,8 @@ from llm_server.routes.queue import priority_queue
from llm_server.routes.stats import SemaphoreCheckerThread, calculate_avg_gen_time, get_active_gen_workers, get_total_proompts, server_start_time
# TODO: have routes/__init__.py point to the latest API version generate_stats()
def generate_stats():
model_list, error = get_running_model() # will return False when the fetch fails
if isinstance(model_list, bool):
@ -40,10 +43,11 @@ def generate_stats():
'stats': {
'proompts_in_queue': proompters_in_queue,
'proompters_1_min': SemaphoreCheckerThread.proompters_1_min,
'total_proompts': get_total_proompts() if opts.show_num_prompts else None,
'proompts': get_total_proompts() if opts.show_num_prompts else None,
'uptime': int((datetime.now() - server_start_time).total_seconds()) if opts.show_uptime else None,
'average_generation_elapsed_sec': average_generation_time,
'average_tps': average_tps,
'tokens_generated': sum_column('prompts', 'response_tokens') if opts.show_total_output_tokens else None,
},
'online': online,
'endpoints': {

View File

@ -8,9 +8,12 @@ from llm_server.database import average_column
from llm_server.routes.cache import redis
class BackendHealthCheck(Thread):
class MainBackgroundThread(Thread):
backend_online = False
# TODO: do I really need to put everything in Redis?
# TODO: call generate_stats() every minute, cache the results, put results in a DB table, then have other parts of code call this cache
def __init__(self):
Thread.__init__(self)
self.daemon = True

View File

@ -17,7 +17,7 @@ from llm_server.routes.queue import start_workers
from llm_server.routes.stats import SemaphoreCheckerThread, process_avg_gen_time
from llm_server.routes.v1 import bp
from llm_server.routes.v1.generate_stats import generate_stats
from llm_server.threads import BackendHealthCheck
from llm_server.threads import MainBackgroundThread
script_path = os.path.dirname(os.path.realpath(__file__))
@ -51,6 +51,7 @@ opts.context_size = config['token_limit']
opts.show_num_prompts = config['show_num_prompts']
opts.show_uptime = config['show_uptime']
opts.backend_url = config['backend_url'].strip('/')
opts.show_total_output_tokens = config['show_total_output_tokens']
opts.verify_ssl = config['verify_ssl']
if not opts.verify_ssl:
@ -78,7 +79,7 @@ start_workers(opts.concurrent_gens)
process_avg_gen_time_background_thread = Thread(target=process_avg_gen_time)
process_avg_gen_time_background_thread.daemon = True
process_avg_gen_time_background_thread.start()
BackendHealthCheck().start()
MainBackgroundThread().start()
SemaphoreCheckerThread().start()
app = Flask(__name__)