redo database connection, add pooling, minor logging changes, other clean up

parent ab408c6c5b
commit ff82add09e
@@ -28,7 +28,6 @@ if __name__ == "__main__":
     parser.add_argument('-d', '--debug', action='store_true', help='Enable debug logging.')
     args = parser.parse_args()
 
-    # TODO: have this be set by either the arg or a config value
     if args.debug:
         logging_info.level = logging.DEBUG
 
@@ -6,11 +6,14 @@ from llm_server import opts
 from llm_server.cluster.redis_cycle import add_backend_cycler, redis_cycle
 from llm_server.cluster.stores import redis_running_models
 from llm_server.custom_redis import RedisCustom
+from llm_server.logging import create_logger
 from llm_server.routes.helpers.model import estimate_model_size
 
 # Don't try to reorganize this file or else you'll run into circular imports.
 
+_logger = create_logger('redis')
+
 
 class RedisClusterStore:
     """
     A class used to store the cluster state in Redis.
@@ -67,7 +70,7 @@ class RedisClusterStore:
         if not backend_info['online']:
             old = backend_url
             backend_url = get_a_cluster_backend()
-            print(f'Backend {old} offline. Request was redirected to {backend_url}')
+            _logger.debug(f'Backend {old} offline. Request was redirected to {backend_url}')
         return backend_url
 
 
@@ -108,8 +111,7 @@ def get_backends():
         )
         return [url for url, info in online_backends], [url for url, info in offline_backends]
     except KeyError:
-        traceback.print_exc()
-        print(backends)
+        _logger.error(f'Failed to get a backend from the cluster config: {traceback.format_exc()}\nCurrent backends: {backends}')
 
 
 def get_a_cluster_backend(model=None):
@@ -7,10 +7,13 @@ import llm_server
 from llm_server import opts
 from llm_server.config.config import ConfigLoader, config_default_vars, config_required_vars
 from llm_server.custom_redis import redis
-from llm_server.database.conn import database
+from llm_server.database.conn import Database
 from llm_server.database.database import get_number_of_rows
+from llm_server.logging import create_logger
 from llm_server.routes.queue import PriorityQueue
 
+_logger = create_logger('config')
+
 
 def load_config(config_path):
     config_loader = ConfigLoader(config_path, config_default_vars, config_required_vars)
@@ -58,7 +61,7 @@ def load_config(config_path):
     llm_server.routes.queue.priority_queue = PriorityQueue([x['backend_url'] for x in config['cluster']])
 
     if opts.openai_expose_our_model and not opts.openai_api_key:
-        print('If you set openai_expose_our_model to false, you must set your OpenAI key in openai_api_key.')
+        _logger.error('If you set openai_expose_our_model to false, you must set your OpenAI key in openai_api_key.')
         sys.exit(1)
 
     opts.verify_ssl = config['verify_ssl']
@@ -67,11 +70,11 @@ def load_config(config_path):
         urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
 
     if config['http_host']:
-        http_host = re.sub(r'http(?:s)?://', '', config["http_host"])
+        http_host = re.sub(r'https?://', '', config["http_host"])
         redis.set('http_host', http_host)
         redis.set('base_client_api', f'{http_host}/{opts.frontend_api_client.strip("/")}')
 
-    database.init_db(config['mysql']['host'], config['mysql']['username'], config['mysql']['password'], config['mysql']['database'])
+    Database.initialise(maxconn=config['mysql']['maxconn'], host=config['mysql']['host'], user=config['mysql']['username'], password=config['mysql']['password'], database=config['mysql']['database'])
 
     if config['load_num_prompts']:
         redis.set('proompts', get_number_of_rows('prompts'))
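Note on the change above: Database.initialise() now reads a maxconn value from the mysql section of the config in addition to the existing host, username, password, and database keys, so deployments need to add that entry when upgrading. A minimal sketch of the mapping load_config() is expected to produce (the values shown are illustrative, not from this commit):

    config['mysql'] = {
        'host': '127.0.0.1',
        'username': 'llm',
        'password': 'secret',
        'database': 'llm_server',
        'maxconn': 10,  # assumed example value; becomes pool_size of the MySQLConnectionPool
    }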
@@ -1,3 +1,4 @@
+import logging
 import pickle
 import sys
 import traceback
@@ -29,8 +30,9 @@ class RedisCustom(Redis):
         try:
             self.set('____', 1)
         except redis_pkg.exceptions.ConnectionError as e:
-            print('Failed to connect to the Redis server:', e)
-            print('Did you install and start the Redis server?')
+            logger = logging.getLogger('redis')
+            logger.setLevel(logging.INFO)
+            logger.error(f'Failed to connect to the Redis server: {e}\nDid you install and start the Redis server?')
             sys.exit(1)
 
     def _key(self, key):
@@ -1,28 +1,40 @@
-import pymysql
+from mysql.connector import pooling
 
 
-class DatabaseConnection:
-    host: str = None
-    username: str = None
-    password: str = None
-    database_name: str = None
+class Database:
+    __connection_pool = None
 
-    def init_db(self, host, username, password, database_name):
-        self.host = host
-        self.username = username
-        self.password = password
-        self.database_name = database_name
+    @classmethod
+    def initialise(cls, maxconn, **kwargs):
+        if cls.__connection_pool is not None:
+            raise Exception('Database connection pool is already initialised')
+        cls.__connection_pool = pooling.MySQLConnectionPool(pool_size=maxconn,
+                                                            pool_reset_session=True,
+                                                            **kwargs)
 
-    def cursor(self):
-        db = pymysql.connect(
-            host=self.host,
-            user=self.username,
-            password=self.password,
-            database=self.database_name,
-            charset='utf8mb4',
-            autocommit=True,
-        )
-        return db.cursor()
+    @classmethod
+    def get_connection(cls):
+        return cls.__connection_pool.get_connection()
+
+    @classmethod
+    def return_connection(cls, connection):
+        connection.close()
 
 
-database = DatabaseConnection()
+class CursorFromConnectionFromPool:
+    def __init__(self):
+        self.conn = None
+        self.cursor = None
+
+    def __enter__(self):
+        self.conn = Database.get_connection()
+        self.cursor = self.conn.cursor()
+        return self.cursor
+
+    def __exit__(self, exception_type, exception_value, exception_traceback):
+        if exception_value is not None:  # This is equivalent of saying if there is an exception
+            self.conn.rollback()
+        else:
+            self.cursor.close()
+            self.conn.commit()
+        Database.return_connection(self.conn)
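For orientation, the rewrite above replaces the per-call pymysql connection with a mysql-connector connection pool: Database.initialise() builds the pool once at startup, and CursorFromConnectionFromPool hands out a pooled cursor, committing on a clean exit and rolling back if the block raised. A minimal usage sketch under assumed connection values (only the class names and signatures come from this commit; the literals are placeholders):

    from llm_server.database.conn import CursorFromConnectionFromPool, Database

    # Build the pool once, before any queries run (placeholder credentials).
    Database.initialise(maxconn=10, host='127.0.0.1', user='llm', password='secret', database='llm_server')

    # Borrow a connection; __exit__ commits (or rolls back) and returns it to the pool.
    with CursorFromConnectionFromPool() as cursor:
        cursor.execute('SELECT COUNT(*) FROM prompts')
        print(cursor.fetchone()[0])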
@@ -1,8 +1,8 @@
-from llm_server.database.conn import database
+from llm_server.database.conn import CursorFromConnectionFromPool
 
 
 def create_db():
-    cursor = database.cursor()
+    with CursorFromConnectionFromPool() as cursor:
         cursor.execute('''
         CREATE TABLE IF NOT EXISTS prompts (
             ip TEXT,
|
@ -37,4 +37,3 @@ def create_db():
|
||||||
disabled BOOLEAN DEFAULT 0
|
disabled BOOLEAN DEFAULT 0
|
||||||
)
|
)
|
||||||
''')
|
''')
|
||||||
cursor.close()
|
|
||||||
|
|
|
@@ -5,7 +5,7 @@ from typing import Union
 
 from llm_server import opts
 from llm_server.cluster.cluster_config import cluster_config
-from llm_server.database.conn import database
+from llm_server.database.conn import CursorFromConnectionFromPool
 from llm_server.llm import get_token_count
 
 
@@ -52,21 +52,17 @@ def do_db_log(ip: str, token: str, prompt: str, response: Union[str, None], gen_
     running_model = backend_info.get('model')
     backend_mode = backend_info['mode']
     timestamp = int(time.time())
-    cursor = database.cursor()
-    try:
+    with CursorFromConnectionFromPool() as cursor:
         cursor.execute("""
             INSERT INTO prompts
             (ip, token, model, backend_mode, backend_url, request_url, generation_time, prompt, prompt_tokens, response, response_tokens, response_status, parameters, headers, timestamp)
             VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
             """,
                        (ip, token, running_model, backend_mode, backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
-    finally:
-        cursor.close()
 
 
 def is_valid_api_key(api_key):
-    cursor = database.cursor()
-    try:
+    with CursorFromConnectionFromPool() as cursor:
         cursor.execute("SELECT token, uses, max_uses, expire, disabled FROM token_auth WHERE token = %s", (api_key,))
         row = cursor.fetchone()
         if row is not None:
@@ -75,52 +71,38 @@ def is_valid_api_key(api_key):
             if ((uses is None or max_uses is None) or uses < max_uses) and (expire is None or expire > time.time()) and not disabled:
                 return True
         return False
-    finally:
-        cursor.close()
 
 
 def is_api_key_moderated(api_key):
     if not api_key:
         return opts.openai_moderation_enabled
-    cursor = database.cursor()
-    try:
+    with CursorFromConnectionFromPool() as cursor:
         cursor.execute("SELECT openai_moderation_enabled FROM token_auth WHERE token = %s", (api_key,))
         row = cursor.fetchone()
         if row is not None:
             return bool(row[0])
         return opts.openai_moderation_enabled
-    finally:
-        cursor.close()
 
 
 def get_number_of_rows(table_name):
-    cursor = database.cursor()
-    try:
+    with CursorFromConnectionFromPool() as cursor:
         cursor.execute(f"SELECT COUNT(*) FROM {table_name} WHERE token NOT LIKE 'SYSTEM__%%' OR token IS NULL")
         result = cursor.fetchone()
         return result[0]
-    finally:
-        cursor.close()
 
 
 def average_column(table_name, column_name):
-    cursor = database.cursor()
-    try:
+    with CursorFromConnectionFromPool() as cursor:
         cursor.execute(f"SELECT AVG({column_name}) FROM {table_name} WHERE token NOT LIKE 'SYSTEM__%%' OR token IS NULL")
         result = cursor.fetchone()
         return result[0]
-    finally:
-        cursor.close()
 
 
 def average_column_for_model(table_name, column_name, model_name):
-    cursor = database.cursor()
-    try:
+    with CursorFromConnectionFromPool() as cursor:
         cursor.execute(f"SELECT AVG({column_name}) FROM {table_name} WHERE model = %s AND token NOT LIKE 'SYSTEM__%%' OR token IS NULL", (model_name,))
         result = cursor.fetchone()
         return result[0]
-    finally:
-        cursor.close()
 
 
 def weighted_average_column_for_model(table_name, column_name, model_name, backend_name, backend_url, exclude_zeros: bool = False, include_system_tokens: bool = True):
@@ -129,8 +111,7 @@ def weighted_average_column_for_model(table_name, column_name, model_name, backe
     else:
         sql = f"SELECT {column_name}, id FROM {table_name} WHERE model = %s AND backend_mode = %s AND backend_url = %s AND (token NOT LIKE 'SYSTEM__%%' OR token IS NULL) ORDER BY id DESC"
 
-    cursor = database.cursor()
-    try:
+    with CursorFromConnectionFromPool() as cursor:
         try:
             cursor.execute(sql, (model_name, backend_name, backend_url,))
             results = cursor.fetchall()
@@ -154,46 +135,34 @@ def weighted_average_column_for_model(table_name, column_name, model_name, backe
                 calculated_avg = 0
 
         return calculated_avg
-    finally:
-        cursor.close()
 
 
 def sum_column(table_name, column_name):
-    cursor = database.cursor()
-    try:
+    with CursorFromConnectionFromPool() as cursor:
         cursor.execute(f"SELECT SUM({column_name}) FROM {table_name} WHERE token NOT LIKE 'SYSTEM__%%' OR token IS NULL")
         result = cursor.fetchone()
         return result[0] if result else 0
-    finally:
-        cursor.close()
 
 
 def get_distinct_ips_24h():
     # Get the current time and subtract 24 hours (in seconds)
     past_24_hours = int(time.time()) - 24 * 60 * 60
-    cursor = database.cursor()
-    try:
+    with CursorFromConnectionFromPool() as cursor:
         cursor.execute("SELECT COUNT(DISTINCT ip) FROM prompts WHERE timestamp >= %s AND (token NOT LIKE 'SYSTEM__%%' OR token IS NULL)", (past_24_hours,))
         result = cursor.fetchone()
         return result[0] if result else 0
-    finally:
-        cursor.close()
 
 
 def increment_token_uses(token):
-    cursor = database.cursor()
-    try:
+    with CursorFromConnectionFromPool() as cursor:
         cursor.execute('UPDATE token_auth SET uses = uses + 1 WHERE token = %s', (token,))
-    finally:
-        cursor.close()
 
 
 def get_token_ratelimit(token):
     priority = 9990
     simultaneous_ip = opts.simultaneous_requests_per_ip
     if token:
-        cursor = database.cursor()
-        try:
+        with CursorFromConnectionFromPool() as cursor:
             cursor.execute("SELECT priority, simultaneous_ip FROM token_auth WHERE token = %s", (token,))
             result = cursor.fetchone()
             if result:
@@ -201,6 +170,4 @@ def get_token_ratelimit(token):
                 if simultaneous_ip is None:
                     # No ratelimit for this token if null
                     simultaneous_ip = 999999999
-        finally:
-            cursor.close()
     return priority, simultaneous_ip
@@ -1,6 +1,9 @@
 import requests
 
 from llm_server import opts
+from llm_server.logging import create_logger
+
+_logger = create_logger('moderation')
 
 
 def check_moderation_endpoint(prompt: str):
@@ -10,7 +13,7 @@ def check_moderation_endpoint(prompt: str):
     }
     response = requests.post('https://api.openai.com/v1/moderations', headers=headers, json={"input": prompt}, timeout=10)
     if response.status_code != 200:
-        print('moderation failed:', response)
+        _logger.error(f'moderation failed: {response}')
         response.raise_for_status()
     response = response.json()
 
@@ -1,6 +1,9 @@
 from flask import jsonify
 
 from llm_server import opts
+from llm_server.logging import create_logger
+
+_logger = create_logger('oai_to_vllm')
 
 
 def oai_to_vllm(request_json_body, stop_hashes: bool, mode):
@@ -36,7 +39,7 @@ def oai_to_vllm(request_json_body, stop_hashes: bool, mode):
 
 
 def format_oai_err(err_msg):
-    print('OAI ERROR MESSAGE:', err_msg)
+    _logger.error(f'Got an OAI error message: {err_msg}')
     return jsonify({
         "error": {
             "message": err_msg,
@@ -39,6 +39,7 @@ def init_logging(filepath: Path = None):
     This is only called by `server.py` since there is wierdness with Gunicorn. The deamon doesn't need this.
     :return:
     """
+    global LOG_DIRECTORY
     logger = logging.getLogger('llm_server')
     logger.setLevel(logging_info.level)
 
@@ -5,9 +5,12 @@ from flask import jsonify, request
 
 from llm_server import messages, opts
 from llm_server.database.log_to_db import log_to_db
+from llm_server.logging import create_logger
 from llm_server.routes.helpers.client import format_sillytavern_err
 from llm_server.routes.request_handler import RequestHandler
 
+_logger = create_logger('OobaRequestHandler')
+
 
 class OobaRequestHandler(RequestHandler):
     def __init__(self, *args, **kwargs):
@@ -16,7 +19,7 @@ class OobaRequestHandler(RequestHandler):
     def handle_request(self, return_ok: bool = True):
         assert not self.used
         if self.offline:
-            print('This backend is offline:', messages.BACKEND_OFFLINE)
+            # _logger.debug(f'This backend is offline.')
             return self.handle_error(messages.BACKEND_OFFLINE)
 
         request_valid, invalid_response = self.validate_request()
@@ -1,8 +1,10 @@
 from flask import Blueprint
 
 from ..request_handler import before_request
-from ..server_error import handle_server_error
 from ... import opts
+from ...logging import create_logger
+
+_logger = create_logger('OpenAI')
 
 openai_bp = Blueprint('openai/v1/', __name__)
 openai_model_bp = Blueprint('openai/', __name__)
@@ -24,7 +26,7 @@ def handle_error(e):
     "auth_subrequest_error"
     """
 
-    print('OAI returning error:', e)
+    _logger.error(f'OAI returning error: {e}')
     return jsonify({
         "error": {
             "message": "Internal server error",
@@ -15,6 +15,9 @@ from ... import opts
 from ...database.log_to_db import log_to_db
 from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai
 from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
+from ...logging import create_logger
+
+_logger = create_logger('OpenAIChatCompletions')
 
 
 # TODO: add rate-limit headers?
@@ -99,7 +102,7 @@ def openai_chat_completions(model_name=None):
         # return a 408 if necessary.
         _, stream_name, error_msg = event.wait()
         if error_msg:
-            print('OAI failed to start streaming:', error_msg)
+            _logger.error(f'OAI failed to start streaming: {error_msg}')
             stream_name = None  # set to null so that the Finally ignores it.
             return 'Request Timeout', 408
 
@@ -111,7 +114,7 @@ def openai_chat_completions(model_name=None):
             while True:
                 stream_data = stream_redis.xread({stream_name: last_id}, block=opts.redis_stream_timeout)
                 if not stream_data:
-                    print(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.")
+                    _logger.debug(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.")
                     yield 'data: [DONE]\n\n'
                 else:
                     for stream_index, item in stream_data[0][1]:
@@ -120,7 +123,7 @@ def openai_chat_completions(model_name=None):
                         data = ujson.loads(item[b'data'])
                         if data['error']:
                             # Not printing error since we can just check the daemon log.
-                            print('OAI streaming encountered error')
+                            _logger.warn(f'OAI streaming encountered error: {data["error"]}')
                             yield 'data: [DONE]\n\n'
                             return
                         elif data['new']:
@@ -16,10 +16,13 @@ from ...database.log_to_db import log_to_db
 from ...llm import get_token_count
 from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai
 from ...llm.openai.transform import generate_oai_string, trim_string_to_fit
+from ...logging import create_logger
 
 # TODO: add rate-limit headers?
 
+_logger = create_logger('OpenAICompletions')
+
 
 @openai_bp.route('/completions', methods=['POST'])
 @openai_model_bp.route('/<model_name>/v1/completions', methods=['POST'])
 def openai_completions(model_name=None):
@@ -144,7 +147,7 @@ def openai_completions(model_name=None):
 
         _, stream_name, error_msg = event.wait()
         if error_msg:
-            print('OAI failed to start streaming:', error_msg)
+            _logger.error(f'OAI failed to start streaming: {error_msg}')
             stream_name = None
             return 'Request Timeout', 408
 
@@ -156,7 +159,7 @@ def openai_completions(model_name=None):
             while True:
                 stream_data = stream_redis.xread({stream_name: last_id}, block=opts.redis_stream_timeout)
                 if not stream_data:
-                    print(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.")
+                    _logger.debug(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.")
                     yield 'data: [DONE]\n\n'
                 else:
                     for stream_index, item in stream_data[0][1]:
@@ -164,7 +167,7 @@ def openai_completions(model_name=None):
                         timestamp = int(stream_index.decode('utf-8').split('-')[0])
                         data = ujson.loads(item[b'data'])
                         if data['error']:
-                            print('OAI streaming encountered error')
+                            _logger.error(f'OAI streaming encountered error: {data["error"]}')
                             yield 'data: [DONE]\n\n'
                             return
                         elif data['new']:
@@ -16,9 +16,12 @@ from llm_server.database.log_to_db import log_to_db
 from llm_server.llm import get_token_count
 from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai, return_invalid_model_err
 from llm_server.llm.openai.transform import ANTI_CONTINUATION_RE, ANTI_RESPONSE_RE, generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
+from llm_server.logging import create_logger
 from llm_server.routes.request_handler import RequestHandler
 from llm_server.workers.moderator import add_moderation_task, get_results
 
+_logger = create_logger('OpenAIRequestHandler')
+
 
 class OpenAIRequestHandler(RequestHandler):
     def __init__(self, *args, **kwargs):
@@ -29,7 +32,7 @@ class OpenAIRequestHandler(RequestHandler):
         assert not self.used
         if self.offline:
             msg = return_invalid_model_err(self.selected_model)
-            print('OAI Offline:', msg)
+            _logger.error(f'OAI is offline: {msg}')
             return self.handle_error(msg)
 
         if opts.openai_silent_trim:
@@ -72,7 +75,7 @@ class OpenAIRequestHandler(RequestHandler):
                 self.request.json['messages'].insert((len(self.request.json['messages'])), {'role': 'system', 'content': mod_msg})
                 self.prompt = transform_messages_to_prompt(self.request.json['messages'])
             except Exception as e:
-                print(f'OpenAI moderation endpoint failed:', f'{e.__class__.__name__}: {e}')
+                _logger.error(f'OpenAI moderation endpoint failed: {e.__class__.__name__}: {e}')
                 traceback.print_exc()
 
         llm_request = {**self.parameters, 'prompt': self.prompt}
@@ -106,7 +109,7 @@ class OpenAIRequestHandler(RequestHandler):
         return response, 429
 
     def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
-        print('OAI Error:', error_msg)
+        _logger.error(f'OAI Error: {error_msg}')
         return jsonify({
             "error": {
                 "message": "Invalid request, check your parameters and try again.",
@@ -155,7 +158,7 @@ class OpenAIRequestHandler(RequestHandler):
     def validate_request(self, prompt: str = None, do_log: bool = False) -> Tuple[bool, Tuple[Response | None, int]]:
         self.parameters, parameters_invalid_msg = self.get_parameters()
         if not self.parameters:
-            print('OAI BACKEND VALIDATION ERROR:', parameters_invalid_msg)
+            _logger.error(f'OAI BACKEND VALIDATION ERROR: {parameters_invalid_msg}')
             return False, (Response('Invalid request, check your parameters and try again.'), 400)
         invalid_oai_err_msg = validate_oai(self.parameters)
         if invalid_oai_err_msg:
@@ -10,6 +10,7 @@ from llm_server import opts
 from llm_server.cluster.cluster_config import cluster_config
 from llm_server.custom_redis import RedisCustom, redis
 from llm_server.database.database import get_token_ratelimit
+from llm_server.logging import create_logger
 
 
 def increment_ip_count(client_ip: str, redis_key):
@@ -30,6 +31,7 @@ class RedisPriorityQueue:
     def __init__(self, name, db: int = 12):
         self.name = name
         self.redis = RedisCustom(name, db=db)
+        self._logger = create_logger('RedisPriorityQueue')
 
     def put(self, item, priority: int, selected_model: str, do_stream: bool = False):
         # TODO: remove this when we're sure nothing strange is happening
@@ -41,7 +43,7 @@ class RedisPriorityQueue:
         ip_count = self.get_ip_request_count(item[1])
         _, simultaneous_ip = get_token_ratelimit(item[2])
         if ip_count and int(ip_count) >= simultaneous_ip and priority != 0:
-            print(f'Rejecting request from {item[1]} - {ip_count} request queued.')
+            self._logger.debug(f'Rejecting request from {item[1]} - {ip_count} request queued.')
             return None  # reject the request
 
         timestamp = time.time()
@@ -98,7 +100,7 @@ class RedisPriorityQueue:
                 event_id = item_data[1]
                 event = DataEvent(event_id)
                 event.set((False, None, 'closed'))
-                print('Removed timed-out item from queue:', event_id)
+                self._logger.debug(f'Removed timed-out item from queue: {event_id}')
 
 
 class DataEvent:
@@ -12,10 +12,13 @@ from llm_server.database.log_to_db import log_to_db
 from llm_server.helpers import auto_set_base_client_api
 from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
 from llm_server.llm.vllm.vllm_backend import VLLMBackend
+from llm_server.logging import create_logger
 from llm_server.routes.auth import parse_token
 from llm_server.routes.helpers.http import require_api_key, validate_json
 from llm_server.routes.queue import priority_queue
 
+_logger = create_logger('RequestHandler')
+
 
 class RequestHandler:
     def __init__(self, incoming_request: flask.Request, selected_model: str = None, incoming_json: Union[dict, str] = None):
@@ -223,7 +226,7 @@ class RequestHandler:
             processing_ip = 0
 
         if queued_ip_count + processing_ip >= self.token_simultaneous_ip:
-            print(f'Rejecting request from {self.client_ip} - {processing_ip} processing, {queued_ip_count} queued')
+            _logger.debug(f'Rejecting request from {self.client_ip} - {processing_ip} processing, {queued_ip_count} queued')
             return True
         else:
             return False
@@ -1,3 +1,8 @@
+from llm_server.logging import create_logger
+
+_logger = create_logger('handle_server_error')
+
+
 def handle_server_error(e):
-    print('Internal Error:', e)
+    _logger.error(f'Internal Error: {e}')
     return {'error': True, 'code': 500, 'message': 'Internal Server Error :('}, 500
@@ -13,12 +13,15 @@ from ..queue import priority_queue
 from ... import opts
 from ...custom_redis import redis
 from ...database.log_to_db import log_to_db
+from ...logging import create_logger
 from ...sock import sock
 
 
 # Stacking the @sock.route() creates a TypeError error on the /v1/stream endpoint.
 # We solve this by splitting the routes
 
+_logger = create_logger('GenerateStream')
+
 
 @bp.route('/v1/stream')
 @bp.route('/<model_name>/v1/stream')
 def stream(model_name=None):
@@ -85,7 +88,7 @@ def do_stream(ws, model_name):
     handler = OobaRequestHandler(incoming_request=request, selected_model=model_name, incoming_json=request_json_body)
     if handler.offline:
         msg = f'{handler.selected_model} is not a valid model choice.'
-        print(msg)
+        _logger.debug(msg)
         ws.send(json.dumps({
             'event': 'text_stream',
             'message_num': 0,
@@ -131,7 +134,7 @@ def do_stream(ws, model_name):
 
     _, stream_name, error_msg = event.wait()
     if error_msg:
-        print('Stream failed to start streaming:', error_msg)
+        _logger.error(f'Stream failed to start streaming: {error_msg}')
         ws.close(reason=1014, message='Request Timeout')
         return
 
@@ -143,14 +146,14 @@ def do_stream(ws, model_name):
         while True:
             stream_data = stream_redis.xread({stream_name: last_id}, block=opts.redis_stream_timeout)
             if not stream_data:
-                print(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.")
+                _logger.error(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.")
                 return
             else:
                 for stream_index, item in stream_data[0][1]:
                     last_id = stream_index
                     data = ujson.loads(item[b'data'])
                     if data['error']:
-                        print(data['error'])
+                        _logger.error(f'Encountered error while streaming: {data["error"]}')
                         send_err_and_quit('Encountered exception while streaming.')
                         return
                     elif data['new']:
@@ -2,6 +2,7 @@ import time
 
 from redis import Redis
 
+from llm_server.logging import create_logger
 from llm_server.workers.inferencer import STREAM_NAME_PREFIX
 
 
@@ -10,6 +11,7 @@ from llm_server.workers.inferencer import STREAM_NAME_PREFIX
 def cleaner():
     r = Redis(db=8)
     stream_info = {}
+    logger = create_logger('cleaner')
 
     while True:
         all_streams = r.keys(f'{STREAM_NAME_PREFIX}:*')
|
||||||
# If the size hasn't changed for 5 minutes, delete the stream
|
# If the size hasn't changed for 5 minutes, delete the stream
|
||||||
if time.time() - stream_info[stream]['time'] >= 300:
|
if time.time() - stream_info[stream]['time'] >= 300:
|
||||||
r.delete(stream)
|
r.delete(stream)
|
||||||
print(f"Stream '{stream}' deleted due to inactivity.")
|
logger.debug(f"Stream '{stream}' deleted due to inactivity.")
|
||||||
del stream_info[stream]
|
del stream_info[stream]
|
||||||
|
|
||||||
time.sleep(60)
|
time.sleep(60)
|
||||||
|
|
|
@@ -4,6 +4,7 @@ import traceback
 import redis
 
 from llm_server.database.database import do_db_log
+from llm_server.logging import create_logger
 
 
 def db_logger():
@@ -16,6 +17,7 @@ def db_logger():
     r = redis.Redis(host='localhost', port=6379, db=3)
     p = r.pubsub()
     p.subscribe('database-logger')
+    logger = create_logger('main_bg')
 
     for message in p.listen():
         try:
@@ -28,4 +30,4 @@ def db_logger():
             if function_name == 'log_prompt':
                 do_db_log(*args, **kwargs)
         except:
-            traceback.print_exc()
+            logger.error(traceback.format_exc())
@@ -7,10 +7,12 @@ from llm_server.cluster.cluster_config import get_backends, cluster_config
 from llm_server.custom_redis import redis
 from llm_server.database.database import weighted_average_column_for_model
 from llm_server.llm.info import get_info
+from llm_server.logging import create_logger
 from llm_server.routes.queue import RedisPriorityQueue, priority_queue
 
 
 def main_background_thread():
+    logger = create_logger('main_bg')
     while True:
         online, offline = get_backends()
         for backend_url in online:
@@ -34,7 +36,7 @@ def main_background_thread():
             base_client_api = redis.get('base_client_api', dtype=str)
             r = requests.get('https://' + base_client_api, timeout=5)
         except Exception as e:
-            print(f'Failed fetch the homepage - {e.__class__.__name__}: {e}')
+            logger.error(f'Failed fetch the homepage - {e.__class__.__name__}: {e}')
 
         backends = priority_queue.get_backends()
         for backend_url in backends:
@@ -4,7 +4,7 @@ Flask-Caching==2.0.2
 requests~=2.31.0
 tiktoken~=0.5.0
 gevent~=23.9.0.post1
-PyMySQL~=1.1.0
+mysql-connector-python==8.4.0
 simplejson~=3.19.1
 websockets~=11.0.3
 basicauth~=1.0.0
@@ -14,5 +14,4 @@ gunicorn==21.2.0
 redis==5.0.1
 ujson==5.8.0
 vllm==0.2.7
-gradio~=3.46.1
 coloredlogs~=15.0.1
server.py
@@ -22,11 +22,11 @@ from llm_server.cluster.cluster_config import cluster_config
 from llm_server.config.config import mode_ui_names
 from llm_server.config.load import load_config
 from llm_server.custom_redis import flask_cache, redis
-from llm_server.database.conn import database
+from llm_server.database.conn import database, Database
 from llm_server.database.create import create_db
 from llm_server.helpers import auto_set_base_client_api
 from llm_server.llm.vllm.info import vllm_info
-from llm_server.logging import init_logging
+from llm_server.logging import init_logging, create_logger
 from llm_server.routes.openai import openai_bp, openai_model_bp
 from llm_server.routes.server_error import handle_server_error
 from llm_server.routes.v1 import bp
@@ -62,13 +62,6 @@ from llm_server.sock import init_wssocket
 # TODO: add more excluding to SYSTEM__ tokens
 # TODO: return 200 when returning formatted sillytavern error
 
-try:
-    import vllm
-except ModuleNotFoundError as e:
-    print('Could not import vllm-gptq:', e)
-    print('Please see README.md for install instructions.')
-    sys.exit(1)
-
 script_path = os.path.dirname(os.path.realpath(__file__))
 config_path_environ = os.getenv("CONFIG_PATH")
 if config_path_environ:
@@ -78,11 +71,20 @@ else:
 
 success, config, msg = load_config(config_path)
 if not success:
-    print('Failed to load config:', msg)
+    logger = logging.getLogger('llm_server')
+    logger.setLevel(logging.INFO)
+    logger.error(f'Failed to load config: {msg}')
     sys.exit(1)
 
 init_logging(Path(config['webserver_log_directory']) / 'server.log')
-logger = logging.getLogger('llm_server')
+logger = create_logger('Server')
+logger.debug('Debug logging enabled.')
+
+try:
+    import vllm
+except ModuleNotFoundError as e:
+    logger.error(f'Could not import vllm-gptq: {e}')
+    sys.exit(1)
 
 while not redis.get('daemon_started', dtype=bool):
     logger.warning('Could not find the key daemon_started in Redis. Did you forget to start the daemon process?')
@@ -90,7 +92,7 @@ while not redis.get('daemon_started', dtype=bool):
 
 logger.info('Started HTTP worker!')
 
-database.init_db(config['mysql']['host'], config['mysql']['username'], config['mysql']['password'], config['mysql']['database'])
+Database.initialise(maxconn=config['mysql']['maxconn'], host=config['mysql']['host'], user=config['mysql']['username'], password=config['mysql']['password'], database=config['mysql']['database'])
 create_db()
 
 app = Flask(__name__)
@@ -1,6 +1,5 @@
 <!DOCTYPE html>
 <html lang="en">
-
 <head>
     <title>{{ llm_middleware_name }}</title>
     <meta content="width=device-width, initial-scale=1" name="viewport"/>
@@ -112,7 +111,8 @@
             <li>Set your API type to <kbd>{{ mode_name }}</kbd></li>
             <li>Enter <kbd>{{ client_api }}</kbd> in the <kbd>{{ api_input_textbox }}</kbd> textbox.</li>
             {% if enable_streaming %}
-            <li>Enter <kbd>{{ ws_client_api }}</kbd> in the <kbd>{{ streaming_input_textbox }}</kbd> textbox.</li>
+            <li>Enter <kbd>{{ ws_client_api }}</kbd> in the <kbd>{{ streaming_input_textbox }}</kbd> textbox.
+            </li>
             {% endif %}
             <li>If you have a token, check the <kbd>Mancer AI</kbd> checkbox and enter your token in the <kbd>Mancer
                 API key</kbd> textbox.
@@ -127,7 +127,8 @@
         <br>
         <div id="openai">
             <strong>OpenAI-Compatible API</strong>
-            <p>The OpenAI-compatible API adds a system prompt to set the AI's behavior to a "helpful assistant". You can view this prompt <a href="/api/openai/v1/prompt">here</a>.</p>
+            <p>The OpenAI-compatible API adds a system prompt to set the AI's behavior to a "helpful assistant". You
+                can view this prompt <a href="/api/openai/v1/prompt">here</a>.</p>
         </div>
         {% endif %}
         <br>
@@ -148,7 +149,8 @@
 
         {% for key, value in model_choices.items() %}
         <div class="info-box">
-            <h3>{{ key }} <span class="header-workers">- {{ value.backend_count }} {% if value.backend_count == 1 %}worker{% else %}workers{% endif %}</span></h3>
+            <h3>{{ key }} <span class="header-workers">- {{ value.backend_count }} {% if value.backend_count == 1 %}
+                worker{% else %}workers{% endif %}</span></h3>
 
             {% if value.estimated_wait == 0 and value.estimated_wait >= value.concurrent_gens %}
             {# There will be a wait if the queue is empty but prompts are processing, but we don't know how long. #}