2023-08-21 21:28:52 -06:00
|
|
|
import json
|
|
|
|
import sqlite3
|
|
|
|
import time
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
import tiktoken
|
|
|
|
|
|
|
|
from llm_server import opts
|
|
|
|
|
|
|
|
# Module-level tiktoken encoder (cl100k_base) shared by the functions below
# to count prompt and response tokens.
tokenizer = tiktoken.get_encoding("cl100k_base")
|
|
|
|
|
|
|
|
|
2023-08-23 22:01:06 -06:00
|
|
|
def init_db():
    """Create the SQLite database and its tables when the file is missing.

    If something already exists at ``opts.database_path`` this is a no-op.
    Otherwise a new database is created containing two tables:

    * ``prompts`` - one row per logged request/response.
    * ``token_auth`` - API tokens together with their usage limits.
    """
    if Path(opts.database_path).exists():
        return

    create_prompts_sql = '''
        CREATE TABLE prompts (
            ip TEXT,
            token TEXT DEFAULT NULL,
            backend TEXT,
            prompt TEXT,
            prompt_tokens INTEGER,
            response TEXT,
            response_tokens INTEGER,
            response_status INTEGER,
            generation_time FLOAT,
            model TEXT,
            parameters TEXT CHECK (parameters IS NULL OR json_valid(parameters)),
            headers TEXT CHECK (headers IS NULL OR json_valid(headers)),
            timestamp INTEGER
        )
    '''
    create_token_auth_sql = '''
        CREATE TABLE token_auth (
            token TEXT UNIQUE,
            type TEXT NOT NULL,
            priority INTEGER default 9999,
            uses INTEGER default 0,
            max_uses INTEGER,
            expire INTEGER,
            disabled BOOLEAN default 0
        )
    '''

    db = sqlite3.connect(opts.database_path)
    cur = db.cursor()
    # Tables are created in the same order as before: prompts, then token_auth.
    for statement in (create_prompts_sql, create_token_auth_sql):
        cur.execute(statement)
    db.commit()
    db.close()
|
|
|
|
|
|
|
|
|
2023-08-29 14:48:33 -06:00
|
|
|
def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, response_tokens: int = None, is_error: bool = False):
    """Insert one request/response record into the ``prompts`` table.

    Args:
        ip: Client IP address stored with the record.
        token: API token used for the request (may be None).
        prompt: Prompt text. Its token count is always recorded, but the text
            itself is stored only when ``opts.log_prompts`` is enabled.
        response: Response text. Stored when ``opts.log_prompts`` is enabled,
            or when ``is_error`` is true.
        gen_time: Generation time; rounded to 3 decimals when truthy. Stored
            as NULL for error records.
        parameters: Generation parameters, serialised with ``json.dumps``.
        headers: Request headers, serialised with ``json.dumps``.
        backend_response_code: Status code written to ``response_status``.
        response_tokens: Token count of the response. When falsy it is
            computed from ``response`` with the module tokenizer.
        is_error: Marks the record as an error; ``response_tokens`` and
            ``gen_time`` are then stored as NULL.
    """
    # disallowed_special=() makes tiktoken encode special-token text such as
    # "<|endoftext|>" as ordinary text. Without it, encode() raises ValueError
    # when the (user supplied) prompt contains such text. The response was
    # already encoded this way; the prompt now matches.
    prompt_tokens = len(tokenizer.encode(prompt, disallowed_special=()))

    if is_error:
        response_tokens = None
        gen_time = None
    else:
        if not response_tokens:
            response_tokens = len(tokenizer.encode(response, disallowed_special=()))
        # Sometimes we may want to insert null into the DB, but
        # usually we want to insert a float.
        if gen_time:
            gen_time = round(gen_time, 3)

    if not opts.log_prompts:
        prompt = None
        if not is_error:
            # TODO: test and verify this works as expected
            response = None

    timestamp = int(time.time())

    conn = sqlite3.connect(opts.database_path)
    try:
        c = conn.cursor()
        # Values are positional and must follow the column order of the
        # prompts table.
        c.execute("INSERT INTO prompts VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
                  (ip, token, opts.mode, prompt, prompt_tokens, response, response_tokens, backend_response_code, gen_time, opts.running_model, json.dumps(parameters), json.dumps(headers), timestamp))
        conn.commit()
    finally:
        # Close the connection even if the insert fails.
        conn.close()
|
|
|
|
|
|
|
|
|
|
|
|
def is_valid_api_key(api_key):
    """Return True if ``api_key`` is a usable token in ``token_auth``.

    A token is usable when all of the following hold:

    * a row with that token exists;
    * it has uses left - a NULL ``uses`` or NULL ``max_uses`` is treated as
      "no usage limit";
    * it has not expired - a NULL ``expire`` never expires;
    * it is not disabled.

    Args:
        api_key: Token string to look up.

    Returns:
        bool: True when the token may be used, False otherwise.
    """
    conn = sqlite3.connect(opts.database_path)
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT token, uses, max_uses, expire, disabled FROM token_auth WHERE token = ?", (api_key,))
        row = cursor.fetchone()
    finally:
        # The connection was previously never closed.
        conn.close()

    if row is None:
        return False

    _token, uses, max_uses, expire, disabled = row
    # max_uses has no default in the schema, so it can be NULL; comparing an
    # int with None would raise TypeError, hence the explicit None checks.
    has_uses_left = uses is None or max_uses is None or uses < max_uses
    not_expired = expire is None or expire > time.time()
    return has_uses_left and not_expired and not bool(disabled)
|
|
|
|
|
|
|
|
|
|
|
|
def increment_uses(api_key):
    """Increment the ``uses`` counter of ``api_key`` by one.

    A NULL counter is treated as 0 before incrementing.

    Args:
        api_key: Token string whose counter should be bumped.

    Returns:
        bool: True if a matching token row was updated, False if no row with
        that token exists.
    """
    conn = sqlite3.connect(opts.database_path)
    try:
        cursor = conn.cursor()
        # A single UPDATE replaces the former SELECT-then-UPDATE pair;
        # rowcount reports whether a matching token row existed.
        cursor.execute("UPDATE token_auth SET uses = COALESCE(uses, 0) + 1 WHERE token = ?", (api_key,))
        conn.commit()
        return cursor.rowcount > 0
    finally:
        # The connection was previously never closed.
        conn.close()
|
2023-08-23 22:01:06 -06:00
|
|
|
|
|
|
|
|
|
|
|
def get_number_of_rows(table_name):
    """Return the row count of ``table_name``.

    ``table_name`` is interpolated directly into the SQL text.
    """
    db = sqlite3.connect(opts.database_path)
    row = db.execute(f'SELECT COUNT(*) FROM {table_name}').fetchone()
    db.close()
    return row[0]
|
2023-08-24 16:47:14 -06:00
|
|
|
|
|
|
|
|
|
|
|
def average_column(table_name, column_name):
    """Return ``AVG(column_name)`` over every row of ``table_name``.

    Both identifiers are interpolated directly into the SQL text. The value
    is whatever SQLite returns for the aggregate.
    """
    db = sqlite3.connect(opts.database_path)
    query = f"SELECT AVG({column_name}) FROM {table_name}"
    row = db.execute(query).fetchone()
    db.close()
    return row[0]
|
2023-08-24 20:43:11 -06:00
|
|
|
|
|
|
|
|
2023-08-26 00:30:59 -06:00
|
|
|
def average_column_for_model(table_name, column_name, model_name):
    """Return ``AVG(column_name)`` over the rows whose ``model`` equals ``model_name``.

    ``table_name`` and ``column_name`` are interpolated directly into the SQL
    text; ``model_name`` is passed as a bound parameter.
    """
    db = sqlite3.connect(opts.database_path)
    query = f"SELECT AVG({column_name}) FROM {table_name} WHERE model = ?"
    row = db.execute(query, (model_name,)).fetchone()
    db.close()
    return row[0]
|
|
|
|
|
|
|
|
|
2023-09-12 01:04:11 -06:00
|
|
|
def weighted_average_column_for_model(table_name, column_name, model_name, backend_name, exclude_zeros: bool = False):
    """Return a position-weighted average of a column for one model/backend pair.

    Rows matching ``model_name`` and ``backend_name`` are read newest-first
    (``ORDER BY ROWID DESC``). The row at 0-based position ``i`` gets weight
    ``i + 1``. Skipped rows still occupy a position, so they affect the
    weights of the rows after them.

    Args:
        table_name: Table to read; interpolated directly into the SQL text.
        column_name: Column to average; interpolated directly into the SQL text.
        model_name: Value matched against the ``model`` column.
        backend_name: Value matched against the ``backend`` column.
        exclude_zeros: When true, rows whose value equals 0 are skipped in
            addition to NULL values.

    Returns:
        The weighted average, or None when there are no matching rows or
        every matching row was skipped.
    """
    conn = sqlite3.connect(opts.database_path)
    try:
        cursor = conn.cursor()
        # Query only the requested pair instead of walking every distinct
        # (model, backend) combination and discarding all but one result.
        cursor.execute(f"SELECT {column_name} FROM {table_name} WHERE model = ? AND backend = ? ORDER BY ROWID DESC", (model_name, backend_name))
        values = [row[0] for row in cursor.fetchall()]
    finally:
        conn.close()

    total_weight = 0
    weighted_sum = 0
    # NOTE(review): with newest-first ordering the oldest rows receive the
    # largest weights - confirm this is the intended weighting.
    for weight, value in enumerate(values, start=1):
        if value is None or (exclude_zeros and value == 0):
            continue
        total_weight += weight
        weighted_sum += weight * value

    if total_weight == 0:
        return None
    return weighted_sum / total_weight
|
2023-08-27 19:58:04 -06:00
|
|
|
|
|
|
|
|
2023-08-24 20:43:11 -06:00
|
|
|
def sum_column(table_name, column_name):
    """Return ``SUM(column_name)`` over ``table_name``, or 0 when the sum is falsy.

    A NULL aggregate therefore comes back as 0. Both identifiers are
    interpolated directly into the SQL text.
    """
    db = sqlite3.connect(opts.database_path)
    query = f"SELECT SUM({column_name}) FROM {table_name}"
    row = db.execute(query).fetchone()
    db.close()
    total = row[0]
    return total or 0
|
2023-08-25 12:20:16 -06:00
|
|
|
|
|
|
|
|
|
|
|
def get_distinct_ips_24h():
    """Return how many distinct IPs appear in ``prompts`` within the last 24 hours.

    The window is measured from the current time against the integer
    ``timestamp`` column.
    """
    seconds_per_day = 24 * 60 * 60
    cutoff = int(time.time()) - seconds_per_day

    db = sqlite3.connect(opts.database_path)
    row = db.execute("SELECT COUNT(DISTINCT ip) FROM prompts WHERE timestamp >= ?", (cutoff,)).fetchone()
    db.close()

    if not row:
        return 0
    return row[0]
|