redo database connection handling, add connection pooling, minor logging changes, other cleanup
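
For context, here is a minimal startup sketch (not part of the diff) of the new pooled connection setup that replaces database.init_db(). The host, user, password, and database values are placeholders; maxconn maps to the new config['mysql']['maxconn'] option and becomes the pool size.

# Hypothetical startup sketch: create the shared MySQL connection pool once, before serving requests.
from llm_server.database.conn import Database

Database.initialise(
    maxconn=10,            # passed as pool_size to mysql.connector's MySQLConnectionPool
    host='127.0.0.1',      # placeholder values; the real ones come from config['mysql']
    user='llm_server',
    password='secret',
    database='llm_server',
)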

Cyberes 2024-05-07 09:48:51 -06:00
parent ab408c6c5b
commit ff82add09e
25 changed files with 219 additions and 192 deletions

View File

@ -28,7 +28,6 @@ if __name__ == "__main__":
parser.add_argument('-d', '--debug', action='store_true', help='Enable debug logging.')
args = parser.parse_args()
# TODO: have this be set by either the arg or a config value
if args.debug:
logging_info.level = logging.DEBUG

View File

@ -6,11 +6,14 @@ from llm_server import opts
from llm_server.cluster.redis_cycle import add_backend_cycler, redis_cycle
from llm_server.cluster.stores import redis_running_models
from llm_server.custom_redis import RedisCustom
from llm_server.logging import create_logger
from llm_server.routes.helpers.model import estimate_model_size
# Don't try to reorganize this file or else you'll run into circular imports.
_logger = create_logger('redis')
class RedisClusterStore:
"""
A class used to store the cluster state in Redis.
@ -67,7 +70,7 @@ class RedisClusterStore:
if not backend_info['online']:
old = backend_url
backend_url = get_a_cluster_backend()
print(f'Backend {old} offline. Request was redirected to {backend_url}')
_logger.debug(f'Backend {old} offline. Request was redirected to {backend_url}')
return backend_url
@ -108,8 +111,7 @@ def get_backends():
)
return [url for url, info in online_backends], [url for url, info in offline_backends]
except KeyError:
traceback.print_exc()
print(backends)
_logger.error(f'Failed to get a backend from the cluster config: {traceback.format_exc()}\nCurrent backends: {backends}')
def get_a_cluster_backend(model=None):

View File

@ -7,10 +7,13 @@ import llm_server
from llm_server import opts
from llm_server.config.config import ConfigLoader, config_default_vars, config_required_vars
from llm_server.custom_redis import redis
from llm_server.database.conn import database
from llm_server.database.conn import Database
from llm_server.database.database import get_number_of_rows
from llm_server.logging import create_logger
from llm_server.routes.queue import PriorityQueue
_logger = create_logger('config')
def load_config(config_path):
config_loader = ConfigLoader(config_path, config_default_vars, config_required_vars)
@ -58,7 +61,7 @@ def load_config(config_path):
llm_server.routes.queue.priority_queue = PriorityQueue([x['backend_url'] for x in config['cluster']])
if opts.openai_expose_our_model and not opts.openai_api_key:
print('If you set openai_expose_our_model to false, you must set your OpenAI key in openai_api_key.')
_logger.error('If you set openai_expose_our_model to true, you must set your OpenAI key in openai_api_key.')
sys.exit(1)
opts.verify_ssl = config['verify_ssl']
@ -67,11 +70,11 @@ def load_config(config_path):
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
if config['http_host']:
http_host = re.sub(r'http(?:s)?://', '', config["http_host"])
http_host = re.sub(r'https?://', '', config["http_host"])
redis.set('http_host', http_host)
redis.set('base_client_api', f'{http_host}/{opts.frontend_api_client.strip("/")}')
database.init_db(config['mysql']['host'], config['mysql']['username'], config['mysql']['password'], config['mysql']['database'])
Database.initialise(maxconn=config['mysql']['maxconn'], host=config['mysql']['host'], user=config['mysql']['username'], password=config['mysql']['password'], database=config['mysql']['database'])
if config['load_num_prompts']:
redis.set('proompts', get_number_of_rows('prompts'))

View File

@ -1,3 +1,4 @@
import logging
import pickle
import sys
import traceback
@ -29,8 +30,9 @@ class RedisCustom(Redis):
try:
self.set('____', 1)
except redis_pkg.exceptions.ConnectionError as e:
print('Failed to connect to the Redis server:', e)
print('Did you install and start the Redis server?')
logger = logging.getLogger('redis')
logger.setLevel(logging.INFO)
logger.error(f'Failed to connect to the Redis server: {e}\nDid you install and start the Redis server?')
sys.exit(1)
def _key(self, key):

View File

@ -1,28 +1,40 @@
import pymysql
from mysql.connector import pooling
class DatabaseConnection:
host: str = None
username: str = None
password: str = None
database_name: str = None
class Database:
__connection_pool = None
def init_db(self, host, username, password, database_name):
self.host = host
self.username = username
self.password = password
self.database_name = database_name
@classmethod
def initialise(cls, maxconn, **kwargs):
if cls.__connection_pool is not None:
raise Exception('Database connection pool is already initialised')
cls.__connection_pool = pooling.MySQLConnectionPool(pool_size=maxconn,
pool_reset_session=True,
**kwargs)
def cursor(self):
db = pymysql.connect(
host=self.host,
user=self.username,
password=self.password,
database=self.database_name,
charset='utf8mb4',
autocommit=True,
)
return db.cursor()
@classmethod
def get_connection(cls):
return cls.__connection_pool.get_connection()
@classmethod
def return_connection(cls, connection):
connection.close()
database = DatabaseConnection()
class CursorFromConnectionFromPool:
def __init__(self):
self.conn = None
self.cursor = None
def __enter__(self):
self.conn = Database.get_connection()
self.cursor = self.conn.cursor()
return self.cursor
def __exit__(self, exception_type, exception_value, exception_traceback):
if exception_value is not None:  # an exception was raised inside the with block
self.conn.rollback()
else:
self.cursor.close()
self.conn.commit()
Database.return_connection(self.conn)
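
A brief usage sketch (not from the commit) of how callers are expected to use the new context manager: __exit__ commits and closes the cursor on success, rolls back on an exception, and always hands the connection back to the pool.

# Hypothetical caller, assuming Database.initialise() has already been run at startup.
from llm_server.database.conn import CursorFromConnectionFromPool

def fetch_token_priority(token: str):
    with CursorFromConnectionFromPool() as cursor:
        cursor.execute('SELECT priority FROM token_auth WHERE token = %s', (token,))
        row = cursor.fetchone()
    return row[0] if row else None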

View File

@ -1,40 +1,39 @@
from llm_server.database.conn import database
from llm_server.database.conn import CursorFromConnectionFromPool
def create_db():
cursor = database.cursor()
cursor.execute('''
CREATE TABLE IF NOT EXISTS prompts (
ip TEXT,
token TEXT DEFAULT NULL,
model TEXT,
backend_mode TEXT,
backend_url TEXT,
request_url TEXT,
generation_time FLOAT,
prompt LONGTEXT,
prompt_tokens INTEGER,
response LONGTEXT,
response_tokens INTEGER,
response_status INTEGER,
parameters TEXT,
# CHECK (parameters IS NULL OR JSON_VALID(parameters)),
headers TEXT,
# CHECK (headers IS NULL OR JSON_VALID(headers)),
timestamp INTEGER
)
''')
cursor.execute('''
CREATE TABLE IF NOT EXISTS token_auth (
token TEXT,
UNIQUE (token),
type TEXT NOT NULL,
priority INTEGER DEFAULT 9999,
simultaneous_ip INTEGER DEFAULT NULL,
uses INTEGER DEFAULT 0,
max_uses INTEGER,
expire INTEGER,
disabled BOOLEAN DEFAULT 0
)
''')
cursor.close()
with CursorFromConnectionFromPool() as cursor:
cursor.execute('''
CREATE TABLE IF NOT EXISTS prompts (
ip TEXT,
token TEXT DEFAULT NULL,
model TEXT,
backend_mode TEXT,
backend_url TEXT,
request_url TEXT,
generation_time FLOAT,
prompt LONGTEXT,
prompt_tokens INTEGER,
response LONGTEXT,
response_tokens INTEGER,
response_status INTEGER,
parameters TEXT,
# CHECK (parameters IS NULL OR JSON_VALID(parameters)),
headers TEXT,
# CHECK (headers IS NULL OR JSON_VALID(headers)),
timestamp INTEGER
)
''')
cursor.execute('''
CREATE TABLE IF NOT EXISTS token_auth (
token TEXT,
UNIQUE (token),
type TEXT NOT NULL,
priority INTEGER DEFAULT 9999,
simultaneous_ip INTEGER DEFAULT NULL,
uses INTEGER DEFAULT 0,
max_uses INTEGER,
expire INTEGER,
disabled BOOLEAN DEFAULT 0
)
''')

View File

@ -5,7 +5,7 @@ from typing import Union
from llm_server import opts
from llm_server.cluster.cluster_config import cluster_config
from llm_server.database.conn import database
from llm_server.database.conn import CursorFromConnectionFromPool
from llm_server.llm import get_token_count
@ -52,21 +52,17 @@ def do_db_log(ip: str, token: str, prompt: str, response: Union[str, None], gen_
running_model = backend_info.get('model')
backend_mode = backend_info['mode']
timestamp = int(time.time())
cursor = database.cursor()
try:
with CursorFromConnectionFromPool() as cursor:
cursor.execute("""
INSERT INTO prompts
(ip, token, model, backend_mode, backend_url, request_url, generation_time, prompt, prompt_tokens, response, response_tokens, response_status, parameters, headers, timestamp)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
""",
INSERT INTO prompts
(ip, token, model, backend_mode, backend_url, request_url, generation_time, prompt, prompt_tokens, response, response_tokens, response_status, parameters, headers, timestamp)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
""",
(ip, token, running_model, backend_mode, backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
finally:
cursor.close()
def is_valid_api_key(api_key):
cursor = database.cursor()
try:
with CursorFromConnectionFromPool() as cursor:
cursor.execute("SELECT token, uses, max_uses, expire, disabled FROM token_auth WHERE token = %s", (api_key,))
row = cursor.fetchone()
if row is not None:
@ -75,52 +71,38 @@ def is_valid_api_key(api_key):
if ((uses is None or max_uses is None) or uses < max_uses) and (expire is None or expire > time.time()) and not disabled:
return True
return False
finally:
cursor.close()
def is_api_key_moderated(api_key):
if not api_key:
return opts.openai_moderation_enabled
cursor = database.cursor()
try:
with CursorFromConnectionFromPool() as cursor:
cursor.execute("SELECT openai_moderation_enabled FROM token_auth WHERE token = %s", (api_key,))
row = cursor.fetchone()
if row is not None:
return bool(row[0])
return opts.openai_moderation_enabled
finally:
cursor.close()
def get_number_of_rows(table_name):
cursor = database.cursor()
try:
with CursorFromConnectionFromPool() as cursor:
cursor.execute(f"SELECT COUNT(*) FROM {table_name} WHERE token NOT LIKE 'SYSTEM__%%' OR token IS NULL")
result = cursor.fetchone()
return result[0]
finally:
cursor.close()
def average_column(table_name, column_name):
cursor = database.cursor()
try:
with CursorFromConnectionFromPool() as cursor:
cursor.execute(f"SELECT AVG({column_name}) FROM {table_name} WHERE token NOT LIKE 'SYSTEM__%%' OR token IS NULL")
result = cursor.fetchone()
return result[0]
finally:
cursor.close()
def average_column_for_model(table_name, column_name, model_name):
cursor = database.cursor()
try:
with CursorFromConnectionFromPool() as cursor:
cursor.execute(f"SELECT AVG({column_name}) FROM {table_name} WHERE model = %s AND token NOT LIKE 'SYSTEM__%%' OR token IS NULL", (model_name,))
result = cursor.fetchone()
return result[0]
finally:
cursor.close()
def weighted_average_column_for_model(table_name, column_name, model_name, backend_name, backend_url, exclude_zeros: bool = False, include_system_tokens: bool = True):
@ -129,8 +111,7 @@ def weighted_average_column_for_model(table_name, column_name, model_name, backe
else:
sql = f"SELECT {column_name}, id FROM {table_name} WHERE model = %s AND backend_mode = %s AND backend_url = %s AND (token NOT LIKE 'SYSTEM__%%' OR token IS NULL) ORDER BY id DESC"
cursor = database.cursor()
try:
with CursorFromConnectionFromPool() as cursor:
try:
cursor.execute(sql, (model_name, backend_name, backend_url,))
results = cursor.fetchall()
@ -154,46 +135,34 @@ def weighted_average_column_for_model(table_name, column_name, model_name, backe
calculated_avg = 0
return calculated_avg
finally:
cursor.close()
def sum_column(table_name, column_name):
cursor = database.cursor()
try:
with CursorFromConnectionFromPool() as cursor:
cursor.execute(f"SELECT SUM({column_name}) FROM {table_name} WHERE token NOT LIKE 'SYSTEM__%%' OR token IS NULL")
result = cursor.fetchone()
return result[0] if result else 0
finally:
cursor.close()
def get_distinct_ips_24h():
# Get the current time and subtract 24 hours (in seconds)
past_24_hours = int(time.time()) - 24 * 60 * 60
cursor = database.cursor()
try:
with CursorFromConnectionFromPool() as cursor:
cursor.execute("SELECT COUNT(DISTINCT ip) FROM prompts WHERE timestamp >= %s AND (token NOT LIKE 'SYSTEM__%%' OR token IS NULL)", (past_24_hours,))
result = cursor.fetchone()
return result[0] if result else 0
finally:
cursor.close()
def increment_token_uses(token):
cursor = database.cursor()
try:
with CursorFromConnectionFromPool() as cursor:
cursor.execute('UPDATE token_auth SET uses = uses + 1 WHERE token = %s', (token,))
finally:
cursor.close()
def get_token_ratelimit(token):
priority = 9990
simultaneous_ip = opts.simultaneous_requests_per_ip
if token:
cursor = database.cursor()
try:
with CursorFromConnectionFromPool() as cursor:
cursor.execute("SELECT priority, simultaneous_ip FROM token_auth WHERE token = %s", (token,))
result = cursor.fetchone()
if result:
@ -201,6 +170,4 @@ def get_token_ratelimit(token):
if simultaneous_ip is None:
# No ratelimit for this token if null
simultaneous_ip = 999999999
finally:
cursor.close()
return priority, simultaneous_ip

View File

@ -1,6 +1,9 @@
import requests
from llm_server import opts
from llm_server.logging import create_logger
_logger = create_logger('moderation')
def check_moderation_endpoint(prompt: str):
@ -10,7 +13,7 @@ def check_moderation_endpoint(prompt: str):
}
response = requests.post('https://api.openai.com/v1/moderations', headers=headers, json={"input": prompt}, timeout=10)
if response.status_code != 200:
print('moderation failed:', response)
_logger.error(f'moderation failed: {response}')
response.raise_for_status()
response = response.json()

View File

@ -1,6 +1,9 @@
from flask import jsonify
from llm_server import opts
from llm_server.logging import create_logger
_logger = create_logger('oai_to_vllm')
def oai_to_vllm(request_json_body, stop_hashes: bool, mode):
@ -36,7 +39,7 @@ def oai_to_vllm(request_json_body, stop_hashes: bool, mode):
def format_oai_err(err_msg):
print('OAI ERROR MESSAGE:', err_msg)
_logger.error(f'Got an OAI error message: {err_msg}')
return jsonify({
"error": {
"message": err_msg,

View File

@ -39,6 +39,7 @@ def init_logging(filepath: Path = None):
This is only called by `server.py` since there is weirdness with Gunicorn. The daemon doesn't need this.
:return:
"""
global LOG_DIRECTORY
logger = logging.getLogger('llm_server')
logger.setLevel(logging_info.level)

View File

@ -5,9 +5,12 @@ from flask import jsonify, request
from llm_server import messages, opts
from llm_server.database.log_to_db import log_to_db
from llm_server.logging import create_logger
from llm_server.routes.helpers.client import format_sillytavern_err
from llm_server.routes.request_handler import RequestHandler
_logger = create_logger('OobaRequestHandler')
class OobaRequestHandler(RequestHandler):
def __init__(self, *args, **kwargs):
@ -16,7 +19,7 @@ class OobaRequestHandler(RequestHandler):
def handle_request(self, return_ok: bool = True):
assert not self.used
if self.offline:
print('This backend is offline:', messages.BACKEND_OFFLINE)
# _logger.debug(f'This backend is offline.')
return self.handle_error(messages.BACKEND_OFFLINE)
request_valid, invalid_response = self.validate_request()

View File

@ -1,8 +1,10 @@
from flask import Blueprint
from ..request_handler import before_request
from ..server_error import handle_server_error
from ... import opts
from ...logging import create_logger
_logger = create_logger('OpenAI')
openai_bp = Blueprint('openai/v1/', __name__)
openai_model_bp = Blueprint('openai/', __name__)
@ -24,7 +26,7 @@ def handle_error(e):
"auth_subrequest_error"
"""
print('OAI returning error:', e)
_logger.error(f'OAI returning error: {e}')
return jsonify({
"error": {
"message": "Internal server error",

View File

@ -15,6 +15,9 @@ from ... import opts
from ...database.log_to_db import log_to_db
from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai
from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
from ...logging import create_logger
_logger = create_logger('OpenAIChatCompletions')
# TODO: add rate-limit headers?
@ -99,7 +102,7 @@ def openai_chat_completions(model_name=None):
# return a 408 if necessary.
_, stream_name, error_msg = event.wait()
if error_msg:
print('OAI failed to start streaming:', error_msg)
_logger.error(f'OAI failed to start streaming: {error_msg}')
stream_name = None # set to null so that the Finally ignores it.
return 'Request Timeout', 408
@ -111,7 +114,7 @@ def openai_chat_completions(model_name=None):
while True:
stream_data = stream_redis.xread({stream_name: last_id}, block=opts.redis_stream_timeout)
if not stream_data:
print(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.")
_logger.debug(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.")
yield 'data: [DONE]\n\n'
else:
for stream_index, item in stream_data[0][1]:
@ -120,7 +123,7 @@ def openai_chat_completions(model_name=None):
data = ujson.loads(item[b'data'])
if data['error']:
# Not printing error since we can just check the daemon log.
print('OAI streaming encountered error')
_logger.warning(f'OAI streaming encountered error: {data["error"]}')
yield 'data: [DONE]\n\n'
return
elif data['new']:

View File

@ -16,10 +16,13 @@ from ...database.log_to_db import log_to_db
from ...llm import get_token_count
from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai
from ...llm.openai.transform import generate_oai_string, trim_string_to_fit
from ...logging import create_logger
# TODO: add rate-limit headers?
_logger = create_logger('OpenAICompletions')
@openai_bp.route('/completions', methods=['POST'])
@openai_model_bp.route('/<model_name>/v1/completions', methods=['POST'])
def openai_completions(model_name=None):
@ -144,7 +147,7 @@ def openai_completions(model_name=None):
_, stream_name, error_msg = event.wait()
if error_msg:
print('OAI failed to start streaming:', error_msg)
_logger.error(f'OAI failed to start streaming: {error_msg}')
stream_name = None
return 'Request Timeout', 408
@ -156,7 +159,7 @@ def openai_completions(model_name=None):
while True:
stream_data = stream_redis.xread({stream_name: last_id}, block=opts.redis_stream_timeout)
if not stream_data:
print(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.")
_logger.debug(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.")
yield 'data: [DONE]\n\n'
else:
for stream_index, item in stream_data[0][1]:
@ -164,7 +167,7 @@ def openai_completions(model_name=None):
timestamp = int(stream_index.decode('utf-8').split('-')[0])
data = ujson.loads(item[b'data'])
if data['error']:
print('OAI streaming encountered error')
_logger.error(f'OAI streaming encountered error: {data["error"]}')
yield 'data: [DONE]\n\n'
return
elif data['new']:

View File

@ -16,9 +16,12 @@ from llm_server.database.log_to_db import log_to_db
from llm_server.llm import get_token_count
from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai, return_invalid_model_err
from llm_server.llm.openai.transform import ANTI_CONTINUATION_RE, ANTI_RESPONSE_RE, generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
from llm_server.logging import create_logger
from llm_server.routes.request_handler import RequestHandler
from llm_server.workers.moderator import add_moderation_task, get_results
_logger = create_logger('OpenAIRequestHandler')
class OpenAIRequestHandler(RequestHandler):
def __init__(self, *args, **kwargs):
@ -29,7 +32,7 @@ class OpenAIRequestHandler(RequestHandler):
assert not self.used
if self.offline:
msg = return_invalid_model_err(self.selected_model)
print('OAI Offline:', msg)
_logger.error(f'OAI is offline: {msg}')
return self.handle_error(msg)
if opts.openai_silent_trim:
@ -72,7 +75,7 @@ class OpenAIRequestHandler(RequestHandler):
self.request.json['messages'].insert((len(self.request.json['messages'])), {'role': 'system', 'content': mod_msg})
self.prompt = transform_messages_to_prompt(self.request.json['messages'])
except Exception as e:
print(f'OpenAI moderation endpoint failed:', f'{e.__class__.__name__}: {e}')
_logger.error(f'OpenAI moderation endpoint failed: {e.__class__.__name__}: {e}')
traceback.print_exc()
llm_request = {**self.parameters, 'prompt': self.prompt}
@ -106,7 +109,7 @@ class OpenAIRequestHandler(RequestHandler):
return response, 429
def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
print('OAI Error:', error_msg)
_logger.error(f'OAI Error: {error_msg}')
return jsonify({
"error": {
"message": "Invalid request, check your parameters and try again.",
@ -155,7 +158,7 @@ class OpenAIRequestHandler(RequestHandler):
def validate_request(self, prompt: str = None, do_log: bool = False) -> Tuple[bool, Tuple[Response | None, int]]:
self.parameters, parameters_invalid_msg = self.get_parameters()
if not self.parameters:
print('OAI BACKEND VALIDATION ERROR:', parameters_invalid_msg)
_logger.error(f'OAI BACKEND VALIDATION ERROR: {parameters_invalid_msg}')
return False, (Response('Invalid request, check your parameters and try again.'), 400)
invalid_oai_err_msg = validate_oai(self.parameters)
if invalid_oai_err_msg:

View File

@ -10,6 +10,7 @@ from llm_server import opts
from llm_server.cluster.cluster_config import cluster_config
from llm_server.custom_redis import RedisCustom, redis
from llm_server.database.database import get_token_ratelimit
from llm_server.logging import create_logger
def increment_ip_count(client_ip: str, redis_key):
@ -30,6 +31,7 @@ class RedisPriorityQueue:
def __init__(self, name, db: int = 12):
self.name = name
self.redis = RedisCustom(name, db=db)
self._logger = create_logger('RedisPriorityQueue')
def put(self, item, priority: int, selected_model: str, do_stream: bool = False):
# TODO: remove this when we're sure nothing strange is happening
@ -41,7 +43,7 @@ class RedisPriorityQueue:
ip_count = self.get_ip_request_count(item[1])
_, simultaneous_ip = get_token_ratelimit(item[2])
if ip_count and int(ip_count) >= simultaneous_ip and priority != 0:
print(f'Rejecting request from {item[1]} - {ip_count} request queued.')
self._logger.debug(f'Rejecting request from {item[1]} - {ip_count} requests queued.')
return None # reject the request
timestamp = time.time()
@ -98,7 +100,7 @@ class RedisPriorityQueue:
event_id = item_data[1]
event = DataEvent(event_id)
event.set((False, None, 'closed'))
print('Removed timed-out item from queue:', event_id)
self._logger.debug(f'Removed timed-out item from queue: {event_id}')
class DataEvent:

View File

@ -12,10 +12,13 @@ from llm_server.database.log_to_db import log_to_db
from llm_server.helpers import auto_set_base_client_api
from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
from llm_server.llm.vllm.vllm_backend import VLLMBackend
from llm_server.logging import create_logger
from llm_server.routes.auth import parse_token
from llm_server.routes.helpers.http import require_api_key, validate_json
from llm_server.routes.queue import priority_queue
_logger = create_logger('RequestHandler')
class RequestHandler:
def __init__(self, incoming_request: flask.Request, selected_model: str = None, incoming_json: Union[dict, str] = None):
@ -223,7 +226,7 @@ class RequestHandler:
processing_ip = 0
if queued_ip_count + processing_ip >= self.token_simultaneous_ip:
print(f'Rejecting request from {self.client_ip} - {processing_ip} processing, {queued_ip_count} queued')
_logger.debug(f'Rejecting request from {self.client_ip} - {processing_ip} processing, {queued_ip_count} queued')
return True
else:
return False

View File

@ -1,3 +1,8 @@
from llm_server.logging import create_logger
_logger = create_logger('handle_server_error')
def handle_server_error(e):
print('Internal Error:', e)
_logger.error(f'Internal Error: {e}')
return {'error': True, 'code': 500, 'message': 'Internal Server Error :('}, 500

View File

@ -13,12 +13,15 @@ from ..queue import priority_queue
from ... import opts
from ...custom_redis import redis
from ...database.log_to_db import log_to_db
from ...logging import create_logger
from ...sock import sock
# Stacking the @sock.route() decorators raises a TypeError on the /v1/stream endpoint.
# We solve this by splitting the routes.
_logger = create_logger('GenerateStream')
@bp.route('/v1/stream')
@bp.route('/<model_name>/v1/stream')
def stream(model_name=None):
@ -85,7 +88,7 @@ def do_stream(ws, model_name):
handler = OobaRequestHandler(incoming_request=request, selected_model=model_name, incoming_json=request_json_body)
if handler.offline:
msg = f'{handler.selected_model} is not a valid model choice.'
print(msg)
_logger.debug(msg)
ws.send(json.dumps({
'event': 'text_stream',
'message_num': 0,
@ -131,7 +134,7 @@ def do_stream(ws, model_name):
_, stream_name, error_msg = event.wait()
if error_msg:
print('Stream failed to start streaming:', error_msg)
_logger.error(f'Stream failed to start streaming: {error_msg}')
ws.close(reason=1014, message='Request Timeout')
return
@ -143,14 +146,14 @@ def do_stream(ws, model_name):
while True:
stream_data = stream_redis.xread({stream_name: last_id}, block=opts.redis_stream_timeout)
if not stream_data:
print(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.")
_logger.error(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.")
return
else:
for stream_index, item in stream_data[0][1]:
last_id = stream_index
data = ujson.loads(item[b'data'])
if data['error']:
print(data['error'])
_logger.error(f'Encountered error while streaming: {data["error"]}')
send_err_and_quit('Encountered exception while streaming.')
return
elif data['new']:

View File

@ -2,6 +2,7 @@ import time
from redis import Redis
from llm_server.logging import create_logger
from llm_server.workers.inferencer import STREAM_NAME_PREFIX
@ -10,6 +11,7 @@ from llm_server.workers.inferencer import STREAM_NAME_PREFIX
def cleaner():
r = Redis(db=8)
stream_info = {}
logger = create_logger('cleaner')
while True:
all_streams = r.keys(f'{STREAM_NAME_PREFIX}:*')
@ -26,7 +28,7 @@ def cleaner():
# If the size hasn't changed for 5 minutes, delete the stream
if time.time() - stream_info[stream]['time'] >= 300:
r.delete(stream)
print(f"Stream '{stream}' deleted due to inactivity.")
logger.debug(f"Stream '{stream}' deleted due to inactivity.")
del stream_info[stream]
time.sleep(60)

View File

@ -4,6 +4,7 @@ import traceback
import redis
from llm_server.database.database import do_db_log
from llm_server.logging import create_logger
def db_logger():
@ -16,6 +17,7 @@ def db_logger():
r = redis.Redis(host='localhost', port=6379, db=3)
p = r.pubsub()
p.subscribe('database-logger')
logger = create_logger('main_bg')
for message in p.listen():
try:
@ -28,4 +30,4 @@ def db_logger():
if function_name == 'log_prompt':
do_db_log(*args, **kwargs)
except:
traceback.print_exc()
logger.error(traceback.format_exc())

View File

@ -7,10 +7,12 @@ from llm_server.cluster.cluster_config import get_backends, cluster_config
from llm_server.custom_redis import redis
from llm_server.database.database import weighted_average_column_for_model
from llm_server.llm.info import get_info
from llm_server.logging import create_logger
from llm_server.routes.queue import RedisPriorityQueue, priority_queue
def main_background_thread():
logger = create_logger('main_bg')
while True:
online, offline = get_backends()
for backend_url in online:
@ -34,7 +36,7 @@ def main_background_thread():
base_client_api = redis.get('base_client_api', dtype=str)
r = requests.get('https://' + base_client_api, timeout=5)
except Exception as e:
print(f'Failed fetch the homepage - {e.__class__.__name__}: {e}')
logger.error(f'Failed to fetch the homepage - {e.__class__.__name__}: {e}')
backends = priority_queue.get_backends()
for backend_url in backends:

View File

@ -4,7 +4,7 @@ Flask-Caching==2.0.2
requests~=2.31.0
tiktoken~=0.5.0
gevent~=23.9.0.post1
PyMySQL~=1.1.0
mysql-connector-python==8.4.0
simplejson~=3.19.1
websockets~=11.0.3
basicauth~=1.0.0
@ -14,5 +14,4 @@ gunicorn==21.2.0
redis==5.0.1
ujson==5.8.0
vllm==0.2.7
gradio~=3.46.1
coloredlogs~=15.0.1

View File

@ -22,11 +22,11 @@ from llm_server.cluster.cluster_config import cluster_config
from llm_server.config.config import mode_ui_names
from llm_server.config.load import load_config
from llm_server.custom_redis import flask_cache, redis
from llm_server.database.conn import database
from llm_server.database.conn import database, Database
from llm_server.database.create import create_db
from llm_server.helpers import auto_set_base_client_api
from llm_server.llm.vllm.info import vllm_info
from llm_server.logging import init_logging
from llm_server.logging import init_logging, create_logger
from llm_server.routes.openai import openai_bp, openai_model_bp
from llm_server.routes.server_error import handle_server_error
from llm_server.routes.v1 import bp
@ -62,13 +62,6 @@ from llm_server.sock import init_wssocket
# TODO: add more excluding to SYSTEM__ tokens
# TODO: return 200 when returning formatted sillytavern error
try:
import vllm
except ModuleNotFoundError as e:
print('Could not import vllm-gptq:', e)
print('Please see README.md for install instructions.')
sys.exit(1)
script_path = os.path.dirname(os.path.realpath(__file__))
config_path_environ = os.getenv("CONFIG_PATH")
if config_path_environ:
@ -78,11 +71,20 @@ else:
success, config, msg = load_config(config_path)
if not success:
print('Failed to load config:', msg)
logger = logging.getLogger('llm_server')
logger.setLevel(logging.INFO)
logger.error(f'Failed to load config: {msg}')
sys.exit(1)
init_logging(Path(config['webserver_log_directory']) / 'server.log')
logger = logging.getLogger('llm_server')
logger = create_logger('Server')
logger.debug('Debug logging enabled.')
try:
import vllm
except ModuleNotFoundError as e:
logger.error(f'Could not import vllm-gptq: {e}')
sys.exit(1)
while not redis.get('daemon_started', dtype=bool):
logger.warning('Could not find the key daemon_started in Redis. Did you forget to start the daemon process?')
@ -90,7 +92,7 @@ while not redis.get('daemon_started', dtype=bool):
logger.info('Started HTTP worker!')
database.init_db(config['mysql']['host'], config['mysql']['username'], config['mysql']['password'], config['mysql']['database'])
Database.initialise(maxconn=config['mysql']['maxconn'], host=config['mysql']['host'], user=config['mysql']['username'], password=config['mysql']['password'], database=config['mysql']['database'])
create_db()
app = Flask(__name__)

View File

@ -1,6 +1,5 @@
<!DOCTYPE html>
<html lang="en">
<head>
<title>{{ llm_middleware_name }}</title>
<meta content="width=device-width, initial-scale=1" name="viewport"/>
@ -97,8 +96,8 @@
<p><strong>Streaming API URL:</strong> {{ ws_client_api if enable_streaming else 'Disabled' }}</p>
<p><strong>OpenAI-Compatible API URL:</strong> {{ openai_client_api }}</p>
{% if info_html|length > 1 %}
<br>
{{ info_html|safe }}
<br>
{{ info_html|safe }}
{% endif %}
</div>
@ -112,7 +111,8 @@
<li>Set your API type to <kbd>{{ mode_name }}</kbd></li>
<li>Enter <kbd>{{ client_api }}</kbd> in the <kbd>{{ api_input_textbox }}</kbd> textbox.</li>
{% if enable_streaming %}
<li>Enter <kbd>{{ ws_client_api }}</kbd> in the <kbd>{{ streaming_input_textbox }}</kbd> textbox.</li>
<li>Enter <kbd>{{ ws_client_api }}</kbd> in the <kbd>{{ streaming_input_textbox }}</kbd> textbox.
</li>
{% endif %}
<li>If you have a token, check the <kbd>Mancer AI</kbd> checkbox and enter your token in the <kbd>Mancer
API key</kbd> textbox.
@ -124,11 +124,12 @@
</ol>
</div>
{% if openai_client_api != 'disabled' and expose_openai_system_prompt %}
<br>
<div id="openai">
<strong>OpenAI-Compatible API</strong>
<p>The OpenAI-compatible API adds a system prompt to set the AI's behavior to a "helpful assistant". You can view this prompt <a href="/api/openai/v1/prompt">here</a>.</p>
</div>
<br>
<div id="openai">
<strong>OpenAI-Compatible API</strong>
<p>The OpenAI-compatible API adds a system prompt to set the AI's behavior to a "helpful assistant". You
can view this prompt <a href="/api/openai/v1/prompt">here</a>.</p>
</div>
{% endif %}
<br>
<div id="extra-info">{{ extra_info|safe }}</div>
@ -147,30 +148,31 @@
<br>
{% for key, value in model_choices.items() %}
<div class="info-box">
<h3>{{ key }} <span class="header-workers">- {{ value.backend_count }} {% if value.backend_count == 1 %}worker{% else %}workers{% endif %}</span></h3>
<div class="info-box">
<h3>{{ key }} <span class="header-workers">- {{ value.backend_count }} {% if value.backend_count == 1 %}
worker{% else %}workers{% endif %}</span></h3>
{% if value.estimated_wait == 0 and value.estimated_wait >= value.concurrent_gens %}
{# There will be a wait if the queue is empty but prompts are processing, but we don't know how long. #}
{% set estimated_wait_sec = "less than " + value.estimated_wait|int|string + " seconds" %}
{% else %}
{% set estimated_wait_sec = value.estimated_wait|int|string + " seconds" %}
{% endif %}
{% if value.estimated_wait == 0 and value.estimated_wait >= value.concurrent_gens %}
{# There will be a wait if the queue is empty but prompts are processing, but we don't know how long. #}
{% set estimated_wait_sec = "less than " + value.estimated_wait|int|string + " seconds" %}
{% else %}
{% set estimated_wait_sec = value.estimated_wait|int|string + " seconds" %}
{% endif %}
<p>
<strong>Estimated Wait Time:</strong> {{ estimated_wait_sec }}<br>
Processing: {{ value.processing }}<br>
Queued: {{ value.queued }}<br>
</p>
<p>
<strong>Client API URL:</strong> {{ value.client_api }}<br>
<strong>Streaming API URL:</strong> {{ value.ws_client_api }}<br>
<strong>OpenAI-Compatible API URL:</strong> {{ value.openai_client_api }}
</p>
<p><strong>Context Size:</strong> {{ value.context_size }}</p>
<p><strong>Average Generation Time:</strong> {{ value.avg_generation_time | int }} seconds</p>
</div>
<br>
<p>
<strong>Estimated Wait Time:</strong> {{ estimated_wait_sec }}<br>
Processing: {{ value.processing }}<br>
Queued: {{ value.queued }}<br>
</p>
<p>
<strong>Client API URL:</strong> {{ value.client_api }}<br>
<strong>Streaming API URL:</strong> {{ value.ws_client_api }}<br>
<strong>OpenAI-Compatible API URL:</strong> {{ value.openai_client_api }}
</p>
<p><strong>Context Size:</strong> {{ value.context_size }}</p>
<p><strong>Average Generation Time:</strong> {{ value.avg_generation_time | int }} seconds</p>
</div>
<br>
{% endfor %}
</div>
<div class="footer">