refactor a lot of things, major cleanup, use postgresql

This commit is contained in:
Cyberes 2024-05-07 17:03:41 -06:00
parent ee9a0d4858
commit fd09c783d3
31 changed files with 220 additions and 367 deletions

View File

@ -3,7 +3,6 @@ import logging
import os
import sys
import time
from pathlib import Path
from redis import Redis
@ -14,6 +13,7 @@ from llm_server.custom_redis import redis
from llm_server.database.conn import Database
from llm_server.database.create import create_db
from llm_server.database.database import get_number_of_rows
from llm_server.helpers import resolve_path
from llm_server.logging import create_logger, logging_info, init_logging
from llm_server.routes.v1.generate_stats import generate_stats
from llm_server.workers.threader import start_background
@ -23,7 +23,7 @@ config_path_environ = os.getenv("CONFIG_PATH")
if config_path_environ:
config_path = config_path_environ
else:
config_path = Path(script_path, 'config', 'config.yml')
config_path = resolve_path(script_path, 'config', 'config.yml')
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Daemon microservice.')
@ -47,7 +47,7 @@ if __name__ == "__main__":
logger.info(f'Failed to load config: {msg}')
sys.exit(1)
Database.initialise(maxconn=GlobalConfig.get().mysql.maxconn, host=GlobalConfig.get().mysql.host, user=GlobalConfig.get().mysql.username, password=GlobalConfig.get().mysql.password, database=GlobalConfig.get().mysql.database)
Database.initialise(**GlobalConfig.get().postgresql.dict())
create_db()
cluster_config.clear()
@ -57,7 +57,7 @@ if __name__ == "__main__":
generate_stats(regen=True)
if GlobalConfig.get().load_num_prompts:
redis.set('proompts', get_number_of_rows('prompts'))
redis.set('proompts', get_number_of_rows('messages'))
start_background()

View File

@ -17,6 +17,7 @@ def get_backends_from_model(model_name: str):
:param model_name:
:return:
"""
assert isinstance(model_name, str)
return [x.decode('utf-8') for x in redis_running_models.smembers(model_name)]
@ -25,7 +26,7 @@ def get_running_models():
Get all the models that are in the cluster.
:return:
"""
return list(redis_running_models.keys())
return [x.decode('utf-8') for x in list(redis_running_models.keys())]
def is_valid_model(model_name: str) -> bool:
@ -81,6 +82,7 @@ def get_model_choices(regen: bool = False) -> tuple[dict, dict]:
base_client_api = redis.get('base_client_api', dtype=str)
running_models = get_running_models()
model_choices = {}
for model in running_models:
b = get_backends_from_model(model)

View File

@ -33,7 +33,7 @@ class RedisClusterStore:
item.backend_url = backend_url
stuff[backend_url] = item
for k, v in stuff.items():
self.add_backend(k, v)
self.add_backend(k, v.dict())
def add_backend(self, name: str, values: dict):
self.config_redis.hset(name, mapping={k: pickle.dumps(v) for k, v in values.items()})

View File

@ -1,14 +1,22 @@
from pydantic import BaseModel
from llm_server.config.global_config import GlobalConfig
def cluster_worker_count():
count = 0
for item in GlobalConfig.get().cluster:
count += item['concurrent_gens']
count += item.concurrent_gens
return count
mode_ui_names = {
'ooba': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),
'vllm': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),
class ModeUINameStr(BaseModel):
name: str
api_name: str
streaming_name: str
MODE_UI_NAMES = {
'ooba': ModeUINameStr(name='Text Gen WebUI (ooba)', api_name='Blocking API url', streaming_name='Streaming API url'),
'vllm': ModeUINameStr(name='Text Gen WebUI (ooba)', api_name='Blocking API url', streaming_name='Streaming API url'),
}

View File

@ -63,8 +63,8 @@ def load_config(config_path: Path):
config_model = ConfigModel(**config.config)
GlobalConfig.initalize(config_model)
if not (0 < GlobalConfig.get().mysql.maxconn <= 32):
return False, f'"maxcon" should be higher than 0 and lower or equal to 32. Current value: "{GlobalConfig.get().mysql.maxconn}"'
if GlobalConfig.get().postgresql.maxconn < 0:
return False, f'"maxcon" should be higher than 0. Current value: "{GlobalConfig.get().postgresql.maxconn}"'
openai.api_key = GlobalConfig.get().openai_api_key

View File

@ -19,9 +19,9 @@ class ConfigFrontendApiModes(str, Enum):
ooba = 'ooba'
class ConfigMysql(BaseModel):
class ConfigPostgresql(BaseModel):
host: str
username: str
user: str
password: str
database: str
maxconn: int
@ -37,9 +37,8 @@ class ConfigModel(BaseModel):
cluster: List[ConfigCluser]
prioritize_by_size: bool
admin_token: Union[str, None]
mysql: ConfigMysql
postgresql: ConfigPostgresql
http_host: str
webserver_log_directory: str
include_system_tokens_in_stats: bool
background_homepage_cacher: bool
max_new_tokens: int
@ -55,6 +54,7 @@ class ConfigModel(BaseModel):
info_html: Union[str, None]
enable_openi_compatible_backend: bool
openai_api_key: Union[str, None]
openai_system_prompt: str
expose_openai_system_prompt: bool
openai_expose_our_model: bool
openai_force_no_hashes: bool
@ -72,3 +72,4 @@ class ConfigModel(BaseModel):
load_num_prompts: bool
manual_model_name: Union[str, None]
backend_request_timeout: int
backend_generate_request_timeout: int

View File

@ -2,7 +2,7 @@ from typing import Union
import bison
from llm_server.opts import default_openai_system_prompt
from llm_server.globals import DEFAULT_OPENAI_SYSTEM_PROMPT
config_scheme = bison.Scheme(
bison.Option('frontend_api_mode', choices=['ooba'], field_type=str),
@ -14,15 +14,14 @@ config_scheme = bison.Scheme(
)),
bison.Option('prioritize_by_size', default=True, field_type=bool),
bison.Option('admin_token', default=None, field_type=Union[str, None]),
bison.ListOption('mysql', member_scheme=bison.Scheme(
bison.ListOption('postgresql', member_scheme=bison.Scheme(
bison.Option('host', field_type=str),
bison.Option('username', field_type=str),
bison.Option('user', field_type=str),
bison.Option('password', field_type=str),
bison.Option('database', field_type=str),
bison.Option('maxconn', field_type=int)
)),
bison.Option('http_host', default='', field_type=str),
bison.Option('webserver_log_directory', default='/var/log/localllm', field_type=str),
bison.Option('include_system_tokens_in_stats', default=True, field_type=bool),
bison.Option('background_homepage_cacher', default=True, field_type=bool),
bison.Option('max_new_tokens', default=500, field_type=int),
@ -41,7 +40,7 @@ config_scheme = bison.Scheme(
bison.Option('expose_openai_system_prompt', default=True, field_type=bool),
bison.Option('openai_expose_our_model', default='', field_type=bool),
bison.Option('openai_force_no_hashes', default=True, field_type=bool),
bison.Option('openai_system_prompt', default=default_openai_system_prompt, field_type=str),
bison.Option('openai_system_prompt', default=DEFAULT_OPENAI_SYSTEM_PROMPT, field_type=str),
bison.Option('openai_moderation_enabled', default=False, field_type=bool),
bison.Option('openai_moderation_timeout', default=5, field_type=int),
bison.Option('openai_moderation_scan_last_n', default=5, field_type=int),
@ -55,5 +54,6 @@ config_scheme = bison.Scheme(
bison.Option('show_backend_info', default=True, field_type=bool),
bison.Option('load_num_prompts', default=True, field_type=bool),
bison.Option('manual_model_name', default=None, field_type=Union[str, None]),
bison.Option('backend_request_timeout', default=30, field_type=int)
bison.Option('backend_request_timeout', default=30, field_type=int),
bison.Option('backend_generate_request_timeout', default=95, field_type=int)
)

View File

@ -2,13 +2,13 @@ import logging
import pickle
import sys
import traceback
from typing import Callable, List, Mapping, Optional, Union
from typing import Union
import redis as redis_pkg
import simplejson as json
from flask_caching import Cache
from redis import Redis
from redis.typing import AnyKeyT, EncodableT, ExpiryT, FieldT, KeyT, PatternT, ZScoreBoundT, AbsExpiryT
from redis.typing import ExpiryT, KeyT, PatternT
flask_cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/15', 'CACHE_KEY_PREFIX': 'local_llm_flask'})
@ -38,18 +38,11 @@ class RedisCustom(Redis):
def _key(self, key):
return f"{self.prefix}:{key}"
def set(self, key: KeyT,
value: EncodableT,
ex: Union[ExpiryT, None] = None,
px: Union[ExpiryT, None] = None,
nx: bool = False,
xx: bool = False,
keepttl: bool = False,
get: bool = False,
exat: Union[AbsExpiryT, None] = None,
pxat: Union[AbsExpiryT, None] = None
):
return self.redis.set(self._key(key), value, ex=ex)
def execute_command(self, *args, **options):
if args[0] != 'GET':
args = list(args)
args[1] = self._key(args[1])
return super().execute_command(*args, **options)
def get(self, key, default=None, dtype=None):
# TODO: use pickle
@ -73,103 +66,6 @@ class RedisCustom(Redis):
else:
return d
def incr(self, key, amount=1):
return self.redis.incr(self._key(key), amount)
def decr(self, key, amount=1):
return self.redis.decr(self._key(key), amount)
def sadd(self, key: str, *values: FieldT):
return self.redis.sadd(self._key(key), *values)
def srem(self, key: str, *values: FieldT):
return self.redis.srem(self._key(key), *values)
def sismember(self, key: str, value: str):
return self.redis.sismember(self._key(key), value)
def lindex(
self, name: str, index: int
):
return self.redis.lindex(self._key(name), index)
def lrem(self, name: str, count: int, value: str):
return self.redis.lrem(self._key(name), count, value)
def rpush(self, name: str, *values: FieldT):
return self.redis.rpush(self._key(name), *values)
def llen(self, name: str):
return self.redis.llen(self._key(name))
def zrangebyscore(
self,
name: KeyT,
min: ZScoreBoundT,
max: ZScoreBoundT,
start: Union[int, None] = None,
num: Union[int, None] = None,
withscores: bool = False,
score_cast_func: Union[type, Callable] = float,
):
return self.redis.zrangebyscore(self._key(name), min, max, start, num, withscores, score_cast_func)
def zremrangebyscore(
self, name: KeyT, min: ZScoreBoundT, max: ZScoreBoundT
):
return self.redis.zremrangebyscore(self._key(name), min, max)
def hincrby(
self, name: str, key: str, amount: int = 1
):
return self.redis.hincrby(self._key(name), key, amount)
def zcard(self, name: KeyT):
return self.redis.zcard(self._key(name))
def hdel(self, name: str, *keys: str):
return self.redis.hdel(self._key(name), *keys)
def hget(
self, name: str, key: str
):
return self.redis.hget(self._key(name), key)
def zadd(
self,
name: KeyT,
mapping: Mapping[AnyKeyT, EncodableT],
nx: bool = False,
xx: bool = False,
ch: bool = False,
incr: bool = False,
gt: bool = False,
lt: bool = False,
):
return self.redis.zadd(self._key(name), mapping, nx, xx, ch, incr, gt, lt)
def lpush(self, name: str, *values: FieldT):
return self.redis.lpush(self._key(name), *values)
def hset(
self,
name: str,
key: Optional = None,
value=None,
mapping: Optional[dict] = None,
items: Optional[list] = None,
):
return self.redis.hset(self._key(name), key, value, mapping, items)
def hkeys(self, name: str):
return self.redis.hkeys(self._key(name))
def hmget(self, name: str, keys: List, *args: List):
return self.redis.hmget(self._key(name), keys, *args)
def hgetall(self, name: str):
return self.redis.hgetall(self._key(name))
def keys(self, pattern: PatternT = "*", **kwargs):
raw_keys = self.redis.keys(self._key(pattern), **kwargs)
keys = []
@ -179,25 +75,9 @@ class RedisCustom(Redis):
# Delete prefix
del p[0]
k = ':'.join(p)
if k != '____':
keys.append(k)
# keys.append(k)
return keys
def pipeline(self, transaction=True, shard_hint=None):
return self.redis.pipeline(transaction, shard_hint)
def smembers(self, name: str):
return self.redis.smembers(self._key(name))
def spop(self, name: str, count: Optional[int] = None):
return self.redis.spop(self._key(name), count)
def rpoplpush(self, src, dst):
return self.redis.rpoplpush(src, dst)
def zpopmin(self, name: KeyT, count: Union[int, None] = None):
return self.redis.zpopmin(self._key(name), count)
def exists(self, *names: KeyT):
n = []
for name in names:
@ -238,32 +118,5 @@ class RedisCustom(Redis):
self.flush()
return True
def lrange(self, name: str, start: int, end: int):
return self.redis.lrange(self._key(name), start, end)
def delete(self, *names: KeyT):
return self.redis.delete(*[self._key(i) for i in names])
def lpop(self, name: str, count: Optional[int] = None):
return self.redis.lpop(self._key(name), count)
def zrange(
self,
name: KeyT,
start: int,
end: int,
desc: bool = False,
withscores: bool = False,
score_cast_func: Union[type, Callable] = float,
byscore: bool = False,
bylex: bool = False,
offset: int = None,
num: int = None,
):
return self.redis.zrange(self._key(name), start, end, desc, withscores, score_cast_func, byscore, bylex, offset, num)
def zrem(self, name: KeyT, *values: FieldT):
return self.redis.zrem(self._key(name), *values)
redis = RedisCustom('local_llm')

View File

@ -1,39 +1,42 @@
from mysql.connector import pooling
from psycopg2 import pool, InterfaceError
class Database:
__connection_pool = None
@classmethod
def initialise(cls, maxconn: int, **kwargs):
def initialise(cls, maxconn, **kwargs):
if cls.__connection_pool is not None:
raise Exception('Database connection pool is already initialised')
cls.__connection_pool = pooling.MySQLConnectionPool(pool_size=maxconn,
pool_reset_session=True,
**kwargs)
cls.__connection_pool = pool.ThreadedConnectionPool(minconn=1, maxconn=maxconn, **kwargs)
@classmethod
def get_connection(cls):
return cls.__connection_pool.get_connection()
return cls.__connection_pool.getconn()
@classmethod
def return_connection(cls, connection):
connection.close()
cls.__connection_pool.putconn(connection)
class CursorFromConnectionFromPool:
def __init__(self):
def __init__(self, cursor_factory=None):
self.conn = None
self.cursor = None
self.cursor_factory = cursor_factory
def __enter__(self):
self.conn = Database.get_connection()
self.cursor = self.conn.cursor()
self.cursor = self.conn.cursor(cursor_factory=self.cursor_factory)
return self.cursor
def __exit__(self, exception_type, exception_value, exception_traceback):
if exception_value is not None: # This is equivalent of saying if there is an exception
self.conn.rollback()
try:
self.conn.rollback()
except InterfaceError as e:
if e != 'connection already closed':
raise
else:
self.cursor.close()
self.conn.commit()

View File

@ -4,36 +4,39 @@ from llm_server.database.conn import CursorFromConnectionFromPool
def create_db():
with CursorFromConnectionFromPool() as cursor:
cursor.execute('''
CREATE TABLE IF NOT EXISTS prompts (
ip TEXT,
token TEXT DEFAULT NULL,
model TEXT,
backend_mode TEXT,
backend_url TEXT,
request_url TEXT,
generation_time FLOAT,
prompt LONGTEXT,
prompt_tokens INTEGER,
response LONGTEXT,
response_tokens INTEGER,
response_status INTEGER,
parameters TEXT,
# CHECK (parameters IS NULL OR JSON_VALID(parameters)),
headers TEXT,
# CHECK (headers IS NULL OR JSON_VALID(headers)),
timestamp INTEGER
)
CREATE TABLE IF NOT EXISTS public.messages
(
ip text COLLATE pg_catalog."default" NOT NULL,
token text COLLATE pg_catalog."default",
model text COLLATE pg_catalog."default" NOT NULL,
backend_mode text COLLATE pg_catalog."default" NOT NULL,
backend_url text COLLATE pg_catalog."default" NOT NULL,
request_url text COLLATE pg_catalog."default" NOT NULL,
generation_time double precision NOT NULL,
prompt text COLLATE pg_catalog."default" NOT NULL,
prompt_tokens integer NOT NULL,
response text COLLATE pg_catalog."default" NOT NULL,
response_tokens integer NOT NULL,
response_status integer NOT NULL,
parameters jsonb NOT NULL,
headers jsonb,
"timestamp" timestamp with time zone NOT NULL DEFAULT CURRENT_TIMESTAMP,
id SERIAL PRIMARY KEY
);
''')
cursor.execute('''
CREATE TABLE IF NOT EXISTS token_auth (
token TEXT,
UNIQUE (token),
type TEXT NOT NULL,
priority INTEGER DEFAULT 9999,
simultaneous_ip INTEGER DEFAULT NULL,
uses INTEGER DEFAULT 0,
max_uses INTEGER,
expire INTEGER,
disabled BOOLEAN DEFAULT 0
)
CREATE TABLE IF NOT EXISTS public.token_auth
(
token text COLLATE pg_catalog."default" NOT NULL,
type text COLLATE pg_catalog."default" NOT NULL,
priority integer NOT NULL DEFAULT 9999,
simultaneous_ip text COLLATE pg_catalog."default",
openai_moderation_enabled boolean NOT NULL DEFAULT true,
uses integer NOT NULL DEFAULT 0,
max_uses integer,
expire timestamp with time zone,
disabled boolean NOT NULL DEFAULT false,
notes text COLLATE pg_catalog."default" NOT NULL DEFAULT ''::text,
CONSTRAINT token_auth_pkey PRIMARY KEY (token)
)
''')

View File

@ -1,6 +1,7 @@
import json
import time
import traceback
from datetime import datetime, timedelta
from typing import Union
from llm_server.cluster.cluster_config import cluster_config
@ -51,10 +52,10 @@ def do_db_log(ip: str, token: str, prompt: str, response: Union[str, None], gen_
backend_info = cluster_config.get_backend(backend_url)
running_model = backend_info.get('model')
backend_mode = backend_info['mode']
timestamp = int(time.time())
timestamp = datetime.now()
with CursorFromConnectionFromPool() as cursor:
cursor.execute("""
INSERT INTO prompts
INSERT INTO messages
(ip, token, model, backend_mode, backend_url, request_url, generation_time, prompt, prompt_tokens, response, response_tokens, response_status, parameters, headers, timestamp)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
""",
@ -65,12 +66,12 @@ def is_valid_api_key(api_key):
with CursorFromConnectionFromPool() as cursor:
cursor.execute("SELECT token, uses, max_uses, expire, disabled FROM token_auth WHERE token = %s", (api_key,))
row = cursor.fetchone()
if row is not None:
token, uses, max_uses, expire, disabled = row
disabled = bool(disabled)
if ((uses is None or max_uses is None) or uses < max_uses) and (expire is None or expire > time.time()) and not disabled:
return True
return False
if row is not None:
token, uses, max_uses, expire, disabled = row
disabled = bool(disabled)
if ((uses is None or max_uses is None) or uses < max_uses) and (expire is None or expire > time.time()) and not disabled:
return True
return False
def is_api_key_moderated(api_key):
@ -146,9 +147,9 @@ def sum_column(table_name, column_name):
def get_distinct_ips_24h():
# Get the current time and subtract 24 hours (in seconds)
past_24_hours = int(time.time()) - 24 * 60 * 60
past_24_hours = datetime.now() - timedelta(days=1)
with CursorFromConnectionFromPool() as cursor:
cursor.execute("SELECT COUNT(DISTINCT ip) FROM prompts WHERE timestamp >= %s AND (token NOT LIKE 'SYSTEM__%%' OR token IS NULL)", (past_24_hours,))
cursor.execute("SELECT COUNT(DISTINCT ip) FROM messages WHERE timestamp >= %s AND (token NOT LIKE 'SYSTEM__%%' OR token IS NULL)", (past_24_hours,))
result = cursor.fetchone()
return result[0] if result else 0

View File

@ -1,10 +1,6 @@
# Read-only global variables
default_openai_system_prompt = """You are an assistant chatbot. Your main function is to provide accurate and helpful responses to the user's queries. You should always be polite, respectful, and patient. You should not provide any personal opinions or advice unless specifically asked by the user. You should not make any assumptions about the user's knowledge or abilities. You should always strive to provide clear and concise answers. If you do not understand a user's query, ask for clarification. If you cannot provide an answer, apologize and suggest the user seek help elsewhere.\nLines that start with "### ASSISTANT" were messages you sent previously.\nLines that start with "### USER" were messages sent by the user you are chatting with.\nYou will respond to the "### RESPONSE:" prompt as the assistant and follow the instructions given by the user.\n\n"""
# cluster = {}
DEFAULT_OPENAI_SYSTEM_PROMPT = """You are an assistant chatbot. Your main function is to provide accurate and helpful responses to the user's queries. You should always be polite, respectful, and patient. You should not provide any personal opinions or advice unless specifically asked by the user. You should not make any assumptions about the user's knowledge or abilities. You should always strive to provide clear and concise answers. If you do not understand a user's query, ask for clarification. If you cannot provide an answer, apologize and suggest the user seek help elsewhere.\nLines that start with "### ASSISTANT" were messages you sent previously.\nLines that start with "### USER" were messages sent by the user you are chatting with.\nYou will respond to the "### RESPONSE:" prompt as the assistant and follow the instructions given by the user.\n\n"""
REDIS_STREAM_TIMEOUT = 25000
LOGGING_FORMAT = "%(asctime)s: %(levelname)s:%(name)s - %(message)s"
BACKEND_OFFLINE = 'The model you requested is not a valid choice. Please retry your query.'

View File

@ -15,19 +15,6 @@ def resolve_path(*p: str):
return Path(*p).expanduser().resolve().absolute()
def safe_list_get(l, idx, default):
"""
https://stackoverflow.com/a/5125636
:param l:
:param idx:
:param default:
:return:
"""
try:
return l[idx]
except IndexError:
return default
def deep_sort(obj):
if isinstance(obj, dict):

View File

@ -1,4 +1,4 @@
from llm_server import opts
from llm_server import globals
from llm_server.cluster.cluster_config import cluster_config

View File

@ -8,14 +8,14 @@ import requests
from llm_server.config.global_config import GlobalConfig
def generate(json_data: dict):
try:
r = requests.post(f'{GlobalConfig.get().backend_url}/api/v1/generate', json=json_data, verify=GlobalConfig.get().verify_ssl, timeout=GlobalConfig.get().backend_generate_request_timeout)
except requests.exceptions.ReadTimeout:
return False, None, 'Request to backend timed out'
except Exception as e:
traceback.print_exc()
return False, None, 'Request to backend encountered error'
if r.status_code != 200:
return False, r, f'Backend returned {r.status_code}'
return True, r, None
# def generate(json_data: dict):
# try:
# r = requests.post(f'{GlobalConfig.get().backend_url}/api/v1/generate', json=json_data, verify=GlobalConfig.get().verify_ssl, timeout=GlobalConfig.get().backend_generate_request_timeout)
# except requests.exceptions.ReadTimeout:
# return False, None, 'Request to backend timed out'
# except Exception as e:
# traceback.print_exc()
# return False, None, 'Request to backend encountered error'
# if r.status_code != 200:
# return False, r, f'Backend returned {r.status_code}'
# return True, r, None

View File

@ -98,3 +98,14 @@ def return_invalid_model_err(requested_model: str):
"code": "model_not_found"
}
}), 404
def return_oai_internal_server_error():
return jsonify({
"error": {
"message": "Internal server error",
"type": "auth_subrequest_error",
"param": None,
"code": "internal_error"
}
}), 500

View File

@ -31,7 +31,7 @@ def handle_blocking_request(json_data: dict, cluster_backend, timeout: int = 10)
return False, None, 'Request to backend timed out'
except Exception as e:
# print(f'Failed to reach VLLM inference endpoint -', f'{e.__class__.__name__}: {e}')
return False, None, 'Request to backend encountered error'
return False, None, f'Request to backend encountered error -- {e.__class__.__name__}: {e}'
if r.status_code != 200:
# print(f'Failed to reach VLLM inference endpoint - got code {r.status_code}')
return False, r, f'Backend returned {r.status_code}'

View File

@ -1,16 +1,14 @@
import logging
import sys
from pathlib import Path
import coloredlogs
from llm_server import opts
from llm_server import globals
class LoggingInfo:
def __init__(self):
self._level = logging.INFO
self._format = opts.LOGGING_FORMAT
self._format = globals.LOGGING_FORMAT
@property
def level(self):
@ -30,30 +28,17 @@ class LoggingInfo:
logging_info = LoggingInfo()
LOG_DIRECTORY = None
def init_logging(filepath: Path = None):
def init_logging():
"""
Set up the parent logger. Ensures this logger and all children to log to a file.
This is only called by `server.py` since there is wierdness with Gunicorn. The deamon doesn't need this.
:return:
"""
global LOG_DIRECTORY
logger = logging.getLogger('llm_server')
logger.setLevel(logging_info.level)
if filepath:
p = Path(filepath)
if not p.parent.is_dir():
logger.fatal(f'Log directory does not exist: {p.parent}')
sys.exit(1)
LOG_DIRECTORY = p.parent
handler = logging.FileHandler(filepath)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)
def create_logger(name):
logger = logging.getLogger('llm_server').getChild(name)
@ -65,7 +50,4 @@ def create_logger(name):
handler.setFormatter(formatter)
logger.addHandler(handler)
coloredlogs.install(logger=logger, level=logging_info.level)
if LOG_DIRECTORY:
handler = logging.FileHandler(LOG_DIRECTORY / f'{name}.log')
logger.addHandler(handler)
return logger

View File

@ -1 +0,0 @@
BACKEND_OFFLINE = 'The model you requested is not a valid choice. Please retry your query.'

View File

@ -3,7 +3,7 @@ from typing import Tuple
import flask
from flask import jsonify, request
from llm_server import messages
import llm_server.globals
from llm_server.config.global_config import GlobalConfig
from llm_server.database.log_to_db import log_to_db
from llm_server.logging import create_logger
@ -21,7 +21,7 @@ class OobaRequestHandler(RequestHandler):
assert not self.used
if self.offline:
# _logger.debug(f'This backend is offline.')
return self.handle_error(messages.BACKEND_OFFLINE)
return self.handle_error(llm_server.globals.BACKEND_OFFLINE)
request_valid, invalid_response = self.validate_request()
if not request_valid:

View File

@ -2,6 +2,7 @@ from flask import Blueprint
from ..request_handler import before_request
from ...config.global_config import GlobalConfig
from ...llm.openai.oai_to_vllm import return_oai_internal_server_error
from ...logging import create_logger
_logger = create_logger('OpenAI')
@ -26,15 +27,8 @@ def handle_error(e):
"auth_subrequest_error"
"""
_logger(f'OAI returning error: {e}')
return jsonify({
"error": {
"message": "Internal server error",
"type": "auth_subrequest_error",
"param": None,
"code": "internal_error"
}
}), 500
_logger.error(f'OAI returning error: {e}')
return_oai_internal_server_error()
from .models import openai_list_models

View File

@ -13,7 +13,7 @@ from ..openai_request_handler import OpenAIRequestHandler
from ..queue import priority_queue
from ...config.global_config import GlobalConfig
from ...database.log_to_db import log_to_db
from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai
from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai, return_oai_internal_server_error
from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
from ...logging import create_logger
@ -32,7 +32,8 @@ def openai_chat_completions(model_name=None):
else:
handler = OpenAIRequestHandler(incoming_request=request, incoming_json=request_json_body, selected_model=model_name)
if handler.offline:
return return_invalid_model_err(model_name)
# return return_invalid_model_err(model_name)
return_oai_internal_server_error()
if not request_json_body.get('stream'):
try:

View File

@ -14,7 +14,7 @@ from llm_server.custom_redis import redis
from llm_server.database.database import is_api_key_moderated
from llm_server.database.log_to_db import log_to_db
from llm_server.llm import get_token_count
from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai, return_invalid_model_err
from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai, return_invalid_model_err, return_oai_internal_server_error
from llm_server.llm.openai.transform import ANTI_CONTINUATION_RE, ANTI_RESPONSE_RE, generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
from llm_server.logging import create_logger
from llm_server.routes.request_handler import RequestHandler
@ -31,9 +31,10 @@ class OpenAIRequestHandler(RequestHandler):
def handle_request(self) -> Tuple[flask.Response, int]:
assert not self.used
if self.offline:
msg = return_invalid_model_err(self.selected_model)
_logger.error(f'OAI is offline: {msg}')
return self.handle_error(msg)
# msg = return_invalid_model_err(self.selected_model)
# _logger.error(f'OAI is offline: {msg}')
# return self.handle_error(msg)
return_oai_internal_server_error()
if GlobalConfig.get().openai_silent_trim:
oai_messages = trim_messages_to_fit(self.request.json['messages'], self.cluster_backend_info['model_config']['max_position_embeddings'], self.backend_url)
@ -109,7 +110,7 @@ class OpenAIRequestHandler(RequestHandler):
return response, 429
def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
_logger.error('OAI Error: {error_msg}')
_logger.error(f'OAI Error: {error_msg}')
return jsonify({
"error": {
"message": "Invalid request, check your parameters and try again.",

View File

@ -34,7 +34,7 @@ def generate_stats(regen: bool = False):
'proompts_total': get_total_proompts() if GlobalConfig.get().show_num_prompts else None,
'uptime': int((datetime.now() - server_start_time).total_seconds()) if GlobalConfig.get().show_uptime else None,
# 'estimated_avg_tps': estimated_avg_tps,
'tokens_generated': sum_column('prompts', 'response_tokens') if GlobalConfig.get().show_total_output_tokens else None,
'tokens_generated': sum_column('messages', 'response_tokens') if GlobalConfig.get().show_total_output_tokens else None,
'num_backends': len(cluster_config.all()) if GlobalConfig.get().show_backends else None,
},
'endpoints': {

View File

@ -8,6 +8,7 @@ import ujson
from redis import Redis
from llm_server.cluster.cluster_config import cluster_config
from llm_server.config.global_config import GlobalConfig
from llm_server.custom_redis import RedisCustom, redis
from llm_server.llm.generator import generator
from llm_server.logging import create_logger
@ -148,12 +149,12 @@ def worker(backend_url):
status_redis.setp(str(worker_id), None)
def start_workers(cluster: dict):
def start_workers():
logger = create_logger('inferencer')
i = 0
for item in cluster:
for _ in range(item['concurrent_gens']):
t = threading.Thread(target=worker, args=(item['backend_url'],))
for item in GlobalConfig.get().cluster:
for _ in range(item.concurrent_gens):
t = threading.Thread(target=worker, args=(item.backend_url,))
t.daemon = True
t.start()
i += 1

View File

@ -49,10 +49,10 @@ def main_background_thread():
def calc_stats_for_backend(backend_url, running_model, backend_mode):
# exclude_zeros=True filters out rows where an error message was returned. Previously, if there was an error, 0
# was entered into the column. The new code enters null instead but we need to be backwards compatible for now.
average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time',
average_generation_elapsed_sec = weighted_average_column_for_model('messages', 'generation_time',
running_model, backend_mode, backend_url, exclude_zeros=True,
include_system_tokens=GlobalConfig.get().include_system_tokens_in_stats) or 0
average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens',
average_output_tokens = weighted_average_column_for_model('messages', 'response_tokens',
running_model, backend_mode, backend_url, exclude_zeros=True,
include_system_tokens=GlobalConfig.get().include_system_tokens_in_stats) or 0
estimated_avg_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0 # Avoid division by zero

View File

@ -22,7 +22,7 @@ def cache_stats():
def start_background():
logger = create_logger('threader')
start_workers(GlobalConfig.get().cluster)
start_workers()
t = Thread(target=main_background_thread)
t.daemon = True
@ -46,7 +46,7 @@ def start_background():
t = Thread(target=console_printer)
t.daemon = True
t.start()
logger.info('Started the console logger.infoer.')
logger.info('Started the console logger.')
t = Thread(target=cluster_worker)
t.daemon = True

54
other/gunicorn_conf.py Normal file
View File

@ -0,0 +1,54 @@
from llm_server.helpers import resolve_path
try:
import gevent.monkey
gevent.monkey.patch_all()
except ImportError:
pass
import logging
import os
import sys
import time
from pathlib import Path
from llm_server.config.global_config import GlobalConfig
from llm_server.config.load import load_config
from llm_server.custom_redis import redis
from llm_server.database.conn import Database
from llm_server.database.create import create_db
from llm_server.logging import init_logging, create_logger
def post_fork(server, worker):
"""
Initalize the worker after gunicorn has forked. This is done to avoid issues with the database manager.
"""
script_path = Path(os.path.dirname(os.path.realpath(__file__)))
config_path_environ = os.getenv("CONFIG_PATH")
if config_path_environ:
config_path = config_path_environ
else:
config_path = script_path / '../config/config.yml'
config_path = resolve_path(config_path)
success, msg = load_config(config_path)
if not success:
logger = logging.getLogger('llm_server')
logger.setLevel(logging.INFO)
logger.error(f'Failed to load config: {msg}')
sys.exit(1)
init_logging()
logger = create_logger('Server')
logger.debug('Debug logging enabled.')
while not redis.get('daemon_started', dtype=bool):
logger.warning('Could not find the key daemon_started in Redis. Did you forget to start the daemon process?')
time.sleep(10)
Database.initialise(**GlobalConfig.get().postgresql.dict())
create_db()
logger.info('Started HTTP worker!')

View File

@ -11,6 +11,8 @@ WorkingDirectory=/srv/server/local-llm-server
# Sometimes the old processes aren't terminated when the service is restarted.
ExecStartPre=/usr/bin/pkill -9 -f "/srv/server/local-llm-server/venv/bin/python3 /srv/server/local-llm-server/venv/bin/gunicorn"
# TODO: make sure gunicorn logs to stdout and logging also goes to stdout
# Need a lot of workers since we have long-running requests. This takes about 3.5G memory.
ExecStart=/srv/server/local-llm-server/venv/bin/gunicorn --workers 20 --bind 0.0.0.0:5000 server:app --timeout 60 --worker-class gevent --access-logfile '-' --error-logfile '-'

View File

@ -4,7 +4,6 @@ Flask-Caching==2.0.2
requests~=2.31.0
tiktoken~=0.5.0
gevent~=23.9.0.post1
mysql-connector-python==8.4.0
simplejson~=3.19.1
websockets~=11.0.3
basicauth~=1.0.0
@ -17,3 +16,4 @@ vllm==0.2.7
coloredlogs~=15.0.1
git+https://git.evulid.cc/cyberes/bison.git
pydantic
psycopg2-binary==2.9.9

View File

@ -1,33 +1,13 @@
import time
from llm_server.config.global_config import GlobalConfig
try:
import gevent.monkey
gevent.monkey.patch_all()
except ImportError:
pass
import logging
import os
import sys
from pathlib import Path
import simplejson as json
from flask import Flask, jsonify, render_template, request, Response
import config
from llm_server.cluster.backend import get_model_choices
from llm_server.cluster.cluster_config import cluster_config
from llm_server.config.config import mode_ui_names
from llm_server.config.load import load_config
from llm_server.config.config import MODE_UI_NAMES
from llm_server.config.global_config import GlobalConfig
from llm_server.custom_redis import flask_cache, redis
from llm_server.database.conn import Database
from llm_server.database.create import create_db
from llm_server.helpers import auto_set_base_client_api
from llm_server.llm.vllm.info import vllm_info
from llm_server.logging import init_logging, create_logger
from llm_server.routes.openai import openai_bp, openai_model_bp
from llm_server.routes.server_error import handle_server_error
from llm_server.routes.v1 import bp
@ -63,32 +43,6 @@ from llm_server.sock import init_wssocket
# TODO: add more excluding to SYSTEM__ tokens
# TODO: return 200 when returning formatted sillytavern error
script_path = os.path.dirname(os.path.realpath(__file__))
config_path_environ = os.getenv("CONFIG_PATH")
if config_path_environ:
config_path = config_path_environ
else:
config_path = Path(script_path, 'config', 'config.yml')
success, msg = load_config(config_path)
if not success:
logger = logging.getLogger('llm_server')
logger.setLevel(logging.INFO)
logger.error(f'Failed to load config: {msg}')
sys.exit(1)
init_logging(Path(GlobalConfig.get().webserver_log_directory) / 'server.log')
logger = create_logger('Server')
logger.debug('Debug logging enabled.')
while not redis.get('daemon_started', dtype=bool):
logger.warning('Could not find the key daemon_started in Redis. Did you forget to start the daemon process?')
time.sleep(10)
logger.info('Started HTTP worker!')
Database.initialise(maxconn=GlobalConfig.get().mysql.maxconn, host=GlobalConfig.get().mysql.host, user=GlobalConfig.get().mysql.username, password=GlobalConfig.get().mysql.password, database=GlobalConfig.get().mysql.database)
create_db()
app = Flask(__name__)
@ -139,13 +93,13 @@ def home():
# to None by the daemon.
default_model_info['context_size'] = '-'
if len(config['analytics_tracking_code']):
analytics_tracking_code = f"<script>\n{config['analytics_tracking_code']}\n</script>"
if len(GlobalConfig.get().analytics_tracking_code):
analytics_tracking_code = f"<script>\n{GlobalConfig.get().analytics_tracking_code}\n</script>"
else:
analytics_tracking_code = ''
if config['info_html']:
info_html = config['info_html']
if GlobalConfig.get().info_html:
info_html = GlobalConfig.get().info_html
else:
info_html = ''
@ -166,9 +120,9 @@ def home():
client_api=f'https://{base_client_api}',
ws_client_api=f'wss://{base_client_api}/v1/stream' if GlobalConfig.get().enable_streaming else 'disabled',
default_estimated_wait=default_estimated_wait_sec,
mode_name=mode_ui_names[GlobalConfig.get().frontend_api_mode][0],
api_input_textbox=mode_ui_names[GlobalConfig.get().frontend_api_mode][1],
streaming_input_textbox=mode_ui_names[GlobalConfig.get().frontend_api_mode][2],
mode_name=MODE_UI_NAMES[GlobalConfig.get().frontend_api_mode].name,
api_input_textbox=MODE_UI_NAMES[GlobalConfig.get().frontend_api_mode].api_name,
streaming_input_textbox=MODE_UI_NAMES[GlobalConfig.get().frontend_api_mode].streaming_name,
default_context_size=default_model_info['context_size'],
stats_json=json.dumps(stats, indent=4, ensure_ascii=False),
extra_info=mode_info,
@ -212,6 +166,6 @@ def before_app_request():
if __name__ == "__main__":
# server_startup(None)
print('FLASK MODE - Startup complete!')
app.run(host='0.0.0.0', threaded=False, processes=15)
print('Do not run this file directly. Instead, use gunicorn:')
print("gunicorn -c other/gunicorn_conf.py server:app -b 0.0.0.0:5000 --worker-class gevent --workers 3 --access-logfile '-' --error-logfile '-'")
quit(1)