calculate weighted average for stat tracking
This commit is contained in:
parent
6a09ffc8a4
commit
441a870e85
|
@ -51,3 +51,4 @@ should probably clear the `generation_time` time column in the `prompts` table.
|
||||||
- Add `huggingface/text-generation-inference`
|
- Add `huggingface/text-generation-inference`
|
||||||
- Convince Oobabooga to implement concurrent generation
|
- Convince Oobabooga to implement concurrent generation
|
||||||
- Make sure stats work when starting from an empty database
|
- Make sure stats work when starting from an empty database
|
||||||
|
- Make sure we're correctly canceling requests when the client cancels
|
||||||
|
|
|
@ -118,6 +118,39 @@ def average_column_for_model(table_name, column_name, model_name):
|
||||||
return result[0]
|
return result[0]
|
||||||
|
|
||||||
|
|
||||||
|
def weighted_average_column_for_model(table_name, column_name, model_name):
|
||||||
|
conn = sqlite3.connect(opts.database_path)
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute(f"SELECT DISTINCT model FROM {table_name}")
|
||||||
|
models = [row[0] for row in cursor.fetchall()]
|
||||||
|
|
||||||
|
model_averages = {}
|
||||||
|
for model in models:
|
||||||
|
cursor.execute(f"SELECT {column_name}, ROWID FROM {table_name} WHERE model = ? ORDER BY ROWID DESC", (model,))
|
||||||
|
results = cursor.fetchall()
|
||||||
|
|
||||||
|
if not results:
|
||||||
|
continue
|
||||||
|
|
||||||
|
total_weight = 0
|
||||||
|
weighted_sum = 0
|
||||||
|
for i, (value, rowid) in enumerate(results):
|
||||||
|
if value is None:
|
||||||
|
continue
|
||||||
|
weight = i + 1
|
||||||
|
total_weight += weight
|
||||||
|
weighted_sum += weight * value
|
||||||
|
|
||||||
|
if total_weight == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
model_averages[model] = weighted_sum / total_weight
|
||||||
|
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
return model_averages.get(model_name)
|
||||||
|
|
||||||
|
|
||||||
def sum_column(table_name, column_name):
|
def sum_column(table_name, column_name):
|
||||||
conn = sqlite3.connect(opts.database_path)
|
conn = sqlite3.connect(opts.database_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
|
@ -4,7 +4,7 @@ from threading import Thread
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from llm_server import opts
|
from llm_server import opts
|
||||||
from llm_server.database import average_column_for_model
|
from llm_server.database import average_column_for_model, weighted_average_column_for_model
|
||||||
from llm_server.routes.cache import redis
|
from llm_server.routes.cache import redis
|
||||||
|
|
||||||
|
|
||||||
|
@ -38,12 +38,18 @@ class MainBackgroundThread(Thread):
|
||||||
else:
|
else:
|
||||||
raise Exception
|
raise Exception
|
||||||
|
|
||||||
average_generation_elapsed_sec = average_column_for_model('prompts', 'generation_time', opts.running_model) or 0
|
average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time', opts.running_model) or 0
|
||||||
redis.set('average_generation_elapsed_sec', average_generation_elapsed_sec)
|
redis.set('average_generation_elapsed_sec', average_generation_elapsed_sec)
|
||||||
|
|
||||||
average_output_tokens = average_column_for_model('prompts', 'response_tokens', opts.running_model) or 0
|
# overall = average_column_for_model('prompts', 'generation_time', opts.running_model)
|
||||||
|
# print(f'Weighted: {average_generation_elapsed_sec}, overall: {overall}')
|
||||||
|
|
||||||
|
average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens', opts.running_model) or 0
|
||||||
redis.set('average_output_tokens', average_output_tokens)
|
redis.set('average_output_tokens', average_output_tokens)
|
||||||
|
|
||||||
|
# overall = average_column_for_model('prompts', 'response_tokens', opts.running_model)
|
||||||
|
# print(f'Weighted: {average_output_tokens}, overall: {overall}')
|
||||||
|
|
||||||
# Avoid division by zero
|
# Avoid division by zero
|
||||||
average_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0
|
average_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0
|
||||||
redis.set('average_tps', average_tps)
|
redis.set('average_tps', average_tps)
|
||||||
|
|
Reference in New Issue