refactor a lot of things, major cleanup, use postgresql
This commit is contained in:
parent
ee9a0d4858
commit
fd09c783d3
|
@ -3,7 +3,6 @@ import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from redis import Redis
|
from redis import Redis
|
||||||
|
|
||||||
|
@ -14,6 +13,7 @@ from llm_server.custom_redis import redis
|
||||||
from llm_server.database.conn import Database
|
from llm_server.database.conn import Database
|
||||||
from llm_server.database.create import create_db
|
from llm_server.database.create import create_db
|
||||||
from llm_server.database.database import get_number_of_rows
|
from llm_server.database.database import get_number_of_rows
|
||||||
|
from llm_server.helpers import resolve_path
|
||||||
from llm_server.logging import create_logger, logging_info, init_logging
|
from llm_server.logging import create_logger, logging_info, init_logging
|
||||||
from llm_server.routes.v1.generate_stats import generate_stats
|
from llm_server.routes.v1.generate_stats import generate_stats
|
||||||
from llm_server.workers.threader import start_background
|
from llm_server.workers.threader import start_background
|
||||||
|
@ -23,7 +23,7 @@ config_path_environ = os.getenv("CONFIG_PATH")
|
||||||
if config_path_environ:
|
if config_path_environ:
|
||||||
config_path = config_path_environ
|
config_path = config_path_environ
|
||||||
else:
|
else:
|
||||||
config_path = Path(script_path, 'config', 'config.yml')
|
config_path = resolve_path(script_path, 'config', 'config.yml')
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description='Daemon microservice.')
|
parser = argparse.ArgumentParser(description='Daemon microservice.')
|
||||||
|
@ -47,7 +47,7 @@ if __name__ == "__main__":
|
||||||
logger.info(f'Failed to load config: {msg}')
|
logger.info(f'Failed to load config: {msg}')
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
Database.initialise(maxconn=GlobalConfig.get().mysql.maxconn, host=GlobalConfig.get().mysql.host, user=GlobalConfig.get().mysql.username, password=GlobalConfig.get().mysql.password, database=GlobalConfig.get().mysql.database)
|
Database.initialise(**GlobalConfig.get().postgresql.dict())
|
||||||
create_db()
|
create_db()
|
||||||
|
|
||||||
cluster_config.clear()
|
cluster_config.clear()
|
||||||
|
@ -57,7 +57,7 @@ if __name__ == "__main__":
|
||||||
generate_stats(regen=True)
|
generate_stats(regen=True)
|
||||||
|
|
||||||
if GlobalConfig.get().load_num_prompts:
|
if GlobalConfig.get().load_num_prompts:
|
||||||
redis.set('proompts', get_number_of_rows('prompts'))
|
redis.set('proompts', get_number_of_rows('messages'))
|
||||||
|
|
||||||
start_background()
|
start_background()
|
||||||
|
|
||||||
|
|
|
@ -17,6 +17,7 @@ def get_backends_from_model(model_name: str):
|
||||||
:param model_name:
|
:param model_name:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
|
assert isinstance(model_name, str)
|
||||||
return [x.decode('utf-8') for x in redis_running_models.smembers(model_name)]
|
return [x.decode('utf-8') for x in redis_running_models.smembers(model_name)]
|
||||||
|
|
||||||
|
|
||||||
|
@ -25,7 +26,7 @@ def get_running_models():
|
||||||
Get all the models that are in the cluster.
|
Get all the models that are in the cluster.
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
return list(redis_running_models.keys())
|
return [x.decode('utf-8') for x in list(redis_running_models.keys())]
|
||||||
|
|
||||||
|
|
||||||
def is_valid_model(model_name: str) -> bool:
|
def is_valid_model(model_name: str) -> bool:
|
||||||
|
@ -81,6 +82,7 @@ def get_model_choices(regen: bool = False) -> tuple[dict, dict]:
|
||||||
|
|
||||||
base_client_api = redis.get('base_client_api', dtype=str)
|
base_client_api = redis.get('base_client_api', dtype=str)
|
||||||
running_models = get_running_models()
|
running_models = get_running_models()
|
||||||
|
|
||||||
model_choices = {}
|
model_choices = {}
|
||||||
for model in running_models:
|
for model in running_models:
|
||||||
b = get_backends_from_model(model)
|
b = get_backends_from_model(model)
|
||||||
|
|
|
@ -33,7 +33,7 @@ class RedisClusterStore:
|
||||||
item.backend_url = backend_url
|
item.backend_url = backend_url
|
||||||
stuff[backend_url] = item
|
stuff[backend_url] = item
|
||||||
for k, v in stuff.items():
|
for k, v in stuff.items():
|
||||||
self.add_backend(k, v)
|
self.add_backend(k, v.dict())
|
||||||
|
|
||||||
def add_backend(self, name: str, values: dict):
|
def add_backend(self, name: str, values: dict):
|
||||||
self.config_redis.hset(name, mapping={k: pickle.dumps(v) for k, v in values.items()})
|
self.config_redis.hset(name, mapping={k: pickle.dumps(v) for k, v in values.items()})
|
||||||
|
|
|
@ -1,14 +1,22 @@
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from llm_server.config.global_config import GlobalConfig
|
from llm_server.config.global_config import GlobalConfig
|
||||||
|
|
||||||
|
|
||||||
def cluster_worker_count():
|
def cluster_worker_count():
|
||||||
count = 0
|
count = 0
|
||||||
for item in GlobalConfig.get().cluster:
|
for item in GlobalConfig.get().cluster:
|
||||||
count += item['concurrent_gens']
|
count += item.concurrent_gens
|
||||||
return count
|
return count
|
||||||
|
|
||||||
|
|
||||||
mode_ui_names = {
|
class ModeUINameStr(BaseModel):
|
||||||
'ooba': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),
|
name: str
|
||||||
'vllm': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),
|
api_name: str
|
||||||
|
streaming_name: str
|
||||||
|
|
||||||
|
|
||||||
|
MODE_UI_NAMES = {
|
||||||
|
'ooba': ModeUINameStr(name='Text Gen WebUI (ooba)', api_name='Blocking API url', streaming_name='Streaming API url'),
|
||||||
|
'vllm': ModeUINameStr(name='Text Gen WebUI (ooba)', api_name='Blocking API url', streaming_name='Streaming API url'),
|
||||||
}
|
}
|
||||||
|
|
|
@ -63,8 +63,8 @@ def load_config(config_path: Path):
|
||||||
config_model = ConfigModel(**config.config)
|
config_model = ConfigModel(**config.config)
|
||||||
GlobalConfig.initalize(config_model)
|
GlobalConfig.initalize(config_model)
|
||||||
|
|
||||||
if not (0 < GlobalConfig.get().mysql.maxconn <= 32):
|
if GlobalConfig.get().postgresql.maxconn < 0:
|
||||||
return False, f'"maxcon" should be higher than 0 and lower or equal to 32. Current value: "{GlobalConfig.get().mysql.maxconn}"'
|
return False, f'"maxcon" should be higher than 0. Current value: "{GlobalConfig.get().postgresql.maxconn}"'
|
||||||
|
|
||||||
openai.api_key = GlobalConfig.get().openai_api_key
|
openai.api_key = GlobalConfig.get().openai_api_key
|
||||||
|
|
||||||
|
|
|
@ -19,9 +19,9 @@ class ConfigFrontendApiModes(str, Enum):
|
||||||
ooba = 'ooba'
|
ooba = 'ooba'
|
||||||
|
|
||||||
|
|
||||||
class ConfigMysql(BaseModel):
|
class ConfigPostgresql(BaseModel):
|
||||||
host: str
|
host: str
|
||||||
username: str
|
user: str
|
||||||
password: str
|
password: str
|
||||||
database: str
|
database: str
|
||||||
maxconn: int
|
maxconn: int
|
||||||
|
@ -37,9 +37,8 @@ class ConfigModel(BaseModel):
|
||||||
cluster: List[ConfigCluser]
|
cluster: List[ConfigCluser]
|
||||||
prioritize_by_size: bool
|
prioritize_by_size: bool
|
||||||
admin_token: Union[str, None]
|
admin_token: Union[str, None]
|
||||||
mysql: ConfigMysql
|
postgresql: ConfigPostgresql
|
||||||
http_host: str
|
http_host: str
|
||||||
webserver_log_directory: str
|
|
||||||
include_system_tokens_in_stats: bool
|
include_system_tokens_in_stats: bool
|
||||||
background_homepage_cacher: bool
|
background_homepage_cacher: bool
|
||||||
max_new_tokens: int
|
max_new_tokens: int
|
||||||
|
@ -55,6 +54,7 @@ class ConfigModel(BaseModel):
|
||||||
info_html: Union[str, None]
|
info_html: Union[str, None]
|
||||||
enable_openi_compatible_backend: bool
|
enable_openi_compatible_backend: bool
|
||||||
openai_api_key: Union[str, None]
|
openai_api_key: Union[str, None]
|
||||||
|
openai_system_prompt: str
|
||||||
expose_openai_system_prompt: bool
|
expose_openai_system_prompt: bool
|
||||||
openai_expose_our_model: bool
|
openai_expose_our_model: bool
|
||||||
openai_force_no_hashes: bool
|
openai_force_no_hashes: bool
|
||||||
|
@ -72,3 +72,4 @@ class ConfigModel(BaseModel):
|
||||||
load_num_prompts: bool
|
load_num_prompts: bool
|
||||||
manual_model_name: Union[str, None]
|
manual_model_name: Union[str, None]
|
||||||
backend_request_timeout: int
|
backend_request_timeout: int
|
||||||
|
backend_generate_request_timeout: int
|
||||||
|
|
|
@ -2,7 +2,7 @@ from typing import Union
|
||||||
|
|
||||||
import bison
|
import bison
|
||||||
|
|
||||||
from llm_server.opts import default_openai_system_prompt
|
from llm_server.globals import DEFAULT_OPENAI_SYSTEM_PROMPT
|
||||||
|
|
||||||
config_scheme = bison.Scheme(
|
config_scheme = bison.Scheme(
|
||||||
bison.Option('frontend_api_mode', choices=['ooba'], field_type=str),
|
bison.Option('frontend_api_mode', choices=['ooba'], field_type=str),
|
||||||
|
@ -14,15 +14,14 @@ config_scheme = bison.Scheme(
|
||||||
)),
|
)),
|
||||||
bison.Option('prioritize_by_size', default=True, field_type=bool),
|
bison.Option('prioritize_by_size', default=True, field_type=bool),
|
||||||
bison.Option('admin_token', default=None, field_type=Union[str, None]),
|
bison.Option('admin_token', default=None, field_type=Union[str, None]),
|
||||||
bison.ListOption('mysql', member_scheme=bison.Scheme(
|
bison.ListOption('postgresql', member_scheme=bison.Scheme(
|
||||||
bison.Option('host', field_type=str),
|
bison.Option('host', field_type=str),
|
||||||
bison.Option('username', field_type=str),
|
bison.Option('user', field_type=str),
|
||||||
bison.Option('password', field_type=str),
|
bison.Option('password', field_type=str),
|
||||||
bison.Option('database', field_type=str),
|
bison.Option('database', field_type=str),
|
||||||
bison.Option('maxconn', field_type=int)
|
bison.Option('maxconn', field_type=int)
|
||||||
)),
|
)),
|
||||||
bison.Option('http_host', default='', field_type=str),
|
bison.Option('http_host', default='', field_type=str),
|
||||||
bison.Option('webserver_log_directory', default='/var/log/localllm', field_type=str),
|
|
||||||
bison.Option('include_system_tokens_in_stats', default=True, field_type=bool),
|
bison.Option('include_system_tokens_in_stats', default=True, field_type=bool),
|
||||||
bison.Option('background_homepage_cacher', default=True, field_type=bool),
|
bison.Option('background_homepage_cacher', default=True, field_type=bool),
|
||||||
bison.Option('max_new_tokens', default=500, field_type=int),
|
bison.Option('max_new_tokens', default=500, field_type=int),
|
||||||
|
@ -41,7 +40,7 @@ config_scheme = bison.Scheme(
|
||||||
bison.Option('expose_openai_system_prompt', default=True, field_type=bool),
|
bison.Option('expose_openai_system_prompt', default=True, field_type=bool),
|
||||||
bison.Option('openai_expose_our_model', default='', field_type=bool),
|
bison.Option('openai_expose_our_model', default='', field_type=bool),
|
||||||
bison.Option('openai_force_no_hashes', default=True, field_type=bool),
|
bison.Option('openai_force_no_hashes', default=True, field_type=bool),
|
||||||
bison.Option('openai_system_prompt', default=default_openai_system_prompt, field_type=str),
|
bison.Option('openai_system_prompt', default=DEFAULT_OPENAI_SYSTEM_PROMPT, field_type=str),
|
||||||
bison.Option('openai_moderation_enabled', default=False, field_type=bool),
|
bison.Option('openai_moderation_enabled', default=False, field_type=bool),
|
||||||
bison.Option('openai_moderation_timeout', default=5, field_type=int),
|
bison.Option('openai_moderation_timeout', default=5, field_type=int),
|
||||||
bison.Option('openai_moderation_scan_last_n', default=5, field_type=int),
|
bison.Option('openai_moderation_scan_last_n', default=5, field_type=int),
|
||||||
|
@ -55,5 +54,6 @@ config_scheme = bison.Scheme(
|
||||||
bison.Option('show_backend_info', default=True, field_type=bool),
|
bison.Option('show_backend_info', default=True, field_type=bool),
|
||||||
bison.Option('load_num_prompts', default=True, field_type=bool),
|
bison.Option('load_num_prompts', default=True, field_type=bool),
|
||||||
bison.Option('manual_model_name', default=None, field_type=Union[str, None]),
|
bison.Option('manual_model_name', default=None, field_type=Union[str, None]),
|
||||||
bison.Option('backend_request_timeout', default=30, field_type=int)
|
bison.Option('backend_request_timeout', default=30, field_type=int),
|
||||||
|
bison.Option('backend_generate_request_timeout', default=95, field_type=int)
|
||||||
)
|
)
|
||||||
|
|
|
@ -2,13 +2,13 @@ import logging
|
||||||
import pickle
|
import pickle
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
from typing import Callable, List, Mapping, Optional, Union
|
from typing import Union
|
||||||
|
|
||||||
import redis as redis_pkg
|
import redis as redis_pkg
|
||||||
import simplejson as json
|
import simplejson as json
|
||||||
from flask_caching import Cache
|
from flask_caching import Cache
|
||||||
from redis import Redis
|
from redis import Redis
|
||||||
from redis.typing import AnyKeyT, EncodableT, ExpiryT, FieldT, KeyT, PatternT, ZScoreBoundT, AbsExpiryT
|
from redis.typing import ExpiryT, KeyT, PatternT
|
||||||
|
|
||||||
flask_cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/15', 'CACHE_KEY_PREFIX': 'local_llm_flask'})
|
flask_cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/15', 'CACHE_KEY_PREFIX': 'local_llm_flask'})
|
||||||
|
|
||||||
|
@ -38,18 +38,11 @@ class RedisCustom(Redis):
|
||||||
def _key(self, key):
|
def _key(self, key):
|
||||||
return f"{self.prefix}:{key}"
|
return f"{self.prefix}:{key}"
|
||||||
|
|
||||||
def set(self, key: KeyT,
|
def execute_command(self, *args, **options):
|
||||||
value: EncodableT,
|
if args[0] != 'GET':
|
||||||
ex: Union[ExpiryT, None] = None,
|
args = list(args)
|
||||||
px: Union[ExpiryT, None] = None,
|
args[1] = self._key(args[1])
|
||||||
nx: bool = False,
|
return super().execute_command(*args, **options)
|
||||||
xx: bool = False,
|
|
||||||
keepttl: bool = False,
|
|
||||||
get: bool = False,
|
|
||||||
exat: Union[AbsExpiryT, None] = None,
|
|
||||||
pxat: Union[AbsExpiryT, None] = None
|
|
||||||
):
|
|
||||||
return self.redis.set(self._key(key), value, ex=ex)
|
|
||||||
|
|
||||||
def get(self, key, default=None, dtype=None):
|
def get(self, key, default=None, dtype=None):
|
||||||
# TODO: use pickle
|
# TODO: use pickle
|
||||||
|
@ -73,103 +66,6 @@ class RedisCustom(Redis):
|
||||||
else:
|
else:
|
||||||
return d
|
return d
|
||||||
|
|
||||||
def incr(self, key, amount=1):
|
|
||||||
return self.redis.incr(self._key(key), amount)
|
|
||||||
|
|
||||||
def decr(self, key, amount=1):
|
|
||||||
return self.redis.decr(self._key(key), amount)
|
|
||||||
|
|
||||||
def sadd(self, key: str, *values: FieldT):
|
|
||||||
return self.redis.sadd(self._key(key), *values)
|
|
||||||
|
|
||||||
def srem(self, key: str, *values: FieldT):
|
|
||||||
return self.redis.srem(self._key(key), *values)
|
|
||||||
|
|
||||||
def sismember(self, key: str, value: str):
|
|
||||||
return self.redis.sismember(self._key(key), value)
|
|
||||||
|
|
||||||
def lindex(
|
|
||||||
self, name: str, index: int
|
|
||||||
):
|
|
||||||
return self.redis.lindex(self._key(name), index)
|
|
||||||
|
|
||||||
def lrem(self, name: str, count: int, value: str):
|
|
||||||
return self.redis.lrem(self._key(name), count, value)
|
|
||||||
|
|
||||||
def rpush(self, name: str, *values: FieldT):
|
|
||||||
return self.redis.rpush(self._key(name), *values)
|
|
||||||
|
|
||||||
def llen(self, name: str):
|
|
||||||
return self.redis.llen(self._key(name))
|
|
||||||
|
|
||||||
def zrangebyscore(
|
|
||||||
self,
|
|
||||||
name: KeyT,
|
|
||||||
min: ZScoreBoundT,
|
|
||||||
max: ZScoreBoundT,
|
|
||||||
start: Union[int, None] = None,
|
|
||||||
num: Union[int, None] = None,
|
|
||||||
withscores: bool = False,
|
|
||||||
score_cast_func: Union[type, Callable] = float,
|
|
||||||
):
|
|
||||||
return self.redis.zrangebyscore(self._key(name), min, max, start, num, withscores, score_cast_func)
|
|
||||||
|
|
||||||
def zremrangebyscore(
|
|
||||||
self, name: KeyT, min: ZScoreBoundT, max: ZScoreBoundT
|
|
||||||
):
|
|
||||||
return self.redis.zremrangebyscore(self._key(name), min, max)
|
|
||||||
|
|
||||||
def hincrby(
|
|
||||||
self, name: str, key: str, amount: int = 1
|
|
||||||
):
|
|
||||||
return self.redis.hincrby(self._key(name), key, amount)
|
|
||||||
|
|
||||||
def zcard(self, name: KeyT):
|
|
||||||
return self.redis.zcard(self._key(name))
|
|
||||||
|
|
||||||
def hdel(self, name: str, *keys: str):
|
|
||||||
return self.redis.hdel(self._key(name), *keys)
|
|
||||||
|
|
||||||
def hget(
|
|
||||||
self, name: str, key: str
|
|
||||||
):
|
|
||||||
return self.redis.hget(self._key(name), key)
|
|
||||||
|
|
||||||
def zadd(
|
|
||||||
self,
|
|
||||||
name: KeyT,
|
|
||||||
mapping: Mapping[AnyKeyT, EncodableT],
|
|
||||||
nx: bool = False,
|
|
||||||
xx: bool = False,
|
|
||||||
ch: bool = False,
|
|
||||||
incr: bool = False,
|
|
||||||
gt: bool = False,
|
|
||||||
lt: bool = False,
|
|
||||||
):
|
|
||||||
return self.redis.zadd(self._key(name), mapping, nx, xx, ch, incr, gt, lt)
|
|
||||||
|
|
||||||
def lpush(self, name: str, *values: FieldT):
|
|
||||||
return self.redis.lpush(self._key(name), *values)
|
|
||||||
|
|
||||||
def hset(
|
|
||||||
self,
|
|
||||||
name: str,
|
|
||||||
key: Optional = None,
|
|
||||||
value=None,
|
|
||||||
mapping: Optional[dict] = None,
|
|
||||||
items: Optional[list] = None,
|
|
||||||
):
|
|
||||||
return self.redis.hset(self._key(name), key, value, mapping, items)
|
|
||||||
|
|
||||||
def hkeys(self, name: str):
|
|
||||||
return self.redis.hkeys(self._key(name))
|
|
||||||
|
|
||||||
def hmget(self, name: str, keys: List, *args: List):
|
|
||||||
return self.redis.hmget(self._key(name), keys, *args)
|
|
||||||
|
|
||||||
def hgetall(self, name: str):
|
|
||||||
return self.redis.hgetall(self._key(name))
|
|
||||||
|
|
||||||
def keys(self, pattern: PatternT = "*", **kwargs):
|
def keys(self, pattern: PatternT = "*", **kwargs):
|
||||||
raw_keys = self.redis.keys(self._key(pattern), **kwargs)
|
raw_keys = self.redis.keys(self._key(pattern), **kwargs)
|
||||||
keys = []
|
keys = []
|
||||||
|
@ -179,25 +75,9 @@ class RedisCustom(Redis):
|
||||||
# Delete prefix
|
# Delete prefix
|
||||||
del p[0]
|
del p[0]
|
||||||
k = ':'.join(p)
|
k = ':'.join(p)
|
||||||
if k != '____':
|
# keys.append(k)
|
||||||
keys.append(k)
|
|
||||||
return keys
|
return keys
|
||||||
|
|
||||||
def pipeline(self, transaction=True, shard_hint=None):
|
|
||||||
return self.redis.pipeline(transaction, shard_hint)
|
|
||||||
|
|
||||||
def smembers(self, name: str):
|
|
||||||
return self.redis.smembers(self._key(name))
|
|
||||||
|
|
||||||
def spop(self, name: str, count: Optional[int] = None):
|
|
||||||
return self.redis.spop(self._key(name), count)
|
|
||||||
|
|
||||||
def rpoplpush(self, src, dst):
|
|
||||||
return self.redis.rpoplpush(src, dst)
|
|
||||||
|
|
||||||
def zpopmin(self, name: KeyT, count: Union[int, None] = None):
|
|
||||||
return self.redis.zpopmin(self._key(name), count)
|
|
||||||
|
|
||||||
def exists(self, *names: KeyT):
|
def exists(self, *names: KeyT):
|
||||||
n = []
|
n = []
|
||||||
for name in names:
|
for name in names:
|
||||||
|
@ -238,32 +118,5 @@ class RedisCustom(Redis):
|
||||||
self.flush()
|
self.flush()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def lrange(self, name: str, start: int, end: int):
|
|
||||||
return self.redis.lrange(self._key(name), start, end)
|
|
||||||
|
|
||||||
def delete(self, *names: KeyT):
|
|
||||||
return self.redis.delete(*[self._key(i) for i in names])
|
|
||||||
|
|
||||||
def lpop(self, name: str, count: Optional[int] = None):
|
|
||||||
return self.redis.lpop(self._key(name), count)
|
|
||||||
|
|
||||||
def zrange(
|
|
||||||
self,
|
|
||||||
name: KeyT,
|
|
||||||
start: int,
|
|
||||||
end: int,
|
|
||||||
desc: bool = False,
|
|
||||||
withscores: bool = False,
|
|
||||||
score_cast_func: Union[type, Callable] = float,
|
|
||||||
byscore: bool = False,
|
|
||||||
bylex: bool = False,
|
|
||||||
offset: int = None,
|
|
||||||
num: int = None,
|
|
||||||
):
|
|
||||||
return self.redis.zrange(self._key(name), start, end, desc, withscores, score_cast_func, byscore, bylex, offset, num)
|
|
||||||
|
|
||||||
def zrem(self, name: KeyT, *values: FieldT):
|
|
||||||
return self.redis.zrem(self._key(name), *values)
|
|
||||||
|
|
||||||
|
|
||||||
redis = RedisCustom('local_llm')
|
redis = RedisCustom('local_llm')
|
||||||
|
|
|
@ -1,39 +1,42 @@
|
||||||
from mysql.connector import pooling
|
from psycopg2 import pool, InterfaceError
|
||||||
|
|
||||||
|
|
||||||
class Database:
|
class Database:
|
||||||
__connection_pool = None
|
__connection_pool = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def initialise(cls, maxconn: int, **kwargs):
|
def initialise(cls, maxconn, **kwargs):
|
||||||
if cls.__connection_pool is not None:
|
if cls.__connection_pool is not None:
|
||||||
raise Exception('Database connection pool is already initialised')
|
raise Exception('Database connection pool is already initialised')
|
||||||
cls.__connection_pool = pooling.MySQLConnectionPool(pool_size=maxconn,
|
cls.__connection_pool = pool.ThreadedConnectionPool(minconn=1, maxconn=maxconn, **kwargs)
|
||||||
pool_reset_session=True,
|
|
||||||
**kwargs)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_connection(cls):
|
def get_connection(cls):
|
||||||
return cls.__connection_pool.get_connection()
|
return cls.__connection_pool.getconn()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def return_connection(cls, connection):
|
def return_connection(cls, connection):
|
||||||
connection.close()
|
cls.__connection_pool.putconn(connection)
|
||||||
|
|
||||||
|
|
||||||
class CursorFromConnectionFromPool:
|
class CursorFromConnectionFromPool:
|
||||||
def __init__(self):
|
def __init__(self, cursor_factory=None):
|
||||||
self.conn = None
|
self.conn = None
|
||||||
self.cursor = None
|
self.cursor = None
|
||||||
|
self.cursor_factory = cursor_factory
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
self.conn = Database.get_connection()
|
self.conn = Database.get_connection()
|
||||||
self.cursor = self.conn.cursor()
|
self.cursor = self.conn.cursor(cursor_factory=self.cursor_factory)
|
||||||
return self.cursor
|
return self.cursor
|
||||||
|
|
||||||
def __exit__(self, exception_type, exception_value, exception_traceback):
|
def __exit__(self, exception_type, exception_value, exception_traceback):
|
||||||
if exception_value is not None: # This is equivalent of saying if there is an exception
|
if exception_value is not None: # This is equivalent of saying if there is an exception
|
||||||
self.conn.rollback()
|
try:
|
||||||
|
self.conn.rollback()
|
||||||
|
except InterfaceError as e:
|
||||||
|
if e != 'connection already closed':
|
||||||
|
raise
|
||||||
else:
|
else:
|
||||||
self.cursor.close()
|
self.cursor.close()
|
||||||
self.conn.commit()
|
self.conn.commit()
|
||||||
|
|
|
@ -4,36 +4,39 @@ from llm_server.database.conn import CursorFromConnectionFromPool
|
||||||
def create_db():
|
def create_db():
|
||||||
with CursorFromConnectionFromPool() as cursor:
|
with CursorFromConnectionFromPool() as cursor:
|
||||||
cursor.execute('''
|
cursor.execute('''
|
||||||
CREATE TABLE IF NOT EXISTS prompts (
|
CREATE TABLE IF NOT EXISTS public.messages
|
||||||
ip TEXT,
|
(
|
||||||
token TEXT DEFAULT NULL,
|
ip text COLLATE pg_catalog."default" NOT NULL,
|
||||||
model TEXT,
|
token text COLLATE pg_catalog."default",
|
||||||
backend_mode TEXT,
|
model text COLLATE pg_catalog."default" NOT NULL,
|
||||||
backend_url TEXT,
|
backend_mode text COLLATE pg_catalog."default" NOT NULL,
|
||||||
request_url TEXT,
|
backend_url text COLLATE pg_catalog."default" NOT NULL,
|
||||||
generation_time FLOAT,
|
request_url text COLLATE pg_catalog."default" NOT NULL,
|
||||||
prompt LONGTEXT,
|
generation_time double precision NOT NULL,
|
||||||
prompt_tokens INTEGER,
|
prompt text COLLATE pg_catalog."default" NOT NULL,
|
||||||
response LONGTEXT,
|
prompt_tokens integer NOT NULL,
|
||||||
response_tokens INTEGER,
|
response text COLLATE pg_catalog."default" NOT NULL,
|
||||||
response_status INTEGER,
|
response_tokens integer NOT NULL,
|
||||||
parameters TEXT,
|
response_status integer NOT NULL,
|
||||||
# CHECK (parameters IS NULL OR JSON_VALID(parameters)),
|
parameters jsonb NOT NULL,
|
||||||
headers TEXT,
|
headers jsonb,
|
||||||
# CHECK (headers IS NULL OR JSON_VALID(headers)),
|
"timestamp" timestamp with time zone NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
timestamp INTEGER
|
id SERIAL PRIMARY KEY
|
||||||
)
|
);
|
||||||
''')
|
''')
|
||||||
cursor.execute('''
|
cursor.execute('''
|
||||||
CREATE TABLE IF NOT EXISTS token_auth (
|
CREATE TABLE IF NOT EXISTS public.token_auth
|
||||||
token TEXT,
|
(
|
||||||
UNIQUE (token),
|
token text COLLATE pg_catalog."default" NOT NULL,
|
||||||
type TEXT NOT NULL,
|
type text COLLATE pg_catalog."default" NOT NULL,
|
||||||
priority INTEGER DEFAULT 9999,
|
priority integer NOT NULL DEFAULT 9999,
|
||||||
simultaneous_ip INTEGER DEFAULT NULL,
|
simultaneous_ip text COLLATE pg_catalog."default",
|
||||||
uses INTEGER DEFAULT 0,
|
openai_moderation_enabled boolean NOT NULL DEFAULT true,
|
||||||
max_uses INTEGER,
|
uses integer NOT NULL DEFAULT 0,
|
||||||
expire INTEGER,
|
max_uses integer,
|
||||||
disabled BOOLEAN DEFAULT 0
|
expire timestamp with time zone,
|
||||||
)
|
disabled boolean NOT NULL DEFAULT false,
|
||||||
|
notes text COLLATE pg_catalog."default" NOT NULL DEFAULT ''::text,
|
||||||
|
CONSTRAINT token_auth_pkey PRIMARY KEY (token)
|
||||||
|
)
|
||||||
''')
|
''')
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
|
from datetime import datetime, timedelta
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
from llm_server.cluster.cluster_config import cluster_config
|
from llm_server.cluster.cluster_config import cluster_config
|
||||||
|
@ -51,10 +52,10 @@ def do_db_log(ip: str, token: str, prompt: str, response: Union[str, None], gen_
|
||||||
backend_info = cluster_config.get_backend(backend_url)
|
backend_info = cluster_config.get_backend(backend_url)
|
||||||
running_model = backend_info.get('model')
|
running_model = backend_info.get('model')
|
||||||
backend_mode = backend_info['mode']
|
backend_mode = backend_info['mode']
|
||||||
timestamp = int(time.time())
|
timestamp = datetime.now()
|
||||||
with CursorFromConnectionFromPool() as cursor:
|
with CursorFromConnectionFromPool() as cursor:
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
INSERT INTO prompts
|
INSERT INTO messages
|
||||||
(ip, token, model, backend_mode, backend_url, request_url, generation_time, prompt, prompt_tokens, response, response_tokens, response_status, parameters, headers, timestamp)
|
(ip, token, model, backend_mode, backend_url, request_url, generation_time, prompt, prompt_tokens, response, response_tokens, response_status, parameters, headers, timestamp)
|
||||||
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
|
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
|
||||||
""",
|
""",
|
||||||
|
@ -65,12 +66,12 @@ def is_valid_api_key(api_key):
|
||||||
with CursorFromConnectionFromPool() as cursor:
|
with CursorFromConnectionFromPool() as cursor:
|
||||||
cursor.execute("SELECT token, uses, max_uses, expire, disabled FROM token_auth WHERE token = %s", (api_key,))
|
cursor.execute("SELECT token, uses, max_uses, expire, disabled FROM token_auth WHERE token = %s", (api_key,))
|
||||||
row = cursor.fetchone()
|
row = cursor.fetchone()
|
||||||
if row is not None:
|
if row is not None:
|
||||||
token, uses, max_uses, expire, disabled = row
|
token, uses, max_uses, expire, disabled = row
|
||||||
disabled = bool(disabled)
|
disabled = bool(disabled)
|
||||||
if ((uses is None or max_uses is None) or uses < max_uses) and (expire is None or expire > time.time()) and not disabled:
|
if ((uses is None or max_uses is None) or uses < max_uses) and (expire is None or expire > time.time()) and not disabled:
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def is_api_key_moderated(api_key):
|
def is_api_key_moderated(api_key):
|
||||||
|
@ -146,9 +147,9 @@ def sum_column(table_name, column_name):
|
||||||
|
|
||||||
def get_distinct_ips_24h():
|
def get_distinct_ips_24h():
|
||||||
# Get the current time and subtract 24 hours (in seconds)
|
# Get the current time and subtract 24 hours (in seconds)
|
||||||
past_24_hours = int(time.time()) - 24 * 60 * 60
|
past_24_hours = datetime.now() - timedelta(days=1)
|
||||||
with CursorFromConnectionFromPool() as cursor:
|
with CursorFromConnectionFromPool() as cursor:
|
||||||
cursor.execute("SELECT COUNT(DISTINCT ip) FROM prompts WHERE timestamp >= %s AND (token NOT LIKE 'SYSTEM__%%' OR token IS NULL)", (past_24_hours,))
|
cursor.execute("SELECT COUNT(DISTINCT ip) FROM messages WHERE timestamp >= %s AND (token NOT LIKE 'SYSTEM__%%' OR token IS NULL)", (past_24_hours,))
|
||||||
result = cursor.fetchone()
|
result = cursor.fetchone()
|
||||||
return result[0] if result else 0
|
return result[0] if result else 0
|
||||||
|
|
||||||
|
|
|
@ -1,10 +1,6 @@
|
||||||
# Read-only global variables
|
# Read-only global variables
|
||||||
|
|
||||||
default_openai_system_prompt = """You are an assistant chatbot. Your main function is to provide accurate and helpful responses to the user's queries. You should always be polite, respectful, and patient. You should not provide any personal opinions or advice unless specifically asked by the user. You should not make any assumptions about the user's knowledge or abilities. You should always strive to provide clear and concise answers. If you do not understand a user's query, ask for clarification. If you cannot provide an answer, apologize and suggest the user seek help elsewhere.\nLines that start with "### ASSISTANT" were messages you sent previously.\nLines that start with "### USER" were messages sent by the user you are chatting with.\nYou will respond to the "### RESPONSE:" prompt as the assistant and follow the instructions given by the user.\n\n"""
|
DEFAULT_OPENAI_SYSTEM_PROMPT = """You are an assistant chatbot. Your main function is to provide accurate and helpful responses to the user's queries. You should always be polite, respectful, and patient. You should not provide any personal opinions or advice unless specifically asked by the user. You should not make any assumptions about the user's knowledge or abilities. You should always strive to provide clear and concise answers. If you do not understand a user's query, ask for clarification. If you cannot provide an answer, apologize and suggest the user seek help elsewhere.\nLines that start with "### ASSISTANT" were messages you sent previously.\nLines that start with "### USER" were messages sent by the user you are chatting with.\nYou will respond to the "### RESPONSE:" prompt as the assistant and follow the instructions given by the user.\n\n"""
|
||||||
|
|
||||||
# cluster = {}
|
|
||||||
|
|
||||||
|
|
||||||
REDIS_STREAM_TIMEOUT = 25000
|
REDIS_STREAM_TIMEOUT = 25000
|
||||||
|
|
||||||
LOGGING_FORMAT = "%(asctime)s: %(levelname)s:%(name)s - %(message)s"
|
LOGGING_FORMAT = "%(asctime)s: %(levelname)s:%(name)s - %(message)s"
|
||||||
|
BACKEND_OFFLINE = 'The model you requested is not a valid choice. Please retry your query.'
|
|
@ -15,19 +15,6 @@ def resolve_path(*p: str):
|
||||||
return Path(*p).expanduser().resolve().absolute()
|
return Path(*p).expanduser().resolve().absolute()
|
||||||
|
|
||||||
|
|
||||||
def safe_list_get(l, idx, default):
|
|
||||||
"""
|
|
||||||
https://stackoverflow.com/a/5125636
|
|
||||||
:param l:
|
|
||||||
:param idx:
|
|
||||||
:param default:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
return l[idx]
|
|
||||||
except IndexError:
|
|
||||||
return default
|
|
||||||
|
|
||||||
|
|
||||||
def deep_sort(obj):
|
def deep_sort(obj):
|
||||||
if isinstance(obj, dict):
|
if isinstance(obj, dict):
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from llm_server import opts
|
from llm_server import globals
|
||||||
from llm_server.cluster.cluster_config import cluster_config
|
from llm_server.cluster.cluster_config import cluster_config
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -8,14 +8,14 @@ import requests
|
||||||
from llm_server.config.global_config import GlobalConfig
|
from llm_server.config.global_config import GlobalConfig
|
||||||
|
|
||||||
|
|
||||||
def generate(json_data: dict):
|
# def generate(json_data: dict):
|
||||||
try:
|
# try:
|
||||||
r = requests.post(f'{GlobalConfig.get().backend_url}/api/v1/generate', json=json_data, verify=GlobalConfig.get().verify_ssl, timeout=GlobalConfig.get().backend_generate_request_timeout)
|
# r = requests.post(f'{GlobalConfig.get().backend_url}/api/v1/generate', json=json_data, verify=GlobalConfig.get().verify_ssl, timeout=GlobalConfig.get().backend_generate_request_timeout)
|
||||||
except requests.exceptions.ReadTimeout:
|
# except requests.exceptions.ReadTimeout:
|
||||||
return False, None, 'Request to backend timed out'
|
# return False, None, 'Request to backend timed out'
|
||||||
except Exception as e:
|
# except Exception as e:
|
||||||
traceback.print_exc()
|
# traceback.print_exc()
|
||||||
return False, None, 'Request to backend encountered error'
|
# return False, None, 'Request to backend encountered error'
|
||||||
if r.status_code != 200:
|
# if r.status_code != 200:
|
||||||
return False, r, f'Backend returned {r.status_code}'
|
# return False, r, f'Backend returned {r.status_code}'
|
||||||
return True, r, None
|
# return True, r, None
|
||||||
|
|
|
@ -98,3 +98,14 @@ def return_invalid_model_err(requested_model: str):
|
||||||
"code": "model_not_found"
|
"code": "model_not_found"
|
||||||
}
|
}
|
||||||
}), 404
|
}), 404
|
||||||
|
|
||||||
|
|
||||||
|
def return_oai_internal_server_error():
|
||||||
|
return jsonify({
|
||||||
|
"error": {
|
||||||
|
"message": "Internal server error",
|
||||||
|
"type": "auth_subrequest_error",
|
||||||
|
"param": None,
|
||||||
|
"code": "internal_error"
|
||||||
|
}
|
||||||
|
}), 500
|
||||||
|
|
|
@ -31,7 +31,7 @@ def handle_blocking_request(json_data: dict, cluster_backend, timeout: int = 10)
|
||||||
return False, None, 'Request to backend timed out'
|
return False, None, 'Request to backend timed out'
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# print(f'Failed to reach VLLM inference endpoint -', f'{e.__class__.__name__}: {e}')
|
# print(f'Failed to reach VLLM inference endpoint -', f'{e.__class__.__name__}: {e}')
|
||||||
return False, None, 'Request to backend encountered error'
|
return False, None, f'Request to backend encountered error -- {e.__class__.__name__}: {e}'
|
||||||
if r.status_code != 200:
|
if r.status_code != 200:
|
||||||
# print(f'Failed to reach VLLM inference endpoint - got code {r.status_code}')
|
# print(f'Failed to reach VLLM inference endpoint - got code {r.status_code}')
|
||||||
return False, r, f'Backend returned {r.status_code}'
|
return False, r, f'Backend returned {r.status_code}'
|
||||||
|
|
|
@ -1,16 +1,14 @@
|
||||||
import logging
|
import logging
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import coloredlogs
|
import coloredlogs
|
||||||
|
|
||||||
from llm_server import opts
|
from llm_server import globals
|
||||||
|
|
||||||
|
|
||||||
class LoggingInfo:
|
class LoggingInfo:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._level = logging.INFO
|
self._level = logging.INFO
|
||||||
self._format = opts.LOGGING_FORMAT
|
self._format = globals.LOGGING_FORMAT
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def level(self):
|
def level(self):
|
||||||
|
@ -30,30 +28,17 @@ class LoggingInfo:
|
||||||
|
|
||||||
|
|
||||||
logging_info = LoggingInfo()
|
logging_info = LoggingInfo()
|
||||||
LOG_DIRECTORY = None
|
|
||||||
|
|
||||||
|
|
||||||
def init_logging(filepath: Path = None):
|
def init_logging():
|
||||||
"""
|
"""
|
||||||
Set up the parent logger. Ensures this logger and all children to log to a file.
|
Set up the parent logger. Ensures this logger and all children to log to a file.
|
||||||
This is only called by `server.py` since there is wierdness with Gunicorn. The deamon doesn't need this.
|
This is only called by `server.py` since there is wierdness with Gunicorn. The deamon doesn't need this.
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
global LOG_DIRECTORY
|
|
||||||
logger = logging.getLogger('llm_server')
|
logger = logging.getLogger('llm_server')
|
||||||
logger.setLevel(logging_info.level)
|
logger.setLevel(logging_info.level)
|
||||||
|
|
||||||
if filepath:
|
|
||||||
p = Path(filepath)
|
|
||||||
if not p.parent.is_dir():
|
|
||||||
logger.fatal(f'Log directory does not exist: {p.parent}')
|
|
||||||
sys.exit(1)
|
|
||||||
LOG_DIRECTORY = p.parent
|
|
||||||
handler = logging.FileHandler(filepath)
|
|
||||||
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
|
||||||
handler.setFormatter(formatter)
|
|
||||||
logger.addHandler(handler)
|
|
||||||
|
|
||||||
|
|
||||||
def create_logger(name):
|
def create_logger(name):
|
||||||
logger = logging.getLogger('llm_server').getChild(name)
|
logger = logging.getLogger('llm_server').getChild(name)
|
||||||
|
@ -65,7 +50,4 @@ def create_logger(name):
|
||||||
handler.setFormatter(formatter)
|
handler.setFormatter(formatter)
|
||||||
logger.addHandler(handler)
|
logger.addHandler(handler)
|
||||||
coloredlogs.install(logger=logger, level=logging_info.level)
|
coloredlogs.install(logger=logger, level=logging_info.level)
|
||||||
if LOG_DIRECTORY:
|
|
||||||
handler = logging.FileHandler(LOG_DIRECTORY / f'{name}.log')
|
|
||||||
logger.addHandler(handler)
|
|
||||||
return logger
|
return logger
|
||||||
|
|
|
@ -1 +0,0 @@
|
||||||
BACKEND_OFFLINE = 'The model you requested is not a valid choice. Please retry your query.'
|
|
|
@ -3,7 +3,7 @@ from typing import Tuple
|
||||||
import flask
|
import flask
|
||||||
from flask import jsonify, request
|
from flask import jsonify, request
|
||||||
|
|
||||||
from llm_server import messages
|
import llm_server.globals
|
||||||
from llm_server.config.global_config import GlobalConfig
|
from llm_server.config.global_config import GlobalConfig
|
||||||
from llm_server.database.log_to_db import log_to_db
|
from llm_server.database.log_to_db import log_to_db
|
||||||
from llm_server.logging import create_logger
|
from llm_server.logging import create_logger
|
||||||
|
@ -21,7 +21,7 @@ class OobaRequestHandler(RequestHandler):
|
||||||
assert not self.used
|
assert not self.used
|
||||||
if self.offline:
|
if self.offline:
|
||||||
# _logger.debug(f'This backend is offline.')
|
# _logger.debug(f'This backend is offline.')
|
||||||
return self.handle_error(messages.BACKEND_OFFLINE)
|
return self.handle_error(llm_server.globals.BACKEND_OFFLINE)
|
||||||
|
|
||||||
request_valid, invalid_response = self.validate_request()
|
request_valid, invalid_response = self.validate_request()
|
||||||
if not request_valid:
|
if not request_valid:
|
||||||
|
|
|
@ -2,6 +2,7 @@ from flask import Blueprint
|
||||||
|
|
||||||
from ..request_handler import before_request
|
from ..request_handler import before_request
|
||||||
from ...config.global_config import GlobalConfig
|
from ...config.global_config import GlobalConfig
|
||||||
|
from ...llm.openai.oai_to_vllm import return_oai_internal_server_error
|
||||||
from ...logging import create_logger
|
from ...logging import create_logger
|
||||||
|
|
||||||
_logger = create_logger('OpenAI')
|
_logger = create_logger('OpenAI')
|
||||||
|
@ -26,15 +27,8 @@ def handle_error(e):
|
||||||
"auth_subrequest_error"
|
"auth_subrequest_error"
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_logger(f'OAI returning error: {e}')
|
_logger.error(f'OAI returning error: {e}')
|
||||||
return jsonify({
|
return_oai_internal_server_error()
|
||||||
"error": {
|
|
||||||
"message": "Internal server error",
|
|
||||||
"type": "auth_subrequest_error",
|
|
||||||
"param": None,
|
|
||||||
"code": "internal_error"
|
|
||||||
}
|
|
||||||
}), 500
|
|
||||||
|
|
||||||
|
|
||||||
from .models import openai_list_models
|
from .models import openai_list_models
|
||||||
|
|
|
@ -13,7 +13,7 @@ from ..openai_request_handler import OpenAIRequestHandler
|
||||||
from ..queue import priority_queue
|
from ..queue import priority_queue
|
||||||
from ...config.global_config import GlobalConfig
|
from ...config.global_config import GlobalConfig
|
||||||
from ...database.log_to_db import log_to_db
|
from ...database.log_to_db import log_to_db
|
||||||
from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai
|
from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai, return_oai_internal_server_error
|
||||||
from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
|
from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
|
||||||
from ...logging import create_logger
|
from ...logging import create_logger
|
||||||
|
|
||||||
|
@ -32,7 +32,8 @@ def openai_chat_completions(model_name=None):
|
||||||
else:
|
else:
|
||||||
handler = OpenAIRequestHandler(incoming_request=request, incoming_json=request_json_body, selected_model=model_name)
|
handler = OpenAIRequestHandler(incoming_request=request, incoming_json=request_json_body, selected_model=model_name)
|
||||||
if handler.offline:
|
if handler.offline:
|
||||||
return return_invalid_model_err(model_name)
|
# return return_invalid_model_err(model_name)
|
||||||
|
return_oai_internal_server_error()
|
||||||
|
|
||||||
if not request_json_body.get('stream'):
|
if not request_json_body.get('stream'):
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -14,7 +14,7 @@ from llm_server.custom_redis import redis
|
||||||
from llm_server.database.database import is_api_key_moderated
|
from llm_server.database.database import is_api_key_moderated
|
||||||
from llm_server.database.log_to_db import log_to_db
|
from llm_server.database.log_to_db import log_to_db
|
||||||
from llm_server.llm import get_token_count
|
from llm_server.llm import get_token_count
|
||||||
from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai, return_invalid_model_err
|
from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai, return_invalid_model_err, return_oai_internal_server_error
|
||||||
from llm_server.llm.openai.transform import ANTI_CONTINUATION_RE, ANTI_RESPONSE_RE, generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
|
from llm_server.llm.openai.transform import ANTI_CONTINUATION_RE, ANTI_RESPONSE_RE, generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
|
||||||
from llm_server.logging import create_logger
|
from llm_server.logging import create_logger
|
||||||
from llm_server.routes.request_handler import RequestHandler
|
from llm_server.routes.request_handler import RequestHandler
|
||||||
|
@ -31,9 +31,10 @@ class OpenAIRequestHandler(RequestHandler):
|
||||||
def handle_request(self) -> Tuple[flask.Response, int]:
|
def handle_request(self) -> Tuple[flask.Response, int]:
|
||||||
assert not self.used
|
assert not self.used
|
||||||
if self.offline:
|
if self.offline:
|
||||||
msg = return_invalid_model_err(self.selected_model)
|
# msg = return_invalid_model_err(self.selected_model)
|
||||||
_logger.error(f'OAI is offline: {msg}')
|
# _logger.error(f'OAI is offline: {msg}')
|
||||||
return self.handle_error(msg)
|
# return self.handle_error(msg)
|
||||||
|
return_oai_internal_server_error()
|
||||||
|
|
||||||
if GlobalConfig.get().openai_silent_trim:
|
if GlobalConfig.get().openai_silent_trim:
|
||||||
oai_messages = trim_messages_to_fit(self.request.json['messages'], self.cluster_backend_info['model_config']['max_position_embeddings'], self.backend_url)
|
oai_messages = trim_messages_to_fit(self.request.json['messages'], self.cluster_backend_info['model_config']['max_position_embeddings'], self.backend_url)
|
||||||
|
@ -109,7 +110,7 @@ class OpenAIRequestHandler(RequestHandler):
|
||||||
return response, 429
|
return response, 429
|
||||||
|
|
||||||
def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
|
def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
|
||||||
_logger.error('OAI Error: {error_msg}')
|
_logger.error(f'OAI Error: {error_msg}')
|
||||||
return jsonify({
|
return jsonify({
|
||||||
"error": {
|
"error": {
|
||||||
"message": "Invalid request, check your parameters and try again.",
|
"message": "Invalid request, check your parameters and try again.",
|
||||||
|
|
|
@ -34,7 +34,7 @@ def generate_stats(regen: bool = False):
|
||||||
'proompts_total': get_total_proompts() if GlobalConfig.get().show_num_prompts else None,
|
'proompts_total': get_total_proompts() if GlobalConfig.get().show_num_prompts else None,
|
||||||
'uptime': int((datetime.now() - server_start_time).total_seconds()) if GlobalConfig.get().show_uptime else None,
|
'uptime': int((datetime.now() - server_start_time).total_seconds()) if GlobalConfig.get().show_uptime else None,
|
||||||
# 'estimated_avg_tps': estimated_avg_tps,
|
# 'estimated_avg_tps': estimated_avg_tps,
|
||||||
'tokens_generated': sum_column('prompts', 'response_tokens') if GlobalConfig.get().show_total_output_tokens else None,
|
'tokens_generated': sum_column('messages', 'response_tokens') if GlobalConfig.get().show_total_output_tokens else None,
|
||||||
'num_backends': len(cluster_config.all()) if GlobalConfig.get().show_backends else None,
|
'num_backends': len(cluster_config.all()) if GlobalConfig.get().show_backends else None,
|
||||||
},
|
},
|
||||||
'endpoints': {
|
'endpoints': {
|
||||||
|
|
|
@ -8,6 +8,7 @@ import ujson
|
||||||
from redis import Redis
|
from redis import Redis
|
||||||
|
|
||||||
from llm_server.cluster.cluster_config import cluster_config
|
from llm_server.cluster.cluster_config import cluster_config
|
||||||
|
from llm_server.config.global_config import GlobalConfig
|
||||||
from llm_server.custom_redis import RedisCustom, redis
|
from llm_server.custom_redis import RedisCustom, redis
|
||||||
from llm_server.llm.generator import generator
|
from llm_server.llm.generator import generator
|
||||||
from llm_server.logging import create_logger
|
from llm_server.logging import create_logger
|
||||||
|
@ -148,12 +149,12 @@ def worker(backend_url):
|
||||||
status_redis.setp(str(worker_id), None)
|
status_redis.setp(str(worker_id), None)
|
||||||
|
|
||||||
|
|
||||||
def start_workers(cluster: dict):
|
def start_workers():
|
||||||
logger = create_logger('inferencer')
|
logger = create_logger('inferencer')
|
||||||
i = 0
|
i = 0
|
||||||
for item in cluster:
|
for item in GlobalConfig.get().cluster:
|
||||||
for _ in range(item['concurrent_gens']):
|
for _ in range(item.concurrent_gens):
|
||||||
t = threading.Thread(target=worker, args=(item['backend_url'],))
|
t = threading.Thread(target=worker, args=(item.backend_url,))
|
||||||
t.daemon = True
|
t.daemon = True
|
||||||
t.start()
|
t.start()
|
||||||
i += 1
|
i += 1
|
||||||
|
|
|
@ -49,10 +49,10 @@ def main_background_thread():
|
||||||
def calc_stats_for_backend(backend_url, running_model, backend_mode):
|
def calc_stats_for_backend(backend_url, running_model, backend_mode):
|
||||||
# exclude_zeros=True filters out rows where an error message was returned. Previously, if there was an error, 0
|
# exclude_zeros=True filters out rows where an error message was returned. Previously, if there was an error, 0
|
||||||
# was entered into the column. The new code enters null instead but we need to be backwards compatible for now.
|
# was entered into the column. The new code enters null instead but we need to be backwards compatible for now.
|
||||||
average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time',
|
average_generation_elapsed_sec = weighted_average_column_for_model('messages', 'generation_time',
|
||||||
running_model, backend_mode, backend_url, exclude_zeros=True,
|
running_model, backend_mode, backend_url, exclude_zeros=True,
|
||||||
include_system_tokens=GlobalConfig.get().include_system_tokens_in_stats) or 0
|
include_system_tokens=GlobalConfig.get().include_system_tokens_in_stats) or 0
|
||||||
average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens',
|
average_output_tokens = weighted_average_column_for_model('messages', 'response_tokens',
|
||||||
running_model, backend_mode, backend_url, exclude_zeros=True,
|
running_model, backend_mode, backend_url, exclude_zeros=True,
|
||||||
include_system_tokens=GlobalConfig.get().include_system_tokens_in_stats) or 0
|
include_system_tokens=GlobalConfig.get().include_system_tokens_in_stats) or 0
|
||||||
estimated_avg_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0 # Avoid division by zero
|
estimated_avg_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0 # Avoid division by zero
|
||||||
|
|
|
@ -22,7 +22,7 @@ def cache_stats():
|
||||||
|
|
||||||
def start_background():
|
def start_background():
|
||||||
logger = create_logger('threader')
|
logger = create_logger('threader')
|
||||||
start_workers(GlobalConfig.get().cluster)
|
start_workers()
|
||||||
|
|
||||||
t = Thread(target=main_background_thread)
|
t = Thread(target=main_background_thread)
|
||||||
t.daemon = True
|
t.daemon = True
|
||||||
|
@ -46,7 +46,7 @@ def start_background():
|
||||||
t = Thread(target=console_printer)
|
t = Thread(target=console_printer)
|
||||||
t.daemon = True
|
t.daemon = True
|
||||||
t.start()
|
t.start()
|
||||||
logger.info('Started the console logger.infoer.')
|
logger.info('Started the console logger.')
|
||||||
|
|
||||||
t = Thread(target=cluster_worker)
|
t = Thread(target=cluster_worker)
|
||||||
t.daemon = True
|
t.daemon = True
|
||||||
|
|
|
@ -0,0 +1,54 @@
|
||||||
|
from llm_server.helpers import resolve_path
|
||||||
|
|
||||||
|
try:
|
||||||
|
import gevent.monkey
|
||||||
|
|
||||||
|
gevent.monkey.patch_all()
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from llm_server.config.global_config import GlobalConfig
|
||||||
|
from llm_server.config.load import load_config
|
||||||
|
from llm_server.custom_redis import redis
|
||||||
|
from llm_server.database.conn import Database
|
||||||
|
from llm_server.database.create import create_db
|
||||||
|
from llm_server.logging import init_logging, create_logger
|
||||||
|
|
||||||
|
|
||||||
|
def post_fork(server, worker):
|
||||||
|
"""
|
||||||
|
Initalize the worker after gunicorn has forked. This is done to avoid issues with the database manager.
|
||||||
|
"""
|
||||||
|
script_path = Path(os.path.dirname(os.path.realpath(__file__)))
|
||||||
|
config_path_environ = os.getenv("CONFIG_PATH")
|
||||||
|
if config_path_environ:
|
||||||
|
config_path = config_path_environ
|
||||||
|
else:
|
||||||
|
config_path = script_path / '../config/config.yml'
|
||||||
|
config_path = resolve_path(config_path)
|
||||||
|
|
||||||
|
success, msg = load_config(config_path)
|
||||||
|
if not success:
|
||||||
|
logger = logging.getLogger('llm_server')
|
||||||
|
logger.setLevel(logging.INFO)
|
||||||
|
logger.error(f'Failed to load config: {msg}')
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
init_logging()
|
||||||
|
logger = create_logger('Server')
|
||||||
|
logger.debug('Debug logging enabled.')
|
||||||
|
|
||||||
|
while not redis.get('daemon_started', dtype=bool):
|
||||||
|
logger.warning('Could not find the key daemon_started in Redis. Did you forget to start the daemon process?')
|
||||||
|
time.sleep(10)
|
||||||
|
|
||||||
|
Database.initialise(**GlobalConfig.get().postgresql.dict())
|
||||||
|
create_db()
|
||||||
|
|
||||||
|
logger.info('Started HTTP worker!')
|
|
@ -11,6 +11,8 @@ WorkingDirectory=/srv/server/local-llm-server
|
||||||
# Sometimes the old processes aren't terminated when the service is restarted.
|
# Sometimes the old processes aren't terminated when the service is restarted.
|
||||||
ExecStartPre=/usr/bin/pkill -9 -f "/srv/server/local-llm-server/venv/bin/python3 /srv/server/local-llm-server/venv/bin/gunicorn"
|
ExecStartPre=/usr/bin/pkill -9 -f "/srv/server/local-llm-server/venv/bin/python3 /srv/server/local-llm-server/venv/bin/gunicorn"
|
||||||
|
|
||||||
|
# TODO: make sure gunicorn logs to stdout and logging also goes to stdout
|
||||||
|
|
||||||
# Need a lot of workers since we have long-running requests. This takes about 3.5G memory.
|
# Need a lot of workers since we have long-running requests. This takes about 3.5G memory.
|
||||||
ExecStart=/srv/server/local-llm-server/venv/bin/gunicorn --workers 20 --bind 0.0.0.0:5000 server:app --timeout 60 --worker-class gevent --access-logfile '-' --error-logfile '-'
|
ExecStart=/srv/server/local-llm-server/venv/bin/gunicorn --workers 20 --bind 0.0.0.0:5000 server:app --timeout 60 --worker-class gevent --access-logfile '-' --error-logfile '-'
|
||||||
|
|
||||||
|
|
|
@ -4,7 +4,6 @@ Flask-Caching==2.0.2
|
||||||
requests~=2.31.0
|
requests~=2.31.0
|
||||||
tiktoken~=0.5.0
|
tiktoken~=0.5.0
|
||||||
gevent~=23.9.0.post1
|
gevent~=23.9.0.post1
|
||||||
mysql-connector-python==8.4.0
|
|
||||||
simplejson~=3.19.1
|
simplejson~=3.19.1
|
||||||
websockets~=11.0.3
|
websockets~=11.0.3
|
||||||
basicauth~=1.0.0
|
basicauth~=1.0.0
|
||||||
|
@ -17,3 +16,4 @@ vllm==0.2.7
|
||||||
coloredlogs~=15.0.1
|
coloredlogs~=15.0.1
|
||||||
git+https://git.evulid.cc/cyberes/bison.git
|
git+https://git.evulid.cc/cyberes/bison.git
|
||||||
pydantic
|
pydantic
|
||||||
|
psycopg2-binary==2.9.9
|
70
server.py
70
server.py
|
@ -1,33 +1,13 @@
|
||||||
import time
|
|
||||||
|
|
||||||
from llm_server.config.global_config import GlobalConfig
|
|
||||||
|
|
||||||
try:
|
|
||||||
import gevent.monkey
|
|
||||||
|
|
||||||
gevent.monkey.patch_all()
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import simplejson as json
|
import simplejson as json
|
||||||
from flask import Flask, jsonify, render_template, request, Response
|
from flask import Flask, jsonify, render_template, request, Response
|
||||||
|
|
||||||
import config
|
|
||||||
from llm_server.cluster.backend import get_model_choices
|
from llm_server.cluster.backend import get_model_choices
|
||||||
from llm_server.cluster.cluster_config import cluster_config
|
from llm_server.cluster.cluster_config import cluster_config
|
||||||
from llm_server.config.config import mode_ui_names
|
from llm_server.config.config import MODE_UI_NAMES
|
||||||
from llm_server.config.load import load_config
|
from llm_server.config.global_config import GlobalConfig
|
||||||
from llm_server.custom_redis import flask_cache, redis
|
from llm_server.custom_redis import flask_cache, redis
|
||||||
from llm_server.database.conn import Database
|
|
||||||
from llm_server.database.create import create_db
|
|
||||||
from llm_server.helpers import auto_set_base_client_api
|
from llm_server.helpers import auto_set_base_client_api
|
||||||
from llm_server.llm.vllm.info import vllm_info
|
from llm_server.llm.vllm.info import vllm_info
|
||||||
from llm_server.logging import init_logging, create_logger
|
|
||||||
from llm_server.routes.openai import openai_bp, openai_model_bp
|
from llm_server.routes.openai import openai_bp, openai_model_bp
|
||||||
from llm_server.routes.server_error import handle_server_error
|
from llm_server.routes.server_error import handle_server_error
|
||||||
from llm_server.routes.v1 import bp
|
from llm_server.routes.v1 import bp
|
||||||
|
@ -63,32 +43,6 @@ from llm_server.sock import init_wssocket
|
||||||
# TODO: add more excluding to SYSTEM__ tokens
|
# TODO: add more excluding to SYSTEM__ tokens
|
||||||
# TODO: return 200 when returning formatted sillytavern error
|
# TODO: return 200 when returning formatted sillytavern error
|
||||||
|
|
||||||
script_path = os.path.dirname(os.path.realpath(__file__))
|
|
||||||
config_path_environ = os.getenv("CONFIG_PATH")
|
|
||||||
if config_path_environ:
|
|
||||||
config_path = config_path_environ
|
|
||||||
else:
|
|
||||||
config_path = Path(script_path, 'config', 'config.yml')
|
|
||||||
|
|
||||||
success, msg = load_config(config_path)
|
|
||||||
if not success:
|
|
||||||
logger = logging.getLogger('llm_server')
|
|
||||||
logger.setLevel(logging.INFO)
|
|
||||||
logger.error(f'Failed to load config: {msg}')
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
init_logging(Path(GlobalConfig.get().webserver_log_directory) / 'server.log')
|
|
||||||
logger = create_logger('Server')
|
|
||||||
logger.debug('Debug logging enabled.')
|
|
||||||
|
|
||||||
while not redis.get('daemon_started', dtype=bool):
|
|
||||||
logger.warning('Could not find the key daemon_started in Redis. Did you forget to start the daemon process?')
|
|
||||||
time.sleep(10)
|
|
||||||
|
|
||||||
logger.info('Started HTTP worker!')
|
|
||||||
|
|
||||||
Database.initialise(maxconn=GlobalConfig.get().mysql.maxconn, host=GlobalConfig.get().mysql.host, user=GlobalConfig.get().mysql.username, password=GlobalConfig.get().mysql.password, database=GlobalConfig.get().mysql.database)
|
|
||||||
create_db()
|
|
||||||
|
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
|
|
||||||
|
@ -139,13 +93,13 @@ def home():
|
||||||
# to None by the daemon.
|
# to None by the daemon.
|
||||||
default_model_info['context_size'] = '-'
|
default_model_info['context_size'] = '-'
|
||||||
|
|
||||||
if len(config['analytics_tracking_code']):
|
if len(GlobalConfig.get().analytics_tracking_code):
|
||||||
analytics_tracking_code = f"<script>\n{config['analytics_tracking_code']}\n</script>"
|
analytics_tracking_code = f"<script>\n{GlobalConfig.get().analytics_tracking_code}\n</script>"
|
||||||
else:
|
else:
|
||||||
analytics_tracking_code = ''
|
analytics_tracking_code = ''
|
||||||
|
|
||||||
if config['info_html']:
|
if GlobalConfig.get().info_html:
|
||||||
info_html = config['info_html']
|
info_html = GlobalConfig.get().info_html
|
||||||
else:
|
else:
|
||||||
info_html = ''
|
info_html = ''
|
||||||
|
|
||||||
|
@ -166,9 +120,9 @@ def home():
|
||||||
client_api=f'https://{base_client_api}',
|
client_api=f'https://{base_client_api}',
|
||||||
ws_client_api=f'wss://{base_client_api}/v1/stream' if GlobalConfig.get().enable_streaming else 'disabled',
|
ws_client_api=f'wss://{base_client_api}/v1/stream' if GlobalConfig.get().enable_streaming else 'disabled',
|
||||||
default_estimated_wait=default_estimated_wait_sec,
|
default_estimated_wait=default_estimated_wait_sec,
|
||||||
mode_name=mode_ui_names[GlobalConfig.get().frontend_api_mode][0],
|
mode_name=MODE_UI_NAMES[GlobalConfig.get().frontend_api_mode].name,
|
||||||
api_input_textbox=mode_ui_names[GlobalConfig.get().frontend_api_mode][1],
|
api_input_textbox=MODE_UI_NAMES[GlobalConfig.get().frontend_api_mode].api_name,
|
||||||
streaming_input_textbox=mode_ui_names[GlobalConfig.get().frontend_api_mode][2],
|
streaming_input_textbox=MODE_UI_NAMES[GlobalConfig.get().frontend_api_mode].streaming_name,
|
||||||
default_context_size=default_model_info['context_size'],
|
default_context_size=default_model_info['context_size'],
|
||||||
stats_json=json.dumps(stats, indent=4, ensure_ascii=False),
|
stats_json=json.dumps(stats, indent=4, ensure_ascii=False),
|
||||||
extra_info=mode_info,
|
extra_info=mode_info,
|
||||||
|
@ -212,6 +166,6 @@ def before_app_request():
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# server_startup(None)
|
print('Do not run this file directly. Instead, use gunicorn:')
|
||||||
print('FLASK MODE - Startup complete!')
|
print("gunicorn -c other/gunicorn_conf.py server:app -b 0.0.0.0:5000 --worker-class gevent --workers 3 --access-logfile '-' --error-logfile '-'")
|
||||||
app.run(host='0.0.0.0', threaded=False, processes=15)
|
quit(1)
|
||||||
|
|
Reference in New Issue