# (viewer metadata, not part of the module: 182 lines, 6.1 KiB, Python)
import json
|
|
import time
|
|
import traceback
|
|
|
|
import llm_server
|
|
from llm_server import opts
|
|
from llm_server.database.conn import database
|
|
from llm_server.llm.vllm import tokenize
|
|
from llm_server.routes.cache import redis
|
|
|
|
|
|
def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, response_tokens: int = None, is_error: bool = False):
    """Record a completed (or failed) prompt/response exchange in the `prompts` table.

    :param ip: client IP address.
    :param token: API token used for the request (may be falsy for anonymous use).
    :param prompt: the prompt text sent to the backend.
    :param response: backend response; may be the raw text, a result dict, or a JSON string.
    :param gen_time: generation time in seconds, or a falsy value to store NULL.
    :param parameters: generation parameters dict (stored as JSON).
    :param headers: request headers dict (stored as JSON).
    :param backend_response_code: HTTP status returned by the backend.
    :param request_url: URL the client hit on this server.
    :param response_tokens: pre-computed response token count; computed here if omitted.
    :param is_error: True when the exchange failed; suppresses token/time accounting.
    """
    # Unwrap the generated text when the backend handed us a full result dict.
    if isinstance(response, dict) and response.get('results'):
        response = response['results'][0]['text']
    try:
        j = json.loads(response)
        if j.get('results'):
            response = j['results'][0]['text']
    except (TypeError, ValueError, AttributeError):
        # Not a JSON string (TypeError/ValueError) or parsed to a non-dict
        # (AttributeError on .get) -- keep the response as-is.
        pass

    prompt_tokens = llm_server.llm.get_token_count(prompt)
    if not is_error:
        if not response_tokens:
            response_tokens = llm_server.llm.get_token_count(response)
    else:
        # Errors have no meaningful completion, so store NULL.
        response_tokens = None

    # Sometimes we may want to insert null into the DB, but
    # usually we want to insert a float.
    if gen_time:
        gen_time = round(gen_time, 3)
        if is_error:
            gen_time = None

    if not opts.log_prompts:
        prompt = None

    if not opts.log_prompts and not is_error:
        # TODO: test and verify this works as expected
        response = None

    if token:
        increment_token_uses(token)

    running_model = redis.get('running_model', str, 'ERROR')
    timestamp = int(time.time())
    cursor = database.cursor()
    try:
        cursor.execute("""
        INSERT INTO prompts
        (ip, token, model, backend_mode, backend_url, request_url, generation_time, prompt, prompt_tokens, response, response_tokens, response_status, parameters, headers, timestamp)
        VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
        """,
                       (ip, token, running_model, opts.mode, opts.backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
    finally:
        cursor.close()
|
|
|
|
|
|
def is_valid_api_key(api_key):
    """Return True if *api_key* exists, is enabled, unexpired, and under its use cap."""
    cursor = database.cursor()
    try:
        cursor.execute("SELECT token, uses, max_uses, expire, disabled FROM token_auth WHERE token = %s", (api_key,))
        record = cursor.fetchone()
        if record is None:
            return False
        _token, uses, max_uses, expire, disabled = record
        if bool(disabled):
            return False
        # An expiry of NULL means the key never expires.
        if expire is not None and expire <= time.time():
            return False
        # A NULL uses or max_uses means the key is uncapped.
        if uses is not None and max_uses is not None and uses >= max_uses:
            return False
        return True
    finally:
        cursor.close()
|
|
|
|
|
|
def is_api_key_moderated(api_key):
    """Return whether OpenAI moderation applies to *api_key*.

    Falls back to the global `opts.openai_moderation_enabled` setting when no
    key is supplied or the key has no row in token_auth.
    """
    if not api_key:
        return opts.openai_moderation_enabled
    cursor = database.cursor()
    try:
        cursor.execute("SELECT openai_moderation_enabled FROM token_auth WHERE token = %s", (api_key,))
        record = cursor.fetchone()
        return bool(record[0]) if record is not None else opts.openai_moderation_enabled
    finally:
        cursor.close()
|
|
|
|
|
|
def get_number_of_rows(table_name):
    """Count rows in *table_name*, excluding system tokens (prefix 'SYSTEM__').

    NOTE: *table_name* is interpolated into the SQL -- callers must pass a
    trusted identifier, never user input.
    """
    cursor = database.cursor()
    try:
        cursor.execute(f"SELECT COUNT(*) FROM {table_name} WHERE token NOT LIKE 'SYSTEM__%%' OR token IS NULL")
        return cursor.fetchone()[0]
    finally:
        cursor.close()
|
|
|
|
|
|
def average_column(table_name, column_name):
    """Average *column_name* over non-system rows of *table_name*.

    NOTE: both identifiers are interpolated into the SQL -- trusted input only.
    """
    cursor = database.cursor()
    try:
        cursor.execute(f"SELECT AVG({column_name}) FROM {table_name} WHERE token NOT LIKE 'SYSTEM__%%' OR token IS NULL")
        return cursor.fetchone()[0]
    finally:
        cursor.close()
|
|
|
|
|
|
def average_column_for_model(table_name, column_name, model_name, model_name_filtered=None):
    """Average *column_name* over non-system rows of *table_name* for one model.

    Bug fix: the original WHERE clause was
    `model = %s AND token NOT LIKE 'SYSTEM__%%' OR token IS NULL` -- since AND
    binds tighter than OR, every NULL-token row was included regardless of
    model. Parenthesize the OR, matching the sibling queries in this file.

    NOTE: *table_name*/*column_name* are interpolated into the SQL -- trusted
    identifiers only; the model name is properly parameterized.
    """
    cursor = database.cursor()
    try:
        cursor.execute(f"SELECT AVG({column_name}) FROM {table_name} WHERE model = %s AND (token NOT LIKE 'SYSTEM__%%' OR token IS NULL)", (model_name,))
        result = cursor.fetchone()
        return result[0]
    finally:
        cursor.close()
|
|
|
|
|
|
def weighted_average_column_for_model(table_name, column_name, model_name, backend_name, backend_url, exclude_zeros: bool = False, include_system_tokens: bool = True):
    """Position-weighted average of *column_name* for a model/backend pair.

    Rows are fetched newest-first (ORDER BY id DESC); row i (0-based) gets
    weight i + 1, so rows further down the DESC ordering -- i.e. *older* ids --
    carry the larger weights. NOTE(review): confirm that weighting older rows
    more heavily is the intended direction.

    Returns None if the query fails, 0 when no usable rows exist.
    NOTE: *table_name*/*column_name* are interpolated into the SQL -- trusted
    identifiers only.
    """
    token_filter = "" if include_system_tokens else " AND (token NOT LIKE 'SYSTEM__%%' OR token IS NULL)"
    sql = f"SELECT {column_name}, id FROM {table_name} WHERE model = %s AND backend_mode = %s AND backend_url = %s{token_filter} ORDER BY id DESC"

    cursor = database.cursor()
    try:
        try:
            cursor.execute(sql, (model_name, backend_name, backend_url,))
            rows = cursor.fetchall()
        except Exception:
            # Best-effort stat: report the failure and signal "unavailable".
            traceback.print_exc()
            return None

        weighted_sum = 0
        total_weight = 0
        for position, (value, _rowid) in enumerate(rows):
            # Skip NULLs, and zeros when the caller asked to exclude them.
            if value is None or (exclude_zeros and value == 0):
                continue
            weight = position + 1
            weighted_sum += weight * value
            total_weight += weight

        # total_weight == 0 means no usable rows; avoid division by zero.
        return weighted_sum / total_weight if total_weight > 0 else 0
    finally:
        cursor.close()
|
|
|
|
|
|
def sum_column(table_name, column_name):
    """Sum *column_name* over non-system rows of *table_name*; 0 if no row.

    NOTE: both identifiers are interpolated into the SQL -- trusted input only.
    """
    cursor = database.cursor()
    try:
        cursor.execute(f"SELECT SUM({column_name}) FROM {table_name} WHERE token NOT LIKE 'SYSTEM__%%' OR token IS NULL")
        row = cursor.fetchone()
        return row[0] if row else 0
    finally:
        cursor.close()
|
|
|
|
|
|
def get_distinct_ips_24h():
    """Count distinct non-system client IPs seen in `prompts` over the last 24 hours."""
    # Unix timestamp for "now minus 24 hours".
    cutoff = int(time.time()) - 24 * 60 * 60
    cursor = database.cursor()
    try:
        cursor.execute("SELECT COUNT(DISTINCT ip) FROM prompts WHERE timestamp >= %s AND (token NOT LIKE 'SYSTEM__%%' OR token IS NULL)", (cutoff,))
        row = cursor.fetchone()
        return row[0] if row else 0
    finally:
        cursor.close()
|
|
|
|
|
|
def increment_token_uses(token):
    """Increment the `uses` counter for *token* in token_auth.

    NOTE(review): no explicit commit here -- presumably the connection is in
    autocommit mode; verify against the database setup.
    """
    cur = database.cursor()
    try:
        cur.execute('UPDATE token_auth SET uses = uses + 1 WHERE token = %s', (token,))
    finally:
        cur.close()
|