refactor a lot of things, major cleanup, use postgresql

Cyberes 2024-05-07 17:03:41 -06:00
parent ee9a0d4858
commit fd09c783d3
31 changed files with 220 additions and 367 deletions

View File

@@ -3,7 +3,6 @@ import logging
 import os
 import sys
 import time
-from pathlib import Path
 from redis import Redis
@@ -14,6 +13,7 @@ from llm_server.custom_redis import redis
 from llm_server.database.conn import Database
 from llm_server.database.create import create_db
 from llm_server.database.database import get_number_of_rows
+from llm_server.helpers import resolve_path
 from llm_server.logging import create_logger, logging_info, init_logging
 from llm_server.routes.v1.generate_stats import generate_stats
 from llm_server.workers.threader import start_background
@@ -23,7 +23,7 @@ config_path_environ = os.getenv("CONFIG_PATH")
 if config_path_environ:
     config_path = config_path_environ
 else:
-    config_path = Path(script_path, 'config', 'config.yml')
+    config_path = resolve_path(script_path, 'config', 'config.yml')
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description='Daemon microservice.')
@@ -47,7 +47,7 @@ if __name__ == "__main__":
         logger.info(f'Failed to load config: {msg}')
         sys.exit(1)
-    Database.initialise(maxconn=GlobalConfig.get().mysql.maxconn, host=GlobalConfig.get().mysql.host, user=GlobalConfig.get().mysql.username, password=GlobalConfig.get().mysql.password, database=GlobalConfig.get().mysql.database)
+    Database.initialise(**GlobalConfig.get().postgresql.dict())
     create_db()
     cluster_config.clear()
@@ -57,7 +57,7 @@ if __name__ == "__main__":
     generate_stats(regen=True)
     if GlobalConfig.get().load_num_prompts:
-        redis.set('proompts', get_number_of_rows('prompts'))
+        redis.set('proompts', get_number_of_rows('messages'))
     start_background()

View File

@@ -17,6 +17,7 @@ def get_backends_from_model(model_name: str):
     :param model_name:
     :return:
     """
+    assert isinstance(model_name, str)
     return [x.decode('utf-8') for x in redis_running_models.smembers(model_name)]
@@ -25,7 +26,7 @@ def get_running_models():
     Get all the models that are in the cluster.
     :return:
     """
-    return list(redis_running_models.keys())
+    return [x.decode('utf-8') for x in list(redis_running_models.keys())]
 def is_valid_model(model_name: str) -> bool:
@@ -81,6 +82,7 @@ def get_model_choices(regen: bool = False) -> tuple[dict, dict]:
     base_client_api = redis.get('base_client_api', dtype=str)
     running_models = get_running_models()
     model_choices = {}
     for model in running_models:
         b = get_backends_from_model(model)

View File

@@ -33,7 +33,7 @@ class RedisClusterStore:
             item.backend_url = backend_url
             stuff[backend_url] = item
         for k, v in stuff.items():
-            self.add_backend(k, v)
+            self.add_backend(k, v.dict())
     def add_backend(self, name: str, values: dict):
         self.config_redis.hset(name, mapping={k: pickle.dumps(v) for k, v in values.items()})
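Note: add_backend() now receives the backend's dict() and pickles each field into a Redis hash. A minimal read-back sketch (an assumption for illustration; the project's real accessor may differ):

    import pickle
    from redis import Redis

    def read_backend(config_redis: Redis, backend_url: str) -> dict:
        # Each hash field was written with pickle.dumps(), so unpickle on the way out.
        return {k.decode('utf-8'): pickle.loads(v) for k, v in config_redis.hgetall(backend_url).items()}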

View File

@@ -1,14 +1,22 @@
+from pydantic import BaseModel
 from llm_server.config.global_config import GlobalConfig
 def cluster_worker_count():
     count = 0
     for item in GlobalConfig.get().cluster:
-        count += item['concurrent_gens']
+        count += item.concurrent_gens
     return count
-mode_ui_names = {
-    'ooba': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),
-    'vllm': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),
+class ModeUINameStr(BaseModel):
+    name: str
+    api_name: str
+    streaming_name: str
+MODE_UI_NAMES = {
+    'ooba': ModeUINameStr(name='Text Gen WebUI (ooba)', api_name='Blocking API url', streaming_name='Streaming API url'),
+    'vllm': ModeUINameStr(name='Text Gen WebUI (ooba)', api_name='Blocking API url', streaming_name='Streaming API url'),
 }
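Note: lookups that used to index the tuples (mode_ui_names[...][0]) become attribute access on the Pydantic model, matching the template variables updated in server.py later in this commit:

    MODE_UI_NAMES['ooba'].name            # 'Text Gen WebUI (ooba)'
    MODE_UI_NAMES['ooba'].api_name        # 'Blocking API url'
    MODE_UI_NAMES['ooba'].streaming_name  # 'Streaming API url'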

View File

@@ -63,8 +63,8 @@ def load_config(config_path: Path):
     config_model = ConfigModel(**config.config)
     GlobalConfig.initalize(config_model)
-    if not (0 < GlobalConfig.get().mysql.maxconn <= 32):
-        return False, f'"maxcon" should be higher than 0 and lower or equal to 32. Current value: "{GlobalConfig.get().mysql.maxconn}"'
+    if GlobalConfig.get().postgresql.maxconn < 0:
+        return False, f'"maxcon" should be higher than 0. Current value: "{GlobalConfig.get().postgresql.maxconn}"'
     openai.api_key = GlobalConfig.get().openai_api_key

View File

@@ -19,9 +19,9 @@ class ConfigFrontendApiModes(str, Enum):
     ooba = 'ooba'
-class ConfigMysql(BaseModel):
+class ConfigPostgresql(BaseModel):
     host: str
-    username: str
+    user: str
     password: str
     database: str
     maxconn: int
@@ -37,9 +37,8 @@ class ConfigModel(BaseModel):
     cluster: List[ConfigCluser]
     prioritize_by_size: bool
     admin_token: Union[str, None]
-    mysql: ConfigMysql
+    postgresql: ConfigPostgresql
     http_host: str
-    webserver_log_directory: str
     include_system_tokens_in_stats: bool
     background_homepage_cacher: bool
     max_new_tokens: int
@@ -55,6 +54,7 @@ class ConfigModel(BaseModel):
     info_html: Union[str, None]
     enable_openi_compatible_backend: bool
     openai_api_key: Union[str, None]
+    openai_system_prompt: str
     expose_openai_system_prompt: bool
     openai_expose_our_model: bool
     openai_force_no_hashes: bool
@@ -72,3 +72,4 @@ class ConfigModel(BaseModel):
     load_num_prompts: bool
     manual_model_name: Union[str, None]
     backend_request_timeout: int
+    backend_generate_request_timeout: int
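Note: the field names of ConfigPostgresql line up with psycopg2 connection kwargs so its dict() can be splatted straight into the pool, as the daemon does with Database.initialise(**GlobalConfig.get().postgresql.dict()). A sketch with made-up values (the model is normally populated from config.yml):

    cfg = ConfigPostgresql(host='127.0.0.1', user='llm', password='secret',
                           database='llm_server', maxconn=10)  # example credentials only
    Database.initialise(**cfg.dict())  # maxconn sizes the pool; the rest become connection kwargs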

View File

@@ -2,7 +2,7 @@ from typing import Union
 import bison
-from llm_server.opts import default_openai_system_prompt
+from llm_server.globals import DEFAULT_OPENAI_SYSTEM_PROMPT
 config_scheme = bison.Scheme(
     bison.Option('frontend_api_mode', choices=['ooba'], field_type=str),
@@ -14,15 +14,14 @@ config_scheme = bison.Scheme(
     )),
     bison.Option('prioritize_by_size', default=True, field_type=bool),
     bison.Option('admin_token', default=None, field_type=Union[str, None]),
-    bison.ListOption('mysql', member_scheme=bison.Scheme(
+    bison.ListOption('postgresql', member_scheme=bison.Scheme(
         bison.Option('host', field_type=str),
-        bison.Option('username', field_type=str),
+        bison.Option('user', field_type=str),
         bison.Option('password', field_type=str),
         bison.Option('database', field_type=str),
         bison.Option('maxconn', field_type=int)
     )),
     bison.Option('http_host', default='', field_type=str),
-    bison.Option('webserver_log_directory', default='/var/log/localllm', field_type=str),
     bison.Option('include_system_tokens_in_stats', default=True, field_type=bool),
     bison.Option('background_homepage_cacher', default=True, field_type=bool),
     bison.Option('max_new_tokens', default=500, field_type=int),
@@ -41,7 +40,7 @@ config_scheme = bison.Scheme(
     bison.Option('expose_openai_system_prompt', default=True, field_type=bool),
     bison.Option('openai_expose_our_model', default='', field_type=bool),
     bison.Option('openai_force_no_hashes', default=True, field_type=bool),
-    bison.Option('openai_system_prompt', default=default_openai_system_prompt, field_type=str),
+    bison.Option('openai_system_prompt', default=DEFAULT_OPENAI_SYSTEM_PROMPT, field_type=str),
     bison.Option('openai_moderation_enabled', default=False, field_type=bool),
     bison.Option('openai_moderation_timeout', default=5, field_type=int),
     bison.Option('openai_moderation_scan_last_n', default=5, field_type=int),
@@ -55,5 +54,6 @@ config_scheme = bison.Scheme(
     bison.Option('show_backend_info', default=True, field_type=bool),
     bison.Option('load_num_prompts', default=True, field_type=bool),
     bison.Option('manual_model_name', default=None, field_type=Union[str, None]),
-    bison.Option('backend_request_timeout', default=30, field_type=int)
+    bison.Option('backend_request_timeout', default=30, field_type=int),
+    bison.Option('backend_generate_request_timeout', default=95, field_type=int)
 )

View File

@@ -2,13 +2,13 @@ import logging
 import pickle
 import sys
 import traceback
-from typing import Callable, List, Mapping, Optional, Union
+from typing import Union
 import redis as redis_pkg
 import simplejson as json
 from flask_caching import Cache
 from redis import Redis
-from redis.typing import AnyKeyT, EncodableT, ExpiryT, FieldT, KeyT, PatternT, ZScoreBoundT, AbsExpiryT
+from redis.typing import ExpiryT, KeyT, PatternT
 flask_cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/15', 'CACHE_KEY_PREFIX': 'local_llm_flask'})
@@ -38,18 +38,11 @@ class RedisCustom(Redis):
     def _key(self, key):
         return f"{self.prefix}:{key}"
-    def set(self, key: KeyT,
-            value: EncodableT,
-            ex: Union[ExpiryT, None] = None,
-            px: Union[ExpiryT, None] = None,
-            nx: bool = False,
-            xx: bool = False,
-            keepttl: bool = False,
-            get: bool = False,
-            exat: Union[AbsExpiryT, None] = None,
-            pxat: Union[AbsExpiryT, None] = None
-            ):
-        return self.redis.set(self._key(key), value, ex=ex)
+    def execute_command(self, *args, **options):
+        if args[0] != 'GET':
+            args = list(args)
+            args[1] = self._key(args[1])
+        return super().execute_command(*args, **options)
     def get(self, key, default=None, dtype=None):
         # TODO: use pickle
@@ -73,103 +66,6 @@ class RedisCustom(Redis):
         else:
             return d
-    def incr(self, key, amount=1):
-        return self.redis.incr(self._key(key), amount)
-    def decr(self, key, amount=1):
-        return self.redis.decr(self._key(key), amount)
-    def sadd(self, key: str, *values: FieldT):
-        return self.redis.sadd(self._key(key), *values)
-    def srem(self, key: str, *values: FieldT):
-        return self.redis.srem(self._key(key), *values)
-    def sismember(self, key: str, value: str):
-        return self.redis.sismember(self._key(key), value)
-    def lindex(
-            self, name: str, index: int
-    ):
-        return self.redis.lindex(self._key(name), index)
-    def lrem(self, name: str, count: int, value: str):
-        return self.redis.lrem(self._key(name), count, value)
-    def rpush(self, name: str, *values: FieldT):
-        return self.redis.rpush(self._key(name), *values)
-    def llen(self, name: str):
-        return self.redis.llen(self._key(name))
-    def zrangebyscore(
-            self,
-            name: KeyT,
-            min: ZScoreBoundT,
-            max: ZScoreBoundT,
-            start: Union[int, None] = None,
-            num: Union[int, None] = None,
-            withscores: bool = False,
-            score_cast_func: Union[type, Callable] = float,
-    ):
-        return self.redis.zrangebyscore(self._key(name), min, max, start, num, withscores, score_cast_func)
-    def zremrangebyscore(
-            self, name: KeyT, min: ZScoreBoundT, max: ZScoreBoundT
-    ):
-        return self.redis.zremrangebyscore(self._key(name), min, max)
-    def hincrby(
-            self, name: str, key: str, amount: int = 1
-    ):
-        return self.redis.hincrby(self._key(name), key, amount)
-    def zcard(self, name: KeyT):
-        return self.redis.zcard(self._key(name))
-    def hdel(self, name: str, *keys: str):
-        return self.redis.hdel(self._key(name), *keys)
-    def hget(
-            self, name: str, key: str
-    ):
-        return self.redis.hget(self._key(name), key)
-    def zadd(
-            self,
-            name: KeyT,
-            mapping: Mapping[AnyKeyT, EncodableT],
-            nx: bool = False,
-            xx: bool = False,
-            ch: bool = False,
-            incr: bool = False,
-            gt: bool = False,
-            lt: bool = False,
-    ):
-        return self.redis.zadd(self._key(name), mapping, nx, xx, ch, incr, gt, lt)
-    def lpush(self, name: str, *values: FieldT):
-        return self.redis.lpush(self._key(name), *values)
-    def hset(
-            self,
-            name: str,
-            key: Optional = None,
-            value=None,
-            mapping: Optional[dict] = None,
-            items: Optional[list] = None,
-    ):
-        return self.redis.hset(self._key(name), key, value, mapping, items)
-    def hkeys(self, name: str):
-        return self.redis.hkeys(self._key(name))
-    def hmget(self, name: str, keys: List, *args: List):
-        return self.redis.hmget(self._key(name), keys, *args)
-    def hgetall(self, name: str):
-        return self.redis.hgetall(self._key(name))
     def keys(self, pattern: PatternT = "*", **kwargs):
         raw_keys = self.redis.keys(self._key(pattern), **kwargs)
         keys = []
@@ -179,25 +75,9 @@ class RedisCustom(Redis):
             # Delete prefix
             del p[0]
             k = ':'.join(p)
-            if k != '____':
-                keys.append(k)
+            # keys.append(k)
         return keys
-    def pipeline(self, transaction=True, shard_hint=None):
-        return self.redis.pipeline(transaction, shard_hint)
-    def smembers(self, name: str):
-        return self.redis.smembers(self._key(name))
-    def spop(self, name: str, count: Optional[int] = None):
-        return self.redis.spop(self._key(name), count)
-    def rpoplpush(self, src, dst):
-        return self.redis.rpoplpush(src, dst)
-    def zpopmin(self, name: KeyT, count: Union[int, None] = None):
-        return self.redis.zpopmin(self._key(name), count)
     def exists(self, *names: KeyT):
         n = []
         for name in names:
@@ -238,32 +118,5 @@ class RedisCustom(Redis):
         self.flush()
         return True
-    def lrange(self, name: str, start: int, end: int):
-        return self.redis.lrange(self._key(name), start, end)
-    def delete(self, *names: KeyT):
-        return self.redis.delete(*[self._key(i) for i in names])
-    def lpop(self, name: str, count: Optional[int] = None):
-        return self.redis.lpop(self._key(name), count)
-    def zrange(
-            self,
-            name: KeyT,
-            start: int,
-            end: int,
-            desc: bool = False,
-            withscores: bool = False,
-            score_cast_func: Union[type, Callable] = float,
-            byscore: bool = False,
-            bylex: bool = False,
-            offset: int = None,
-            num: int = None,
-    ):
-        return self.redis.zrange(self._key(name), start, end, desc, withscores, score_cast_func, byscore, bylex, offset, num)
-    def zrem(self, name: KeyT, *values: FieldT):
-        return self.redis.zrem(self._key(name), *values)
 redis = RedisCustom('local_llm')
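Note: the single execute_command() override replaces the long list of per-command wrapper methods removed above: every command except GET has its first key argument rewritten to '<prefix>:<key>' before it is sent (GET is left alone because the custom get() already prefixes). A hypothetical illustration using the module-level client:

    redis.sadd('running_models', 'backend1')
    # redis-py routes this through execute_command('SADD', 'running_models', 'backend1'),
    # which rewrites the key to 'local_llm:running_models' before talking to Redis.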

View File

@@ -1,39 +1,42 @@
-from mysql.connector import pooling
+from psycopg2 import pool, InterfaceError
 class Database:
     __connection_pool = None
     @classmethod
-    def initialise(cls, maxconn: int, **kwargs):
+    def initialise(cls, maxconn, **kwargs):
         if cls.__connection_pool is not None:
             raise Exception('Database connection pool is already initialised')
-        cls.__connection_pool = pooling.MySQLConnectionPool(pool_size=maxconn,
-                                                            pool_reset_session=True,
-                                                            **kwargs)
+        cls.__connection_pool = pool.ThreadedConnectionPool(minconn=1, maxconn=maxconn, **kwargs)
     @classmethod
     def get_connection(cls):
-        return cls.__connection_pool.get_connection()
+        return cls.__connection_pool.getconn()
     @classmethod
     def return_connection(cls, connection):
-        connection.close()
+        cls.__connection_pool.putconn(connection)
 class CursorFromConnectionFromPool:
-    def __init__(self):
+    def __init__(self, cursor_factory=None):
         self.conn = None
         self.cursor = None
+        self.cursor_factory = cursor_factory
     def __enter__(self):
         self.conn = Database.get_connection()
-        self.cursor = self.conn.cursor()
+        self.cursor = self.conn.cursor(cursor_factory=self.cursor_factory)
         return self.cursor
     def __exit__(self, exception_type, exception_value, exception_traceback):
         if exception_value is not None:  # This is equivalent of saying if there is an exception
-            self.conn.rollback()
+            try:
+                self.conn.rollback()
+            except InterfaceError as e:
+                if e != 'connection already closed':
+                    raise
         else:
             self.cursor.close()
             self.conn.commit()
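Note: with psycopg2's ThreadedConnectionPool, getconn()/putconn() replace the MySQL pool, and the context manager checks out a connection, yields a cursor, and commits (or rolls back on error) when the block exits. A usage sketch with placeholder credentials:

    Database.initialise(maxconn=10, host='127.0.0.1', user='llm',
                        password='secret', database='llm_server')  # example values only
    with CursorFromConnectionFromPool() as cursor:
        cursor.execute('SELECT COUNT(*) FROM messages')
        print(cursor.fetchone()[0])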

View File

@@ -4,36 +4,39 @@ from llm_server.database.conn import CursorFromConnectionFromPool
 def create_db():
     with CursorFromConnectionFromPool() as cursor:
         cursor.execute('''
-            CREATE TABLE IF NOT EXISTS prompts (
-                ip TEXT,
-                token TEXT DEFAULT NULL,
-                model TEXT,
-                backend_mode TEXT,
-                backend_url TEXT,
-                request_url TEXT,
-                generation_time FLOAT,
-                prompt LONGTEXT,
-                prompt_tokens INTEGER,
-                response LONGTEXT,
-                response_tokens INTEGER,
-                response_status INTEGER,
-                parameters TEXT,
-                # CHECK (parameters IS NULL OR JSON_VALID(parameters)),
-                headers TEXT,
-                # CHECK (headers IS NULL OR JSON_VALID(headers)),
-                timestamp INTEGER
-            )
+            CREATE TABLE IF NOT EXISTS public.messages
+            (
+                ip text COLLATE pg_catalog."default" NOT NULL,
+                token text COLLATE pg_catalog."default",
+                model text COLLATE pg_catalog."default" NOT NULL,
+                backend_mode text COLLATE pg_catalog."default" NOT NULL,
+                backend_url text COLLATE pg_catalog."default" NOT NULL,
+                request_url text COLLATE pg_catalog."default" NOT NULL,
+                generation_time double precision NOT NULL,
+                prompt text COLLATE pg_catalog."default" NOT NULL,
+                prompt_tokens integer NOT NULL,
+                response text COLLATE pg_catalog."default" NOT NULL,
+                response_tokens integer NOT NULL,
+                response_status integer NOT NULL,
+                parameters jsonb NOT NULL,
+                headers jsonb,
+                "timestamp" timestamp with time zone NOT NULL DEFAULT CURRENT_TIMESTAMP,
+                id SERIAL PRIMARY KEY
+            );
         ''')
         cursor.execute('''
-            CREATE TABLE IF NOT EXISTS token_auth (
-                token TEXT,
-                UNIQUE (token),
-                type TEXT NOT NULL,
-                priority INTEGER DEFAULT 9999,
-                simultaneous_ip INTEGER DEFAULT NULL,
-                uses INTEGER DEFAULT 0,
-                max_uses INTEGER,
-                expire INTEGER,
-                disabled BOOLEAN DEFAULT 0
-            )
+            CREATE TABLE IF NOT EXISTS public.token_auth
+            (
+                token text COLLATE pg_catalog."default" NOT NULL,
+                type text COLLATE pg_catalog."default" NOT NULL,
+                priority integer NOT NULL DEFAULT 9999,
+                simultaneous_ip text COLLATE pg_catalog."default",
+                openai_moderation_enabled boolean NOT NULL DEFAULT true,
+                uses integer NOT NULL DEFAULT 0,
+                max_uses integer,
+                expire timestamp with time zone,
+                disabled boolean NOT NULL DEFAULT false,
+                notes text COLLATE pg_catalog."default" NOT NULL DEFAULT ''::text,
+                CONSTRAINT token_auth_pkey PRIMARY KEY (token)
+            )
         ''')
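Note: parameters and headers are now jsonb, so call sites can hand psycopg2 a dict wrapped in psycopg2.extras.Json instead of a pre-serialized string (an assumption about usage, not something this hunk shows):

    from psycopg2.extras import Json
    with CursorFromConnectionFromPool() as cursor:
        cursor.execute('UPDATE messages SET parameters = %s WHERE id = %s',
                       (Json({'temperature': 0.7, 'max_new_tokens': 500}), 1))  # hypothetical row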

View File

@@ -1,6 +1,7 @@
 import json
 import time
 import traceback
+from datetime import datetime, timedelta
 from typing import Union
 from llm_server.cluster.cluster_config import cluster_config
@@ -51,10 +52,10 @@ def do_db_log(ip: str, token: str, prompt: str, response: Union[str, None], gen_
     backend_info = cluster_config.get_backend(backend_url)
     running_model = backend_info.get('model')
     backend_mode = backend_info['mode']
-    timestamp = int(time.time())
+    timestamp = datetime.now()
     with CursorFromConnectionFromPool() as cursor:
         cursor.execute("""
-            INSERT INTO prompts
+            INSERT INTO messages
             (ip, token, model, backend_mode, backend_url, request_url, generation_time, prompt, prompt_tokens, response, response_tokens, response_status, parameters, headers, timestamp)
             VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
             """,
@@ -65,12 +66,12 @@ def is_valid_api_key(api_key):
     with CursorFromConnectionFromPool() as cursor:
         cursor.execute("SELECT token, uses, max_uses, expire, disabled FROM token_auth WHERE token = %s", (api_key,))
         row = cursor.fetchone()
         if row is not None:
             token, uses, max_uses, expire, disabled = row
             disabled = bool(disabled)
             if ((uses is None or max_uses is None) or uses < max_uses) and (expire is None or expire > time.time()) and not disabled:
                 return True
     return False
 def is_api_key_moderated(api_key):
@@ -146,9 +147,9 @@ def sum_column(table_name, column_name):
 def get_distinct_ips_24h():
     # Get the current time and subtract 24 hours (in seconds)
-    past_24_hours = int(time.time()) - 24 * 60 * 60
+    past_24_hours = datetime.now() - timedelta(days=1)
     with CursorFromConnectionFromPool() as cursor:
-        cursor.execute("SELECT COUNT(DISTINCT ip) FROM prompts WHERE timestamp >= %s AND (token NOT LIKE 'SYSTEM__%%' OR token IS NULL)", (past_24_hours,))
+        cursor.execute("SELECT COUNT(DISTINCT ip) FROM messages WHERE timestamp >= %s AND (token NOT LIKE 'SYSTEM__%%' OR token IS NULL)", (past_24_hours,))
         result = cursor.fetchone()
         return result[0] if result else 0

View File

@@ -1,10 +1,6 @@
 # Read-only global variables
-default_openai_system_prompt = """You are an assistant chatbot. Your main function is to provide accurate and helpful responses to the user's queries. You should always be polite, respectful, and patient. You should not provide any personal opinions or advice unless specifically asked by the user. You should not make any assumptions about the user's knowledge or abilities. You should always strive to provide clear and concise answers. If you do not understand a user's query, ask for clarification. If you cannot provide an answer, apologize and suggest the user seek help elsewhere.\nLines that start with "### ASSISTANT" were messages you sent previously.\nLines that start with "### USER" were messages sent by the user you are chatting with.\nYou will respond to the "### RESPONSE:" prompt as the assistant and follow the instructions given by the user.\n\n"""
+DEFAULT_OPENAI_SYSTEM_PROMPT = """You are an assistant chatbot. Your main function is to provide accurate and helpful responses to the user's queries. You should always be polite, respectful, and patient. You should not provide any personal opinions or advice unless specifically asked by the user. You should not make any assumptions about the user's knowledge or abilities. You should always strive to provide clear and concise answers. If you do not understand a user's query, ask for clarification. If you cannot provide an answer, apologize and suggest the user seek help elsewhere.\nLines that start with "### ASSISTANT" were messages you sent previously.\nLines that start with "### USER" were messages sent by the user you are chatting with.\nYou will respond to the "### RESPONSE:" prompt as the assistant and follow the instructions given by the user.\n\n"""
-# cluster = {}
 REDIS_STREAM_TIMEOUT = 25000
 LOGGING_FORMAT = "%(asctime)s: %(levelname)s:%(name)s - %(message)s"
+BACKEND_OFFLINE = 'The model you requested is not a valid choice. Please retry your query.'

View File

@@ -15,19 +15,6 @@ def resolve_path(*p: str):
     return Path(*p).expanduser().resolve().absolute()
-def safe_list_get(l, idx, default):
-    """
-    https://stackoverflow.com/a/5125636
-    :param l:
-    :param idx:
-    :param default:
-    :return:
-    """
-    try:
-        return l[idx]
-    except IndexError:
-        return default
 def deep_sort(obj):
     if isinstance(obj, dict):

View File

@@ -1,4 +1,4 @@
-from llm_server import opts
+from llm_server import globals
 from llm_server.cluster.cluster_config import cluster_config

View File

@@ -8,14 +8,14 @@ import requests
 from llm_server.config.global_config import GlobalConfig
-def generate(json_data: dict):
-    try:
-        r = requests.post(f'{GlobalConfig.get().backend_url}/api/v1/generate', json=json_data, verify=GlobalConfig.get().verify_ssl, timeout=GlobalConfig.get().backend_generate_request_timeout)
-    except requests.exceptions.ReadTimeout:
-        return False, None, 'Request to backend timed out'
-    except Exception as e:
-        traceback.print_exc()
-        return False, None, 'Request to backend encountered error'
-    if r.status_code != 200:
-        return False, r, f'Backend returned {r.status_code}'
-    return True, r, None
+# def generate(json_data: dict):
+#     try:
+#         r = requests.post(f'{GlobalConfig.get().backend_url}/api/v1/generate', json=json_data, verify=GlobalConfig.get().verify_ssl, timeout=GlobalConfig.get().backend_generate_request_timeout)
+#     except requests.exceptions.ReadTimeout:
+#         return False, None, 'Request to backend timed out'
+#     except Exception as e:
+#         traceback.print_exc()
+#         return False, None, 'Request to backend encountered error'
+#     if r.status_code != 200:
+#         return False, r, f'Backend returned {r.status_code}'
+#     return True, r, None

View File

@@ -98,3 +98,14 @@ def return_invalid_model_err(requested_model: str):
             "code": "model_not_found"
         }
     }), 404
+def return_oai_internal_server_error():
+    return jsonify({
+        "error": {
+            "message": "Internal server error",
+            "type": "auth_subrequest_error",
+            "param": None,
+            "code": "internal_error"
+        }
+    }), 500

View File

@@ -31,7 +31,7 @@ def handle_blocking_request(json_data: dict, cluster_backend, timeout: int = 10)
         return False, None, 'Request to backend timed out'
     except Exception as e:
         # print(f'Failed to reach VLLM inference endpoint -', f'{e.__class__.__name__}: {e}')
-        return False, None, 'Request to backend encountered error'
+        return False, None, f'Request to backend encountered error -- {e.__class__.__name__}: {e}'
     if r.status_code != 200:
         # print(f'Failed to reach VLLM inference endpoint - got code {r.status_code}')
         return False, r, f'Backend returned {r.status_code}'

View File

@@ -1,16 +1,14 @@
 import logging
-import sys
-from pathlib import Path
 import coloredlogs
-from llm_server import opts
+from llm_server import globals
 class LoggingInfo:
     def __init__(self):
         self._level = logging.INFO
-        self._format = opts.LOGGING_FORMAT
+        self._format = globals.LOGGING_FORMAT
     @property
     def level(self):
@@ -30,30 +28,17 @@ class LoggingInfo:
 logging_info = LoggingInfo()
-LOG_DIRECTORY = None
-def init_logging(filepath: Path = None):
+def init_logging():
     """
     Set up the parent logger. Ensures this logger and all children to log to a file.
     This is only called by `server.py` since there is wierdness with Gunicorn. The deamon doesn't need this.
     :return:
     """
-    global LOG_DIRECTORY
     logger = logging.getLogger('llm_server')
     logger.setLevel(logging_info.level)
-    if filepath:
-        p = Path(filepath)
-        if not p.parent.is_dir():
-            logger.fatal(f'Log directory does not exist: {p.parent}')
-            sys.exit(1)
-        LOG_DIRECTORY = p.parent
-        handler = logging.FileHandler(filepath)
-        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
-        handler.setFormatter(formatter)
-        logger.addHandler(handler)
 def create_logger(name):
     logger = logging.getLogger('llm_server').getChild(name)
@@ -65,7 +50,4 @@ def create_logger(name):
     handler.setFormatter(formatter)
     logger.addHandler(handler)
     coloredlogs.install(logger=logger, level=logging_info.level)
-    if LOG_DIRECTORY:
-        handler = logging.FileHandler(LOG_DIRECTORY / f'{name}.log')
-        logger.addHandler(handler)
     return logger

View File

@@ -1 +0,0 @@
-BACKEND_OFFLINE = 'The model you requested is not a valid choice. Please retry your query.'

View File

@@ -3,7 +3,7 @@ from typing import Tuple
 import flask
 from flask import jsonify, request
-from llm_server import messages
+import llm_server.globals
 from llm_server.config.global_config import GlobalConfig
 from llm_server.database.log_to_db import log_to_db
 from llm_server.logging import create_logger
@@ -21,7 +21,7 @@ class OobaRequestHandler(RequestHandler):
         assert not self.used
         if self.offline:
             # _logger.debug(f'This backend is offline.')
-            return self.handle_error(messages.BACKEND_OFFLINE)
+            return self.handle_error(llm_server.globals.BACKEND_OFFLINE)
         request_valid, invalid_response = self.validate_request()
         if not request_valid:

View File

@@ -2,6 +2,7 @@ from flask import Blueprint
 from ..request_handler import before_request
 from ...config.global_config import GlobalConfig
+from ...llm.openai.oai_to_vllm import return_oai_internal_server_error
 from ...logging import create_logger
 _logger = create_logger('OpenAI')
@@ -26,15 +27,8 @@ def handle_error(e):
     "auth_subrequest_error"
     """
-    _logger(f'OAI returning error: {e}')
-    return jsonify({
-        "error": {
-            "message": "Internal server error",
-            "type": "auth_subrequest_error",
-            "param": None,
-            "code": "internal_error"
-        }
-    }), 500
+    _logger.error(f'OAI returning error: {e}')
+    return_oai_internal_server_error()
 from .models import openai_list_models

View File

@@ -13,7 +13,7 @@ from ..openai_request_handler import OpenAIRequestHandler
 from ..queue import priority_queue
 from ...config.global_config import GlobalConfig
 from ...database.log_to_db import log_to_db
-from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai
+from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai, return_oai_internal_server_error
 from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
 from ...logging import create_logger
@@ -32,7 +32,8 @@ def openai_chat_completions(model_name=None):
     else:
         handler = OpenAIRequestHandler(incoming_request=request, incoming_json=request_json_body, selected_model=model_name)
     if handler.offline:
-        return return_invalid_model_err(model_name)
+        # return return_invalid_model_err(model_name)
+        return_oai_internal_server_error()
     if not request_json_body.get('stream'):
         try:

View File

@@ -14,7 +14,7 @@ from llm_server.custom_redis import redis
 from llm_server.database.database import is_api_key_moderated
 from llm_server.database.log_to_db import log_to_db
 from llm_server.llm import get_token_count
-from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai, return_invalid_model_err
+from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai, return_invalid_model_err, return_oai_internal_server_error
 from llm_server.llm.openai.transform import ANTI_CONTINUATION_RE, ANTI_RESPONSE_RE, generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
 from llm_server.logging import create_logger
 from llm_server.routes.request_handler import RequestHandler
@@ -31,9 +31,10 @@ class OpenAIRequestHandler(RequestHandler):
     def handle_request(self) -> Tuple[flask.Response, int]:
         assert not self.used
         if self.offline:
-            msg = return_invalid_model_err(self.selected_model)
-            _logger.error(f'OAI is offline: {msg}')
-            return self.handle_error(msg)
+            # msg = return_invalid_model_err(self.selected_model)
+            # _logger.error(f'OAI is offline: {msg}')
+            # return self.handle_error(msg)
+            return_oai_internal_server_error()
         if GlobalConfig.get().openai_silent_trim:
             oai_messages = trim_messages_to_fit(self.request.json['messages'], self.cluster_backend_info['model_config']['max_position_embeddings'], self.backend_url)
@@ -109,7 +110,7 @@ class OpenAIRequestHandler(RequestHandler):
         return response, 429
     def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
-        _logger.error('OAI Error: {error_msg}')
+        _logger.error(f'OAI Error: {error_msg}')
         return jsonify({
             "error": {
                 "message": "Invalid request, check your parameters and try again.",

View File

@@ -34,7 +34,7 @@ def generate_stats(regen: bool = False):
             'proompts_total': get_total_proompts() if GlobalConfig.get().show_num_prompts else None,
             'uptime': int((datetime.now() - server_start_time).total_seconds()) if GlobalConfig.get().show_uptime else None,
             # 'estimated_avg_tps': estimated_avg_tps,
-            'tokens_generated': sum_column('prompts', 'response_tokens') if GlobalConfig.get().show_total_output_tokens else None,
+            'tokens_generated': sum_column('messages', 'response_tokens') if GlobalConfig.get().show_total_output_tokens else None,
             'num_backends': len(cluster_config.all()) if GlobalConfig.get().show_backends else None,
         },
         'endpoints': {

View File

@@ -8,6 +8,7 @@ import ujson
 from redis import Redis
 from llm_server.cluster.cluster_config import cluster_config
+from llm_server.config.global_config import GlobalConfig
 from llm_server.custom_redis import RedisCustom, redis
 from llm_server.llm.generator import generator
 from llm_server.logging import create_logger
@@ -148,12 +149,12 @@ def worker(backend_url):
         status_redis.setp(str(worker_id), None)
-def start_workers(cluster: dict):
+def start_workers():
     logger = create_logger('inferencer')
     i = 0
-    for item in cluster:
-        for _ in range(item['concurrent_gens']):
-            t = threading.Thread(target=worker, args=(item['backend_url'],))
+    for item in GlobalConfig.get().cluster:
+        for _ in range(item.concurrent_gens):
+            t = threading.Thread(target=worker, args=(item.backend_url,))
             t.daemon = True
             t.start()
             i += 1

View File

@@ -49,10 +49,10 @@ def main_background_thread():
 def calc_stats_for_backend(backend_url, running_model, backend_mode):
     # exclude_zeros=True filters out rows where an error message was returned. Previously, if there was an error, 0
     # was entered into the column. The new code enters null instead but we need to be backwards compatible for now.
-    average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time',
+    average_generation_elapsed_sec = weighted_average_column_for_model('messages', 'generation_time',
                                                                        running_model, backend_mode, backend_url, exclude_zeros=True,
                                                                        include_system_tokens=GlobalConfig.get().include_system_tokens_in_stats) or 0
-    average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens',
+    average_output_tokens = weighted_average_column_for_model('messages', 'response_tokens',
                                                               running_model, backend_mode, backend_url, exclude_zeros=True,
                                                               include_system_tokens=GlobalConfig.get().include_system_tokens_in_stats) or 0
     estimated_avg_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0  # Avoid division by zero

View File

@@ -22,7 +22,7 @@ def cache_stats():
 def start_background():
     logger = create_logger('threader')
-    start_workers(GlobalConfig.get().cluster)
+    start_workers()
     t = Thread(target=main_background_thread)
     t.daemon = True
@@ -46,7 +46,7 @@ def start_background():
     t = Thread(target=console_printer)
     t.daemon = True
     t.start()
-    logger.info('Started the console logger.infoer.')
+    logger.info('Started the console logger.')
     t = Thread(target=cluster_worker)
     t.daemon = True

other/gunicorn_conf.py (new file, 54 lines)
View File

@@ -0,0 +1,54 @@
+from llm_server.helpers import resolve_path
+
+try:
+    import gevent.monkey
+    gevent.monkey.patch_all()
+except ImportError:
+    pass
+
+import logging
+import os
+import sys
+import time
+from pathlib import Path
+
+from llm_server.config.global_config import GlobalConfig
+from llm_server.config.load import load_config
+from llm_server.custom_redis import redis
+from llm_server.database.conn import Database
+from llm_server.database.create import create_db
+from llm_server.logging import init_logging, create_logger
+
+def post_fork(server, worker):
+    """
+    Initalize the worker after gunicorn has forked. This is done to avoid issues with the database manager.
+    """
+    script_path = Path(os.path.dirname(os.path.realpath(__file__)))
+    config_path_environ = os.getenv("CONFIG_PATH")
+    if config_path_environ:
+        config_path = config_path_environ
+    else:
+        config_path = script_path / '../config/config.yml'
+    config_path = resolve_path(config_path)
+
+    success, msg = load_config(config_path)
+    if not success:
+        logger = logging.getLogger('llm_server')
+        logger.setLevel(logging.INFO)
+        logger.error(f'Failed to load config: {msg}')
+        sys.exit(1)
+
+    init_logging()
+    logger = create_logger('Server')
+    logger.debug('Debug logging enabled.')
+
+    while not redis.get('daemon_started', dtype=bool):
+        logger.warning('Could not find the key daemon_started in Redis. Did you forget to start the daemon process?')
+        time.sleep(10)
+
+    Database.initialise(**GlobalConfig.get().postgresql.dict())
+    create_db()
+
+    logger.info('Started HTTP worker!')
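Note: because load_config(), init_logging(), and Database.initialise() now run in post_fork(), each gunicorn worker builds its own connection pool after forking. The server is then launched with the command printed at the bottom of server.py in this commit: gunicorn -c other/gunicorn_conf.py server:app -b 0.0.0.0:5000 --worker-class gevent --workers 3 --access-logfile '-' --error-logfile '-'.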

View File

@@ -11,6 +11,8 @@ WorkingDirectory=/srv/server/local-llm-server
 # Sometimes the old processes aren't terminated when the service is restarted.
 ExecStartPre=/usr/bin/pkill -9 -f "/srv/server/local-llm-server/venv/bin/python3 /srv/server/local-llm-server/venv/bin/gunicorn"
+# TODO: make sure gunicorn logs to stdout and logging also goes to stdout
 # Need a lot of workers since we have long-running requests. This takes about 3.5G memory.
 ExecStart=/srv/server/local-llm-server/venv/bin/gunicorn --workers 20 --bind 0.0.0.0:5000 server:app --timeout 60 --worker-class gevent --access-logfile '-' --error-logfile '-'

View File

@@ -4,7 +4,6 @@ Flask-Caching==2.0.2
 requests~=2.31.0
 tiktoken~=0.5.0
 gevent~=23.9.0.post1
-mysql-connector-python==8.4.0
 simplejson~=3.19.1
 websockets~=11.0.3
 basicauth~=1.0.0
@@ -17,3 +16,4 @@ vllm==0.2.7
 coloredlogs~=15.0.1
 git+https://git.evulid.cc/cyberes/bison.git
 pydantic
+psycopg2-binary==2.9.9

View File

@@ -1,33 +1,13 @@
-import time
-from llm_server.config.global_config import GlobalConfig
-try:
-    import gevent.monkey
-    gevent.monkey.patch_all()
-except ImportError:
-    pass
-import logging
-import os
-import sys
-from pathlib import Path
 import simplejson as json
 from flask import Flask, jsonify, render_template, request, Response
-import config
 from llm_server.cluster.backend import get_model_choices
 from llm_server.cluster.cluster_config import cluster_config
-from llm_server.config.config import mode_ui_names
-from llm_server.config.load import load_config
+from llm_server.config.config import MODE_UI_NAMES
+from llm_server.config.global_config import GlobalConfig
 from llm_server.custom_redis import flask_cache, redis
-from llm_server.database.conn import Database
-from llm_server.database.create import create_db
 from llm_server.helpers import auto_set_base_client_api
 from llm_server.llm.vllm.info import vllm_info
-from llm_server.logging import init_logging, create_logger
 from llm_server.routes.openai import openai_bp, openai_model_bp
 from llm_server.routes.server_error import handle_server_error
 from llm_server.routes.v1 import bp
@@ -63,32 +43,6 @@ from llm_server.sock import init_wssocket
 # TODO: add more excluding to SYSTEM__ tokens
 # TODO: return 200 when returning formatted sillytavern error
-script_path = os.path.dirname(os.path.realpath(__file__))
-config_path_environ = os.getenv("CONFIG_PATH")
-if config_path_environ:
-    config_path = config_path_environ
-else:
-    config_path = Path(script_path, 'config', 'config.yml')
-success, msg = load_config(config_path)
-if not success:
-    logger = logging.getLogger('llm_server')
-    logger.setLevel(logging.INFO)
-    logger.error(f'Failed to load config: {msg}')
-    sys.exit(1)
-init_logging(Path(GlobalConfig.get().webserver_log_directory) / 'server.log')
-logger = create_logger('Server')
-logger.debug('Debug logging enabled.')
-while not redis.get('daemon_started', dtype=bool):
-    logger.warning('Could not find the key daemon_started in Redis. Did you forget to start the daemon process?')
-    time.sleep(10)
-logger.info('Started HTTP worker!')
-Database.initialise(maxconn=GlobalConfig.get().mysql.maxconn, host=GlobalConfig.get().mysql.host, user=GlobalConfig.get().mysql.username, password=GlobalConfig.get().mysql.password, database=GlobalConfig.get().mysql.database)
-create_db()
 app = Flask(__name__)
@@ -139,13 +93,13 @@ def home():
         # to None by the daemon.
         default_model_info['context_size'] = '-'
-    if len(config['analytics_tracking_code']):
-        analytics_tracking_code = f"<script>\n{config['analytics_tracking_code']}\n</script>"
+    if len(GlobalConfig.get().analytics_tracking_code):
+        analytics_tracking_code = f"<script>\n{GlobalConfig.get().analytics_tracking_code}\n</script>"
     else:
        analytics_tracking_code = ''
-    if config['info_html']:
-        info_html = config['info_html']
+    if GlobalConfig.get().info_html:
+        info_html = GlobalConfig.get().info_html
     else:
        info_html = ''
@@ -166,9 +120,9 @@ def home():
                            client_api=f'https://{base_client_api}',
                            ws_client_api=f'wss://{base_client_api}/v1/stream' if GlobalConfig.get().enable_streaming else 'disabled',
                            default_estimated_wait=default_estimated_wait_sec,
-                           mode_name=mode_ui_names[GlobalConfig.get().frontend_api_mode][0],
-                           api_input_textbox=mode_ui_names[GlobalConfig.get().frontend_api_mode][1],
-                           streaming_input_textbox=mode_ui_names[GlobalConfig.get().frontend_api_mode][2],
+                           mode_name=MODE_UI_NAMES[GlobalConfig.get().frontend_api_mode].name,
+                           api_input_textbox=MODE_UI_NAMES[GlobalConfig.get().frontend_api_mode].api_name,
+                           streaming_input_textbox=MODE_UI_NAMES[GlobalConfig.get().frontend_api_mode].streaming_name,
                            default_context_size=default_model_info['context_size'],
                            stats_json=json.dumps(stats, indent=4, ensure_ascii=False),
                            extra_info=mode_info,
@@ -212,6 +166,6 @@ def before_app_request():
 if __name__ == "__main__":
-    # server_startup(None)
-    print('FLASK MODE - Startup complete!')
-    app.run(host='0.0.0.0', threaded=False, processes=15)
+    print('Do not run this file directly. Instead, use gunicorn:')
+    print("gunicorn -c other/gunicorn_conf.py server:app -b 0.0.0.0:5000 --worker-class gevent --workers 3 --access-logfile '-' --error-logfile '-'")
+    quit(1)