import json
import sqlite3
import time
from contextlib import closing
from pathlib import Path

import tiktoken

from llm_server import opts

tokenizer = tiktoken.get_encoding("cl100k_base")


def _open_db():
    """Return a context manager yielding a new connection to ``opts.database_path``.

    The connection is closed when the ``with`` block exits, including on error.
    ``opts.database_path`` is read on every call, so the configured path is
    looked up at use time rather than at import time.
    """
    return closing(sqlite3.connect(opts.database_path))


def _count_tokens(text):
    """Return the number of cl100k_base tokens in *text*.

    ``disallowed_special=()`` makes tiktoken encode special-token strings such
    as ``<|endoftext|>`` as ordinary text instead of raising ``ValueError``;
    prompts and responses are arbitrary text and may contain such strings.
    """
    return len(tokenizer.encode(text, disallowed_special=()))


def init_db():
    """Create the SQLite database and its tables if the database file is missing.

    Creates two tables:
      * ``prompts``    - one row per logged request/response (see ``log_prompt``).
      * ``token_auth`` - API tokens with usage counters, limits, expiry and a
                         disabled flag (see ``is_valid_api_key``).

    Does nothing if a file already exists at ``opts.database_path``; tables
    missing from an existing file are NOT created.
    """
    if Path(opts.database_path).exists():
        return
    with _open_db() as conn:
        c = conn.cursor()
        c.execute('''
            CREATE TABLE prompts (
                ip TEXT,
                token TEXT DEFAULT NULL,
                prompt TEXT,
                prompt_tokens INTEGER,
                response TEXT,
                response_tokens INTEGER,
                response_status INTEGER,
                generation_time FLOAT,
                parameters TEXT CHECK (parameters IS NULL OR json_valid(parameters)),
                headers TEXT CHECK (headers IS NULL OR json_valid(headers)),
                timestamp INTEGER
            )
        ''')
        c.execute('''
            CREATE TABLE token_auth (
                token TEXT UNIQUE,
                type TEXT NOT NULL,
                priority INTEGER default 9999,
                uses INTEGER default 0,
                max_uses INTEGER,
                expire INTEGER,
                disabled BOOLEAN default 0
            )
        ''')
        conn.commit()


def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code):
    """Insert one row into the ``prompts`` table.

    Token counts are always recorded. The prompt and response text are stored
    only when ``opts.log_prompts`` is truthy; otherwise they are stored as NULL.

    :param ip: client IP address.
    :param token: API token used for the request (may be None).
    :param prompt: prompt text; its token count is always computed.
    :param response: response text; its token count is always computed.
    :param gen_time: generation time, rounded to 3 decimal places.
    :param parameters: JSON-serializable request parameters.
    :param headers: JSON-serializable request headers.
    :param backend_response_code: status code returned by the backend.
    """
    prompt_tokens = _count_tokens(prompt)
    response_tokens = _count_tokens(response)
    if not opts.log_prompts:
        prompt = response = None
    timestamp = int(time.time())
    with _open_db() as conn:
        conn.execute(
            "INSERT INTO prompts VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
            (ip, token, prompt, prompt_tokens, response, response_tokens,
             backend_response_code, round(gen_time, 3),
             json.dumps(parameters), json.dumps(headers), timestamp),
        )
        conn.commit()


def is_valid_api_key(api_key):
    """Return True if *api_key* exists in ``token_auth`` and may be used now.

    A key is valid when all of these hold:
      * it is under its usage limit; a NULL ``max_uses`` means unlimited, and a
        NULL ``uses`` counts as 0 (matching ``increment_uses``);
      * it has not expired; a NULL ``expire`` means it never expires;
      * it is not disabled.
    """
    with _open_db() as conn:
        row = conn.execute(
            "SELECT token, uses, max_uses, expire, disabled FROM token_auth WHERE token = ?",
            (api_key,),
        ).fetchone()
    if row is None:
        return False
    _token, uses, max_uses, expire, disabled = row
    # Guard against NULL max_uses: comparing an int with None raises TypeError.
    under_limit = max_uses is None or (uses or 0) < max_uses
    not_expired = expire is None or expire > time.time()
    return under_limit and not_expired and not bool(disabled)


def increment_uses(api_key):
    """Increment the ``uses`` counter of *api_key*, treating NULL as 0.

    :returns: True if a matching token row was updated, False if the token
              does not exist.
    """
    with _open_db() as conn:
        cur = conn.execute(
            "UPDATE token_auth SET uses = COALESCE(uses, 0) + 1 WHERE token = ?",
            (api_key,),
        )
        conn.commit()
        return cur.rowcount > 0


def get_number_of_rows(table_name):
    """Return the number of rows in *table_name*.

    WARNING: *table_name* is interpolated directly into the SQL (identifiers
    cannot be bound as parameters); never pass untrusted input.
    """
    with _open_db() as conn:
        result = conn.execute(f'SELECT COUNT(*) FROM {table_name}').fetchone()
    return result[0]


def average_column(table_name, column_name):
    """Return ``AVG(column_name)`` over *table_name*, or None if the table is empty.

    WARNING: *table_name* and *column_name* are interpolated directly into the
    SQL (identifiers cannot be bound as parameters); never pass untrusted input.
    """
    with _open_db() as conn:
        result = conn.execute(f"SELECT AVG({column_name}) FROM {table_name}").fetchone()
    return result[0]