redo database connection, add pooling, minor logging changes, other clean up

This commit is contained in:
Cyberes 2024-05-07 09:48:51 -06:00
parent ab408c6c5b
commit ff82add09e
25 changed files with 219 additions and 192 deletions

View File

@@ -28,7 +28,6 @@ if __name__ == "__main__":
    parser.add_argument('-d', '--debug', action='store_true', help='Enable debug logging.')
    args = parser.parse_args()
-    # TODO: have this be set by either the arg or a config value
    if args.debug:
        logging_info.level = logging.DEBUG

View File

@@ -6,11 +6,14 @@ from llm_server import opts
from llm_server.cluster.redis_cycle import add_backend_cycler, redis_cycle
from llm_server.cluster.stores import redis_running_models
from llm_server.custom_redis import RedisCustom
+from llm_server.logging import create_logger
from llm_server.routes.helpers.model import estimate_model_size

# Don't try to reorganize this file or else you'll run into circular imports.

+_logger = create_logger('redis')


class RedisClusterStore:
    """
    A class used to store the cluster state in Redis.
@@ -67,7 +70,7 @@ class RedisClusterStore:
        if not backend_info['online']:
            old = backend_url
            backend_url = get_a_cluster_backend()
-            print(f'Backend {old} offline. Request was redirected to {backend_url}')
+            _logger.debug(f'Backend {old} offline. Request was redirected to {backend_url}')
        return backend_url
@@ -108,8 +111,7 @@ def get_backends():
        )
        return [url for url, info in online_backends], [url for url, info in offline_backends]
    except KeyError:
-        traceback.print_exc()
-        print(backends)
+        _logger.error(f'Failed to get a backend from the cluster config: {traceback.format_exc()}\nCurrent backends: {backends}')


def get_a_cluster_backend(model=None):

View File

@@ -7,10 +7,13 @@ import llm_server
from llm_server import opts
from llm_server.config.config import ConfigLoader, config_default_vars, config_required_vars
from llm_server.custom_redis import redis
-from llm_server.database.conn import database
+from llm_server.database.conn import Database
from llm_server.database.database import get_number_of_rows
+from llm_server.logging import create_logger
from llm_server.routes.queue import PriorityQueue

+_logger = create_logger('config')


def load_config(config_path):
    config_loader = ConfigLoader(config_path, config_default_vars, config_required_vars)
@@ -58,7 +61,7 @@ def load_config(config_path):
    llm_server.routes.queue.priority_queue = PriorityQueue([x['backend_url'] for x in config['cluster']])

    if opts.openai_expose_our_model and not opts.openai_api_key:
-        print('If you set openai_expose_our_model to false, you must set your OpenAI key in openai_api_key.')
+        _logger.error('If you set openai_expose_our_model to false, you must set your OpenAI key in openai_api_key.')
        sys.exit(1)

    opts.verify_ssl = config['verify_ssl']
@@ -67,11 +70,11 @@ def load_config(config_path):
        urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

    if config['http_host']:
-        http_host = re.sub(r'http(?:s)?://', '', config["http_host"])
+        http_host = re.sub(r'https?://', '', config["http_host"])
        redis.set('http_host', http_host)
        redis.set('base_client_api', f'{http_host}/{opts.frontend_api_client.strip("/")}')

-    database.init_db(config['mysql']['host'], config['mysql']['username'], config['mysql']['password'], config['mysql']['database'])
+    Database.initialise(maxconn=config['mysql']['maxconn'], host=config['mysql']['host'], user=config['mysql']['username'], password=config['mysql']['password'], database=config['mysql']['database'])

    if config['load_num_prompts']:
        redis.set('proompts', get_number_of_rows('prompts'))
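
Note that load_config() now reads config['mysql']['maxconn'] and forwards it to Database.initialise() as the pool size, so existing configs need the new key. A hypothetical sketch of the resulting dict shape, showing only the keys read here (the values are placeholders and the on-disk format is whatever ConfigLoader already parses):

config = {
    'mysql': {
        'host': '127.0.0.1',
        'username': 'localllm',
        'password': 'secret',
        'database': 'localllm',
        'maxconn': 10,  # new: passed to Database.initialise() as the connection pool size
    },
}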

View File

@@ -1,3 +1,4 @@
+import logging
import pickle
import sys
import traceback
@@ -29,8 +30,9 @@ class RedisCustom(Redis):
        try:
            self.set('____', 1)
        except redis_pkg.exceptions.ConnectionError as e:
-            print('Failed to connect to the Redis server:', e)
-            print('Did you install and start the Redis server?')
+            logger = logging.getLogger('redis')
+            logger.setLevel(logging.INFO)
+            logger.error(f'Failed to connect to the Redis server: {e}\nDid you install and start the Redis server?')
            sys.exit(1)

    def _key(self, key):

View File

@@ -1,28 +1,40 @@
-import pymysql
-
-
-class DatabaseConnection:
-    host: str = None
-    username: str = None
-    password: str = None
-    database_name: str = None
-
-    def init_db(self, host, username, password, database_name):
-        self.host = host
-        self.username = username
-        self.password = password
-        self.database_name = database_name
-
-    def cursor(self):
-        db = pymysql.connect(
-            host=self.host,
-            user=self.username,
-            password=self.password,
-            database=self.database_name,
-            charset='utf8mb4',
-            autocommit=True,
-        )
-        return db.cursor()
-
-
-database = DatabaseConnection()
+from mysql.connector import pooling
+
+
+class Database:
+    __connection_pool = None
+
+    @classmethod
+    def initialise(cls, maxconn, **kwargs):
+        if cls.__connection_pool is not None:
+            raise Exception('Database connection pool is already initialised')
+        cls.__connection_pool = pooling.MySQLConnectionPool(pool_size=maxconn,
+                                                            pool_reset_session=True,
+                                                            **kwargs)
+
+    @classmethod
+    def get_connection(cls):
+        return cls.__connection_pool.get_connection()
+
+    @classmethod
+    def return_connection(cls, connection):
+        connection.close()
+
+
+class CursorFromConnectionFromPool:
+    def __init__(self):
+        self.conn = None
+        self.cursor = None
+
+    def __enter__(self):
+        self.conn = Database.get_connection()
+        self.cursor = self.conn.cursor()
+        return self.cursor
+
+    def __exit__(self, exception_type, exception_value, exception_traceback):
+        if exception_value is not None:  # This is equivalent of saying if there is an exception
+            self.conn.rollback()
+        else:
+            self.cursor.close()
+            self.conn.commit()
+        Database.return_connection(self.conn)
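
For orientation, a minimal sketch of how the new pool and context manager are meant to be used together. The connection settings below are placeholders; in the server they come from config['mysql'], as wired up in config/load.py earlier in this commit:

from llm_server.database.conn import Database, CursorFromConnectionFromPool

# Placeholder credentials, for illustration only.
Database.initialise(maxconn=10, host='127.0.0.1', user='llm', password='secret', database='llm_server')

with CursorFromConnectionFromPool() as cursor:
    # A clean exit closes the cursor, commits, and returns the connection to the pool;
    # an exception inside the block triggers a rollback instead.
    cursor.execute('SELECT COUNT(*) FROM prompts')
    print(cursor.fetchone()[0])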

View File

@@ -1,8 +1,8 @@
-from llm_server.database.conn import database
+from llm_server.database.conn import CursorFromConnectionFromPool


def create_db():
-    cursor = database.cursor()
-    cursor.execute('''
+    with CursorFromConnectionFromPool() as cursor:
+        cursor.execute('''
        CREATE TABLE IF NOT EXISTS prompts (
            ip TEXT,
@@ -37,4 +37,3 @@ def create_db():
            disabled BOOLEAN DEFAULT 0
        )
        ''')
-    cursor.close()

View File

@@ -5,7 +5,7 @@ from typing import Union
from llm_server import opts
from llm_server.cluster.cluster_config import cluster_config
-from llm_server.database.conn import database
+from llm_server.database.conn import CursorFromConnectionFromPool
from llm_server.llm import get_token_count
@@ -52,21 +52,17 @@ def do_db_log(ip: str, token: str, prompt: str, response: Union[str, None], gen_
    running_model = backend_info.get('model')
    backend_mode = backend_info['mode']
    timestamp = int(time.time())
-    cursor = database.cursor()
-    try:
+    with CursorFromConnectionFromPool() as cursor:
        cursor.execute("""
            INSERT INTO prompts
            (ip, token, model, backend_mode, backend_url, request_url, generation_time, prompt, prompt_tokens, response, response_tokens, response_status, parameters, headers, timestamp)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
            """,
                       (ip, token, running_model, backend_mode, backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
-    finally:
-        cursor.close()


def is_valid_api_key(api_key):
-    cursor = database.cursor()
-    try:
+    with CursorFromConnectionFromPool() as cursor:
        cursor.execute("SELECT token, uses, max_uses, expire, disabled FROM token_auth WHERE token = %s", (api_key,))
        row = cursor.fetchone()
        if row is not None:
@@ -75,52 +71,38 @@ def is_valid_api_key(api_key):
            if ((uses is None or max_uses is None) or uses < max_uses) and (expire is None or expire > time.time()) and not disabled:
                return True
        return False
-    finally:
-        cursor.close()


def is_api_key_moderated(api_key):
    if not api_key:
        return opts.openai_moderation_enabled
-    cursor = database.cursor()
-    try:
+    with CursorFromConnectionFromPool() as cursor:
        cursor.execute("SELECT openai_moderation_enabled FROM token_auth WHERE token = %s", (api_key,))
        row = cursor.fetchone()
        if row is not None:
            return bool(row[0])
        return opts.openai_moderation_enabled
-    finally:
-        cursor.close()


def get_number_of_rows(table_name):
-    cursor = database.cursor()
-    try:
+    with CursorFromConnectionFromPool() as cursor:
        cursor.execute(f"SELECT COUNT(*) FROM {table_name} WHERE token NOT LIKE 'SYSTEM__%%' OR token IS NULL")
        result = cursor.fetchone()
        return result[0]
-    finally:
-        cursor.close()


def average_column(table_name, column_name):
-    cursor = database.cursor()
-    try:
+    with CursorFromConnectionFromPool() as cursor:
        cursor.execute(f"SELECT AVG({column_name}) FROM {table_name} WHERE token NOT LIKE 'SYSTEM__%%' OR token IS NULL")
        result = cursor.fetchone()
        return result[0]
-    finally:
-        cursor.close()


def average_column_for_model(table_name, column_name, model_name):
-    cursor = database.cursor()
-    try:
+    with CursorFromConnectionFromPool() as cursor:
        cursor.execute(f"SELECT AVG({column_name}) FROM {table_name} WHERE model = %s AND token NOT LIKE 'SYSTEM__%%' OR token IS NULL", (model_name,))
        result = cursor.fetchone()
        return result[0]
-    finally:
-        cursor.close()


def weighted_average_column_for_model(table_name, column_name, model_name, backend_name, backend_url, exclude_zeros: bool = False, include_system_tokens: bool = True):
@@ -129,8 +111,7 @@ def weighted_average_column_for_model(table_name, column_name, model_name, backe
    else:
        sql = f"SELECT {column_name}, id FROM {table_name} WHERE model = %s AND backend_mode = %s AND backend_url = %s AND (token NOT LIKE 'SYSTEM__%%' OR token IS NULL) ORDER BY id DESC"
-    cursor = database.cursor()
-    try:
+    with CursorFromConnectionFromPool() as cursor:
        try:
            cursor.execute(sql, (model_name, backend_name, backend_url,))
            results = cursor.fetchall()
@@ -154,46 +135,34 @@ def weighted_average_column_for_model(table_name, column_name, model_name, backe
                calculated_avg = 0
        return calculated_avg
-    finally:
-        cursor.close()


def sum_column(table_name, column_name):
-    cursor = database.cursor()
-    try:
+    with CursorFromConnectionFromPool() as cursor:
        cursor.execute(f"SELECT SUM({column_name}) FROM {table_name} WHERE token NOT LIKE 'SYSTEM__%%' OR token IS NULL")
        result = cursor.fetchone()
        return result[0] if result else 0
-    finally:
-        cursor.close()


def get_distinct_ips_24h():
    # Get the current time and subtract 24 hours (in seconds)
    past_24_hours = int(time.time()) - 24 * 60 * 60
-    cursor = database.cursor()
-    try:
+    with CursorFromConnectionFromPool() as cursor:
        cursor.execute("SELECT COUNT(DISTINCT ip) FROM prompts WHERE timestamp >= %s AND (token NOT LIKE 'SYSTEM__%%' OR token IS NULL)", (past_24_hours,))
        result = cursor.fetchone()
        return result[0] if result else 0
-    finally:
-        cursor.close()


def increment_token_uses(token):
-    cursor = database.cursor()
-    try:
+    with CursorFromConnectionFromPool() as cursor:
        cursor.execute('UPDATE token_auth SET uses = uses + 1 WHERE token = %s', (token,))
-    finally:
-        cursor.close()


def get_token_ratelimit(token):
    priority = 9990
    simultaneous_ip = opts.simultaneous_requests_per_ip
    if token:
-        cursor = database.cursor()
-        try:
+        with CursorFromConnectionFromPool() as cursor:
            cursor.execute("SELECT priority, simultaneous_ip FROM token_auth WHERE token = %s", (token,))
            result = cursor.fetchone()
            if result:
@@ -201,6 +170,4 @@ def get_token_ratelimit(token):
                if simultaneous_ip is None:
                    # No ratelimit for this token if null
                    simultaneous_ip = 999999999
-        finally:
-            cursor.close()
    return priority, simultaneous_ip
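
One behavioral consequence of the switch worth noting: the old PyMySQL connection was opened with autocommit=True, while the pooled connections (see conn.py above) commit only when the with-block exits cleanly and roll back on an exception. A small illustrative sketch; the token value is made up:

with CursorFromConnectionFromPool() as cursor:
    cursor.execute('UPDATE token_auth SET uses = uses + 1 WHERE token = %s', ('example-token',))
    # If an exception were raised here, __exit__ would roll the UPDATE back
    # instead of committing it.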

View File

@@ -1,6 +1,9 @@
import requests

from llm_server import opts
+from llm_server.logging import create_logger
+
+_logger = create_logger('moderation')


def check_moderation_endpoint(prompt: str):
@@ -10,7 +13,7 @@ def check_moderation_endpoint(prompt: str):
    }
    response = requests.post('https://api.openai.com/v1/moderations', headers=headers, json={"input": prompt}, timeout=10)
    if response.status_code != 200:
-        print('moderation failed:', response)
+        _logger.error(f'moderation failed: {response}')
        response.raise_for_status()
    response = response.json()

View File

@@ -1,6 +1,9 @@
from flask import jsonify

from llm_server import opts
+from llm_server.logging import create_logger
+
+_logger = create_logger('oai_to_vllm')


def oai_to_vllm(request_json_body, stop_hashes: bool, mode):
@@ -36,7 +39,7 @@ def oai_to_vllm(request_json_body, stop_hashes: bool, mode):

def format_oai_err(err_msg):
-    print('OAI ERROR MESSAGE:', err_msg)
+    _logger.error(f'Got an OAI error message: {err_msg}')
    return jsonify({
        "error": {
            "message": err_msg,

View File

@@ -39,6 +39,7 @@ def init_logging(filepath: Path = None):
    This is only called by `server.py` since there is weirdness with Gunicorn. The daemon doesn't need this.
    :return:
    """
+    global LOG_DIRECTORY
    logger = logging.getLogger('llm_server')
    logger.setLevel(logging_info.level)
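
The create_logger() helper referenced throughout this commit is not part of the diff shown here. As a rough, assumed sketch only (the real helper lives in llm_server/logging.py and may differ, e.g. by wiring up coloredlogs):

import logging

def create_logger(name: str) -> logging.Logger:
    # Child loggers inherit the handlers and level that init_logging() configured
    # on the top-level 'llm_server' logger.
    return logging.getLogger('llm_server').getChild(name)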

View File

@@ -5,9 +5,12 @@ from flask import jsonify, request
from llm_server import messages, opts
from llm_server.database.log_to_db import log_to_db
+from llm_server.logging import create_logger
from llm_server.routes.helpers.client import format_sillytavern_err
from llm_server.routes.request_handler import RequestHandler

+_logger = create_logger('OobaRequestHandler')


class OobaRequestHandler(RequestHandler):
    def __init__(self, *args, **kwargs):
@@ -16,7 +19,7 @@ class OobaRequestHandler(RequestHandler):
    def handle_request(self, return_ok: bool = True):
        assert not self.used
        if self.offline:
-            print('This backend is offline:', messages.BACKEND_OFFLINE)
+            # _logger.debug(f'This backend is offline.')
            return self.handle_error(messages.BACKEND_OFFLINE)

        request_valid, invalid_response = self.validate_request()

View File

@@ -1,8 +1,10 @@
from flask import Blueprint

from ..request_handler import before_request
+from ..server_error import handle_server_error
from ... import opts
+from ...logging import create_logger

+_logger = create_logger('OpenAI')

openai_bp = Blueprint('openai/v1/', __name__)
openai_model_bp = Blueprint('openai/', __name__)
@@ -24,7 +26,7 @@ def handle_error(e):
        "auth_subrequest_error"
    """
-    print('OAI returning error:', e)
+    _logger.error(f'OAI returning error: {e}')
    return jsonify({
        "error": {
            "message": "Internal server error",

View File

@@ -15,6 +15,9 @@ from ... import opts
from ...database.log_to_db import log_to_db
from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai
from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
+from ...logging import create_logger
+
+_logger = create_logger('OpenAIChatCompletions')

# TODO: add rate-limit headers?
@@ -99,7 +102,7 @@ def openai_chat_completions(model_name=None):
            # return a 408 if necessary.
            _, stream_name, error_msg = event.wait()
            if error_msg:
-                print('OAI failed to start streaming:', error_msg)
+                _logger.error(f'OAI failed to start streaming: {error_msg}')
                stream_name = None  # set to null so that the Finally ignores it.
                return 'Request Timeout', 408
@@ -111,7 +114,7 @@ def openai_chat_completions(model_name=None):
            while True:
                stream_data = stream_redis.xread({stream_name: last_id}, block=opts.redis_stream_timeout)
                if not stream_data:
-                    print(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.")
+                    _logger.debug(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.")
                    yield 'data: [DONE]\n\n'
                else:
                    for stream_index, item in stream_data[0][1]:
@@ -120,7 +123,7 @@ def openai_chat_completions(model_name=None):
                        data = ujson.loads(item[b'data'])
                        if data['error']:
                            # Not printing error since we can just check the daemon log.
-                            print('OAI streaming encountered error')
+                            _logger.warn(f'OAI streaming encountered error: {data["error"]}')
                            yield 'data: [DONE]\n\n'
                            return
                        elif data['new']:

View File

@@ -16,10 +16,13 @@ from ...database.log_to_db import log_to_db
from ...llm import get_token_count
from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai
from ...llm.openai.transform import generate_oai_string, trim_string_to_fit
+from ...logging import create_logger

# TODO: add rate-limit headers?

+_logger = create_logger('OpenAICompletions')


@openai_bp.route('/completions', methods=['POST'])
@openai_model_bp.route('/<model_name>/v1/completions', methods=['POST'])
def openai_completions(model_name=None):
@@ -144,7 +147,7 @@ def openai_completions(model_name=None):
            _, stream_name, error_msg = event.wait()
            if error_msg:
-                print('OAI failed to start streaming:', error_msg)
+                _logger.error(f'OAI failed to start streaming: {error_msg}')
                stream_name = None
                return 'Request Timeout', 408
@@ -156,7 +159,7 @@ def openai_completions(model_name=None):
            while True:
                stream_data = stream_redis.xread({stream_name: last_id}, block=opts.redis_stream_timeout)
                if not stream_data:
-                    print(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.")
+                    _logger.debug(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.")
                    yield 'data: [DONE]\n\n'
                else:
                    for stream_index, item in stream_data[0][1]:
@@ -164,7 +167,7 @@ def openai_completions(model_name=None):
                        timestamp = int(stream_index.decode('utf-8').split('-')[0])
                        data = ujson.loads(item[b'data'])
                        if data['error']:
-                            print('OAI streaming encountered error')
+                            _logger.error(f'OAI streaming encountered error: {data["error"]}')
                            yield 'data: [DONE]\n\n'
                            return
                        elif data['new']:

View File

@@ -16,9 +16,12 @@ from llm_server.database.log_to_db import log_to_db
from llm_server.llm import get_token_count
from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai, return_invalid_model_err
from llm_server.llm.openai.transform import ANTI_CONTINUATION_RE, ANTI_RESPONSE_RE, generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
+from llm_server.logging import create_logger
from llm_server.routes.request_handler import RequestHandler
from llm_server.workers.moderator import add_moderation_task, get_results

+_logger = create_logger('OpenAIRequestHandler')


class OpenAIRequestHandler(RequestHandler):
    def __init__(self, *args, **kwargs):
@@ -29,7 +32,7 @@ class OpenAIRequestHandler(RequestHandler):
        assert not self.used
        if self.offline:
            msg = return_invalid_model_err(self.selected_model)
-            print('OAI Offline:', msg)
+            _logger.error(f'OAI is offline: {msg}')
            return self.handle_error(msg)

        if opts.openai_silent_trim:
@@ -72,7 +75,7 @@ class OpenAIRequestHandler(RequestHandler):
                self.request.json['messages'].insert((len(self.request.json['messages'])), {'role': 'system', 'content': mod_msg})
                self.prompt = transform_messages_to_prompt(self.request.json['messages'])
            except Exception as e:
-                print(f'OpenAI moderation endpoint failed:', f'{e.__class__.__name__}: {e}')
+                _logger.error(f'OpenAI moderation endpoint failed: {e.__class__.__name__}: {e}')
                traceback.print_exc()

        llm_request = {**self.parameters, 'prompt': self.prompt}
@@ -106,7 +109,7 @@ class OpenAIRequestHandler(RequestHandler):
        return response, 429

    def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
-        print('OAI Error:', error_msg)
+        _logger.error(f'OAI Error: {error_msg}')
        return jsonify({
            "error": {
                "message": "Invalid request, check your parameters and try again.",
@@ -155,7 +158,7 @@ class OpenAIRequestHandler(RequestHandler):
    def validate_request(self, prompt: str = None, do_log: bool = False) -> Tuple[bool, Tuple[Response | None, int]]:
        self.parameters, parameters_invalid_msg = self.get_parameters()
        if not self.parameters:
-            print('OAI BACKEND VALIDATION ERROR:', parameters_invalid_msg)
+            _logger.error(f'OAI BACKEND VALIDATION ERROR: {parameters_invalid_msg}')
            return False, (Response('Invalid request, check your parameters and try again.'), 400)
        invalid_oai_err_msg = validate_oai(self.parameters)
        if invalid_oai_err_msg:

View File

@@ -10,6 +10,7 @@ from llm_server import opts
from llm_server.cluster.cluster_config import cluster_config
from llm_server.custom_redis import RedisCustom, redis
from llm_server.database.database import get_token_ratelimit
+from llm_server.logging import create_logger


def increment_ip_count(client_ip: str, redis_key):
@@ -30,6 +31,7 @@ class RedisPriorityQueue:
    def __init__(self, name, db: int = 12):
        self.name = name
        self.redis = RedisCustom(name, db=db)
+        self._logger = create_logger('RedisPriorityQueue')

    def put(self, item, priority: int, selected_model: str, do_stream: bool = False):
        # TODO: remove this when we're sure nothing strange is happening
@@ -41,7 +43,7 @@ class RedisPriorityQueue:
        ip_count = self.get_ip_request_count(item[1])
        _, simultaneous_ip = get_token_ratelimit(item[2])
        if ip_count and int(ip_count) >= simultaneous_ip and priority != 0:
-            print(f'Rejecting request from {item[1]} - {ip_count} request queued.')
+            self._logger.debug(f'Rejecting request from {item[1]} - {ip_count} request queued.')
            return None  # reject the request

        timestamp = time.time()
@@ -98,7 +100,7 @@ class RedisPriorityQueue:
            event_id = item_data[1]
            event = DataEvent(event_id)
            event.set((False, None, 'closed'))
-            print('Removed timed-out item from queue:', event_id)
+            self._logger.debug(f'Removed timed-out item from queue: {event_id}')


class DataEvent:

View File

@@ -12,10 +12,13 @@ from llm_server.database.log_to_db import log_to_db
from llm_server.helpers import auto_set_base_client_api
from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
from llm_server.llm.vllm.vllm_backend import VLLMBackend
+from llm_server.logging import create_logger
from llm_server.routes.auth import parse_token
from llm_server.routes.helpers.http import require_api_key, validate_json
from llm_server.routes.queue import priority_queue

+_logger = create_logger('RequestHandler')


class RequestHandler:
    def __init__(self, incoming_request: flask.Request, selected_model: str = None, incoming_json: Union[dict, str] = None):
@@ -223,7 +226,7 @@ class RequestHandler:
            processing_ip = 0

        if queued_ip_count + processing_ip >= self.token_simultaneous_ip:
-            print(f'Rejecting request from {self.client_ip} - {processing_ip} processing, {queued_ip_count} queued')
+            _logger.debug(f'Rejecting request from {self.client_ip} - {processing_ip} processing, {queued_ip_count} queued')
            return True
        else:
            return False

View File

@@ -1,3 +1,8 @@
+from llm_server.logging import create_logger
+
+_logger = create_logger('handle_server_error')
+

def handle_server_error(e):
-    print('Internal Error:', e)
+    _logger.error(f'Internal Error: {e}')
    return {'error': True, 'code': 500, 'message': 'Internal Server Error :('}, 500

View File

@@ -13,12 +13,15 @@ from ..queue import priority_queue
from ... import opts
from ...custom_redis import redis
from ...database.log_to_db import log_to_db
+from ...logging import create_logger
from ...sock import sock

# Stacking the @sock.route() creates a TypeError error on the /v1/stream endpoint.
# We solve this by splitting the routes

+_logger = create_logger('GenerateStream')


@bp.route('/v1/stream')
@bp.route('/<model_name>/v1/stream')
def stream(model_name=None):
@@ -85,7 +88,7 @@ def do_stream(ws, model_name):
        handler = OobaRequestHandler(incoming_request=request, selected_model=model_name, incoming_json=request_json_body)
        if handler.offline:
            msg = f'{handler.selected_model} is not a valid model choice.'
-            print(msg)
+            _logger.debug(msg)
            ws.send(json.dumps({
                'event': 'text_stream',
                'message_num': 0,
@@ -131,7 +134,7 @@ def do_stream(ws, model_name):
            _, stream_name, error_msg = event.wait()
            if error_msg:
-                print('Stream failed to start streaming:', error_msg)
+                _logger.error(f'Stream failed to start streaming: {error_msg}')
                ws.close(reason=1014, message='Request Timeout')
                return
@@ -143,14 +146,14 @@ def do_stream(ws, model_name):
            while True:
                stream_data = stream_redis.xread({stream_name: last_id}, block=opts.redis_stream_timeout)
                if not stream_data:
-                    print(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.")
+                    _logger.error(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.")
                    return
                else:
                    for stream_index, item in stream_data[0][1]:
                        last_id = stream_index
                        data = ujson.loads(item[b'data'])
                        if data['error']:
-                            print(data['error'])
+                            _logger.error(f'Encountered error while streaming: {data["error"]}')
                            send_err_and_quit('Encountered exception while streaming.')
                            return
                        elif data['new']:

View File

@@ -2,6 +2,7 @@ import time

from redis import Redis

+from llm_server.logging import create_logger
from llm_server.workers.inferencer import STREAM_NAME_PREFIX
@@ -10,6 +11,7 @@ from llm_server.workers.inferencer import STREAM_NAME_PREFIX
def cleaner():
    r = Redis(db=8)
    stream_info = {}
+    logger = create_logger('cleaner')

    while True:
        all_streams = r.keys(f'{STREAM_NAME_PREFIX}:*')
@@ -26,7 +28,7 @@ def cleaner():
                # If the size hasn't changed for 5 minutes, delete the stream
                if time.time() - stream_info[stream]['time'] >= 300:
                    r.delete(stream)
-                    print(f"Stream '{stream}' deleted due to inactivity.")
+                    logger.debug(f"Stream '{stream}' deleted due to inactivity.")
                    del stream_info[stream]

        time.sleep(60)

View File

@@ -4,6 +4,7 @@ import traceback
import redis

from llm_server.database.database import do_db_log
+from llm_server.logging import create_logger


def db_logger():
@@ -16,6 +17,7 @@ def db_logger():
    r = redis.Redis(host='localhost', port=6379, db=3)
    p = r.pubsub()
    p.subscribe('database-logger')
+    logger = create_logger('main_bg')

    for message in p.listen():
        try:
@@ -28,4 +30,4 @@ def db_logger():
            if function_name == 'log_prompt':
                do_db_log(*args, **kwargs)
        except:
-            traceback.print_exc()
+            logger.error(traceback.format_exc())

View File

@@ -7,10 +7,12 @@ from llm_server.cluster.cluster_config import get_backends, cluster_config
from llm_server.custom_redis import redis
from llm_server.database.database import weighted_average_column_for_model
from llm_server.llm.info import get_info
+from llm_server.logging import create_logger
from llm_server.routes.queue import RedisPriorityQueue, priority_queue


def main_background_thread():
+    logger = create_logger('main_bg')
    while True:
        online, offline = get_backends()
        for backend_url in online:
@@ -34,7 +36,7 @@ def main_background_thread():
        base_client_api = redis.get('base_client_api', dtype=str)
        r = requests.get('https://' + base_client_api, timeout=5)
    except Exception as e:
-        print(f'Failed fetch the homepage - {e.__class__.__name__}: {e}')
+        logger.error(f'Failed fetch the homepage - {e.__class__.__name__}: {e}')

    backends = priority_queue.get_backends()
    for backend_url in backends:

View File

@@ -4,7 +4,7 @@ Flask-Caching==2.0.2
requests~=2.31.0
tiktoken~=0.5.0
gevent~=23.9.0.post1
-PyMySQL~=1.1.0
+mysql-connector-python==8.4.0
simplejson~=3.19.1
websockets~=11.0.3
basicauth~=1.0.0
@@ -14,5 +14,4 @@ gunicorn==21.2.0
redis==5.0.1
ujson==5.8.0
vllm==0.2.7
-gradio~=3.46.1
coloredlogs~=15.0.1

View File

@@ -22,11 +22,11 @@ from llm_server.cluster.cluster_config import cluster_config
from llm_server.config.config import mode_ui_names
from llm_server.config.load import load_config
from llm_server.custom_redis import flask_cache, redis
-from llm_server.database.conn import database
+from llm_server.database.conn import database, Database
from llm_server.database.create import create_db
from llm_server.helpers import auto_set_base_client_api
from llm_server.llm.vllm.info import vllm_info
-from llm_server.logging import init_logging
+from llm_server.logging import init_logging, create_logger
from llm_server.routes.openai import openai_bp, openai_model_bp
from llm_server.routes.server_error import handle_server_error
from llm_server.routes.v1 import bp
@@ -62,13 +62,6 @@ from llm_server.sock import init_wssocket
# TODO: add more excluding to SYSTEM__ tokens
# TODO: return 200 when returning formatted sillytavern error

-try:
-    import vllm
-except ModuleNotFoundError as e:
-    print('Could not import vllm-gptq:', e)
-    print('Please see README.md for install instructions.')
-    sys.exit(1)

script_path = os.path.dirname(os.path.realpath(__file__))
config_path_environ = os.getenv("CONFIG_PATH")
if config_path_environ:
@@ -78,11 +71,20 @@ else:

success, config, msg = load_config(config_path)
if not success:
-    print('Failed to load config:', msg)
+    logger = logging.getLogger('llm_server')
+    logger.setLevel(logging.INFO)
+    logger.error(f'Failed to load config: {msg}')
    sys.exit(1)

init_logging(Path(config['webserver_log_directory']) / 'server.log')
-logger = logging.getLogger('llm_server')
+logger = create_logger('Server')
+logger.debug('Debug logging enabled.')

+try:
+    import vllm
+except ModuleNotFoundError as e:
+    logger.error(f'Could not import vllm-gptq: {e}')
+    sys.exit(1)

while not redis.get('daemon_started', dtype=bool):
    logger.warning('Could not find the key daemon_started in Redis. Did you forget to start the daemon process?')
@@ -90,7 +92,7 @@ while not redis.get('daemon_started', dtype=bool):

logger.info('Started HTTP worker!')

-database.init_db(config['mysql']['host'], config['mysql']['username'], config['mysql']['password'], config['mysql']['database'])
+Database.initialise(maxconn=config['mysql']['maxconn'], host=config['mysql']['host'], user=config['mysql']['username'], password=config['mysql']['password'], database=config['mysql']['database'])
create_db()

app = Flask(__name__)

View File

@ -1,6 +1,5 @@
<!DOCTYPE html> <!DOCTYPE html>
<html lang="en"> <html lang="en">
<head> <head>
<title>{{ llm_middleware_name }}</title> <title>{{ llm_middleware_name }}</title>
<meta content="width=device-width, initial-scale=1" name="viewport"/> <meta content="width=device-width, initial-scale=1" name="viewport"/>
@ -112,7 +111,8 @@
<li>Set your API type to <kbd>{{ mode_name }}</kbd></li> <li>Set your API type to <kbd>{{ mode_name }}</kbd></li>
<li>Enter <kbd>{{ client_api }}</kbd> in the <kbd>{{ api_input_textbox }}</kbd> textbox.</li> <li>Enter <kbd>{{ client_api }}</kbd> in the <kbd>{{ api_input_textbox }}</kbd> textbox.</li>
{% if enable_streaming %} {% if enable_streaming %}
<li>Enter <kbd>{{ ws_client_api }}</kbd> in the <kbd>{{ streaming_input_textbox }}</kbd> textbox.</li> <li>Enter <kbd>{{ ws_client_api }}</kbd> in the <kbd>{{ streaming_input_textbox }}</kbd> textbox.
</li>
{% endif %} {% endif %}
<li>If you have a token, check the <kbd>Mancer AI</kbd> checkbox and enter your token in the <kbd>Mancer <li>If you have a token, check the <kbd>Mancer AI</kbd> checkbox and enter your token in the <kbd>Mancer
API key</kbd> textbox. API key</kbd> textbox.
@ -127,7 +127,8 @@
<br> <br>
<div id="openai"> <div id="openai">
<strong>OpenAI-Compatible API</strong> <strong>OpenAI-Compatible API</strong>
<p>The OpenAI-compatible API adds a system prompt to set the AI's behavior to a "helpful assistant". You can view this prompt <a href="/api/openai/v1/prompt">here</a>.</p> <p>The OpenAI-compatible API adds a system prompt to set the AI's behavior to a "helpful assistant". You
can view this prompt <a href="/api/openai/v1/prompt">here</a>.</p>
</div> </div>
{% endif %} {% endif %}
<br> <br>
@ -148,7 +149,8 @@
{% for key, value in model_choices.items() %} {% for key, value in model_choices.items() %}
<div class="info-box"> <div class="info-box">
<h3>{{ key }} <span class="header-workers">- {{ value.backend_count }} {% if value.backend_count == 1 %}worker{% else %}workers{% endif %}</span></h3> <h3>{{ key }} <span class="header-workers">- {{ value.backend_count }} {% if value.backend_count == 1 %}
worker{% else %}workers{% endif %}</span></h3>
{% if value.estimated_wait == 0 and value.estimated_wait >= value.concurrent_gens %} {% if value.estimated_wait == 0 and value.estimated_wait >= value.concurrent_gens %}
{# There will be a wait if the queue is empty but prompts are processing, but we don't know how long. #} {# There will be a wait if the queue is empty but prompts are processing, but we don't know how long. #}