2023-08-21 21:28:52 -06:00
import json
import sqlite3
import time
from pathlib import Path
import tiktoken
from llm_server import opts
# Shared tokenizer used to count prompt/response tokens before logging.
# The encoding name must match tiktoken's registry exactly (no padding).
tokenizer = tiktoken.get_encoding("cl100k_base")
2023-08-23 22:01:06 -06:00
def init_db():
    """Create the SQLite schema at ``opts.database_path`` if it is missing.

    Two tables are created:

    * ``prompts`` -- one row per logged request (client ip, token, model,
      backend info, timing, prompt/response text and token counts, the
      backend status code, JSON-encoded parameters and headers, timestamp).
    * ``token_auth`` -- API tokens with their priority, usage counters,
      expiry and a disabled flag.

    ``CREATE TABLE IF NOT EXISTS`` is used instead of a "does the file
    exist" check so that a database file left empty by an interrupted first
    run is still initialised on the next call. Existing tables are left
    untouched.
    """
    conn = sqlite3.connect(opts.database_path)
    try:
        c = conn.cursor()
        c.execute('''
            CREATE TABLE IF NOT EXISTS prompts (
                ip TEXT,
                token TEXT DEFAULT NULL,
                model TEXT,
                backend_mode TEXT,
                backend_url TEXT,
                request_url TEXT,
                generation_time FLOAT,
                prompt TEXT,
                prompt_tokens INTEGER,
                response TEXT,
                response_tokens INTEGER,
                response_status INTEGER,
                parameters TEXT CHECK (parameters IS NULL OR json_valid(parameters)),
                headers TEXT CHECK (headers IS NULL OR json_valid(headers)),
                timestamp INTEGER
            )
        ''')
        c.execute('''
            CREATE TABLE IF NOT EXISTS token_auth (
                token TEXT UNIQUE,
                type TEXT NOT NULL,
                priority INTEGER default 9999,
                uses INTEGER default 0,
                max_uses INTEGER,
                expire INTEGER,
                disabled BOOLEAN default 0
            )
        ''')
        conn.commit()
    finally:
        # Always release the handle, even if a CREATE statement fails.
        conn.close()
2023-09-13 11:22:33 -06:00
def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, response_tokens: int = None, is_error: bool = False):
    """Insert one request record into the ``prompts`` table.

    Args:
        ip: Client IP stored in the ``ip`` column.
        token: API token used for the request (may be ``None``).
        prompt: Prompt text; tokenized to fill ``prompt_tokens``.
        response: Response text; tokenized when ``response_tokens`` is falsy
            and the request is not an error.
        gen_time: Generation time; rounded to 3 decimals when truthy and
            stored as NULL for errors.
        parameters: Request parameters, stored as JSON.
        headers: Request headers, stored as JSON.
        backend_response_code: Stored in ``response_status``.
        request_url: Stored in ``request_url``.
        response_tokens: Pre-computed response token count, if available.
        is_error: When true, ``response_tokens`` and ``generation_time`` are
            stored as NULL and the response text is always kept.

    The model, backend mode and backend URL are read from ``opts``.
    """
    # disallowed_special=() makes tiktoken encode special-token text (e.g.
    # "<|endoftext|>") as plain text instead of raising; the response was
    # already encoded this way, the prompt needs the same treatment.
    prompt_tokens = len(tokenizer.encode(prompt, disallowed_special=()))

    if is_error:
        response_tokens = None
    elif not response_tokens:
        response_tokens = len(tokenizer.encode(response, disallowed_special=()))

    # Sometimes we may want to insert null into the DB, but
    # usually we want to insert a float.
    if gen_time:
        gen_time = round(gen_time, 3)
    if is_error:
        gen_time = None

    if not opts.log_prompts:
        # Token counts were computed above, so the text can be dropped now.
        prompt = None
        if not is_error:
            # TODO: test and verify this works as expected
            response = None

    timestamp = int(time.time())
    row = (
        ip, token, opts.running_model, opts.mode, opts.backend_url, request_url,
        gen_time, prompt, prompt_tokens, response, response_tokens,
        backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp,
    )

    conn = sqlite3.connect(opts.database_path)
    try:
        c = conn.cursor()
        # Columns are named explicitly so the insert does not depend on the
        # physical column order of the table.
        c.execute(
            "INSERT INTO prompts (ip, token, model, backend_mode, backend_url, request_url, "
            "generation_time, prompt, prompt_tokens, response, response_tokens, "
            "response_status, parameters, headers, timestamp) "
            "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
            row,
        )
        conn.commit()
    finally:
        conn.close()
def is_valid_api_key(api_key):
    """Return True if ``api_key`` exists in ``token_auth`` and is usable.

    A token is usable when it is not disabled, has not expired
    (``expire`` is NULL or later than the current time) and has not
    exhausted its quota. A NULL ``uses`` or NULL ``max_uses`` is treated as
    "no quota limit"; ``max_uses`` has no default in the schema, so
    comparing against it unconditionally would raise ``TypeError``.
    """
    conn = sqlite3.connect(opts.database_path)
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT token, uses, max_uses, expire, disabled FROM token_auth WHERE token = ?", (api_key,))
        row = cursor.fetchone()
    finally:
        # The connection was previously never closed on any path.
        conn.close()

    if row is None:
        return False

    _token, uses, max_uses, expire, disabled = row
    within_quota = uses is None or max_uses is None or uses < max_uses
    not_expired = expire is None or expire > time.time()
    return within_quota and not_expired and not bool(disabled)
def increment_uses(api_key):
    """Add one to the ``uses`` counter of ``api_key``.

    Returns:
        True if a matching token row was updated, False if the token does
        not exist in ``token_auth``.
    """
    conn = sqlite3.connect(opts.database_path)
    try:
        cursor = conn.cursor()
        # A single UPDATE both checks for the token and bumps the counter;
        # rowcount reports how many rows the WHERE clause matched.
        cursor.execute("UPDATE token_auth SET uses = COALESCE(uses, 0) + 1 WHERE token = ?", (api_key,))
        if cursor.rowcount > 0:
            conn.commit()
            return True
        return False
    finally:
        # The connection was previously left open on every path.
        conn.close()
2023-08-23 22:01:06 -06:00
def get_number_of_rows(table_name):
    """Return ``COUNT(*)`` for ``table_name``.

    ``table_name`` is interpolated directly into the SQL text (identifiers
    cannot be bound as parameters), so it must come from trusted code, never
    from user input.
    """
    conn = sqlite3.connect(opts.database_path)
    try:
        cur = conn.cursor()
        cur.execute(f'SELECT COUNT(*) FROM {table_name}')
        return cur.fetchone()[0]
    finally:
        # Close the handle even when the query raises (e.g. unknown table).
        conn.close()
2023-08-24 16:47:14 -06:00
def average_column(table_name, column_name):
    """Return ``AVG(column_name)`` over every row of ``table_name``.

    The result is whatever SQLite returns for the aggregate, which is
    ``None`` when there are no non-NULL values. Both names are interpolated
    into the SQL text and must come from trusted code.
    """
    conn = sqlite3.connect(opts.database_path)
    try:
        cursor = conn.cursor()
        cursor.execute(f"SELECT AVG({column_name}) FROM {table_name}")
        return cursor.fetchone()[0]
    finally:
        # Close the handle even when the query raises.
        conn.close()
2023-08-24 20:43:11 -06:00
2023-08-26 00:30:59 -06:00
def average_column_for_model(table_name, column_name, model_name):
    """Return ``AVG(column_name)`` over rows whose ``model`` is ``model_name``.

    The result is ``None`` when no matching non-NULL value exists.
    ``table_name`` and ``column_name`` are interpolated into the SQL text
    and must come from trusted code; ``model_name`` is bound as a parameter.
    """
    conn = sqlite3.connect(opts.database_path)
    try:
        cursor = conn.cursor()
        cursor.execute(f"SELECT AVG({column_name}) FROM {table_name} WHERE model = ?", (model_name,))
        return cursor.fetchone()[0]
    finally:
        # Close the handle even when the query raises.
        conn.close()
2023-09-13 12:34:14 -06:00
def weighted_average_column_for_model(table_name, column_name, model_name, backend_name, backend_url, exclude_zeros: bool = False):
    """Return a position-weighted average of ``column_name`` for one backend.

    Rows are selected where ``model``, ``backend_mode`` and ``backend_url``
    equal the given values, ordered by ROWID descending. The row at index
    ``i`` in that ordering gets weight ``i + 1``; NULL values (and zeros,
    when ``exclude_zeros`` is true) are skipped but still occupy an index.

    Args:
        table_name: Table to query (interpolated into SQL; trusted input only).
        column_name: Numeric column to average (interpolated into SQL).
        model_name: Value matched against the ``model`` column.
        backend_name: Value matched against the ``backend_mode`` column.
        backend_url: Value matched against the ``backend_url`` column.
        exclude_zeros: Skip rows whose value equals 0.

    Returns:
        The weighted average, or 0 when no row qualifies (this previously
        raised ``ZeroDivisionError`` and leaked the connection).
    """
    conn = sqlite3.connect(opts.database_path)
    try:
        cursor = conn.cursor()
        cursor.execute(
            f"SELECT {column_name}, ROWID FROM {table_name} WHERE model = ? AND backend_mode = ? AND backend_url = ? ORDER BY ROWID DESC",
            (model_name, backend_name, backend_url),
        )
        results = cursor.fetchall()
    finally:
        conn.close()

    total_weight = 0
    weighted_sum = 0
    # NOTE(review): with ROWID DESC the highest ROWID gets the smallest
    # weight, so lower-ROWID rows dominate the average -- confirm that this
    # is the intended direction.
    for i, (value, _rowid) in enumerate(results):
        if value is None or (exclude_zeros and value == 0):
            continue
        weight = i + 1
        total_weight += weight
        weighted_sum += weight * value

    if total_weight == 0:
        # No usable samples: report 0 instead of dividing by zero.
        return 0
    return weighted_sum / total_weight
2023-08-27 19:58:04 -06:00
2023-08-24 20:43:11 -06:00
def sum_column(table_name, column_name):
    """Return ``SUM(column_name)`` over ``table_name``, or 0 if it is falsy.

    SQLite yields NULL for a SUM with no non-NULL inputs; that case (and a
    sum of exactly zero) is reported as ``0``. Both names are interpolated
    into the SQL text and must come from trusted code.
    """
    conn = sqlite3.connect(opts.database_path)
    try:
        cursor = conn.cursor()
        cursor.execute(f"SELECT SUM({column_name}) FROM {table_name}")
        total = cursor.fetchone()[0]
    finally:
        # Close the handle even when the query raises.
        conn.close()
    return total or 0
2023-08-25 12:20:16 -06:00
def get_distinct_ips_24h():
    """Return the number of distinct ``ip`` values logged in the last 24 hours.

    The cutoff is the current Unix time minus 24 * 60 * 60 seconds, compared
    against the integer ``timestamp`` column of ``prompts``.
    """
    # Get the current time and subtract 24 hours (in seconds)
    past_24_hours = int(time.time()) - 24 * 60 * 60
    conn = sqlite3.connect(opts.database_path)
    try:
        cur = conn.cursor()
        cur.execute("SELECT COUNT(DISTINCT ip) FROM prompts WHERE timestamp >= ?", (past_24_hours,))
        result = cur.fetchone()
    finally:
        # Close the handle even when the query raises.
        conn.close()
    return result[0] if result else 0