2023-09-20 20:30:31 -06:00
import json
import time
2023-09-23 20:55:49 -06:00
import traceback
2023-10-02 20:53:08 -06:00
from typing import Union
2023-09-20 20:30:31 -06:00
from llm_server import opts
2023-09-30 19:41:50 -06:00
from llm_server . cluster . cluster_config import cluster_config
2023-09-26 23:59:22 -06:00
from llm_server . database . conn import database
2023-10-02 20:53:08 -06:00
from llm_server . llm import get_token_count
2023-09-20 20:30:31 -06:00
2023-10-04 19:24:47 -06:00
def do_db_log(ip: str, token: str, prompt: str, response: Union[str, None], gen_time: Union[int, float, None], parameters: dict, headers: dict, backend_response_code: int, request_url: str, backend_url: str, response_tokens: Union[int, None] = None, is_error: bool = False):
    """
    Insert one request/response record into the ``prompts`` table.

    :param ip: client IP stored in the ``ip`` column.
    :param token: API token of the caller; when truthy its use counter is
        incremented via ``increment_token_uses``.
    :param prompt: prompt text; must be a ``str``. Token-counted via
        ``get_token_count`` before being (optionally) discarded.
    :param response: backend response. If it is a dict, or a JSON string,
        carrying ``{"results": [{"text": ...}]}``, only that text is stored.
    :param gen_time: generation time; rounded to 3 decimals when truthy and
        stored as NULL when ``is_error`` is set.
    :param parameters: request parameters, stored JSON-encoded.
    :param headers: request headers, stored JSON-encoded.
    :param backend_response_code: stored in ``response_status``.
    :param request_url: stored in ``request_url``.
    :param backend_url: must be a ``str``; used to look up the backend's
        ``model`` and ``mode`` through ``cluster_config.get_backend``.
    :param response_tokens: precomputed response token count; computed from
        ``response`` when falsy, and forced to NULL when ``is_error`` is set.
    :param is_error: marks the record as an error response.
    :raises AssertionError: if ``prompt`` or ``backend_url`` is not a ``str``.
    :raises KeyError: if the backend info has no ``'mode'`` key.
    """
    assert isinstance(prompt, str)
    assert isinstance(backend_url, str)

    # Try not to shove JSON into the database.
    if isinstance(response, dict) and response.get('results'):
        response = response['results'][0]['text']
    try:
        j = json.loads(response)
        if j.get('results'):
            response = j['results'][0]['text']
    except (TypeError, ValueError, AttributeError, KeyError, IndexError):
        # Best effort only. The possible failures are:
        #  - TypeError: response is None or a non-str object
        #  - ValueError: response is not JSON (json.JSONDecodeError is a ValueError subclass)
        #  - AttributeError: the JSON decoded to something other than a dict
        #  - KeyError/IndexError/TypeError: the 'results' payload has an unexpected shape
        # In any of these cases the response is stored unchanged.
        pass

    prompt_tokens = get_token_count(prompt, backend_url)

    if not is_error:
        if not response_tokens:
            response_tokens = get_token_count(response, backend_url)
    else:
        response_tokens = None

    # Sometimes we may want to insert null into the DB, but
    # usually we want to insert a float.
    if gen_time:
        gen_time = round(gen_time, 3)
    if is_error:
        gen_time = None

    if not opts.log_prompts:
        prompt = None

    if not opts.log_prompts and not is_error:
        # Error responses are kept even when prompt logging is off.
        # TODO: test and verify this works as expected
        response = None

    if token:
        increment_token_uses(token)

    backend_info = cluster_config.get_backend(backend_url)
    running_model = backend_info.get('model')
    backend_mode = backend_info['mode']
    timestamp = int(time.time())
    cursor = database.cursor()
    try:
        cursor.execute("""
            INSERT INTO prompts 
            (ip, token, model, backend_mode, backend_url, request_url, generation_time, prompt, prompt_tokens, response, response_tokens, response_status, parameters, headers, timestamp) 
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
            """,
                       (ip, token, running_model, backend_mode, backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
    finally:
        cursor.close()
2023-09-20 20:30:31 -06:00
def is_valid_api_key(api_key):
    """
    Tell whether ``api_key`` may currently be used.

    A key is usable when it exists in ``token_auth``, its use count is
    below ``max_uses`` (or either of those is NULL), its ``expire`` time is
    NULL or still in the future (compared against ``time.time()``), and
    ``disabled`` is falsy.

    :return: ``True`` if the key is usable, otherwise ``False``.
    """
    cursor = database.cursor()
    try:
        cursor.execute("SELECT token, uses, max_uses, expire, disabled FROM token_auth WHERE token = %s", (api_key,))
        record = cursor.fetchone()
        if record is None:
            # Unknown key.
            return False
        _stored_token, use_count, use_limit, expires_at, is_disabled = record

        # A NULL count or limit means the key has no usage cap.
        if not (use_count is None or use_limit is None or use_count < use_limit):
            return False
        if expires_at is not None and not expires_at > time.time():
            return False
        return not is_disabled
    finally:
        cursor.close()
2023-09-20 20:30:31 -06:00
2023-09-26 22:09:11 -06:00
def is_api_key_moderated(api_key):
    """
    Decide whether OpenAI moderation applies to requests made with ``api_key``.

    A falsy key, or a key with no ``token_auth`` row, falls back to the
    global ``opts.openai_moderation_enabled``. Otherwise the key's own
    ``openai_moderation_enabled`` column decides, coerced to ``bool``.
    """
    if api_key:
        cursor = database.cursor()
        try:
            cursor.execute("SELECT openai_moderation_enabled FROM token_auth WHERE token = %s", (api_key,))
            found = cursor.fetchone()
        finally:
            cursor.close()
        if found is not None:
            return bool(found[0])
    # No key, or the key is not in token_auth: use the server-wide setting.
    return opts.openai_moderation_enabled
2023-09-20 20:30:31 -06:00
def get_number_of_rows(table_name):
    """
    Return ``COUNT(*)`` of ``table_name``, skipping rows whose ``token``
    matches the system-token LIKE pattern (NULL tokens are counted).

    ``table_name`` is interpolated directly into the SQL, so it must be a
    trusted identifier, never user input.

    NOTE(review): in SQL LIKE, ``_`` matches any single character, so the
    pattern also matches tokens like ``'SYSTEMxy...'``, not only a literal
    ``'SYSTEM__'`` prefix.
    """
    query = f"SELECT COUNT(*) FROM {table_name} WHERE token NOT LIKE 'SYSTEM__%%' OR token IS NULL"
    cursor = database.cursor()
    try:
        cursor.execute(query)
        return cursor.fetchone()[0]
    finally:
        cursor.close()
2023-09-20 20:30:31 -06:00
def average_column(table_name, column_name):
    """
    Return ``AVG(column_name)`` over ``table_name``, skipping rows whose
    ``token`` matches the system-token LIKE pattern (NULL tokens are included).

    The result is whatever the driver returns for AVG, which is ``None``
    when no rows qualify. Both names are interpolated into the SQL and must
    be trusted identifiers.
    """
    query = f"SELECT AVG({column_name}) FROM {table_name} WHERE token NOT LIKE 'SYSTEM__%%' OR token IS NULL"
    cursor = database.cursor()
    try:
        cursor.execute(query)
        aggregate_row = cursor.fetchone()
    finally:
        cursor.close()
    return aggregate_row[0]
2023-09-20 20:30:31 -06:00
def average_column_for_model(table_name, column_name, model_name):
    """
    Return ``AVG(column_name)`` over the rows of ``table_name`` whose
    ``model`` equals ``model_name``, skipping rows whose ``token`` matches
    the system-token LIKE pattern (NULL tokens are included).

    ``table_name`` and ``column_name`` are interpolated into the SQL and must
    be trusted identifiers. ``model_name`` is passed as a bound parameter.

    :return: the AVG value from the driver, or ``None`` when no rows qualify.
    """
    cursor = database.cursor()
    try:
        # The token filter is parenthesised because AND binds tighter than OR.
        # Without the parentheses, every NULL-token row of ANY model matched
        # and was averaged in.
        cursor.execute(f"SELECT AVG({column_name}) FROM {table_name} WHERE model = %s AND (token NOT LIKE 'SYSTEM__%%' OR token IS NULL)", (model_name,))
        result = cursor.fetchone()
        return result[0]
    finally:
        cursor.close()
2023-09-20 20:30:31 -06:00
2023-09-25 23:39:50 -06:00
def weighted_average_column_for_model(table_name, column_name, model_name, backend_name, backend_url, exclude_zeros: bool = False, include_system_tokens: bool = True):
    """
    Compute a position-weighted average of ``column_name``.

    The average covers the rows of ``table_name`` that match ``model_name``,
    ``backend_mode`` (= ``backend_name``) and ``backend_url``.

    Rows are fetched ``ORDER BY id DESC``. The row at 1-based position ``p``
    gets weight ``p``. NULL values, and zeros when ``exclude_zeros`` is set,
    are skipped, but they still occupy a position.

    NOTE(review): because of the DESC ordering, the first (highest-id) row
    gets weight 1 and later rows get larger weights. Assuming ids grow with
    insertion, this favours OLDER rows. Confirm this is intended.

    :param include_system_tokens: when False, rows whose token matches the
        system-token LIKE pattern are excluded.
    :return: the weighted average, ``0`` if no row contributes, or ``None``
        if the query itself fails (the traceback is printed).
    """
    system_filter = '' if include_system_tokens else " AND (token NOT LIKE 'SYSTEM__%%' OR token IS NULL)"
    sql = f"SELECT {column_name}, id FROM {table_name} WHERE model = %s AND backend_mode = %s AND backend_url = %s{system_filter} ORDER BY id DESC"
    cursor = database.cursor()
    try:
        try:
            cursor.execute(sql, (model_name, backend_name, backend_url))
            rows = cursor.fetchall()
        except Exception:
            traceback.print_exc()
            return None

        weight_total = 0
        accumulated = 0
        for position, (value, _row_id) in enumerate(rows, start=1):
            if value is None or (exclude_zeros and value == 0):
                continue
            weight_total += position
            accumulated += position * value
        # Guard against dividing by zero when nothing contributed.
        return accumulated / weight_total if weight_total > 0 else 0
    finally:
        cursor.close()
2023-09-20 20:30:31 -06:00
def sum_column(table_name, column_name):
    """
    Return ``SUM(column_name)`` over ``table_name``, skipping rows whose
    ``token`` matches the system-token LIKE pattern (NULL tokens are included).

    Both names are interpolated into the SQL and must be trusted identifiers.

    :return: the sum, or ``0`` when there is nothing to sum.
    """
    cursor = database.cursor()
    try:
        cursor.execute(f"SELECT SUM({column_name}) FROM {table_name} WHERE token NOT LIKE 'SYSTEM__%%' OR token IS NULL")
        result = cursor.fetchone()
        # An aggregate query always yields exactly one row, and SUM over zero
        # rows (or only NULLs) is NULL. The value itself must therefore be
        # checked, not just the row, to honour the "0 when empty" contract.
        if result is None or result[0] is None:
            return 0
        return result[0]
    finally:
        cursor.close()
2023-09-20 20:30:31 -06:00
def get_distinct_ips_24h():
    """
    Count the distinct ``ip`` values in ``prompts`` over the last 24 hours.

    A row counts when its ``timestamp`` (compared as integer epoch seconds)
    is within the last 24 hours and its token does not match the
    system-token LIKE pattern (NULL tokens are counted).

    :return: the count, or ``0`` if no row is returned.
    """
    seconds_per_day = 24 * 60 * 60
    cutoff = int(time.time()) - seconds_per_day
    cursor = database.cursor()
    try:
        cursor.execute("SELECT COUNT(DISTINCT ip) FROM prompts WHERE timestamp >= %s AND (token NOT LIKE 'SYSTEM__%%' OR token IS NULL)", (cutoff,))
        count_row = cursor.fetchone()
    finally:
        cursor.close()
    return count_row[0] if count_row else 0
2023-09-25 00:55:20 -06:00
def increment_token_uses(token):
    """
    Add one to the ``uses`` counter of ``token`` in ``token_auth``.

    NOTE(review): no commit is issued here. This presumably relies on the
    connection running in autocommit mode; confirm in the database
    connection setup.
    """
    statement = 'UPDATE token_auth SET uses = uses + 1 WHERE token = %s'
    bound = (token,)
    cursor = database.cursor()
    try:
        cursor.execute(statement, bound)
    finally:
        cursor.close()
2023-10-02 02:05:15 -06:00
def get_token_ratelimit(token):
    """
    Look up the queue priority and per-IP concurrency limit for ``token``.

    Defaults are priority ``9990`` and ``opts.simultaneous_requests_per_ip``.
    A known token overrides both with its ``token_auth`` row. If that row's
    ``simultaneous_ip`` is NULL, the limit becomes ``999999999``,
    effectively unlimited.

    :return: ``(priority, simultaneous_ip)``.
    """
    priority = 9990
    simultaneous_ip = opts.simultaneous_requests_per_ip
    if not token:
        return priority, simultaneous_ip

    cursor = database.cursor()
    try:
        cursor.execute("SELECT priority, simultaneous_ip FROM token_auth WHERE token = %s", (token,))
        token_row = cursor.fetchone()
    finally:
        cursor.close()

    if token_row:
        priority, simultaneous_ip = token_row
        if simultaneous_ip is None:
            # No ratelimit for this token if null
            simultaneous_ip = 999999999
    return priority, simultaneous_ip