import json
import sqlite3
import time
from contextlib import closing
from pathlib import Path

import tiktoken

from llm_server import opts

tokenizer = tiktoken.get_encoding("cl100k_base")


def init_db():
    """Create the SQLite database and its tables if the database file is missing.

    Nothing is done when a file already exists at ``opts.database_path``.
    """
    if Path(opts.database_path).exists():
        return
    with closing(sqlite3.connect(opts.database_path)) as conn:
        conn.execute('''
            CREATE TABLE prompts (
                ip TEXT,
                token TEXT DEFAULT NULL,
                model TEXT,
                backend_mode TEXT,
                backend_url TEXT,
                request_url TEXT,
                generation_time FLOAT,
                prompt TEXT,
                prompt_tokens INTEGER,
                response TEXT,
                response_tokens INTEGER,
                response_status INTEGER,
                parameters TEXT CHECK (parameters IS NULL OR json_valid(parameters)),
                headers TEXT CHECK (headers IS NULL OR json_valid(headers)),
                timestamp INTEGER
            )
        ''')
        conn.execute('''
            CREATE TABLE token_auth (
                token TEXT UNIQUE,
                type TEXT NOT NULL,
                priority INTEGER default 9999,
                uses INTEGER default 0,
                max_uses INTEGER,
                expire INTEGER,
                disabled BOOLEAN default 0
            )
        ''')
        conn.commit()


def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, response_tokens: int = None, is_error: bool = False):
    """Insert one row describing a request/response pair into ``prompts``.

    Token counts are computed with the module-level tokenizer. For error
    responses (``is_error=True``) the response token count and generation
    time are stored as NULL. When ``opts.log_prompts`` is false the prompt
    text is not stored, and the response text is stored only for errors.
    ``parameters`` and ``headers`` are serialized to JSON.
    """
    # disallowed_special=() makes special-token text (e.g. "<|endoftext|>")
    # count as ordinary text instead of raising ValueError; the response is
    # encoded the same way below.
    prompt_tokens = len(tokenizer.encode(prompt, disallowed_special=()))

    if is_error:
        response_tokens = None
    elif not response_tokens:
        response_tokens = len(tokenizer.encode(response, disallowed_special=()))

    # Sometimes we may want to insert null into the DB, but
    # usually we want to insert a float.
    if is_error:
        gen_time = None
    elif gen_time:
        gen_time = round(gen_time, 3)

    if not opts.log_prompts:
        prompt = None
        if not is_error:
            # TODO: test and verify this works as expected
            response = None

    timestamp = int(time.time())
    row = (
        ip,
        token,
        opts.running_model,
        opts.mode,
        opts.backend_url,
        request_url,
        gen_time,
        prompt,
        prompt_tokens,
        response,
        response_tokens,
        backend_response_code,
        json.dumps(parameters),
        json.dumps(headers),
        timestamp,
    )
    with closing(sqlite3.connect(opts.database_path)) as conn:
        conn.execute("INSERT INTO prompts VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", row)
        conn.commit()


def is_valid_api_key(api_key):
    """Return True if ``api_key`` exists in ``token_auth`` and is usable.

    A key is usable when it is not disabled, has not expired (a NULL
    ``expire`` never expires) and has uses remaining. A NULL ``uses`` or a
    NULL ``max_uses`` is treated as "no usage limit".
    """
    with closing(sqlite3.connect(opts.database_path)) as conn:
        row = conn.execute(
            "SELECT token, uses, max_uses, expire, disabled FROM token_auth WHERE token = ?",
            (api_key,),
        ).fetchone()
    if row is None:
        return False

    _token, uses, max_uses, expire, disabled = row
    if bool(disabled):
        return False
    # max_uses has no default in the schema, so it may be NULL; comparing an
    # int against None would raise TypeError.
    has_uses_left = uses is None or max_uses is None or uses < max_uses
    not_expired = expire is None or expire > time.time()
    return has_uses_left and not_expired


def increment_uses(api_key):
    """Add one to the ``uses`` counter of ``api_key``.

    Returns True if a matching token row was updated, False otherwise.
    """
    with closing(sqlite3.connect(opts.database_path)) as conn:
        cursor = conn.execute(
            "UPDATE token_auth SET uses = COALESCE(uses, 0) + 1 WHERE token = ?",
            (api_key,),
        )
        if cursor.rowcount > 0:
            conn.commit()
            return True
    return False


# NOTE: the helpers below interpolate table and column names directly into
# the SQL text (identifiers cannot be bound as parameters). Callers must only
# pass trusted, hard-coded names.

def get_number_of_rows(table_name):
    """Return the number of rows in ``table_name``."""
    with closing(sqlite3.connect(opts.database_path)) as conn:
        result = conn.execute(f'SELECT COUNT(*) FROM {table_name}').fetchone()
    return result[0]


def average_column(table_name, column_name):
    """Return ``AVG(column_name)`` over ``table_name`` (None if there are no non-NULL values)."""
    with closing(sqlite3.connect(opts.database_path)) as conn:
        result = conn.execute(f"SELECT AVG({column_name}) FROM {table_name}").fetchone()
    return result[0]


def average_column_for_model(table_name, column_name, model_name):
    """Return ``AVG(column_name)`` over the rows of ``table_name`` whose ``model`` equals ``model_name``."""
    with closing(sqlite3.connect(opts.database_path)) as conn:
        result = conn.execute(
            f"SELECT AVG({column_name}) FROM {table_name} WHERE model = ?",
            (model_name,),
        ).fetchone()
    return result[0]


def weighted_average_column_for_model(table_name, column_name, model_name, backend_name, backend_url, exclude_zeros: bool = False):
    """Return a position-weighted average of ``column_name``.

    Only rows matching ``model_name``, ``backend_name`` (``backend_mode``
    column) and ``backend_url`` are considered. Rows are read newest first
    (``ORDER BY ROWID DESC``) and the row at position ``i`` gets weight
    ``i + 1``. NULL values, and zeros when ``exclude_zeros`` is set, are
    skipped but still occupy a position. Returns 0 when nothing contributes.
    """
    # NOTE(review): with this ordering the *oldest* rows receive the largest
    # weights - confirm that this is the intended weighting.
    with closing(sqlite3.connect(opts.database_path)) as conn:
        rows = conn.execute(
            f"SELECT {column_name} FROM {table_name} WHERE model = ? AND backend_mode = ? AND backend_url = ? ORDER BY ROWID DESC",
            (model_name, backend_name, backend_url),
        ).fetchall()

    total_weight = 0
    weighted_sum = 0
    for position, (value,) in enumerate(rows, start=1):
        if value is None or (exclude_zeros and value == 0):
            continue
        total_weight += position
        weighted_sum += position * value

    # Avoid division by zero when no row contributed.
    return weighted_sum / total_weight if total_weight > 0 else 0


def sum_column(table_name, column_name):
    """Return ``SUM(column_name)`` over ``table_name``, or 0 if the sum is NULL or zero."""
    with closing(sqlite3.connect(opts.database_path)) as conn:
        result = conn.execute(f"SELECT SUM({column_name}) FROM {table_name}").fetchone()
    return result[0] or 0


def get_distinct_ips_24h():
    """Return the number of distinct ``ip`` values logged in ``prompts`` during the last 24 hours."""
    # Get the current time and subtract 24 hours (in seconds)
    past_24_hours = int(time.time()) - 24 * 60 * 60
    with closing(sqlite3.connect(opts.database_path)) as conn:
        result = conn.execute(
            "SELECT COUNT(DISTINCT ip) FROM prompts WHERE timestamp >= ?",
            (past_24_hours,),
        ).fetchone()
    return result[0] if result else 0