diff --git a/daemon.py b/daemon.py
index 69e8532..3f67a76 100644
--- a/daemon.py
+++ b/daemon.py
@@ -28,7 +28,6 @@ if __name__ == "__main__":
     parser.add_argument('-d', '--debug', action='store_true', help='Enable debug logging.')
     args = parser.parse_args()
 
-    # TODO: have this be set by either the arg or a config value
     if args.debug:
         logging_info.level = logging.DEBUG
diff --git a/llm_server/cluster/cluster_config.py b/llm_server/cluster/cluster_config.py
index 277cdb1..8040fa0 100644
--- a/llm_server/cluster/cluster_config.py
+++ b/llm_server/cluster/cluster_config.py
@@ -6,11 +6,14 @@ from llm_server import opts
 from llm_server.cluster.redis_cycle import add_backend_cycler, redis_cycle
 from llm_server.cluster.stores import redis_running_models
 from llm_server.custom_redis import RedisCustom
+from llm_server.logging import create_logger
 from llm_server.routes.helpers.model import estimate_model_size
-
 
 # Don't try to reorganize this file or else you'll run into circular imports.
 
+_logger = create_logger('redis')
+
+
 class RedisClusterStore:
     """
     A class used to store the cluster state in Redis.
@@ -67,7 +70,7 @@ class RedisClusterStore:
             if not backend_info['online']:
                 old = backend_url
                 backend_url = get_a_cluster_backend()
-                print(f'Backend {old} offline. Request was redirected to {backend_url}')
+                _logger.debug(f'Backend {old} offline. Request was redirected to {backend_url}')
             return backend_url
@@ -108,8 +111,7 @@ def get_backends():
         )
         return [url for url, info in online_backends], [url for url, info in offline_backends]
     except KeyError:
-        traceback.print_exc()
-        print(backends)
+        _logger.error(f'Failed to get a backend from the cluster config: {traceback.format_exc()}\nCurrent backends: {backends}')
 
 
 def get_a_cluster_backend(model=None):
diff --git a/llm_server/config/load.py b/llm_server/config/load.py
index 9c95c79..0917f25 100644
--- a/llm_server/config/load.py
+++ b/llm_server/config/load.py
@@ -7,10 +7,13 @@ import llm_server
 from llm_server import opts
 from llm_server.config.config import ConfigLoader, config_default_vars, config_required_vars
 from llm_server.custom_redis import redis
-from llm_server.database.conn import database
+from llm_server.database.conn import Database
 from llm_server.database.database import get_number_of_rows
+from llm_server.logging import create_logger
 from llm_server.routes.queue import PriorityQueue
 
+_logger = create_logger('config')
+
 
 def load_config(config_path):
     config_loader = ConfigLoader(config_path, config_default_vars, config_required_vars)
@@ -58,7 +61,7 @@ def load_config(config_path):
     llm_server.routes.queue.priority_queue = PriorityQueue([x['backend_url'] for x in config['cluster']])
 
     if opts.openai_expose_our_model and not opts.openai_api_key:
-        print('If you set openai_expose_our_model to false, you must set your OpenAI key in openai_api_key.')
+        _logger.error('If you set openai_expose_our_model to true, you must set your OpenAI key in openai_api_key.')
         sys.exit(1)
 
     opts.verify_ssl = config['verify_ssl']
@@ -67,11 +70,11 @@ def load_config(config_path):
         urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
 
     if config['http_host']:
-        http_host = re.sub(r'http(?:s)?://', '', config["http_host"])
+        http_host = re.sub(r'https?://', '', config["http_host"])
         redis.set('http_host', http_host)
         redis.set('base_client_api', f'{http_host}/{opts.frontend_api_client.strip("/")}')
 
-    database.init_db(config['mysql']['host'], config['mysql']['username'], config['mysql']['password'], config['mysql']['database'])
+    Database.initialise(maxconn=config['mysql']['maxconn'], host=config['mysql']['host'], user=config['mysql']['username'], password=config['mysql']['password'], database=config['mysql']['database'])
 
     if config['load_num_prompts']:
         redis.set('proompts', get_number_of_rows('prompts'))
diff --git a/llm_server/custom_redis.py b/llm_server/custom_redis.py
index 00abb06..34f33e9 100644
--- a/llm_server/custom_redis.py
+++ b/llm_server/custom_redis.py
@@ -1,3 +1,4 @@
+import logging
 import pickle
 import sys
 import traceback
@@ -29,8 +30,9 @@ class RedisCustom(Redis):
         try:
             self.set('____', 1)
         except redis_pkg.exceptions.ConnectionError as e:
-            print('Failed to connect to the Redis server:', e)
-            print('Did you install and start the Redis server?')
+            logger = logging.getLogger('redis')
+            logger.setLevel(logging.INFO)
+            logger.error(f'Failed to connect to the Redis server: {e}\nDid you install and start the Redis server?')
             sys.exit(1)
 
     def _key(self, key):
diff --git a/llm_server/database/conn.py b/llm_server/database/conn.py
index f63f555..261b4c5 100644
--- a/llm_server/database/conn.py
+++ b/llm_server/database/conn.py
@@ -1,28 +1,40 @@
-import pymysql
+from mysql.connector import pooling
 
 
-class DatabaseConnection:
-    host: str = None
-    username: str = None
-    password: str = None
-    database_name: str = None
+class Database:
+    __connection_pool = None
 
-    def init_db(self, host, username, password, database_name):
-        self.host = host
-        self.username = username
-        self.password = password
-        self.database_name = database_name
+    @classmethod
+    def initialise(cls, maxconn, **kwargs):
+        if cls.__connection_pool is not None:
+            raise Exception('Database connection pool is already initialised')
+        cls.__connection_pool = pooling.MySQLConnectionPool(pool_size=maxconn,
+                                                            pool_reset_session=True,
+                                                            **kwargs)
 
-    def cursor(self):
-        db = pymysql.connect(
-            host=self.host,
-            user=self.username,
-            password=self.password,
-            database=self.database_name,
-            charset='utf8mb4',
-            autocommit=True,
-        )
-        return db.cursor()
+    @classmethod
+    def get_connection(cls):
+        return cls.__connection_pool.get_connection()
+
+    @classmethod
+    def return_connection(cls, connection):
+        connection.close()
 
 
-database = DatabaseConnection()
+class CursorFromConnectionFromPool:
+    def __init__(self):
+        self.conn = None
+        self.cursor = None
+
+    def __enter__(self):
+        self.conn = Database.get_connection()
+        self.cursor = self.conn.cursor()
+        return self.cursor
+
+    def __exit__(self, exception_type, exception_value, exception_traceback):
+        if exception_value is not None:  # This is equivalent to saying: if there was an exception
+            self.conn.rollback()
+        else:
+            self.cursor.close()
+            self.conn.commit()
+        Database.return_connection(self.conn)
diff --git a/llm_server/database/create.py b/llm_server/database/create.py
index c1788ae..302f8a6 100644
--- a/llm_server/database/create.py
+++ b/llm_server/database/create.py
@@ -1,40 +1,39 @@
-from llm_server.database.conn import database
+from llm_server.database.conn import CursorFromConnectionFromPool
 
 
 def create_db():
-    cursor = database.cursor()
-    cursor.execute('''
-        CREATE TABLE IF NOT EXISTS prompts (
-            ip TEXT,
-            token TEXT DEFAULT NULL,
-            model TEXT,
-            backend_mode TEXT,
-            backend_url TEXT,
-            request_url TEXT,
-            generation_time FLOAT,
-            prompt LONGTEXT,
-            prompt_tokens INTEGER,
-            response LONGTEXT,
-            response_tokens INTEGER,
-            response_status INTEGER,
-            parameters TEXT,
-            # CHECK (parameters IS NULL OR JSON_VALID(parameters)),
-            headers TEXT,
-            # CHECK (headers IS NULL OR JSON_VALID(headers)),
-            timestamp INTEGER
-        )
-    ''')
-    cursor.execute('''
-        CREATE TABLE IF NOT EXISTS token_auth (
-            token TEXT,
-            UNIQUE (token),
-            type TEXT NOT NULL,
-            priority INTEGER DEFAULT 9999,
-            simultaneous_ip INTEGER DEFAULT NULL,
-            uses INTEGER DEFAULT 0,
-            max_uses INTEGER,
-            expire INTEGER,
-            disabled BOOLEAN DEFAULT 0
-        )
-    ''')
-    cursor.close()
+    with CursorFromConnectionFromPool() as cursor:
+        cursor.execute('''
+            CREATE TABLE IF NOT EXISTS prompts (
+                ip TEXT,
+                token TEXT DEFAULT NULL,
+                model TEXT,
+                backend_mode TEXT,
+                backend_url TEXT,
+                request_url TEXT,
+                generation_time FLOAT,
+                prompt LONGTEXT,
+                prompt_tokens INTEGER,
+                response LONGTEXT,
+                response_tokens INTEGER,
+                response_status INTEGER,
+                parameters TEXT,
+                # CHECK (parameters IS NULL OR JSON_VALID(parameters)),
+                headers TEXT,
+                # CHECK (headers IS NULL OR JSON_VALID(headers)),
+                timestamp INTEGER
+            )
+        ''')
+        cursor.execute('''
+            CREATE TABLE IF NOT EXISTS token_auth (
+                token TEXT,
+                UNIQUE (token),
+                type TEXT NOT NULL,
+                priority INTEGER DEFAULT 9999,
+                simultaneous_ip INTEGER DEFAULT NULL,
+                uses INTEGER DEFAULT 0,
+                max_uses INTEGER,
+                expire INTEGER,
+                disabled BOOLEAN DEFAULT 0
+            )
+        ''')
diff --git a/llm_server/database/database.py b/llm_server/database/database.py
index d6bd6b2..a1c5d80 100644
--- a/llm_server/database/database.py
+++ b/llm_server/database/database.py
@@ -5,7 +5,7 @@ from typing import Union
 
 from llm_server import opts
 from llm_server.cluster.cluster_config import cluster_config
-from llm_server.database.conn import database
+from llm_server.database.conn import CursorFromConnectionFromPool
 from llm_server.llm import get_token_count
 
 
@@ -52,21 +52,17 @@ def do_db_log(ip: str, token: str, prompt: str, response: Union[str, None], gen_
     running_model = backend_info.get('model')
     backend_mode = backend_info['mode']
     timestamp = int(time.time())
-    cursor = database.cursor()
-    try:
+    with CursorFromConnectionFromPool() as cursor:
         cursor.execute("""
-            INSERT INTO prompts
-            (ip, token, model, backend_mode, backend_url, request_url, generation_time, prompt, prompt_tokens, response, response_tokens, response_status, parameters, headers, timestamp)
-            VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
-            """,
+                       INSERT INTO prompts
+                       (ip, token, model, backend_mode, backend_url, request_url, generation_time, prompt, prompt_tokens, response, response_tokens, response_status, parameters, headers, timestamp)
+                       VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
+                       """,
                        (ip, token, running_model, backend_mode, backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
-    finally:
-        cursor.close()
 
 
 def is_valid_api_key(api_key):
-    cursor = database.cursor()
-    try:
+    with CursorFromConnectionFromPool() as cursor:
         cursor.execute("SELECT token, uses, max_uses, expire, disabled FROM token_auth WHERE token = %s", (api_key,))
         row = cursor.fetchone()
         if row is not None:
@@ -75,52 +71,38 @@ def is_valid_api_key(api_key):
             if ((uses is None or max_uses is None) or uses < max_uses) and (expire is None or expire > time.time()) and not disabled:
                 return True
         return False
-    finally:
-        cursor.close()
 
 
 def is_api_key_moderated(api_key):
     if not api_key:
         return opts.openai_moderation_enabled
-    cursor = database.cursor()
-    try:
+    with CursorFromConnectionFromPool() as cursor:
         cursor.execute("SELECT openai_moderation_enabled FROM token_auth WHERE token = %s", (api_key,))
         row = cursor.fetchone()
         if row is not None:
             return bool(row[0])
         return opts.openai_moderation_enabled
-    finally:
-        cursor.close()
 
 
 def get_number_of_rows(table_name):
-    cursor = database.cursor()
-    try:
+    with CursorFromConnectionFromPool() as cursor:
        cursor.execute(f"SELECT COUNT(*) FROM {table_name} WHERE token NOT LIKE 'SYSTEM__%%' OR token IS NULL")
        result = cursor.fetchone()
        return result[0]
-    finally:
-        cursor.close()
 
 
 def average_column(table_name, column_name):
-    cursor = database.cursor()
-    try:
+    with CursorFromConnectionFromPool() as cursor:
        cursor.execute(f"SELECT AVG({column_name}) FROM {table_name} WHERE token NOT LIKE 'SYSTEM__%%' OR token IS NULL")
        result = cursor.fetchone()
        return result[0]
-    finally:
-        cursor.close()
 
 
 def average_column_for_model(table_name, column_name, model_name):
-    cursor = database.cursor()
-    try:
+    with CursorFromConnectionFromPool() as cursor:
        cursor.execute(f"SELECT AVG({column_name}) FROM {table_name} WHERE model = %s AND token NOT LIKE 'SYSTEM__%%' OR token IS NULL", (model_name,))
        result = cursor.fetchone()
        return result[0]
-    finally:
-        cursor.close()
 
 
 def weighted_average_column_for_model(table_name, column_name, model_name, backend_name, backend_url, exclude_zeros: bool = False, include_system_tokens: bool = True):
@@ -129,8 +111,7 @@ def weighted_average_column_for_model(table_name, column_name, model_name, backe
     else:
         sql = f"SELECT {column_name}, id FROM {table_name} WHERE model = %s AND backend_mode = %s AND backend_url = %s AND (token NOT LIKE 'SYSTEM__%%' OR token IS NULL) ORDER BY id DESC"
 
-    cursor = database.cursor()
-    try:
+    with CursorFromConnectionFromPool() as cursor:
         try:
             cursor.execute(sql, (model_name, backend_name, backend_url,))
             results = cursor.fetchall()
@@ -154,46 +135,34 @@ def weighted_average_column_for_model(table_name, column_name, model_name, backe
             calculated_avg = 0
 
         return calculated_avg
-    finally:
-        cursor.close()
 
 
 def sum_column(table_name, column_name):
-    cursor = database.cursor()
-    try:
+    with CursorFromConnectionFromPool() as cursor:
        cursor.execute(f"SELECT SUM({column_name}) FROM {table_name} WHERE token NOT LIKE 'SYSTEM__%%' OR token IS NULL")
        result = cursor.fetchone()
        return result[0] if result else 0
-    finally:
-        cursor.close()
 
 
 def get_distinct_ips_24h():
     # Get the current time and subtract 24 hours (in seconds)
     past_24_hours = int(time.time()) - 24 * 60 * 60
-    cursor = database.cursor()
-    try:
+    with CursorFromConnectionFromPool() as cursor:
        cursor.execute("SELECT COUNT(DISTINCT ip) FROM prompts WHERE timestamp >= %s AND (token NOT LIKE 'SYSTEM__%%' OR token IS NULL)", (past_24_hours,))
        result = cursor.fetchone()
        return result[0] if result else 0
-    finally:
-        cursor.close()
 
 
 def increment_token_uses(token):
-    cursor = database.cursor()
-    try:
+    with CursorFromConnectionFromPool() as cursor:
        cursor.execute('UPDATE token_auth SET uses = uses + 1 WHERE token = %s', (token,))
-    finally:
-        cursor.close()
 
 
 def get_token_ratelimit(token):
     priority = 9990
     simultaneous_ip = opts.simultaneous_requests_per_ip
     if token:
-        cursor = database.cursor()
-        try:
+        with CursorFromConnectionFromPool() as cursor:
             cursor.execute("SELECT priority, simultaneous_ip FROM token_auth WHERE token = %s", (token,))
             result = cursor.fetchone()
             if result:
@@ -201,6 +170,4 @@ def get_token_ratelimit(token):
                 if simultaneous_ip is None:
                     # No ratelimit for this token if null
                     simultaneous_ip = 999999999
-        finally:
-            cursor.close()
     return priority, simultaneous_ip
diff --git a/llm_server/llm/openai/moderation.py b/llm_server/llm/openai/moderation.py
index f62241d..ee63e5a 100644
--- a/llm_server/llm/openai/moderation.py
+++ b/llm_server/llm/openai/moderation.py
@@ -1,6 +1,9 @@
 import requests
 
 from llm_server import opts
+from llm_server.logging import create_logger
+
+_logger = create_logger('moderation')
 
 
 def check_moderation_endpoint(prompt: str):
@@ -10,7 +13,7 @@ def check_moderation_endpoint(prompt: str):
     }
     response = requests.post('https://api.openai.com/v1/moderations', headers=headers, json={"input": prompt}, timeout=10)
     if response.status_code != 200:
-        print('moderation failed:', response)
+        _logger.error(f'moderation failed: {response}')
         response.raise_for_status()
     response = response.json()
diff --git a/llm_server/llm/openai/oai_to_vllm.py b/llm_server/llm/openai/oai_to_vllm.py
index ef07a08..feb9364 100644
--- a/llm_server/llm/openai/oai_to_vllm.py
+++ b/llm_server/llm/openai/oai_to_vllm.py
@@ -1,6 +1,9 @@
 from flask import jsonify
 
 from llm_server import opts
+from llm_server.logging import create_logger
+
+_logger = create_logger('oai_to_vllm')
 
 
 def oai_to_vllm(request_json_body, stop_hashes: bool, mode):
@@ -36,7 +39,7 @@ def oai_to_vllm(request_json_body, stop_hashes: bool, mode):
 
 
 def format_oai_err(err_msg):
-    print('OAI ERROR MESSAGE:', err_msg)
+    _logger.error(f'Got an OAI error message: {err_msg}')
     return jsonify({
         "error": {
             "message": err_msg,
diff --git a/llm_server/logging.py b/llm_server/logging.py
index 53f6d72..1d42fd0 100644
--- a/llm_server/logging.py
+++ b/llm_server/logging.py
@@ -39,6 +39,7 @@ def init_logging(filepath: Path = None):
     This is only called by `server.py` since there is wierdness with Gunicorn. The deamon doesn't need this.
     :return:
     """
+    global LOG_DIRECTORY
     logger = logging.getLogger('llm_server')
     logger.setLevel(logging_info.level)
diff --git a/llm_server/routes/ooba_request_handler.py b/llm_server/routes/ooba_request_handler.py
index aadda78..e834873 100644
--- a/llm_server/routes/ooba_request_handler.py
+++ b/llm_server/routes/ooba_request_handler.py
@@ -5,9 +5,12 @@ from flask import jsonify, request
 
 from llm_server import messages, opts
 from llm_server.database.log_to_db import log_to_db
+from llm_server.logging import create_logger
 from llm_server.routes.helpers.client import format_sillytavern_err
 from llm_server.routes.request_handler import RequestHandler
 
+_logger = create_logger('OobaRequestHandler')
+
 
 class OobaRequestHandler(RequestHandler):
     def __init__(self, *args, **kwargs):
@@ -16,7 +19,7 @@ class OobaRequestHandler(RequestHandler):
     def handle_request(self, return_ok: bool = True):
         assert not self.used
         if self.offline:
-            print('This backend is offline:', messages.BACKEND_OFFLINE)
+            _logger.debug(f'This backend is offline: {messages.BACKEND_OFFLINE}')
             return self.handle_error(messages.BACKEND_OFFLINE)
 
         request_valid, invalid_response = self.validate_request()
diff --git a/llm_server/routes/openai/__init__.py b/llm_server/routes/openai/__init__.py
index c3837e4..622cf58 100644
--- a/llm_server/routes/openai/__init__.py
+++ b/llm_server/routes/openai/__init__.py
@@ -1,8 +1,10 @@
 from flask import Blueprint
 
 from ..request_handler import before_request
-from ..server_error import handle_server_error
 from ... import opts
+from ...logging import create_logger
+
+_logger = create_logger('OpenAI')
 
 openai_bp = Blueprint('openai/v1/', __name__)
 openai_model_bp = Blueprint('openai/', __name__)
@@ -24,7 +26,7 @@ def handle_error(e):
     "auth_subrequest_error"
     """
-    print('OAI returning error:', e)
+    _logger.error(f'OAI returning error: {e}')
     return jsonify({
         "error": {
             "message": "Internal server error",
diff --git a/llm_server/routes/openai/chat_completions.py b/llm_server/routes/openai/chat_completions.py
index 9ccc15f..3a0ecef 100644
--- a/llm_server/routes/openai/chat_completions.py
+++ b/llm_server/routes/openai/chat_completions.py
@@ -15,6 +15,9 @@ from ... import opts
 from ...database.log_to_db import log_to_db
 from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai
 from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
+from ...logging import create_logger
+
+_logger = create_logger('OpenAIChatCompletions')
 
 
 # TODO: add rate-limit headers?
@@ -99,7 +102,7 @@ def openai_chat_completions(model_name=None):
             # return a 408 if necessary.
             _, stream_name, error_msg = event.wait()
             if error_msg:
-                print('OAI failed to start streaming:', error_msg)
+                _logger.error(f'OAI failed to start streaming: {error_msg}')
                 stream_name = None  # set to null so that the Finally ignores it.
                 return 'Request Timeout', 408
 
@@ -111,7 +114,7 @@ def openai_chat_completions(model_name=None):
                 while True:
                     stream_data = stream_redis.xread({stream_name: last_id}, block=opts.redis_stream_timeout)
                     if not stream_data:
-                        print(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.")
+                        _logger.debug(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.")
                         yield 'data: [DONE]\n\n'
                     else:
                         for stream_index, item in stream_data[0][1]:
@@ -120,7 +123,7 @@ def openai_chat_completions(model_name=None):
                             data = ujson.loads(item[b'data'])
                             if data['error']:
                                 # Not printing error since we can just check the daemon log.
-                                print('OAI streaming encountered error')
+                                _logger.warning(f'OAI streaming encountered error: {data["error"]}')
                                 yield 'data: [DONE]\n\n'
                                 return
                             elif data['new']:
diff --git a/llm_server/routes/openai/completions.py b/llm_server/routes/openai/completions.py
index 2524b17..908c920 100644
--- a/llm_server/routes/openai/completions.py
+++ b/llm_server/routes/openai/completions.py
@@ -16,10 +16,13 @@ from ...database.log_to_db import log_to_db
 from ...llm import get_token_count
 from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai
 from ...llm.openai.transform import generate_oai_string, trim_string_to_fit
-
+from ...logging import create_logger
 
 # TODO: add rate-limit headers?
 
+_logger = create_logger('OpenAICompletions')
+
+
 @openai_bp.route('/completions', methods=['POST'])
 @openai_model_bp.route('/<model_name>/v1/completions', methods=['POST'])
 def openai_completions(model_name=None):
@@ -144,7 +147,7 @@ def openai_completions(model_name=None):
 
             _, stream_name, error_msg = event.wait()
             if error_msg:
-                print('OAI failed to start streaming:', error_msg)
+                _logger.error(f'OAI failed to start streaming: {error_msg}')
                 stream_name = None
                 return 'Request Timeout', 408
 
@@ -156,7 +159,7 @@ def openai_completions(model_name=None):
                 while True:
                     stream_data = stream_redis.xread({stream_name: last_id}, block=opts.redis_stream_timeout)
                     if not stream_data:
-                        print(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.")
+                        _logger.debug(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.")
                         yield 'data: [DONE]\n\n'
                     else:
                         for stream_index, item in stream_data[0][1]:
@@ -164,7 +167,7 @@ def openai_completions(model_name=None):
                             timestamp = int(stream_index.decode('utf-8').split('-')[0])
                             data = ujson.loads(item[b'data'])
                             if data['error']:
-                                print('OAI streaming encountered error')
+                                _logger.error(f'OAI streaming encountered error: {data["error"]}')
                                 yield 'data: [DONE]\n\n'
                                 return
                             elif data['new']:
diff --git a/llm_server/routes/openai_request_handler.py b/llm_server/routes/openai_request_handler.py
index 170eb77..2bda304 100644
--- a/llm_server/routes/openai_request_handler.py
+++ b/llm_server/routes/openai_request_handler.py
@@ -16,9 +16,12 @@ from llm_server.database.log_to_db import log_to_db
 from llm_server.llm import get_token_count
 from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai, return_invalid_model_err
 from llm_server.llm.openai.transform import ANTI_CONTINUATION_RE, ANTI_RESPONSE_RE, generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
+from llm_server.logging import create_logger
 from llm_server.routes.request_handler import RequestHandler
 from llm_server.workers.moderator import add_moderation_task, get_results
 
+_logger = create_logger('OpenAIRequestHandler')
+
 
 class OpenAIRequestHandler(RequestHandler):
     def __init__(self, *args, **kwargs):
@@ -29,7 +32,7 @@ class OpenAIRequestHandler(RequestHandler):
         assert not self.used
         if self.offline:
             msg = return_invalid_model_err(self.selected_model)
-            print('OAI Offline:', msg)
+            _logger.error(f'OAI is offline: {msg}')
             return self.handle_error(msg)
 
         if opts.openai_silent_trim:
@@ -72,7 +75,7 @@ class OpenAIRequestHandler(RequestHandler):
                 self.request.json['messages'].insert((len(self.request.json['messages'])), {'role': 'system', 'content': mod_msg})
                 self.prompt = transform_messages_to_prompt(self.request.json['messages'])
             except Exception as e:
-                print(f'OpenAI moderation endpoint failed:', f'{e.__class__.__name__}: {e}')
+                _logger.error(f'OpenAI moderation endpoint failed: {e.__class__.__name__}: {e}')
                 traceback.print_exc()
 
         llm_request = {**self.parameters, 'prompt': self.prompt}
@@ -106,7 +109,7 @@ class OpenAIRequestHandler(RequestHandler):
         return response, 429
 
     def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
-        print('OAI Error:', error_msg)
+        _logger.error(f'OAI Error: {error_msg}')
         return jsonify({
             "error": {
                 "message": "Invalid request, check your parameters and try again.",
@@ -155,7 +158,7 @@ class OpenAIRequestHandler(RequestHandler):
     def validate_request(self, prompt: str = None, do_log: bool = False) -> Tuple[bool, Tuple[Response | None, int]]:
         self.parameters, parameters_invalid_msg = self.get_parameters()
         if not self.parameters:
-            print('OAI BACKEND VALIDATION ERROR:', parameters_invalid_msg)
+            _logger.error(f'OAI BACKEND VALIDATION ERROR: {parameters_invalid_msg}')
             return False, (Response('Invalid request, check your parameters and try again.'), 400)
         invalid_oai_err_msg = validate_oai(self.parameters)
         if invalid_oai_err_msg:
diff --git a/llm_server/routes/queue.py b/llm_server/routes/queue.py
index ee66580..562bbed 100644
--- a/llm_server/routes/queue.py
+++ b/llm_server/routes/queue.py
@@ -10,6 +10,7 @@ from llm_server import opts
 from llm_server.cluster.cluster_config import cluster_config
 from llm_server.custom_redis import RedisCustom, redis
 from llm_server.database.database import get_token_ratelimit
+from llm_server.logging import create_logger
 
 
 def increment_ip_count(client_ip: str, redis_key):
@@ -30,6 +31,7 @@ class RedisPriorityQueue:
     def __init__(self, name, db: int = 12):
         self.name = name
         self.redis = RedisCustom(name, db=db)
+        self._logger = create_logger('RedisPriorityQueue')
 
     def put(self, item, priority: int, selected_model: str, do_stream: bool = False):
         # TODO: remove this when we're sure nothing strange is happening
@@ -41,7 +43,7 @@ class RedisPriorityQueue:
         ip_count = self.get_ip_request_count(item[1])
         _, simultaneous_ip = get_token_ratelimit(item[2])
         if ip_count and int(ip_count) >= simultaneous_ip and priority != 0:
-            print(f'Rejecting request from {item[1]} - {ip_count} request queued.')
+            self._logger.debug(f'Rejecting request from {item[1]} - {ip_count} requests queued.')
             return None  # reject the request
 
         timestamp = time.time()
@@ -98,7 +100,7 @@ class RedisPriorityQueue:
                     event_id = item_data[1]
                     event = DataEvent(event_id)
                     event.set((False, None, 'closed'))
-                    print('Removed timed-out item from queue:', event_id)
+                    self._logger.debug(f'Removed timed-out item from queue: {event_id}')
 
 
 class DataEvent:
diff --git a/llm_server/routes/request_handler.py b/llm_server/routes/request_handler.py
index aadb443..4dee1ff 100644
--- a/llm_server/routes/request_handler.py
+++ b/llm_server/routes/request_handler.py
@@ -12,10 +12,13 @@ from llm_server.database.log_to_db import log_to_db
 from llm_server.helpers import auto_set_base_client_api
 from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
 from llm_server.llm.vllm.vllm_backend import VLLMBackend
+from llm_server.logging import create_logger
 from llm_server.routes.auth import parse_token
 from llm_server.routes.helpers.http import require_api_key, validate_json
 from llm_server.routes.queue import priority_queue
 
+_logger = create_logger('RequestHandler')
+
 
 class RequestHandler:
     def __init__(self, incoming_request: flask.Request, selected_model: str = None, incoming_json: Union[dict, str] = None):
@@ -223,7 +226,7 @@ class RequestHandler:
             processing_ip = 0
 
         if queued_ip_count + processing_ip >= self.token_simultaneous_ip:
-            print(f'Rejecting request from {self.client_ip} - {processing_ip} processing, {queued_ip_count} queued')
+            _logger.debug(f'Rejecting request from {self.client_ip} - {processing_ip} processing, {queued_ip_count} queued')
             return True
         else:
             return False
diff --git a/llm_server/routes/server_error.py b/llm_server/routes/server_error.py
index a6d6f99..62d202f 100644
--- a/llm_server/routes/server_error.py
+++ b/llm_server/routes/server_error.py
@@ -1,3 +1,8 @@
+from llm_server.logging import create_logger
+
+_logger = create_logger('handle_server_error')
+
+
 def handle_server_error(e):
-    print('Internal Error:', e)
+    _logger.error(f'Internal Error: {e}')
     return {'error': True, 'code': 500, 'message': 'Internal Server Error :('}, 500
diff --git a/llm_server/routes/v1/generate_stream.py b/llm_server/routes/v1/generate_stream.py
index 3ed2f58..36b7b39 100644
--- a/llm_server/routes/v1/generate_stream.py
+++ b/llm_server/routes/v1/generate_stream.py
@@ -13,12 +13,15 @@ from ..queue import priority_queue
 from ... import opts
 from ...custom_redis import redis
 from ...database.log_to_db import log_to_db
+from ...logging import create_logger
 from ...sock import sock
 
-
 # Stacking the @sock.route() creates a TypeError error on the /v1/stream endpoint.
 # We solve this by splitting the routes
 
+_logger = create_logger('GenerateStream')
+
+
 @bp.route('/v1/stream')
 @bp.route('/<model_name>/v1/stream')
 def stream(model_name=None):
@@ -85,7 +88,7 @@ def do_stream(ws, model_name):
         handler = OobaRequestHandler(incoming_request=request, selected_model=model_name, incoming_json=request_json_body)
         if handler.offline:
             msg = f'{handler.selected_model} is not a valid model choice.'
-            print(msg)
+            _logger.debug(msg)
             ws.send(json.dumps({
                 'event': 'text_stream',
                 'message_num': 0,
@@ -131,7 +134,7 @@ def do_stream(ws, model_name):
 
             _, stream_name, error_msg = event.wait()
             if error_msg:
-                print('Stream failed to start streaming:', error_msg)
+                _logger.error(f'Stream failed to start streaming: {error_msg}')
                 ws.close(reason=1014, message='Request Timeout')
                 return
 
@@ -143,14 +146,14 @@ def do_stream(ws, model_name):
                 while True:
                     stream_data = stream_redis.xread({stream_name: last_id}, block=opts.redis_stream_timeout)
                     if not stream_data:
-                        print(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.")
+                        _logger.error(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.")
                         return
                     else:
                         for stream_index, item in stream_data[0][1]:
                             last_id = stream_index
                             data = ujson.loads(item[b'data'])
                             if data['error']:
-                                print(data['error'])
+                                _logger.error(f'Encountered error while streaming: {data["error"]}')
                                 send_err_and_quit('Encountered exception while streaming.')
                                 return
                             elif data['new']:
diff --git a/llm_server/workers/cleaner.py b/llm_server/workers/cleaner.py
index 95a6a78..cbbcb46 100644
--- a/llm_server/workers/cleaner.py
+++ b/llm_server/workers/cleaner.py
@@ -2,6 +2,7 @@ import time
 
 from redis import Redis
 
+from llm_server.logging import create_logger
 from llm_server.workers.inferencer import STREAM_NAME_PREFIX
 
 
@@ -10,6 +11,7 @@ from llm_server.workers.inferencer import STREAM_NAME_PREFIX
 def cleaner():
     r = Redis(db=8)
     stream_info = {}
+    logger = create_logger('cleaner')
 
     while True:
         all_streams = r.keys(f'{STREAM_NAME_PREFIX}:*')
@@ -26,7 +28,7 @@ def cleaner():
                 # If the size hasn't changed for 5 minutes, delete the stream
                 if time.time() - stream_info[stream]['time'] >= 300:
                     r.delete(stream)
-                    print(f"Stream '{stream}' deleted due to inactivity.")
+                    logger.debug(f"Stream '{stream}' deleted due to inactivity.")
                     del stream_info[stream]
 
         time.sleep(60)
diff --git a/llm_server/workers/logger.py b/llm_server/workers/logger.py
index eada969..58dcaec 100644
--- a/llm_server/workers/logger.py
+++ b/llm_server/workers/logger.py
@@ -4,6 +4,7 @@ import traceback
 import redis
 
 from llm_server.database.database import do_db_log
+from llm_server.logging import create_logger
 
 
 def db_logger():
@@ -16,6 +17,7 @@ def db_logger():
     r = redis.Redis(host='localhost', port=6379, db=3)
     p = r.pubsub()
     p.subscribe('database-logger')
+    logger = create_logger('db_logger')
 
     for message in p.listen():
         try:
@@ -28,4 +30,4 @@ def db_logger():
                 if function_name == 'log_prompt':
                     do_db_log(*args, **kwargs)
         except:
-            traceback.print_exc()
+            logger.error(traceback.format_exc())
diff --git a/llm_server/workers/mainer.py b/llm_server/workers/mainer.py
index fb1f3b0..efe31fc 100644
--- a/llm_server/workers/mainer.py
+++ b/llm_server/workers/mainer.py
@@ -7,10 +7,12 @@ from llm_server.cluster.cluster_config import get_backends, cluster_config
 from llm_server.custom_redis import redis
 from llm_server.database.database import weighted_average_column_for_model
 from llm_server.llm.info import get_info
+from llm_server.logging import create_logger
 from llm_server.routes.queue import RedisPriorityQueue, priority_queue
 
 
 def main_background_thread():
+    logger = create_logger('main_bg')
     while True:
         online, offline = get_backends()
         for backend_url in online:
@@ -34,7 +36,7 @@ def main_background_thread():
             base_client_api = redis.get('base_client_api', dtype=str)
             r = requests.get('https://' + base_client_api, timeout=5)
         except Exception as e:
-            print(f'Failed fetch the homepage - {e.__class__.__name__}: {e}')
+            logger.error(f'Failed to fetch the homepage - {e.__class__.__name__}: {e}')
 
         backends = priority_queue.get_backends()
         for backend_url in backends:
diff --git a/requirements.txt b/requirements.txt
index ce00783..77ac8c1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,7 +4,7 @@ Flask-Caching==2.0.2
 requests~=2.31.0
 tiktoken~=0.5.0
 gevent~=23.9.0.post1
-PyMySQL~=1.1.0
+mysql-connector-python==8.4.0
 simplejson~=3.19.1
 websockets~=11.0.3
 basicauth~=1.0.0
@@ -14,5 +14,4 @@ gunicorn==21.2.0
 redis==5.0.1
 ujson==5.8.0
 vllm==0.2.7
-gradio~=3.46.1
 coloredlogs~=15.0.1
diff --git a/server.py b/server.py
index fdb6f96..f80c474 100644
--- a/server.py
+++ b/server.py
@@ -22,11 +22,11 @@ from llm_server.cluster.cluster_config import cluster_config
 from llm_server.config.config import mode_ui_names
 from llm_server.config.load import load_config
 from llm_server.custom_redis import flask_cache, redis
-from llm_server.database.conn import database
+from llm_server.database.conn import Database
 from llm_server.database.create import create_db
 from llm_server.helpers import auto_set_base_client_api
 from llm_server.llm.vllm.info import vllm_info
-from llm_server.logging import init_logging
+from llm_server.logging import init_logging, create_logger
 from llm_server.routes.openai import openai_bp, openai_model_bp
 from llm_server.routes.server_error import handle_server_error
 from llm_server.routes.v1 import bp
@@ -62,13 +62,6 @@ from llm_server.sock import init_wssocket
 # TODO: add more excluding to SYSTEM__ tokens
 # TODO: return 200 when returning formatted sillytavern error
 
-try:
-    import vllm
-except ModuleNotFoundError as e:
-    print('Could not import vllm-gptq:', e)
-    print('Please see README.md for install instructions.')
-    sys.exit(1)
-
 script_path = os.path.dirname(os.path.realpath(__file__))
 config_path_environ = os.getenv("CONFIG_PATH")
 if config_path_environ:
@@ -78,11 +71,20 @@ else:
 
 success, config, msg = load_config(config_path)
 if not success:
-    print('Failed to load config:', msg)
+    logger = logging.getLogger('llm_server')
+    logger.setLevel(logging.INFO)
+    logger.error(f'Failed to load config: {msg}')
     sys.exit(1)
 
 init_logging(Path(config['webserver_log_directory']) / 'server.log')
-logger = logging.getLogger('llm_server')
+logger = create_logger('Server')
+logger.debug('Debug logging enabled.')
+
+try:
+    import vllm
+except ModuleNotFoundError as e:
+    logger.error(f'Could not import vllm-gptq: {e}')
+    sys.exit(1)
 
 while not redis.get('daemon_started', dtype=bool):
     logger.warning('Could not find the key daemon_started in Redis. Did you forget to start the daemon process?')
@@ -90,7 +92,7 @@ while not redis.get('daemon_started', dtype=bool):
 
 logger.info('Started HTTP worker!')
 
-database.init_db(config['mysql']['host'], config['mysql']['username'], config['mysql']['password'], config['mysql']['database'])
+Database.initialise(maxconn=config['mysql']['maxconn'], host=config['mysql']['host'], user=config['mysql']['username'], password=config['mysql']['password'], database=config['mysql']['database'])
 create_db()
 
 app = Flask(__name__)
diff --git a/templates/home.html b/templates/home.html
index d6d4a57..42d8ed7 100644
--- a/templates/home.html
+++ b/templates/home.html
@@ -1,6 +1,5 @@
-
     {{ llm_middleware_name }}
@@ -97,8 +96,8 @@

Streaming API URL: {{ ws_client_api if enable_streaming else 'Disabled' }}

OpenAI-Compatible API URL: {{ openai_client_api }}

{% if info_html|length > 1 %} -
- {{ info_html|safe }} +
+ {{ info_html|safe }} {% endif %} @@ -112,7 +111,8 @@
  • Set your API type to {{ mode_name }}
  • Enter {{ client_api }} in the {{ api_input_textbox }} textbox.
  • {% if enable_streaming %} -
  • Enter {{ ws_client_api }} in the {{ streaming_input_textbox }} textbox.
  • +
  • Enter {{ ws_client_api }} in the {{ streaming_input_textbox }} textbox. +
  • {% endif %}
  • If you have a token, check the Mancer AI checkbox and enter your token in the Mancer API key textbox. @@ -124,11 +124,12 @@ {% if openai_client_api != 'disabled' and expose_openai_system_prompt %} -
    -
    - OpenAI-Compatible API -

    The OpenAI-compatible API adds a system prompt to set the AI's behavior to a "helpful assistant". You can view this prompt here.

    -
    +
    +
    + OpenAI-Compatible API +

    The OpenAI-compatible API adds a system prompt to set the AI's behavior to a "helpful assistant". You + can view this prompt here.

    +
    {% endif %}
    {{ extra_info|safe }}
    @@ -147,30 +148,31 @@
    {% for key, value in model_choices.items() %} -
    -

    {{ key }} - {{ value.backend_count }} {% if value.backend_count == 1 %}worker{% else %}workers{% endif %}

    +
    +

    {{ key }} - {{ value.backend_count }} {% if value.backend_count == 1 %} + worker{% else %}workers{% endif %}

    - {% if value.estimated_wait == 0 and value.estimated_wait >= value.concurrent_gens %} - {# There will be a wait if the queue is empty but prompts are processing, but we don't know how long. #} - {% set estimated_wait_sec = "less than " + value.estimated_wait|int|string + " seconds" %} - {% else %} - {% set estimated_wait_sec = value.estimated_wait|int|string + " seconds" %} - {% endif %} + {% if value.estimated_wait == 0 and value.estimated_wait >= value.concurrent_gens %} + {# There will be a wait if the queue is empty but prompts are processing, but we don't know how long. #} + {% set estimated_wait_sec = "less than " + value.estimated_wait|int|string + " seconds" %} + {% else %} + {% set estimated_wait_sec = value.estimated_wait|int|string + " seconds" %} + {% endif %} -

    - Estimated Wait Time: {{ estimated_wait_sec }}
    - Processing: {{ value.processing }}
    - Queued: {{ value.queued }}
    -

    -

    - Client API URL: {{ value.client_api }}
    - Streaming API URL: {{ value.ws_client_api }}
    - OpenAI-Compatible API URL: {{ value.openai_client_api }} -

    -

    Context Size: {{ value.context_size }}

    -

    Average Generation Time: {{ value.avg_generation_time | int }} seconds

    -
    -
    +

    + Estimated Wait Time: {{ estimated_wait_sec }}
    + Processing: {{ value.processing }}
    + Queued: {{ value.queued }}
    +

    +

    + Client API URL: {{ value.client_api }}
    + Streaming API URL: {{ value.ws_client_api }}
    + OpenAI-Compatible API URL: {{ value.openai_client_api }} +

    +

    Context Size: {{ value.context_size }}

    +

    Average Generation Time: {{ value.avg_generation_time | int }} seconds

    +
    +
    {% endfor %}