82 lines
2.6 KiB
Python
82 lines
2.6 KiB
Python
import json
|
|
import sqlite3
|
|
import time
|
|
from pathlib import Path
|
|
|
|
import tiktoken
|
|
|
|
from llm_server import opts
|
|
|
|
# Module-level tiktoken encoder for the "cl100k_base" encoding, built once at
# import time; log_prompt() uses it to count prompt and response tokens.
tokenizer = tiktoken.get_encoding("cl100k_base")
|
|
|
|
|
|
def init_db(db_path):
    """Create the SQLite schema at ``db_path`` if it is not already present.

    Two tables are created:

    * ``prompts`` -- one row per logged request (client ip, auth token,
      prompt/response text and token counts, response status, JSON-encoded
      parameters and headers, unix timestamp).
    * ``token_auth`` -- API tokens with usage counters, an optional expiry
      and a ``disabled`` flag.

    The statements use ``CREATE TABLE IF NOT EXISTS`` so the call is
    idempotent and also initialises a database file that already exists but
    has no schema yet (for example an empty file, or a file left behind by an
    earlier run that failed part-way). Tables that already exist are never
    altered.

    Args:
        db_path: Filesystem path of the SQLite database file. The file is
            created by SQLite if it does not exist.
    """
    conn = sqlite3.connect(db_path)
    try:
        c = conn.cursor()
        c.execute('''
            CREATE TABLE IF NOT EXISTS prompts (
                ip TEXT,
                token TEXT DEFAULT NULL,
                prompt TEXT,
                prompt_tokens INTEGER,
                response TEXT,
                response_tokens INTEGER,
                response_status INTEGER,
                parameters TEXT CHECK (parameters IS NULL OR json_valid(parameters)),
                headers TEXT CHECK (headers IS NULL OR json_valid(headers)),
                timestamp INTEGER
            )
        ''')
        c.execute('''
            CREATE TABLE IF NOT EXISTS token_auth
            (token TEXT, type TEXT NOT NULL, uses INTEGER, max_uses INTEGER, expire INTEGER, disabled BOOLEAN default 0)
        ''')
        # NOTE: a `leeches (url TEXT, online TEXT)` table was sketched here in
        # an earlier revision but has never been created.
        conn.commit()
    finally:
        # Release the file handle even when a statement above fails.
        conn.close()
|
|
|
|
|
|
def _count_tokens(text):
    """Return the number of tokens in ``text`` (0 for ``None`` or empty text)."""
    if not text:
        return 0
    # disallowed_special=() makes tiktoken encode special-token strings such
    # as "<|endoftext|>" as ordinary text instead of raising ValueError, so
    # arbitrary user input cannot break logging.
    return len(tokenizer.encode(text, disallowed_special=()))


def log_prompt(db_path, ip, token, prompt, response, parameters, headers, backend_response_code):
    """Insert one request/response record into the ``prompts`` table.

    Token counts are always computed and stored. The prompt and response
    text themselves are stored only when ``opts.log_prompts`` is truthy;
    otherwise both columns are written as NULL.

    Args:
        db_path: Path of the SQLite database file.
        ip: Client address to record.
        token: Auth token to record (may be ``None``).
        prompt: Prompt text; ``None`` or empty counts as 0 tokens.
        response: Response text; ``None`` or empty counts as 0 tokens.
        parameters: JSON-serialisable object stored in the ``parameters`` column.
        headers: JSON-serialisable object stored in the ``headers`` column.
        backend_response_code: Value stored in the ``response_status`` column.
    """
    prompt_tokens = _count_tokens(prompt)
    response_tokens = _count_tokens(response)

    if not opts.log_prompts:
        prompt = response = None

    timestamp = int(time.time())
    conn = sqlite3.connect(db_path)
    try:
        # Columns are named explicitly so the insert does not depend on the
        # physical column order of the table.
        conn.execute(
            "INSERT INTO prompts (ip, token, prompt, prompt_tokens, response, response_tokens, "
            "response_status, parameters, headers, timestamp) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
            (ip, token, prompt, prompt_tokens, response, response_tokens, backend_response_code,
             json.dumps(parameters), json.dumps(headers), timestamp),
        )
        conn.commit()
    finally:
        # Always release the connection, including when the insert fails.
        conn.close()
|
|
|
|
|
|
def is_valid_api_key(api_key):
    """Return True if ``api_key`` is a usable token in ``token_auth``.

    The first row matching the token in the database at
    ``opts.database_path`` is checked. The token is usable when all of the
    following hold:

    * it has uses left: ``uses`` is NULL, or ``max_uses`` is NULL (treated as
      unlimited), or ``uses < max_uses``;
    * it has not expired: ``expire`` is NULL or lies in the future;
    * it is not disabled.

    Args:
        api_key: Token string to look up.

    Returns:
        bool: True when the token exists and passes every check, else False.
    """
    conn = sqlite3.connect(opts.database_path)
    try:
        row = conn.execute(
            "SELECT token, uses, max_uses, expire, disabled FROM token_auth WHERE token = ?",
            (api_key,),
        ).fetchone()
    finally:
        # The lookup is read-only; close the connection on every path.
        conn.close()

    if row is None:
        return False

    _token, uses, max_uses, expire, disabled = row
    # A NULL max_uses must not be compared with an int (TypeError in Python 3);
    # it is interpreted as "no usage limit".
    has_uses_left = uses is None or max_uses is None or uses < max_uses
    not_expired = expire is None or expire > time.time()
    return has_uses_left and not_expired and not bool(disabled)
|
|
|
|
|
|
def increment_uses(api_key):
    """Increment the ``uses`` counter of ``api_key`` in ``token_auth``.

    A NULL counter is treated as 0 before incrementing. Every row whose
    ``token`` equals ``api_key`` is updated in the database at
    ``opts.database_path``.

    Args:
        api_key: Token string whose usage counter should be bumped.

    Returns:
        bool: True if at least one matching row was updated, False if the
        token does not exist.
    """
    conn = sqlite3.connect(opts.database_path)
    try:
        cursor = conn.execute(
            "UPDATE token_auth SET uses = COALESCE(uses, 0) + 1 WHERE token = ?",
            (api_key,),
        )
        conn.commit()
        # rowcount is the number of rows the UPDATE modified, so no separate
        # existence query is needed.
        return cursor.rowcount > 0
    finally:
        conn.close()
|