From ec3fe2c2ac4dd987160e57d2860cc1bace21cf5d Mon Sep 17 00:00:00 2001 From: Cyberes Date: Thu, 24 Aug 2023 20:43:11 -0600 Subject: [PATCH] show total output tokens on stats --- llm_server/config.py | 1 + llm_server/database.py | 9 +++++++++ llm_server/opts.py | 3 +++ llm_server/routes/__init__.py | 1 + llm_server/routes/v1/generate_stats.py | 6 +++++- llm_server/threads.py | 5 ++++- server.py | 5 +++-- 7 files changed, 26 insertions(+), 4 deletions(-) diff --git a/llm_server/config.py b/llm_server/config.py index db87c9e..eaaf00d 100644 --- a/llm_server/config.py +++ b/llm_server/config.py @@ -12,6 +12,7 @@ config_default_vars = { 'analytics_tracking_code': '', 'average_generation_time_mode': 'database', 'info_html': None, + 'show_total_output_tokens': True, } config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middleware_name'] diff --git a/llm_server/database.py b/llm_server/database.py index e3bdafd..9273a28 100644 --- a/llm_server/database.py +++ b/llm_server/database.py @@ -101,3 +101,12 @@ def average_column(table_name, column_name): result = cursor.fetchone() conn.close() return result[0] + + +def sum_column(table_name, column_name): + conn = sqlite3.connect(opts.database_path) + cursor = conn.cursor() + cursor.execute(f"SELECT SUM({column_name}) FROM {table_name}") + result = cursor.fetchone() + conn.close() + return result[0] if result[0] else 0 diff --git a/llm_server/opts.py b/llm_server/opts.py index 9a3cb28..e6038d5 100644 --- a/llm_server/opts.py +++ b/llm_server/opts.py @@ -1,5 +1,7 @@ # Read-only global variables +# TODO: rewrite the config system so I don't have to add every single config default here + running_model = 'none' concurrent_gens = 3 mode = 'oobabooga' @@ -15,3 +17,4 @@ verify_ssl = True show_num_prompts = True show_uptime = True average_generation_time_mode = 'database' +show_total_output_tokens = True diff --git a/llm_server/routes/__init__.py b/llm_server/routes/__init__.py index e69de29..c13144e 100644 --- a/llm_server/routes/__init__.py +++ b/llm_server/routes/__init__.py @@ -0,0 +1 @@ +# TODO: move the inference API to /api/infer/ and the stats api to /api/v1/stats diff --git a/llm_server/routes/v1/generate_stats.py b/llm_server/routes/v1/generate_stats.py index 33d8da0..ad32b54 100644 --- a/llm_server/routes/v1/generate_stats.py +++ b/llm_server/routes/v1/generate_stats.py @@ -2,6 +2,7 @@ import time from datetime import datetime from llm_server import opts +from llm_server.database import sum_column from llm_server.helpers import deep_sort from llm_server.llm.info import get_running_model from llm_server.routes.cache import redis @@ -9,6 +10,8 @@ from llm_server.routes.queue import priority_queue from llm_server.routes.stats import SemaphoreCheckerThread, calculate_avg_gen_time, get_active_gen_workers, get_total_proompts, server_start_time +# TODO: have routes/__init__.py point to the latest API version generate_stats() + def generate_stats(): model_list, error = get_running_model() # will return False when the fetch fails if isinstance(model_list, bool): @@ -40,10 +43,11 @@ def generate_stats(): 'stats': { 'proompts_in_queue': proompters_in_queue, 'proompters_1_min': SemaphoreCheckerThread.proompters_1_min, - 'total_proompts': get_total_proompts() if opts.show_num_prompts else None, + 'proompts': get_total_proompts() if opts.show_num_prompts else None, 'uptime': int((datetime.now() - server_start_time).total_seconds()) if opts.show_uptime else None, 'average_generation_elapsed_sec': average_generation_time, 'average_tps': average_tps, + 'tokens_generated': sum_column('prompts', 'response_tokens') if opts.show_total_output_tokens else None, }, 'online': online, 'endpoints': { diff --git a/llm_server/threads.py b/llm_server/threads.py index 1a89238..437472d 100644 --- a/llm_server/threads.py +++ b/llm_server/threads.py @@ -8,9 +8,12 @@ from llm_server.database import average_column from llm_server.routes.cache import redis -class BackendHealthCheck(Thread): +class MainBackgroundThread(Thread): backend_online = False + # TODO: do I really need to put everything in Redis? + # TODO: call generate_stats() every minute, cache the results, put results in a DB table, then have other parts of code call this cache + def __init__(self): Thread.__init__(self) self.daemon = True diff --git a/server.py b/server.py index 092e5ae..2fe59d8 100644 --- a/server.py +++ b/server.py @@ -17,7 +17,7 @@ from llm_server.routes.queue import start_workers from llm_server.routes.stats import SemaphoreCheckerThread, process_avg_gen_time from llm_server.routes.v1 import bp from llm_server.routes.v1.generate_stats import generate_stats -from llm_server.threads import BackendHealthCheck +from llm_server.threads import MainBackgroundThread script_path = os.path.dirname(os.path.realpath(__file__)) @@ -51,6 +51,7 @@ opts.context_size = config['token_limit'] opts.show_num_prompts = config['show_num_prompts'] opts.show_uptime = config['show_uptime'] opts.backend_url = config['backend_url'].strip('/') +opts.show_total_output_tokens = config['show_total_output_tokens'] opts.verify_ssl = config['verify_ssl'] if not opts.verify_ssl: @@ -78,7 +79,7 @@ start_workers(opts.concurrent_gens) process_avg_gen_time_background_thread = Thread(target=process_avg_gen_time) process_avg_gen_time_background_thread.daemon = True process_avg_gen_time_background_thread.start() -BackendHealthCheck().start() +MainBackgroundThread().start() SemaphoreCheckerThread().start() app = Flask(__name__)