refactor a lot of things, major cleanup, use postgresql
This commit is contained in:
parent ee9a0d4858
commit fd09c783d3
@@ -3,7 +3,6 @@ import logging
import os
import sys
import time
from pathlib import Path

from redis import Redis

@@ -14,6 +13,7 @@ from llm_server.custom_redis import redis
from llm_server.database.conn import Database
from llm_server.database.create import create_db
from llm_server.database.database import get_number_of_rows
from llm_server.helpers import resolve_path
from llm_server.logging import create_logger, logging_info, init_logging
from llm_server.routes.v1.generate_stats import generate_stats
from llm_server.workers.threader import start_background

@@ -23,7 +23,7 @@ config_path_environ = os.getenv("CONFIG_PATH")
if config_path_environ:
    config_path = config_path_environ
else:
    config_path = Path(script_path, 'config', 'config.yml')
    config_path = resolve_path(script_path, 'config', 'config.yml')

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Daemon microservice.')

@@ -47,7 +47,7 @@ if __name__ == "__main__":
        logger.info(f'Failed to load config: {msg}')
        sys.exit(1)

    Database.initialise(maxconn=GlobalConfig.get().mysql.maxconn, host=GlobalConfig.get().mysql.host, user=GlobalConfig.get().mysql.username, password=GlobalConfig.get().mysql.password, database=GlobalConfig.get().mysql.database)
    Database.initialise(**GlobalConfig.get().postgresql.dict())
    create_db()

    cluster_config.clear()

@@ -57,7 +57,7 @@ if __name__ == "__main__":
    generate_stats(regen=True)

    if GlobalConfig.get().load_num_prompts:
        redis.set('proompts', get_number_of_rows('prompts'))
        redis.set('proompts', get_number_of_rows('messages'))

    start_background()
@@ -17,6 +17,7 @@ def get_backends_from_model(model_name: str):
    :param model_name:
    :return:
    """
    assert isinstance(model_name, str)
    return [x.decode('utf-8') for x in redis_running_models.smembers(model_name)]


@@ -25,7 +26,7 @@ def get_running_models():
    Get all the models that are in the cluster.
    :return:
    """
    return list(redis_running_models.keys())
    return [x.decode('utf-8') for x in list(redis_running_models.keys())]


def is_valid_model(model_name: str) -> bool:

@@ -81,6 +82,7 @@ def get_model_choices(regen: bool = False) -> tuple[dict, dict]:

    base_client_api = redis.get('base_client_api', dtype=str)
    running_models = get_running_models()

    model_choices = {}
    for model in running_models:
        b = get_backends_from_model(model)
@@ -33,7 +33,7 @@ class RedisClusterStore:
            item.backend_url = backend_url
            stuff[backend_url] = item
        for k, v in stuff.items():
            self.add_backend(k, v)
            self.add_backend(k, v.dict())

    def add_backend(self, name: str, values: dict):
        self.config_redis.hset(name, mapping={k: pickle.dumps(v) for k, v in values.items()})
@@ -1,14 +1,22 @@
from pydantic import BaseModel

from llm_server.config.global_config import GlobalConfig


def cluster_worker_count():
    count = 0
    for item in GlobalConfig.get().cluster:
        count += item['concurrent_gens']
        count += item.concurrent_gens
    return count


mode_ui_names = {
    'ooba': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),
    'vllm': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),
class ModeUINameStr(BaseModel):
    name: str
    api_name: str
    streaming_name: str


MODE_UI_NAMES = {
    'ooba': ModeUINameStr(name='Text Gen WebUI (ooba)', api_name='Blocking API url', streaming_name='Streaming API url'),
    'vllm': ModeUINameStr(name='Text Gen WebUI (ooba)', api_name='Blocking API url', streaming_name='Streaming API url'),
}
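For illustration, a minimal sketch (not part of the commit) of how the new MODE_UI_NAMES entries are read, compared with tuple indexing into the old mode_ui_names dict:

    from llm_server.config.config import MODE_UI_NAMES

    # Before: mode_name = mode_ui_names['ooba'][0]
    # After: named attributes on a validated pydantic object.
    entry = MODE_UI_NAMES['ooba']
    print(entry.name)            # 'Text Gen WebUI (ooba)'
    print(entry.api_name)        # 'Blocking API url'
    print(entry.streaming_name)  # 'Streaming API url'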
@@ -63,8 +63,8 @@ def load_config(config_path: Path):
    config_model = ConfigModel(**config.config)
    GlobalConfig.initalize(config_model)

    if not (0 < GlobalConfig.get().mysql.maxconn <= 32):
        return False, f'"maxcon" should be higher than 0 and lower or equal to 32. Current value: "{GlobalConfig.get().mysql.maxconn}"'
    if GlobalConfig.get().postgresql.maxconn < 0:
        return False, f'"maxcon" should be higher than 0. Current value: "{GlobalConfig.get().postgresql.maxconn}"'

    openai.api_key = GlobalConfig.get().openai_api_key
@@ -19,9 +19,9 @@ class ConfigFrontendApiModes(str, Enum):
    ooba = 'ooba'


class ConfigMysql(BaseModel):
class ConfigPostgresql(BaseModel):
    host: str
    username: str
    user: str
    password: str
    database: str
    maxconn: int

@@ -37,9 +37,8 @@ class ConfigModel(BaseModel):
    cluster: List[ConfigCluser]
    prioritize_by_size: bool
    admin_token: Union[str, None]
    mysql: ConfigMysql
    postgresql: ConfigPostgresql
    http_host: str
    webserver_log_directory: str
    include_system_tokens_in_stats: bool
    background_homepage_cacher: bool
    max_new_tokens: int

@@ -55,6 +54,7 @@ class ConfigModel(BaseModel):
    info_html: Union[str, None]
    enable_openi_compatible_backend: bool
    openai_api_key: Union[str, None]
    openai_system_prompt: str
    expose_openai_system_prompt: bool
    openai_expose_our_model: bool
    openai_force_no_hashes: bool

@@ -72,3 +72,4 @@ class ConfigModel(BaseModel):
    load_num_prompts: bool
    manual_model_name: Union[str, None]
    backend_request_timeout: int
    backend_generate_request_timeout: int
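To illustrate how the new ConfigPostgresql model feeds Database.initialise(**GlobalConfig.get().postgresql.dict()), here is a hedged sketch; the connection values are placeholders, and the class is redefined locally only so the snippet is self-contained:

    from pydantic import BaseModel

    class ConfigPostgresql(BaseModel):  # mirrors the model defined above
        host: str
        user: str
        password: str
        database: str
        maxconn: int

    pg = ConfigPostgresql(host='127.0.0.1', user='llm', password='secret',
                          database='llm_server', maxconn=10)
    # .dict() produces the keyword arguments Database.initialise() expects:
    # maxconn is consumed by initialise() itself, the rest are forwarded to
    # psycopg2's ThreadedConnectionPool.
    print(pg.dict())
    # {'host': '127.0.0.1', 'user': 'llm', 'password': 'secret', 'database': 'llm_server', 'maxconn': 10}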
@@ -2,7 +2,7 @@ from typing import Union

import bison

from llm_server.opts import default_openai_system_prompt
from llm_server.globals import DEFAULT_OPENAI_SYSTEM_PROMPT

config_scheme = bison.Scheme(
    bison.Option('frontend_api_mode', choices=['ooba'], field_type=str),

@@ -14,15 +14,14 @@ config_scheme = bison.Scheme(
    )),
    bison.Option('prioritize_by_size', default=True, field_type=bool),
    bison.Option('admin_token', default=None, field_type=Union[str, None]),
    bison.ListOption('mysql', member_scheme=bison.Scheme(
    bison.ListOption('postgresql', member_scheme=bison.Scheme(
        bison.Option('host', field_type=str),
        bison.Option('username', field_type=str),
        bison.Option('user', field_type=str),
        bison.Option('password', field_type=str),
        bison.Option('database', field_type=str),
        bison.Option('maxconn', field_type=int)
    )),
    bison.Option('http_host', default='', field_type=str),
    bison.Option('webserver_log_directory', default='/var/log/localllm', field_type=str),
    bison.Option('include_system_tokens_in_stats', default=True, field_type=bool),
    bison.Option('background_homepage_cacher', default=True, field_type=bool),
    bison.Option('max_new_tokens', default=500, field_type=int),

@@ -41,7 +40,7 @@ config_scheme = bison.Scheme(
    bison.Option('expose_openai_system_prompt', default=True, field_type=bool),
    bison.Option('openai_expose_our_model', default='', field_type=bool),
    bison.Option('openai_force_no_hashes', default=True, field_type=bool),
    bison.Option('openai_system_prompt', default=default_openai_system_prompt, field_type=str),
    bison.Option('openai_system_prompt', default=DEFAULT_OPENAI_SYSTEM_PROMPT, field_type=str),
    bison.Option('openai_moderation_enabled', default=False, field_type=bool),
    bison.Option('openai_moderation_timeout', default=5, field_type=int),
    bison.Option('openai_moderation_scan_last_n', default=5, field_type=int),

@@ -55,5 +54,6 @@ config_scheme = bison.Scheme(
    bison.Option('show_backend_info', default=True, field_type=bool),
    bison.Option('load_num_prompts', default=True, field_type=bool),
    bison.Option('manual_model_name', default=None, field_type=Union[str, None]),
    bison.Option('backend_request_timeout', default=30, field_type=int)
    bison.Option('backend_request_timeout', default=30, field_type=int),
    bison.Option('backend_generate_request_timeout', default=95, field_type=int)
)
@@ -2,13 +2,13 @@ import logging
import pickle
import sys
import traceback
from typing import Callable, List, Mapping, Optional, Union
from typing import Union

import redis as redis_pkg
import simplejson as json
from flask_caching import Cache
from redis import Redis
from redis.typing import AnyKeyT, EncodableT, ExpiryT, FieldT, KeyT, PatternT, ZScoreBoundT, AbsExpiryT
from redis.typing import ExpiryT, KeyT, PatternT

flask_cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/15', 'CACHE_KEY_PREFIX': 'local_llm_flask'})

@@ -38,18 +38,11 @@ class RedisCustom(Redis):
    def _key(self, key):
        return f"{self.prefix}:{key}"

    def set(self, key: KeyT,
            value: EncodableT,
            ex: Union[ExpiryT, None] = None,
            px: Union[ExpiryT, None] = None,
            nx: bool = False,
            xx: bool = False,
            keepttl: bool = False,
            get: bool = False,
            exat: Union[AbsExpiryT, None] = None,
            pxat: Union[AbsExpiryT, None] = None
            ):
        return self.redis.set(self._key(key), value, ex=ex)
    def execute_command(self, *args, **options):
        if args[0] != 'GET':
            args = list(args)
            args[1] = self._key(args[1])
        return super().execute_command(*args, **options)

    def get(self, key, default=None, dtype=None):
        # TODO: use pickle

@@ -73,103 +66,6 @@ class RedisCustom(Redis):
        else:
            return d

    def incr(self, key, amount=1):
        return self.redis.incr(self._key(key), amount)

    def decr(self, key, amount=1):
        return self.redis.decr(self._key(key), amount)

    def sadd(self, key: str, *values: FieldT):
        return self.redis.sadd(self._key(key), *values)

    def srem(self, key: str, *values: FieldT):
        return self.redis.srem(self._key(key), *values)

    def sismember(self, key: str, value: str):
        return self.redis.sismember(self._key(key), value)

    def lindex(
            self, name: str, index: int
    ):
        return self.redis.lindex(self._key(name), index)

    def lrem(self, name: str, count: int, value: str):
        return self.redis.lrem(self._key(name), count, value)

    def rpush(self, name: str, *values: FieldT):
        return self.redis.rpush(self._key(name), *values)

    def llen(self, name: str):
        return self.redis.llen(self._key(name))

    def zrangebyscore(
            self,
            name: KeyT,
            min: ZScoreBoundT,
            max: ZScoreBoundT,
            start: Union[int, None] = None,
            num: Union[int, None] = None,
            withscores: bool = False,
            score_cast_func: Union[type, Callable] = float,
    ):
        return self.redis.zrangebyscore(self._key(name), min, max, start, num, withscores, score_cast_func)

    def zremrangebyscore(
            self, name: KeyT, min: ZScoreBoundT, max: ZScoreBoundT
    ):
        return self.redis.zremrangebyscore(self._key(name), min, max)

    def hincrby(
            self, name: str, key: str, amount: int = 1
    ):
        return self.redis.hincrby(self._key(name), key, amount)

    def zcard(self, name: KeyT):
        return self.redis.zcard(self._key(name))

    def hdel(self, name: str, *keys: str):
        return self.redis.hdel(self._key(name), *keys)

    def hget(
            self, name: str, key: str
    ):
        return self.redis.hget(self._key(name), key)

    def zadd(
            self,
            name: KeyT,
            mapping: Mapping[AnyKeyT, EncodableT],
            nx: bool = False,
            xx: bool = False,
            ch: bool = False,
            incr: bool = False,
            gt: bool = False,
            lt: bool = False,
    ):
        return self.redis.zadd(self._key(name), mapping, nx, xx, ch, incr, gt, lt)

    def lpush(self, name: str, *values: FieldT):
        return self.redis.lpush(self._key(name), *values)

    def hset(
            self,
            name: str,
            key: Optional = None,
            value=None,
            mapping: Optional[dict] = None,
            items: Optional[list] = None,
    ):
        return self.redis.hset(self._key(name), key, value, mapping, items)

    def hkeys(self, name: str):
        return self.redis.hkeys(self._key(name))

    def hmget(self, name: str, keys: List, *args: List):
        return self.redis.hmget(self._key(name), keys, *args)

    def hgetall(self, name: str):
        return self.redis.hgetall(self._key(name))

    def keys(self, pattern: PatternT = "*", **kwargs):
        raw_keys = self.redis.keys(self._key(pattern), **kwargs)
        keys = []

@@ -179,25 +75,9 @@ class RedisCustom(Redis):
            # Delete prefix
            del p[0]
            k = ':'.join(p)
            if k != '____':
                keys.append(k)
                # keys.append(k)
        return keys

    def pipeline(self, transaction=True, shard_hint=None):
        return self.redis.pipeline(transaction, shard_hint)

    def smembers(self, name: str):
        return self.redis.smembers(self._key(name))

    def spop(self, name: str, count: Optional[int] = None):
        return self.redis.spop(self._key(name), count)

    def rpoplpush(self, src, dst):
        return self.redis.rpoplpush(src, dst)

    def zpopmin(self, name: KeyT, count: Union[int, None] = None):
        return self.redis.zpopmin(self._key(name), count)

    def exists(self, *names: KeyT):
        n = []
        for name in names:

@@ -238,32 +118,5 @@ class RedisCustom(Redis):
            self.flush()
        return True

    def lrange(self, name: str, start: int, end: int):
        return self.redis.lrange(self._key(name), start, end)

    def delete(self, *names: KeyT):
        return self.redis.delete(*[self._key(i) for i in names])

    def lpop(self, name: str, count: Optional[int] = None):
        return self.redis.lpop(self._key(name), count)

    def zrange(
            self,
            name: KeyT,
            start: int,
            end: int,
            desc: bool = False,
            withscores: bool = False,
            score_cast_func: Union[type, Callable] = float,
            byscore: bool = False,
            bylex: bool = False,
            offset: int = None,
            num: int = None,
    ):
        return self.redis.zrange(self._key(name), start, end, desc, withscores, score_cast_func, byscore, bylex, offset, num)

    def zrem(self, name: KeyT, *values: FieldT):
        return self.redis.zrem(self._key(name), *values)


redis = RedisCustom('local_llm')
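A rough sketch (not from the commit) of what the new execute_command() override buys: any command other than GET has its key argument rewritten to '<prefix>:<key>', which is why the long list of per-method wrappers above could be deleted.

    from llm_server.custom_redis import RedisCustom

    r = RedisCustom('local_llm')
    # redis-py routes sadd() through execute_command('SADD', 'running_models', ...),
    # so the override rewrites the key to 'local_llm:running_models' before it is sent.
    r.sadd('running_models', 'llama-13b')  # key and member names are illustrative placeholders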
@@ -1,39 +1,42 @@
from mysql.connector import pooling
from psycopg2 import pool, InterfaceError


class Database:
    __connection_pool = None

    @classmethod
    def initialise(cls, maxconn: int, **kwargs):
    def initialise(cls, maxconn, **kwargs):
        if cls.__connection_pool is not None:
            raise Exception('Database connection pool is already initialised')
        cls.__connection_pool = pooling.MySQLConnectionPool(pool_size=maxconn,
                                                            pool_reset_session=True,
                                                            **kwargs)
        cls.__connection_pool = pool.ThreadedConnectionPool(minconn=1, maxconn=maxconn, **kwargs)

    @classmethod
    def get_connection(cls):
        return cls.__connection_pool.get_connection()
        return cls.__connection_pool.getconn()

    @classmethod
    def return_connection(cls, connection):
        connection.close()
        cls.__connection_pool.putconn(connection)


class CursorFromConnectionFromPool:
    def __init__(self):
    def __init__(self, cursor_factory=None):
        self.conn = None
        self.cursor = None
        self.cursor_factory = cursor_factory

    def __enter__(self):
        self.conn = Database.get_connection()
        self.cursor = self.conn.cursor()
        self.cursor = self.conn.cursor(cursor_factory=self.cursor_factory)
        return self.cursor

    def __exit__(self, exception_type, exception_value, exception_traceback):
        if exception_value is not None:  # This is equivalent of saying if there is an exception
            self.conn.rollback()
            try:
                self.conn.rollback()
            except InterfaceError as e:
                if e != 'connection already closed':
                    raise
        else:
            self.cursor.close()
            self.conn.commit()
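A hedged usage sketch (not part of the diff) of the reworked pool and context manager: psycopg2's ThreadedConnectionPool hands out connections with getconn()/putconn(), and the new cursor_factory argument lets callers request dict-style rows. The connection parameters below are placeholders.

    from psycopg2.extras import RealDictCursor

    from llm_server.database.conn import CursorFromConnectionFromPool, Database

    Database.initialise(maxconn=10, host='127.0.0.1', user='llm',
                        password='secret', database='llm_server')

    with CursorFromConnectionFromPool(cursor_factory=RealDictCursor) as cursor:
        cursor.execute('SELECT COUNT(*) AS n FROM messages')
        print(cursor.fetchone()['n'])
    # On a clean exit the cursor is closed and the transaction committed;
    # if an exception escapes the block, __exit__ rolls the transaction back.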
@@ -4,36 +4,39 @@ from llm_server.database.conn import CursorFromConnectionFromPool
def create_db():
    with CursorFromConnectionFromPool() as cursor:
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS prompts (
                ip TEXT,
                token TEXT DEFAULT NULL,
                model TEXT,
                backend_mode TEXT,
                backend_url TEXT,
                request_url TEXT,
                generation_time FLOAT,
                prompt LONGTEXT,
                prompt_tokens INTEGER,
                response LONGTEXT,
                response_tokens INTEGER,
                response_status INTEGER,
                parameters TEXT,
                # CHECK (parameters IS NULL OR JSON_VALID(parameters)),
                headers TEXT,
                # CHECK (headers IS NULL OR JSON_VALID(headers)),
                timestamp INTEGER
            )
            CREATE TABLE IF NOT EXISTS public.messages
            (
                ip text COLLATE pg_catalog."default" NOT NULL,
                token text COLLATE pg_catalog."default",
                model text COLLATE pg_catalog."default" NOT NULL,
                backend_mode text COLLATE pg_catalog."default" NOT NULL,
                backend_url text COLLATE pg_catalog."default" NOT NULL,
                request_url text COLLATE pg_catalog."default" NOT NULL,
                generation_time double precision NOT NULL,
                prompt text COLLATE pg_catalog."default" NOT NULL,
                prompt_tokens integer NOT NULL,
                response text COLLATE pg_catalog."default" NOT NULL,
                response_tokens integer NOT NULL,
                response_status integer NOT NULL,
                parameters jsonb NOT NULL,
                headers jsonb,
                "timestamp" timestamp with time zone NOT NULL DEFAULT CURRENT_TIMESTAMP,
                id SERIAL PRIMARY KEY
            );
        ''')
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS token_auth (
                token TEXT,
                UNIQUE (token),
                type TEXT NOT NULL,
                priority INTEGER DEFAULT 9999,
                simultaneous_ip INTEGER DEFAULT NULL,
                uses INTEGER DEFAULT 0,
                max_uses INTEGER,
                expire INTEGER,
                disabled BOOLEAN DEFAULT 0
            )
            CREATE TABLE IF NOT EXISTS public.token_auth
            (
                token text COLLATE pg_catalog."default" NOT NULL,
                type text COLLATE pg_catalog."default" NOT NULL,
                priority integer NOT NULL DEFAULT 9999,
                simultaneous_ip text COLLATE pg_catalog."default",
                openai_moderation_enabled boolean NOT NULL DEFAULT true,
                uses integer NOT NULL DEFAULT 0,
                max_uses integer,
                expire timestamp with time zone,
                disabled boolean NOT NULL DEFAULT false,
                notes text COLLATE pg_catalog."default" NOT NULL DEFAULT ''::text,
                CONSTRAINT token_auth_pkey PRIMARY KEY (token)
            )
        ''')
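Because parameters and headers are now jsonb columns, inserts need JSON-typed values, which psycopg2's Json adapter provides. A small illustrative sketch (not from the commit; the row values are placeholders, the real insert lives in do_db_log()):

    from psycopg2.extras import Json

    from llm_server.database.conn import CursorFromConnectionFromPool

    with CursorFromConnectionFromPool() as cursor:
        cursor.execute(
            'INSERT INTO messages (ip, model, backend_mode, backend_url, request_url, '
            'generation_time, prompt, prompt_tokens, response, response_tokens, '
            'response_status, parameters, headers) '
            'VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)',
            ('127.0.0.1', 'llama-13b', 'vllm', 'http://backend:7000', '/api/v1/generate',
             1.23, 'Hello', 2, 'Hi there', 3, 200,
             Json({'max_new_tokens': 500}), Json({'User-Agent': 'curl'})))
    # token and timestamp are omitted here: token is nullable and timestamp has a
    # CURRENT_TIMESTAMP default in the new schema.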
@@ -1,6 +1,7 @@
import json
import time
import traceback
from datetime import datetime, timedelta
from typing import Union

from llm_server.cluster.cluster_config import cluster_config

@@ -51,10 +52,10 @@ def do_db_log(ip: str, token: str, prompt: str, response: Union[str, None], gen_
    backend_info = cluster_config.get_backend(backend_url)
    running_model = backend_info.get('model')
    backend_mode = backend_info['mode']
    timestamp = int(time.time())
    timestamp = datetime.now()
    with CursorFromConnectionFromPool() as cursor:
        cursor.execute("""
            INSERT INTO prompts
            INSERT INTO messages
            (ip, token, model, backend_mode, backend_url, request_url, generation_time, prompt, prompt_tokens, response, response_tokens, response_status, parameters, headers, timestamp)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
            """,

@@ -65,12 +66,12 @@ def is_valid_api_key(api_key):
    with CursorFromConnectionFromPool() as cursor:
        cursor.execute("SELECT token, uses, max_uses, expire, disabled FROM token_auth WHERE token = %s", (api_key,))
        row = cursor.fetchone()
        if row is not None:
            token, uses, max_uses, expire, disabled = row
            disabled = bool(disabled)
            if ((uses is None or max_uses is None) or uses < max_uses) and (expire is None or expire > time.time()) and not disabled:
                return True
        return False
        if row is not None:
            token, uses, max_uses, expire, disabled = row
            disabled = bool(disabled)
            if ((uses is None or max_uses is None) or uses < max_uses) and (expire is None or expire > time.time()) and not disabled:
                return True
        return False


def is_api_key_moderated(api_key):

@@ -146,9 +147,9 @@ def sum_column(table_name, column_name):

def get_distinct_ips_24h():
    # Get the current time and subtract 24 hours (in seconds)
    past_24_hours = int(time.time()) - 24 * 60 * 60
    past_24_hours = datetime.now() - timedelta(days=1)
    with CursorFromConnectionFromPool() as cursor:
        cursor.execute("SELECT COUNT(DISTINCT ip) FROM prompts WHERE timestamp >= %s AND (token NOT LIKE 'SYSTEM__%%' OR token IS NULL)", (past_24_hours,))
        cursor.execute("SELECT COUNT(DISTINCT ip) FROM messages WHERE timestamp >= %s AND (token NOT LIKE 'SYSTEM__%%' OR token IS NULL)", (past_24_hours,))
        result = cursor.fetchone()
        return result[0] if result else 0
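The move from integer epoch seconds to timestamp with time zone works because psycopg2 adapts Python datetime objects directly; a tiny sketch (not from the commit) of the comparison get_distinct_ips_24h() now performs:

    from datetime import datetime, timedelta

    from llm_server.database.conn import CursorFromConnectionFromPool

    cutoff = datetime.now() - timedelta(days=1)  # same expression as in get_distinct_ips_24h()
    with CursorFromConnectionFromPool() as cursor:
        # psycopg2 passes the datetime through as a timestamptz parameter.
        cursor.execute('SELECT COUNT(DISTINCT ip) FROM messages WHERE timestamp >= %s', (cutoff,))
        print(cursor.fetchone()[0])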
@@ -1,10 +1,6 @@
# Read-only global variables

default_openai_system_prompt = """You are an assistant chatbot. Your main function is to provide accurate and helpful responses to the user's queries. You should always be polite, respectful, and patient. You should not provide any personal opinions or advice unless specifically asked by the user. You should not make any assumptions about the user's knowledge or abilities. You should always strive to provide clear and concise answers. If you do not understand a user's query, ask for clarification. If you cannot provide an answer, apologize and suggest the user seek help elsewhere.\nLines that start with "### ASSISTANT" were messages you sent previously.\nLines that start with "### USER" were messages sent by the user you are chatting with.\nYou will respond to the "### RESPONSE:" prompt as the assistant and follow the instructions given by the user.\n\n"""

# cluster = {}

DEFAULT_OPENAI_SYSTEM_PROMPT = """You are an assistant chatbot. Your main function is to provide accurate and helpful responses to the user's queries. You should always be polite, respectful, and patient. You should not provide any personal opinions or advice unless specifically asked by the user. You should not make any assumptions about the user's knowledge or abilities. You should always strive to provide clear and concise answers. If you do not understand a user's query, ask for clarification. If you cannot provide an answer, apologize and suggest the user seek help elsewhere.\nLines that start with "### ASSISTANT" were messages you sent previously.\nLines that start with "### USER" were messages sent by the user you are chatting with.\nYou will respond to the "### RESPONSE:" prompt as the assistant and follow the instructions given by the user.\n\n"""
REDIS_STREAM_TIMEOUT = 25000

LOGGING_FORMAT = "%(asctime)s: %(levelname)s:%(name)s - %(message)s"
BACKEND_OFFLINE = 'The model you requested is not a valid choice. Please retry your query.'
@@ -15,19 +15,6 @@ def resolve_path(*p: str):
    return Path(*p).expanduser().resolve().absolute()


def safe_list_get(l, idx, default):
    """
    https://stackoverflow.com/a/5125636
    :param l:
    :param idx:
    :param default:
    :return:
    """
    try:
        return l[idx]
    except IndexError:
        return default


def deep_sort(obj):
    if isinstance(obj, dict):
@@ -1,4 +1,4 @@
from llm_server import opts
from llm_server import globals
from llm_server.cluster.cluster_config import cluster_config
@@ -8,14 +8,14 @@ import requests
from llm_server.config.global_config import GlobalConfig


def generate(json_data: dict):
    try:
        r = requests.post(f'{GlobalConfig.get().backend_url}/api/v1/generate', json=json_data, verify=GlobalConfig.get().verify_ssl, timeout=GlobalConfig.get().backend_generate_request_timeout)
    except requests.exceptions.ReadTimeout:
        return False, None, 'Request to backend timed out'
    except Exception as e:
        traceback.print_exc()
        return False, None, 'Request to backend encountered error'
    if r.status_code != 200:
        return False, r, f'Backend returned {r.status_code}'
    return True, r, None
# def generate(json_data: dict):
#     try:
#         r = requests.post(f'{GlobalConfig.get().backend_url}/api/v1/generate', json=json_data, verify=GlobalConfig.get().verify_ssl, timeout=GlobalConfig.get().backend_generate_request_timeout)
#     except requests.exceptions.ReadTimeout:
#         return False, None, 'Request to backend timed out'
#     except Exception as e:
#         traceback.print_exc()
#         return False, None, 'Request to backend encountered error'
#     if r.status_code != 200:
#         return False, r, f'Backend returned {r.status_code}'
#     return True, r, None
@@ -98,3 +98,14 @@ def return_invalid_model_err(requested_model: str):
            "code": "model_not_found"
        }
    }), 404


def return_oai_internal_server_error():
    return jsonify({
        "error": {
            "message": "Internal server error",
            "type": "auth_subrequest_error",
            "param": None,
            "code": "internal_error"
        }
    }), 500
@@ -31,7 +31,7 @@ def handle_blocking_request(json_data: dict, cluster_backend, timeout: int = 10)
        return False, None, 'Request to backend timed out'
    except Exception as e:
        # print(f'Failed to reach VLLM inference endpoint -', f'{e.__class__.__name__}: {e}')
        return False, None, 'Request to backend encountered error'
        return False, None, f'Request to backend encountered error -- {e.__class__.__name__}: {e}'
    if r.status_code != 200:
        # print(f'Failed to reach VLLM inference endpoint - got code {r.status_code}')
        return False, r, f'Backend returned {r.status_code}'
@@ -1,16 +1,14 @@
import logging
import sys
from pathlib import Path

import coloredlogs

from llm_server import opts
from llm_server import globals


class LoggingInfo:
    def __init__(self):
        self._level = logging.INFO
        self._format = opts.LOGGING_FORMAT
        self._format = globals.LOGGING_FORMAT

    @property
    def level(self):

@@ -30,30 +28,17 @@ class LoggingInfo:


logging_info = LoggingInfo()
LOG_DIRECTORY = None


def init_logging(filepath: Path = None):
def init_logging():
    """
    Set up the parent logger. Ensures this logger and all children to log to a file.
    This is only called by `server.py` since there is wierdness with Gunicorn. The deamon doesn't need this.
    :return:
    """
    global LOG_DIRECTORY
    logger = logging.getLogger('llm_server')
    logger.setLevel(logging_info.level)

    if filepath:
        p = Path(filepath)
        if not p.parent.is_dir():
            logger.fatal(f'Log directory does not exist: {p.parent}')
            sys.exit(1)
        LOG_DIRECTORY = p.parent
        handler = logging.FileHandler(filepath)
        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        handler.setFormatter(formatter)
        logger.addHandler(handler)


def create_logger(name):
    logger = logging.getLogger('llm_server').getChild(name)

@@ -65,7 +50,4 @@ def create_logger(name):
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    coloredlogs.install(logger=logger, level=logging_info.level)
    if LOG_DIRECTORY:
        handler = logging.FileHandler(LOG_DIRECTORY / f'{name}.log')
        logger.addHandler(handler)
    return logger
@@ -1 +0,0 @@
BACKEND_OFFLINE = 'The model you requested is not a valid choice. Please retry your query.'
@@ -3,7 +3,7 @@ from typing import Tuple
import flask
from flask import jsonify, request

from llm_server import messages
import llm_server.globals
from llm_server.config.global_config import GlobalConfig
from llm_server.database.log_to_db import log_to_db
from llm_server.logging import create_logger

@@ -21,7 +21,7 @@ class OobaRequestHandler(RequestHandler):
        assert not self.used
        if self.offline:
            # _logger.debug(f'This backend is offline.')
            return self.handle_error(messages.BACKEND_OFFLINE)
            return self.handle_error(llm_server.globals.BACKEND_OFFLINE)

        request_valid, invalid_response = self.validate_request()
        if not request_valid:
@@ -2,6 +2,7 @@ from flask import Blueprint

from ..request_handler import before_request
from ...config.global_config import GlobalConfig
from ...llm.openai.oai_to_vllm import return_oai_internal_server_error
from ...logging import create_logger

_logger = create_logger('OpenAI')

@@ -26,15 +27,8 @@ def handle_error(e):
    "auth_subrequest_error"
    """

    _logger(f'OAI returning error: {e}')
    return jsonify({
        "error": {
            "message": "Internal server error",
            "type": "auth_subrequest_error",
            "param": None,
            "code": "internal_error"
        }
    }), 500
    _logger.error(f'OAI returning error: {e}')
    return_oai_internal_server_error()


from .models import openai_list_models
@@ -13,7 +13,7 @@ from ..openai_request_handler import OpenAIRequestHandler
from ..queue import priority_queue
from ...config.global_config import GlobalConfig
from ...database.log_to_db import log_to_db
from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai
from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai, return_oai_internal_server_error
from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
from ...logging import create_logger

@@ -32,7 +32,8 @@ def openai_chat_completions(model_name=None):
    else:
        handler = OpenAIRequestHandler(incoming_request=request, incoming_json=request_json_body, selected_model=model_name)
        if handler.offline:
            return return_invalid_model_err(model_name)
            # return return_invalid_model_err(model_name)
            return_oai_internal_server_error()

        if not request_json_body.get('stream'):
            try:
@@ -14,7 +14,7 @@ from llm_server.custom_redis import redis
from llm_server.database.database import is_api_key_moderated
from llm_server.database.log_to_db import log_to_db
from llm_server.llm import get_token_count
from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai, return_invalid_model_err
from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai, return_invalid_model_err, return_oai_internal_server_error
from llm_server.llm.openai.transform import ANTI_CONTINUATION_RE, ANTI_RESPONSE_RE, generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
from llm_server.logging import create_logger
from llm_server.routes.request_handler import RequestHandler

@@ -31,9 +31,10 @@ class OpenAIRequestHandler(RequestHandler):
    def handle_request(self) -> Tuple[flask.Response, int]:
        assert not self.used
        if self.offline:
            msg = return_invalid_model_err(self.selected_model)
            _logger.error(f'OAI is offline: {msg}')
            return self.handle_error(msg)
            # msg = return_invalid_model_err(self.selected_model)
            # _logger.error(f'OAI is offline: {msg}')
            # return self.handle_error(msg)
            return_oai_internal_server_error()

        if GlobalConfig.get().openai_silent_trim:
            oai_messages = trim_messages_to_fit(self.request.json['messages'], self.cluster_backend_info['model_config']['max_position_embeddings'], self.backend_url)

@@ -109,7 +110,7 @@ class OpenAIRequestHandler(RequestHandler):
        return response, 429

    def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
        _logger.error('OAI Error: {error_msg}')
        _logger.error(f'OAI Error: {error_msg}')
        return jsonify({
            "error": {
                "message": "Invalid request, check your parameters and try again.",
@@ -34,7 +34,7 @@ def generate_stats(regen: bool = False):
            'proompts_total': get_total_proompts() if GlobalConfig.get().show_num_prompts else None,
            'uptime': int((datetime.now() - server_start_time).total_seconds()) if GlobalConfig.get().show_uptime else None,
            # 'estimated_avg_tps': estimated_avg_tps,
            'tokens_generated': sum_column('prompts', 'response_tokens') if GlobalConfig.get().show_total_output_tokens else None,
            'tokens_generated': sum_column('messages', 'response_tokens') if GlobalConfig.get().show_total_output_tokens else None,
            'num_backends': len(cluster_config.all()) if GlobalConfig.get().show_backends else None,
        },
        'endpoints': {
@@ -8,6 +8,7 @@ import ujson
from redis import Redis

from llm_server.cluster.cluster_config import cluster_config
from llm_server.config.global_config import GlobalConfig
from llm_server.custom_redis import RedisCustom, redis
from llm_server.llm.generator import generator
from llm_server.logging import create_logger

@@ -148,12 +149,12 @@ def worker(backend_url):
        status_redis.setp(str(worker_id), None)


def start_workers(cluster: dict):
def start_workers():
    logger = create_logger('inferencer')
    i = 0
    for item in cluster:
        for _ in range(item['concurrent_gens']):
            t = threading.Thread(target=worker, args=(item['backend_url'],))
    for item in GlobalConfig.get().cluster:
        for _ in range(item.concurrent_gens):
            t = threading.Thread(target=worker, args=(item.backend_url,))
            t.daemon = True
            t.start()
            i += 1
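Worker startup now reads the cluster entries as validated pydantic objects rather than raw dicts; a short sketch (field names come from the config scheme above, nothing else is from the commit):

    from llm_server.config.global_config import GlobalConfig

    # Before: item['backend_url'], item['concurrent_gens'] on plain dicts.
    # After:  attribute access on the parsed config objects.
    for item in GlobalConfig.get().cluster:
        print(item.backend_url, item.concurrent_gens)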
@@ -49,10 +49,10 @@ def main_background_thread():
def calc_stats_for_backend(backend_url, running_model, backend_mode):
    # exclude_zeros=True filters out rows where an error message was returned. Previously, if there was an error, 0
    # was entered into the column. The new code enters null instead but we need to be backwards compatible for now.
    average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time',
    average_generation_elapsed_sec = weighted_average_column_for_model('messages', 'generation_time',
                                                                       running_model, backend_mode, backend_url, exclude_zeros=True,
                                                                       include_system_tokens=GlobalConfig.get().include_system_tokens_in_stats) or 0
    average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens',
    average_output_tokens = weighted_average_column_for_model('messages', 'response_tokens',
                                                              running_model, backend_mode, backend_url, exclude_zeros=True,
                                                              include_system_tokens=GlobalConfig.get().include_system_tokens_in_stats) or 0
    estimated_avg_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0  # Avoid division by zero
@@ -22,7 +22,7 @@ def cache_stats():

def start_background():
    logger = create_logger('threader')
    start_workers(GlobalConfig.get().cluster)
    start_workers()

    t = Thread(target=main_background_thread)
    t.daemon = True

@@ -46,7 +46,7 @@ def start_background():
    t = Thread(target=console_printer)
    t.daemon = True
    t.start()
    logger.info('Started the console logger.infoer.')
    logger.info('Started the console logger.')

    t = Thread(target=cluster_worker)
    t.daemon = True
@@ -0,0 +1,54 @@
from llm_server.helpers import resolve_path

try:
    import gevent.monkey

    gevent.monkey.patch_all()
except ImportError:
    pass

import logging
import os
import sys
import time
from pathlib import Path

from llm_server.config.global_config import GlobalConfig
from llm_server.config.load import load_config
from llm_server.custom_redis import redis
from llm_server.database.conn import Database
from llm_server.database.create import create_db
from llm_server.logging import init_logging, create_logger


def post_fork(server, worker):
    """
    Initalize the worker after gunicorn has forked. This is done to avoid issues with the database manager.
    """
    script_path = Path(os.path.dirname(os.path.realpath(__file__)))
    config_path_environ = os.getenv("CONFIG_PATH")
    if config_path_environ:
        config_path = config_path_environ
    else:
        config_path = script_path / '../config/config.yml'
    config_path = resolve_path(config_path)

    success, msg = load_config(config_path)
    if not success:
        logger = logging.getLogger('llm_server')
        logger.setLevel(logging.INFO)
        logger.error(f'Failed to load config: {msg}')
        sys.exit(1)

    init_logging()
    logger = create_logger('Server')
    logger.debug('Debug logging enabled.')

    while not redis.get('daemon_started', dtype=bool):
        logger.warning('Could not find the key daemon_started in Redis. Did you forget to start the daemon process?')
        time.sleep(10)

    Database.initialise(**GlobalConfig.get().postgresql.dict())
    create_db()

    logger.info('Started HTTP worker!')
@@ -11,6 +11,8 @@ WorkingDirectory=/srv/server/local-llm-server
# Sometimes the old processes aren't terminated when the service is restarted.
ExecStartPre=/usr/bin/pkill -9 -f "/srv/server/local-llm-server/venv/bin/python3 /srv/server/local-llm-server/venv/bin/gunicorn"

# TODO: make sure gunicorn logs to stdout and logging also goes to stdout

# Need a lot of workers since we have long-running requests. This takes about 3.5G memory.
ExecStart=/srv/server/local-llm-server/venv/bin/gunicorn --workers 20 --bind 0.0.0.0:5000 server:app --timeout 60 --worker-class gevent --access-logfile '-' --error-logfile '-'
@@ -4,7 +4,6 @@ Flask-Caching==2.0.2
requests~=2.31.0
tiktoken~=0.5.0
gevent~=23.9.0.post1
mysql-connector-python==8.4.0
simplejson~=3.19.1
websockets~=11.0.3
basicauth~=1.0.0

@@ -17,3 +16,4 @@ vllm==0.2.7
coloredlogs~=15.0.1
git+https://git.evulid.cc/cyberes/bison.git
pydantic
psycopg2-binary==2.9.9
server.py
@@ -1,33 +1,13 @@
import time

from llm_server.config.global_config import GlobalConfig

try:
    import gevent.monkey

    gevent.monkey.patch_all()
except ImportError:
    pass

import logging
import os
import sys
from pathlib import Path

import simplejson as json
from flask import Flask, jsonify, render_template, request, Response

import config
from llm_server.cluster.backend import get_model_choices
from llm_server.cluster.cluster_config import cluster_config
from llm_server.config.config import mode_ui_names
from llm_server.config.load import load_config
from llm_server.config.config import MODE_UI_NAMES
from llm_server.config.global_config import GlobalConfig
from llm_server.custom_redis import flask_cache, redis
from llm_server.database.conn import Database
from llm_server.database.create import create_db
from llm_server.helpers import auto_set_base_client_api
from llm_server.llm.vllm.info import vllm_info
from llm_server.logging import init_logging, create_logger
from llm_server.routes.openai import openai_bp, openai_model_bp
from llm_server.routes.server_error import handle_server_error
from llm_server.routes.v1 import bp

@@ -63,32 +43,6 @@ from llm_server.sock import init_wssocket
# TODO: add more excluding to SYSTEM__ tokens
# TODO: return 200 when returning formatted sillytavern error

script_path = os.path.dirname(os.path.realpath(__file__))
config_path_environ = os.getenv("CONFIG_PATH")
if config_path_environ:
    config_path = config_path_environ
else:
    config_path = Path(script_path, 'config', 'config.yml')

success, msg = load_config(config_path)
if not success:
    logger = logging.getLogger('llm_server')
    logger.setLevel(logging.INFO)
    logger.error(f'Failed to load config: {msg}')
    sys.exit(1)

init_logging(Path(GlobalConfig.get().webserver_log_directory) / 'server.log')
logger = create_logger('Server')
logger.debug('Debug logging enabled.')

while not redis.get('daemon_started', dtype=bool):
    logger.warning('Could not find the key daemon_started in Redis. Did you forget to start the daemon process?')
    time.sleep(10)

logger.info('Started HTTP worker!')

Database.initialise(maxconn=GlobalConfig.get().mysql.maxconn, host=GlobalConfig.get().mysql.host, user=GlobalConfig.get().mysql.username, password=GlobalConfig.get().mysql.password, database=GlobalConfig.get().mysql.database)
create_db()

app = Flask(__name__)

@@ -139,13 +93,13 @@ def home():
        # to None by the daemon.
        default_model_info['context_size'] = '-'

    if len(config['analytics_tracking_code']):
        analytics_tracking_code = f"<script>\n{config['analytics_tracking_code']}\n</script>"
    if len(GlobalConfig.get().analytics_tracking_code):
        analytics_tracking_code = f"<script>\n{GlobalConfig.get().analytics_tracking_code}\n</script>"
    else:
        analytics_tracking_code = ''

    if config['info_html']:
        info_html = config['info_html']
    if GlobalConfig.get().info_html:
        info_html = GlobalConfig.get().info_html
    else:
        info_html = ''

@@ -166,9 +120,9 @@ def home():
        client_api=f'https://{base_client_api}',
        ws_client_api=f'wss://{base_client_api}/v1/stream' if GlobalConfig.get().enable_streaming else 'disabled',
        default_estimated_wait=default_estimated_wait_sec,
        mode_name=mode_ui_names[GlobalConfig.get().frontend_api_mode][0],
        api_input_textbox=mode_ui_names[GlobalConfig.get().frontend_api_mode][1],
        streaming_input_textbox=mode_ui_names[GlobalConfig.get().frontend_api_mode][2],
        mode_name=MODE_UI_NAMES[GlobalConfig.get().frontend_api_mode].name,
        api_input_textbox=MODE_UI_NAMES[GlobalConfig.get().frontend_api_mode].api_name,
        streaming_input_textbox=MODE_UI_NAMES[GlobalConfig.get().frontend_api_mode].streaming_name,
        default_context_size=default_model_info['context_size'],
        stats_json=json.dumps(stats, indent=4, ensure_ascii=False),
        extra_info=mode_info,

@@ -212,6 +166,6 @@ def before_app_request():


if __name__ == "__main__":
    # server_startup(None)
    print('FLASK MODE - Startup complete!')
    app.run(host='0.0.0.0', threaded=False, processes=15)
    print('Do not run this file directly. Instead, use gunicorn:')
    print("gunicorn -c other/gunicorn_conf.py server:app -b 0.0.0.0:5000 --worker-class gevent --workers 3 --access-logfile '-' --error-logfile '-'")
    quit(1)