This repository has been archived on 2024-10-27. You can view files and clone it, but cannot push or open issues or pull requests.
local-llm-server/llm_server/database/database.py

214 lines
7.5 KiB
Python
Raw Normal View History

import json
import time
2023-09-23 20:55:49 -06:00
import traceback
2023-09-29 00:09:44 -06:00
from threading import Thread
from typing import Union
from llm_server import opts
2023-09-30 19:41:50 -06:00
from llm_server.cluster.cluster_config import cluster_config
from llm_server.database.conn import database
from llm_server.llm import get_token_count
def _extract_response_text(response):
    """Return the generated text from a backend ``results`` payload if possible.

    ``response`` may be a dict shaped like ``{'results': [{'text': ...}]}``, a
    JSON string of that shape, or plain text. The dict form is unwrapped first;
    the result (or the original value) is then JSON-decoded and unwrapped again
    if it has the same shape. Values that do not match are returned unchanged,
    which keeps raw JSON out of the ``prompts`` table where possible.
    """
    if isinstance(response, dict) and response.get('results'):
        response = response['results'][0]['text']
    try:
        decoded = json.loads(response)
        if decoded.get('results'):
            response = decoded['results'][0]['text']
    except (TypeError, ValueError, AttributeError, LookupError, RecursionError):
        # Best effort only: not str/bytes (TypeError), not JSON (ValueError /
        # JSONDecodeError), JSON that is not an object (AttributeError), or an
        # object missing the expected nesting (LookupError / TypeError).
        # Keep the value as-is in all of these cases.
        pass
    return response


def log_prompt(ip: str, token: str, prompt: str, response: Union[str, None], gen_time: Union[int, float, None], parameters: dict, headers: dict, backend_response_code: int, request_url: str, backend_url: str, response_tokens: int = None, is_error: bool = False):
    """Insert one request/response record into the ``prompts`` table.

    Before inserting:
      * the response is unwrapped from backend ``results`` JSON where possible;
      * prompt and response token counts are computed with ``get_token_count``
        (``response_tokens`` is only recomputed when falsy, and is stored as
        NULL when ``is_error``);
      * ``gen_time`` is rounded to 3 decimals, or stored as NULL when ``is_error``;
      * the prompt text is dropped unless ``opts.log_prompts`` is enabled, and the
        response text is dropped too unless this is an error record;
      * a truthy ``token`` has its ``uses`` counter incremented.

    The model name and backend mode are looked up from ``cluster_config`` for
    ``backend_url``; ``parameters`` and ``headers`` are stored as JSON text.

    The work runs in a worker thread that is joined immediately, so this call
    blocks until the insert finishes. An exception raised inside the worker is
    reported by ``threading.excepthook`` and is NOT propagated to the caller.

    Raises AssertionError if ``prompt`` or ``backend_url`` is not a ``str``
    (these checks are skipped under ``python -O``).
    """
    assert isinstance(prompt, str)
    assert isinstance(backend_url, str)

    def background_task():
        nonlocal ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, backend_url, response_tokens, is_error

        # Try not to shove JSON into the database.
        response = _extract_response_text(response)

        prompt_tokens = get_token_count(prompt, backend_url)

        if not is_error:
            if not response_tokens:
                response_tokens = get_token_count(response, backend_url)
        else:
            response_tokens = None

        # Sometimes we may want to insert null into the DB, but
        # usually we want to insert a float.
        if gen_time:
            gen_time = round(gen_time, 3)
        if is_error:
            gen_time = None

        if not opts.log_prompts:
            prompt = None

        if not opts.log_prompts and not is_error:
            # TODO: test and verify this works as expected
            response = None

        if token:
            increment_token_uses(token)

        backend_info = cluster_config.get_backend(backend_url)
        running_model = backend_info.get('model')
        backend_mode = backend_info['mode']
        timestamp = int(time.time())

        cursor = database.cursor()
        try:
            cursor.execute(
                """
                INSERT INTO prompts
                (ip, token, model, backend_mode, backend_url, request_url, generation_time, prompt, prompt_tokens, response, response_tokens, response_status, parameters, headers, timestamp)
                VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
                """,
                (ip, token, running_model, backend_mode, backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp),
            )
        finally:
            cursor.close()

    # TODO: use async/await instead of threads
    thread = Thread(target=background_task)
    thread.start()
    thread.join()
def is_valid_api_key(api_key):
    """Return True if ``api_key`` exists in ``token_auth`` and is usable.

    A key is usable when all of the following hold:
      * it has uses remaining (``uses`` or ``max_uses`` is NULL, or ``uses < max_uses``);
      * it has not expired (``expire`` is NULL, or ``expire`` is later than now);
      * it is not ``disabled``.
    Unknown keys return False.
    """
    cursor = database.cursor()
    try:
        cursor.execute("SELECT token, uses, max_uses, expire, disabled FROM token_auth WHERE token = %s", (api_key,))
        record = cursor.fetchone()
    finally:
        cursor.close()

    if record is None:
        return False

    _token, uses, max_uses, expire, disabled = record
    is_disabled = bool(disabled)

    # Checks run in the same order as the original compound condition.
    if not (uses is None or max_uses is None or uses < max_uses):
        return False
    if not (expire is None or expire > time.time()):
        return False
    return not is_disabled
2023-09-26 22:09:11 -06:00
def is_api_key_moderated(api_key):
    """Return whether OpenAI moderation applies to requests made with ``api_key``.

    A falsy key, or a key with no ``token_auth`` row, falls back to the global
    ``opts.openai_moderation_enabled``. Otherwise the row's
    ``openai_moderation_enabled`` column is returned as a bool.
    """
    if api_key:
        cursor = database.cursor()
        try:
            cursor.execute("SELECT openai_moderation_enabled FROM token_auth WHERE token = %s", (api_key,))
            row = cursor.fetchone()
        finally:
            cursor.close()
        if row is not None:
            return bool(row[0])
    return opts.openai_moderation_enabled
def get_number_of_rows(table_name):
    """Return ``COUNT(*)`` of ``table_name``, excluding rows whose token matches ``SYSTEM__%``.

    Rows with a NULL token are counted. ``table_name`` is interpolated directly
    into the SQL, so it must be a trusted identifier.
    NOTE(review): ``_`` is a LIKE single-character wildcard, so the pattern is
    broader than a literal ``SYSTEM__`` prefix; confirm this is intended.
    """
    query = f"SELECT COUNT(*) FROM {table_name} WHERE token NOT LIKE 'SYSTEM__%%' OR token IS NULL"
    cursor = database.cursor()
    try:
        cursor.execute(query)
        return cursor.fetchone()[0]
    finally:
        cursor.close()
def average_column(table_name, column_name):
    """Return ``AVG(column_name)`` over ``table_name``, excluding rows whose token matches ``SYSTEM__%``.

    Rows with a NULL token are included. Both identifiers are interpolated
    directly into the SQL, so they must come from trusted code. The value is
    returned exactly as the driver yields it (presumably None when no rows
    match; confirm for the driver in use).
    """
    query = f"SELECT AVG({column_name}) FROM {table_name} WHERE token NOT LIKE 'SYSTEM__%%' OR token IS NULL"
    cursor = database.cursor()
    try:
        cursor.execute(query)
        row = cursor.fetchone()
        return row[0]
    finally:
        cursor.close()
def average_column_for_model(table_name, column_name, model_name):
    """Return ``AVG(column_name)`` over rows of ``table_name`` whose ``model`` is ``model_name``.

    Rows whose token matches ``SYSTEM__%`` are excluded; rows with a NULL token
    are included. ``table_name`` and ``column_name`` are interpolated directly
    into the SQL, so they must be trusted identifiers. ``model_name`` is passed
    as a bound parameter.
    """
    cursor = database.cursor()
    try:
        # The token filter must be parenthesized. SQL's AND binds tighter than
        # OR, so without the parentheses every NULL-token row would be averaged
        # in regardless of its model.
        cursor.execute(f"SELECT AVG({column_name}) FROM {table_name} WHERE model = %s AND (token NOT LIKE 'SYSTEM__%%' OR token IS NULL)", (model_name,))
        result = cursor.fetchone()
        return result[0]
    finally:
        cursor.close()
2023-09-25 23:39:50 -06:00
def weighted_average_column_for_model(table_name, column_name, model_name, backend_name, backend_url, exclude_zeros: bool = False, include_system_tokens: bool = True):
    """Return a position-weighted average of ``column_name`` for one model/backend.

    Rows of ``table_name`` are selected by ``model``, ``backend_mode`` (matched
    against ``backend_name``) and ``backend_url``, ordered by ``id`` descending.
    The row at 1-based position ``p`` in that order gets weight ``p``.
    NULL values are skipped, and so are zeros when ``exclude_zeros`` is set;
    a skipped row still occupies its position.
    When ``include_system_tokens`` is False, rows whose token matches
    ``SYSTEM__%`` are excluded.

    Returns 0 if no row contributes, or None if the query fails (the traceback
    is printed).

    NOTE(review): with ``ORDER BY id DESC`` the newest row has weight 1 and
    older rows get larger weights. If the intent was to favour recent rows,
    the weighting is reversed; confirm before relying on it.
    """
    sql = (
        f"SELECT {column_name}, id FROM {table_name} WHERE model = %s AND backend_mode = %s AND backend_url = %s ORDER BY id DESC"
        if include_system_tokens
        else f"SELECT {column_name}, id FROM {table_name} WHERE model = %s AND backend_mode = %s AND backend_url = %s AND (token NOT LIKE 'SYSTEM__%%' OR token IS NULL) ORDER BY id DESC"
    )
    cursor = database.cursor()
    try:
        try:
            cursor.execute(sql, (model_name, backend_name, backend_url,))
            rows = cursor.fetchall()
        except Exception:
            traceback.print_exc()
            return None

        kept = [
            (position, value)
            for position, (value, _row_id) in enumerate(rows, start=1)
            if value is not None and not (exclude_zeros and value == 0)
        ]
        total_weight = sum(position for position, _ in kept)
        if total_weight <= 0:
            # Nothing contributed; avoid dividing by zero.
            return 0
        return sum(position * value for position, value in kept) / total_weight
    finally:
        cursor.close()
def sum_column(table_name, column_name):
    """Return ``SUM(column_name)`` over ``table_name``, or 0 when nothing is summed.

    Rows whose token matches ``SYSTEM__%`` are excluded; rows with a NULL token
    are included. Both identifiers are interpolated directly into the SQL, so
    they must be trusted.
    """
    cursor = database.cursor()
    try:
        cursor.execute(f"SELECT SUM({column_name}) FROM {table_name} WHERE token NOT LIKE 'SYSTEM__%%' OR token IS NULL")
        result = cursor.fetchone()
        # An aggregate query yields a single row even when nothing matches, and
        # SUM over no rows is NULL. Test the value itself so the intended 0
        # default applies instead of leaking None.
        if result is None or result[0] is None:
            return 0
        return result[0]
    finally:
        cursor.close()
def get_distinct_ips_24h():
    """Return the number of distinct ``ip`` values in ``prompts`` over the last 24 hours.

    Only rows with ``timestamp`` (epoch seconds) at or after now minus 86400
    are counted. Rows whose token matches ``SYSTEM__%`` are excluded; NULL
    tokens are kept. Returns 0 if the driver yields no row.
    """
    cutoff = int(time.time()) - 86400  # 24 h * 60 min * 60 s
    cursor = database.cursor()
    try:
        cursor.execute("SELECT COUNT(DISTINCT ip) FROM prompts WHERE timestamp >= %s AND (token NOT LIKE 'SYSTEM__%%' OR token IS NULL)", (cutoff,))
        row = cursor.fetchone()
    finally:
        cursor.close()
    return row[0] if row else 0
def increment_token_uses(token):
    """Add one to the ``uses`` counter of ``token`` in ``token_auth``.

    No commit is issued here.
    NOTE(review): presumably the connection from ``conn`` is in autocommit
    mode; confirm.
    """
    params = (token,)
    cursor = database.cursor()
    try:
        cursor.execute('UPDATE token_auth SET uses = uses + 1 WHERE token = %s', params)
    finally:
        cursor.close()
2023-10-02 02:05:15 -06:00
def get_token_ratelimit(token):
    """Return ``(priority, simultaneous_ip)`` limits for ``token``.

    Defaults are priority 9990 and ``opts.simultaneous_requests_per_ip``.
    They are used when ``token`` is falsy or has no ``token_auth`` row.
    Otherwise the row's values are returned, with a NULL ``simultaneous_ip``
    mapped to 999999999 (effectively no limit).
    """
    priority, simultaneous_ip = 9990, opts.simultaneous_requests_per_ip
    if not token:
        return priority, simultaneous_ip

    cursor = database.cursor()
    try:
        cursor.execute("SELECT priority, simultaneous_ip FROM token_auth WHERE token = %s", (token,))
        row = cursor.fetchone()
    finally:
        cursor.close()

    if row:
        priority, simultaneous_ip = row
        if simultaneous_ip is None:
            # No ratelimit for this token if null
            simultaneous_ip = 999999999
    return priority, simultaneous_ip