log model used in request so we can pull the correct averages when we change models

This commit is contained in:
Cyberes 2023-08-26 00:30:59 -06:00
parent 0150bbf8dd
commit 6a09ffc8a4
4 changed files with 27 additions and 13 deletions

View File

@ -42,10 +42,12 @@ To set up token auth, add rows to the `token_auth` table in the SQLite database.
### Use ### Use
**DO NOT** lose your database. It's used for calculating the estimated wait time based on average TPS and response tokens and if you lose those stats your numbers will be inaccurate until the database fills back up again. If you change GPUs, you should probably clear the `generation_time` time column in the `prompts` table. **DO NOT** lose your database. It's used for calculating the estimated wait time based on average TPS and response tokens and if you lose those stats your numbers will be inaccurate until the database fills back up again. If you change GPUs, you
should probably clear the `generation_time` time column in the `prompts` table.
### To Do ### To Do
- Implement streaming - Implement streaming
- Add `huggingface/text-generation-inference` - Add `huggingface/text-generation-inference`
- Convince Oobabooga to implement concurrent generation - Convince Oobabooga to implement concurrent generation
- Make sure stats work when starting from an empty database

View File

@ -24,6 +24,7 @@ def init_db():
response_tokens INTEGER, response_tokens INTEGER,
response_status INTEGER, response_status INTEGER,
generation_time FLOAT, generation_time FLOAT,
model TEXT,
parameters TEXT CHECK (parameters IS NULL OR json_valid(parameters)), parameters TEXT CHECK (parameters IS NULL OR json_valid(parameters)),
headers TEXT CHECK (headers IS NULL OR json_valid(headers)), headers TEXT CHECK (headers IS NULL OR json_valid(headers)),
timestamp INTEGER timestamp INTEGER
@ -59,8 +60,8 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
timestamp = int(time.time()) timestamp = int(time.time())
conn = sqlite3.connect(opts.database_path) conn = sqlite3.connect(opts.database_path)
c = conn.cursor() c = conn.cursor()
c.execute("INSERT INTO prompts VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", c.execute("INSERT INTO prompts VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
(ip, token, prompt, prompt_tokens, response, response_tokens, backend_response_code, gen_time, json.dumps(parameters), json.dumps(headers), timestamp)) (ip, token, prompt, prompt_tokens, response, response_tokens, backend_response_code, gen_time, opts.running_model, json.dumps(parameters), json.dumps(headers), timestamp))
conn.commit() conn.commit()
conn.close() conn.close()
@ -108,6 +109,15 @@ def average_column(table_name, column_name):
return result[0] return result[0]
def average_column_for_model(table_name, column_name, model_name):
conn = sqlite3.connect(opts.database_path)
cursor = conn.cursor()
cursor.execute(f"SELECT AVG({column_name}) FROM {table_name} WHERE model = ?", (model_name,))
result = cursor.fetchone()
conn.close()
return result[0]
def sum_column(table_name, column_name): def sum_column(table_name, column_name):
conn = sqlite3.connect(opts.database_path) conn = sqlite3.connect(opts.database_path)
cursor = conn.cursor() cursor = conn.cursor()

View File

@ -34,7 +34,7 @@ def generate_stats():
if opts.average_generation_time_mode == 'database': if opts.average_generation_time_mode == 'database':
average_generation_time = int(float(redis.get('average_generation_elapsed_sec'))) average_generation_time = int(float(redis.get('average_generation_elapsed_sec')))
average_output_tokens = int(float(redis.get('average_output_tokens'))) average_output_tokens = int(float(redis.get('average_output_tokens')))
estimated_wait_sec = int(((average_output_tokens / average_tps) * proompters_in_queue) / opts.concurrent_gens) estimated_wait_sec = int(((average_output_tokens / average_tps) * proompters_in_queue) / opts.concurrent_gens) if average_tps > 0 else 0
elif opts.average_generation_time_mode == 'minute': elif opts.average_generation_time_mode == 'minute':
average_generation_time = int(calculate_avg_gen_time()) average_generation_time = int(calculate_avg_gen_time())
estimated_wait_sec = int((average_generation_time * proompters_in_queue) / opts.concurrent_gens) estimated_wait_sec = int((average_generation_time * proompters_in_queue) / opts.concurrent_gens)

View File

@ -4,7 +4,7 @@ from threading import Thread
import requests import requests
from llm_server import opts from llm_server import opts
from llm_server.database import average_column from llm_server.database import average_column_for_model
from llm_server.routes.cache import redis from llm_server.routes.cache import redis
@ -24,14 +24,6 @@ class MainBackgroundThread(Thread):
def run(self): def run(self):
while True: while True:
average_generation_elapsed_sec = average_column('prompts', 'generation_time') if not None else 0
redis.set('average_generation_elapsed_sec', average_generation_elapsed_sec)
average_output_tokens = average_column('prompts', 'response_tokens') if not None else 0
redis.set('average_output_tokens', average_output_tokens)
average_tps = round(average_output_tokens / average_generation_elapsed_sec, 2)
redis.set('average_tps', average_tps)
if opts.mode == 'oobabooga': if opts.mode == 'oobabooga':
try: try:
r = requests.get(f'{opts.backend_url}/api/v1/model', timeout=3, verify=opts.verify_ssl) r = requests.get(f'{opts.backend_url}/api/v1/model', timeout=3, verify=opts.verify_ssl)
@ -45,4 +37,14 @@ class MainBackgroundThread(Thread):
pass pass
else: else:
raise Exception raise Exception
average_generation_elapsed_sec = average_column_for_model('prompts', 'generation_time', opts.running_model) or 0
redis.set('average_generation_elapsed_sec', average_generation_elapsed_sec)
average_output_tokens = average_column_for_model('prompts', 'response_tokens', opts.running_model) or 0
redis.set('average_output_tokens', average_output_tokens)
# Avoid division by zero
average_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0
redis.set('average_tps', average_tps)
time.sleep(60) time.sleep(60)