local-llm-server/llm_server/database.py

104 lines
3.2 KiB
Python
Raw Normal View History

2023-08-21 21:28:52 -06:00
import json
import sqlite3
import time
from contextlib import closing
from pathlib import Path

import tiktoken

from llm_server import opts
# Shared tiktoken encoder; log_prompt uses it to count prompt and response tokens.
# NOTE(review): the backend model may not use this encoding, so stored token
# counts are presumably approximate — confirm if exact counts matter.
_ENCODING_NAME = "cl100k_base"
tokenizer = tiktoken.get_encoding(_ENCODING_NAME)
def init_db():
if not Path(opts.database_path).exists():
conn = sqlite3.connect(opts.database_path)
2023-08-21 21:28:52 -06:00
c = conn.cursor()
c.execute('''
CREATE TABLE prompts (
ip TEXT,
token TEXT DEFAULT NULL,
prompt TEXT,
prompt_tokens INTEGER,
response TEXT,
response_tokens INTEGER,
response_status INTEGER,
2023-08-23 22:28:03 -06:00
generation_time FLOAT,
2023-08-21 21:28:52 -06:00
parameters TEXT CHECK (parameters IS NULL OR json_valid(parameters)),
headers TEXT CHECK (headers IS NULL OR json_valid(headers)),
timestamp INTEGER
)
''')
c.execute('''
2023-08-23 20:12:38 -06:00
CREATE TABLE token_auth (
token TEXT UNIQUE,
type TEXT NOT NULL,
priority INTEGER default 9999,
uses INTEGER default 0,
max_uses INTEGER,
expire INTEGER,
disabled BOOLEAN default 0
)
2023-08-21 21:28:52 -06:00
''')
conn.commit()
conn.close()
def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code):
2023-08-21 21:28:52 -06:00
prompt_tokens = len(tokenizer.encode(prompt))
response_tokens = len(tokenizer.encode(response))
if not opts.log_prompts:
prompt = response = None
timestamp = int(time.time())
conn = sqlite3.connect(opts.database_path)
2023-08-21 21:28:52 -06:00
c = conn.cursor()
c.execute("INSERT INTO prompts VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
2023-08-23 22:28:03 -06:00
(ip, token, prompt, prompt_tokens, response, response_tokens, backend_response_code, round(gen_time, 3), json.dumps(parameters), json.dumps(headers), timestamp))
2023-08-21 21:28:52 -06:00
conn.commit()
conn.close()
def is_valid_api_key(api_key):
conn = sqlite3.connect(opts.database_path)
cursor = conn.cursor()
2023-08-21 22:49:44 -06:00
cursor.execute("SELECT token, uses, max_uses, expire, disabled FROM token_auth WHERE token = ?", (api_key,))
2023-08-21 21:28:52 -06:00
row = cursor.fetchone()
if row is not None:
2023-08-21 22:49:44 -06:00
token, uses, max_uses, expire, disabled = row
disabled = bool(disabled)
if (uses is None or uses < max_uses) and (expire is None or expire > time.time()) and not disabled:
2023-08-21 21:28:52 -06:00
return True
return False
def increment_uses(api_key):
conn = sqlite3.connect(opts.database_path)
cursor = conn.cursor()
cursor.execute("SELECT token FROM token_auth WHERE token = ?", (api_key,))
row = cursor.fetchone()
if row is not None:
cursor.execute("UPDATE token_auth SET uses = COALESCE(uses, 0) + 1 WHERE token = ?", (api_key,))
conn.commit()
return True
return False
def get_number_of_rows(table_name):
conn = sqlite3.connect(opts.database_path)
cur = conn.cursor()
cur.execute(f'SELECT COUNT(*) FROM {table_name}')
result = cur.fetchone()
conn.close()
return result[0]
def average_column(table_name, column_name):
conn = sqlite3.connect(opts.database_path)
cursor = conn.cursor()
cursor.execute(f"SELECT AVG({column_name}) FROM {table_name}")
result = cursor.fetchone()
conn.close()
return result[0]