From 441a870e855a572e1464091adc9bb58984097745 Mon Sep 17 00:00:00 2001
From: Cyberes
Date: Sun, 27 Aug 2023 19:58:04 -0600
Subject: [PATCH] calculate weighted average for stat tracking

---
 README.md              |  1 +
 llm_server/database.py | 33 +++++++++++++++++++++++++++++++++
 llm_server/threads.py  | 12 +++++++++---
 3 files changed, 43 insertions(+), 3 deletions(-)

diff --git a/README.md b/README.md
index f8ca1e4..e71c625 100644
--- a/README.md
+++ b/README.md
@@ -51,3 +51,4 @@ should probably clear the `generation_time` time column in the `prompts` table.
 - Add `huggingface/text-generation-inference`
 - Convince Oobabooga to implement concurrent generation
 - Make sure stats work when starting from an empty database
+- Make sure we're correctly canceling requests when the client cancels
diff --git a/llm_server/database.py b/llm_server/database.py
index 269b7ba..db347da 100644
--- a/llm_server/database.py
+++ b/llm_server/database.py
@@ -118,6 +118,39 @@ def average_column_for_model(table_name, column_name, model_name):
     return result[0]
 
 
+def weighted_average_column_for_model(table_name, column_name, model_name):
+    conn = sqlite3.connect(opts.database_path)
+    cursor = conn.cursor()
+    cursor.execute(f"SELECT DISTINCT model FROM {table_name}")
+    models = [row[0] for row in cursor.fetchall()]
+
+    model_averages = {}
+    for model in models:
+        cursor.execute(f"SELECT {column_name}, ROWID FROM {table_name} WHERE model = ? ORDER BY ROWID DESC", (model,))
+        results = cursor.fetchall()
+
+        if not results:
+            continue
+
+        total_weight = 0
+        weighted_sum = 0
+        for i, (value, rowid) in enumerate(results):
+            if value is None:
+                continue
+            weight = len(results) - i  # rows are newest-first, so the newest row gets the largest weight
+            total_weight += weight
+            weighted_sum += weight * value
+
+        if total_weight == 0:
+            continue
+
+        model_averages[model] = weighted_sum / total_weight
+
+    conn.close()
+
+    return model_averages.get(model_name)
+
+
 def sum_column(table_name, column_name):
     conn = sqlite3.connect(opts.database_path)
     cursor = conn.cursor()
diff --git a/llm_server/threads.py b/llm_server/threads.py
index 9d1ca15..2a1df2e 100644
--- a/llm_server/threads.py
+++ b/llm_server/threads.py
@@ -4,7 +4,7 @@ from threading import Thread
 import requests
 
 from llm_server import opts
-from llm_server.database import average_column_for_model
+from llm_server.database import average_column_for_model, weighted_average_column_for_model
 from llm_server.routes.cache import redis
 
 
@@ -38,12 +38,18 @@ class MainBackgroundThread(Thread):
         else:
             raise Exception
 
-        average_generation_elapsed_sec = average_column_for_model('prompts', 'generation_time', opts.running_model) or 0
+        average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time', opts.running_model) or 0
         redis.set('average_generation_elapsed_sec', average_generation_elapsed_sec)
 
-        average_output_tokens = average_column_for_model('prompts', 'response_tokens', opts.running_model) or 0
+        # overall = average_column_for_model('prompts', 'generation_time', opts.running_model)
+        # print(f'Weighted: {average_generation_elapsed_sec}, overall: {overall}')
+
+        average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens', opts.running_model) or 0
         redis.set('average_output_tokens', average_output_tokens)
+        # overall = average_column_for_model('prompts', 'response_tokens', opts.running_model)
+        # print(f'Weighted: {average_output_tokens}, overall: {overall}')
+
 
         # Avoid division by zero
         average_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0
         redis.set('average_tps', average_tps)
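
For intuition, the new helper reduces to a recency-weighted mean: with rows ordered newest-first, row i of n gets weight n - i, so the latest sample counts n times as much as the oldest. Below is a minimal standalone sketch of that math in pure Python; the sample data is hypothetical and `recency_weighted_mean` is an illustrative name, not part of the codebase.

    # Recency-weighted mean over samples ordered newest-first, mirroring the
    # ORDER BY ROWID DESC query in weighted_average_column_for_model.
    def recency_weighted_mean(values):
        total_weight = 0
        weighted_sum = 0
        n = len(values)
        for i, value in enumerate(values):
            if value is None:  # the SQL results can contain NULLs; skip them
                continue
            weight = n - i  # i = 0 is the newest sample and gets the largest weight
            total_weight += weight
            weighted_sum += weight * value
        return weighted_sum / total_weight if total_weight else None

    # Hypothetical generation times in seconds, newest first:
    times = [2.0, 2.5, 8.0, 9.0]
    print(recency_weighted_mean(times))  # 4.05 -- pulled toward the recent, faster samples
    print(sum(times) / len(times))       # 5.375 -- the plain mean the old code reported

Because the weights grow only linearly with history length, the estimate still adapts gradually; an exponentially decaying weight would react faster to a model swap, at the cost of noisier stats.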