diff --git a/daemon.py b/daemon.py index 909cf34..4d2f77d 100644 --- a/daemon.py +++ b/daemon.py @@ -3,7 +3,6 @@ import logging import os import sys import time -from pathlib import Path from redis import Redis @@ -14,6 +13,7 @@ from llm_server.custom_redis import redis from llm_server.database.conn import Database from llm_server.database.create import create_db from llm_server.database.database import get_number_of_rows +from llm_server.helpers import resolve_path from llm_server.logging import create_logger, logging_info, init_logging from llm_server.routes.v1.generate_stats import generate_stats from llm_server.workers.threader import start_background @@ -23,7 +23,7 @@ config_path_environ = os.getenv("CONFIG_PATH") if config_path_environ: config_path = config_path_environ else: - config_path = Path(script_path, 'config', 'config.yml') + config_path = resolve_path(script_path, 'config', 'config.yml') if __name__ == "__main__": parser = argparse.ArgumentParser(description='Daemon microservice.') @@ -47,7 +47,7 @@ if __name__ == "__main__": logger.info(f'Failed to load config: {msg}') sys.exit(1) - Database.initialise(maxconn=GlobalConfig.get().mysql.maxconn, host=GlobalConfig.get().mysql.host, user=GlobalConfig.get().mysql.username, password=GlobalConfig.get().mysql.password, database=GlobalConfig.get().mysql.database) + Database.initialise(**GlobalConfig.get().postgresql.dict()) create_db() cluster_config.clear() @@ -57,7 +57,7 @@ if __name__ == "__main__": generate_stats(regen=True) if GlobalConfig.get().load_num_prompts: - redis.set('proompts', get_number_of_rows('prompts')) + redis.set('proompts', get_number_of_rows('messages')) start_background() diff --git a/llm_server/cluster/backend.py b/llm_server/cluster/backend.py index 7fa26c8..20e8140 100644 --- a/llm_server/cluster/backend.py +++ b/llm_server/cluster/backend.py @@ -17,6 +17,7 @@ def get_backends_from_model(model_name: str): :param model_name: :return: """ + assert isinstance(model_name, str) return [x.decode('utf-8') for x in redis_running_models.smembers(model_name)] @@ -25,7 +26,7 @@ def get_running_models(): Get all the models that are in the cluster. :return: """ - return list(redis_running_models.keys()) + return [x.decode('utf-8') for x in list(redis_running_models.keys())] def is_valid_model(model_name: str) -> bool: @@ -81,6 +82,7 @@ def get_model_choices(regen: bool = False) -> tuple[dict, dict]: base_client_api = redis.get('base_client_api', dtype=str) running_models = get_running_models() + model_choices = {} for model in running_models: b = get_backends_from_model(model) diff --git a/llm_server/cluster/cluster_config.py b/llm_server/cluster/cluster_config.py index ecee40d..3627813 100644 --- a/llm_server/cluster/cluster_config.py +++ b/llm_server/cluster/cluster_config.py @@ -33,7 +33,7 @@ class RedisClusterStore: item.backend_url = backend_url stuff[backend_url] = item for k, v in stuff.items(): - self.add_backend(k, v) + self.add_backend(k, v.dict()) def add_backend(self, name: str, values: dict): self.config_redis.hset(name, mapping={k: pickle.dumps(v) for k, v in values.items()}) diff --git a/llm_server/config/config.py b/llm_server/config/config.py index 55c8538..b2ad765 100644 --- a/llm_server/config/config.py +++ b/llm_server/config/config.py @@ -1,14 +1,22 @@ +from pydantic import BaseModel + from llm_server.config.global_config import GlobalConfig def cluster_worker_count(): count = 0 for item in GlobalConfig.get().cluster: - count += item['concurrent_gens'] + count += item.concurrent_gens return count -mode_ui_names = { - 'ooba': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'), - 'vllm': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'), +class ModeUINameStr(BaseModel): + name: str + api_name: str + streaming_name: str + + +MODE_UI_NAMES = { + 'ooba': ModeUINameStr(name='Text Gen WebUI (ooba)', api_name='Blocking API url', streaming_name='Streaming API url'), + 'vllm': ModeUINameStr(name='Text Gen WebUI (ooba)', api_name='Blocking API url', streaming_name='Streaming API url'), } diff --git a/llm_server/config/load.py b/llm_server/config/load.py index 39786bb..4f2c20f 100644 --- a/llm_server/config/load.py +++ b/llm_server/config/load.py @@ -63,8 +63,8 @@ def load_config(config_path: Path): config_model = ConfigModel(**config.config) GlobalConfig.initalize(config_model) - if not (0 < GlobalConfig.get().mysql.maxconn <= 32): - return False, f'"maxcon" should be higher than 0 and lower or equal to 32. Current value: "{GlobalConfig.get().mysql.maxconn}"' + if GlobalConfig.get().postgresql.maxconn < 0: + return False, f'"maxcon" should be higher than 0. Current value: "{GlobalConfig.get().postgresql.maxconn}"' openai.api_key = GlobalConfig.get().openai_api_key diff --git a/llm_server/config/model.py b/llm_server/config/model.py index b7cf156..3133e0b 100644 --- a/llm_server/config/model.py +++ b/llm_server/config/model.py @@ -19,9 +19,9 @@ class ConfigFrontendApiModes(str, Enum): ooba = 'ooba' -class ConfigMysql(BaseModel): +class ConfigPostgresql(BaseModel): host: str - username: str + user: str password: str database: str maxconn: int @@ -37,9 +37,8 @@ class ConfigModel(BaseModel): cluster: List[ConfigCluser] prioritize_by_size: bool admin_token: Union[str, None] - mysql: ConfigMysql + postgresql: ConfigPostgresql http_host: str - webserver_log_directory: str include_system_tokens_in_stats: bool background_homepage_cacher: bool max_new_tokens: int @@ -55,6 +54,7 @@ class ConfigModel(BaseModel): info_html: Union[str, None] enable_openi_compatible_backend: bool openai_api_key: Union[str, None] + openai_system_prompt: str expose_openai_system_prompt: bool openai_expose_our_model: bool openai_force_no_hashes: bool @@ -72,3 +72,4 @@ class ConfigModel(BaseModel): load_num_prompts: bool manual_model_name: Union[str, None] backend_request_timeout: int + backend_generate_request_timeout: int diff --git a/llm_server/config/scheme.py b/llm_server/config/scheme.py index 822a4b3..3468765 100644 --- a/llm_server/config/scheme.py +++ b/llm_server/config/scheme.py @@ -2,7 +2,7 @@ from typing import Union import bison -from llm_server.opts import default_openai_system_prompt +from llm_server.globals import DEFAULT_OPENAI_SYSTEM_PROMPT config_scheme = bison.Scheme( bison.Option('frontend_api_mode', choices=['ooba'], field_type=str), @@ -14,15 +14,14 @@ config_scheme = bison.Scheme( )), bison.Option('prioritize_by_size', default=True, field_type=bool), bison.Option('admin_token', default=None, field_type=Union[str, None]), - bison.ListOption('mysql', member_scheme=bison.Scheme( + bison.ListOption('postgresql', member_scheme=bison.Scheme( bison.Option('host', field_type=str), - bison.Option('username', field_type=str), + bison.Option('user', field_type=str), bison.Option('password', field_type=str), bison.Option('database', field_type=str), bison.Option('maxconn', field_type=int) )), bison.Option('http_host', default='', field_type=str), - bison.Option('webserver_log_directory', default='/var/log/localllm', field_type=str), bison.Option('include_system_tokens_in_stats', default=True, field_type=bool), bison.Option('background_homepage_cacher', default=True, field_type=bool), bison.Option('max_new_tokens', default=500, field_type=int), @@ -41,7 +40,7 @@ config_scheme = bison.Scheme( bison.Option('expose_openai_system_prompt', default=True, field_type=bool), bison.Option('openai_expose_our_model', default='', field_type=bool), bison.Option('openai_force_no_hashes', default=True, field_type=bool), - bison.Option('openai_system_prompt', default=default_openai_system_prompt, field_type=str), + bison.Option('openai_system_prompt', default=DEFAULT_OPENAI_SYSTEM_PROMPT, field_type=str), bison.Option('openai_moderation_enabled', default=False, field_type=bool), bison.Option('openai_moderation_timeout', default=5, field_type=int), bison.Option('openai_moderation_scan_last_n', default=5, field_type=int), @@ -55,5 +54,6 @@ config_scheme = bison.Scheme( bison.Option('show_backend_info', default=True, field_type=bool), bison.Option('load_num_prompts', default=True, field_type=bool), bison.Option('manual_model_name', default=None, field_type=Union[str, None]), - bison.Option('backend_request_timeout', default=30, field_type=int) + bison.Option('backend_request_timeout', default=30, field_type=int), + bison.Option('backend_generate_request_timeout', default=95, field_type=int) ) diff --git a/llm_server/custom_redis.py b/llm_server/custom_redis.py index 34f33e9..3aa338c 100644 --- a/llm_server/custom_redis.py +++ b/llm_server/custom_redis.py @@ -2,13 +2,13 @@ import logging import pickle import sys import traceback -from typing import Callable, List, Mapping, Optional, Union +from typing import Union import redis as redis_pkg import simplejson as json from flask_caching import Cache from redis import Redis -from redis.typing import AnyKeyT, EncodableT, ExpiryT, FieldT, KeyT, PatternT, ZScoreBoundT, AbsExpiryT +from redis.typing import ExpiryT, KeyT, PatternT flask_cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/15', 'CACHE_KEY_PREFIX': 'local_llm_flask'}) @@ -38,18 +38,11 @@ class RedisCustom(Redis): def _key(self, key): return f"{self.prefix}:{key}" - def set(self, key: KeyT, - value: EncodableT, - ex: Union[ExpiryT, None] = None, - px: Union[ExpiryT, None] = None, - nx: bool = False, - xx: bool = False, - keepttl: bool = False, - get: bool = False, - exat: Union[AbsExpiryT, None] = None, - pxat: Union[AbsExpiryT, None] = None - ): - return self.redis.set(self._key(key), value, ex=ex) + def execute_command(self, *args, **options): + if args[0] != 'GET': + args = list(args) + args[1] = self._key(args[1]) + return super().execute_command(*args, **options) def get(self, key, default=None, dtype=None): # TODO: use pickle @@ -73,103 +66,6 @@ class RedisCustom(Redis): else: return d - def incr(self, key, amount=1): - return self.redis.incr(self._key(key), amount) - - def decr(self, key, amount=1): - return self.redis.decr(self._key(key), amount) - - def sadd(self, key: str, *values: FieldT): - return self.redis.sadd(self._key(key), *values) - - def srem(self, key: str, *values: FieldT): - return self.redis.srem(self._key(key), *values) - - def sismember(self, key: str, value: str): - return self.redis.sismember(self._key(key), value) - - def lindex( - self, name: str, index: int - ): - return self.redis.lindex(self._key(name), index) - - def lrem(self, name: str, count: int, value: str): - return self.redis.lrem(self._key(name), count, value) - - def rpush(self, name: str, *values: FieldT): - return self.redis.rpush(self._key(name), *values) - - def llen(self, name: str): - return self.redis.llen(self._key(name)) - - def zrangebyscore( - self, - name: KeyT, - min: ZScoreBoundT, - max: ZScoreBoundT, - start: Union[int, None] = None, - num: Union[int, None] = None, - withscores: bool = False, - score_cast_func: Union[type, Callable] = float, - ): - return self.redis.zrangebyscore(self._key(name), min, max, start, num, withscores, score_cast_func) - - def zremrangebyscore( - self, name: KeyT, min: ZScoreBoundT, max: ZScoreBoundT - ): - return self.redis.zremrangebyscore(self._key(name), min, max) - - def hincrby( - self, name: str, key: str, amount: int = 1 - ): - return self.redis.hincrby(self._key(name), key, amount) - - def zcard(self, name: KeyT): - return self.redis.zcard(self._key(name)) - - def hdel(self, name: str, *keys: str): - return self.redis.hdel(self._key(name), *keys) - - def hget( - self, name: str, key: str - ): - return self.redis.hget(self._key(name), key) - - def zadd( - self, - name: KeyT, - mapping: Mapping[AnyKeyT, EncodableT], - nx: bool = False, - xx: bool = False, - ch: bool = False, - incr: bool = False, - gt: bool = False, - lt: bool = False, - ): - return self.redis.zadd(self._key(name), mapping, nx, xx, ch, incr, gt, lt) - - def lpush(self, name: str, *values: FieldT): - return self.redis.lpush(self._key(name), *values) - - def hset( - self, - name: str, - key: Optional = None, - value=None, - mapping: Optional[dict] = None, - items: Optional[list] = None, - ): - return self.redis.hset(self._key(name), key, value, mapping, items) - - def hkeys(self, name: str): - return self.redis.hkeys(self._key(name)) - - def hmget(self, name: str, keys: List, *args: List): - return self.redis.hmget(self._key(name), keys, *args) - - def hgetall(self, name: str): - return self.redis.hgetall(self._key(name)) - def keys(self, pattern: PatternT = "*", **kwargs): raw_keys = self.redis.keys(self._key(pattern), **kwargs) keys = [] @@ -179,25 +75,9 @@ class RedisCustom(Redis): # Delete prefix del p[0] k = ':'.join(p) - if k != '____': - keys.append(k) + # keys.append(k) return keys - def pipeline(self, transaction=True, shard_hint=None): - return self.redis.pipeline(transaction, shard_hint) - - def smembers(self, name: str): - return self.redis.smembers(self._key(name)) - - def spop(self, name: str, count: Optional[int] = None): - return self.redis.spop(self._key(name), count) - - def rpoplpush(self, src, dst): - return self.redis.rpoplpush(src, dst) - - def zpopmin(self, name: KeyT, count: Union[int, None] = None): - return self.redis.zpopmin(self._key(name), count) - def exists(self, *names: KeyT): n = [] for name in names: @@ -238,32 +118,5 @@ class RedisCustom(Redis): self.flush() return True - def lrange(self, name: str, start: int, end: int): - return self.redis.lrange(self._key(name), start, end) - - def delete(self, *names: KeyT): - return self.redis.delete(*[self._key(i) for i in names]) - - def lpop(self, name: str, count: Optional[int] = None): - return self.redis.lpop(self._key(name), count) - - def zrange( - self, - name: KeyT, - start: int, - end: int, - desc: bool = False, - withscores: bool = False, - score_cast_func: Union[type, Callable] = float, - byscore: bool = False, - bylex: bool = False, - offset: int = None, - num: int = None, - ): - return self.redis.zrange(self._key(name), start, end, desc, withscores, score_cast_func, byscore, bylex, offset, num) - - def zrem(self, name: KeyT, *values: FieldT): - return self.redis.zrem(self._key(name), *values) - redis = RedisCustom('local_llm') diff --git a/llm_server/database/conn.py b/llm_server/database/conn.py index 6a0f063..248b62d 100644 --- a/llm_server/database/conn.py +++ b/llm_server/database/conn.py @@ -1,39 +1,42 @@ -from mysql.connector import pooling +from psycopg2 import pool, InterfaceError class Database: __connection_pool = None @classmethod - def initialise(cls, maxconn: int, **kwargs): + def initialise(cls, maxconn, **kwargs): if cls.__connection_pool is not None: raise Exception('Database connection pool is already initialised') - cls.__connection_pool = pooling.MySQLConnectionPool(pool_size=maxconn, - pool_reset_session=True, - **kwargs) + cls.__connection_pool = pool.ThreadedConnectionPool(minconn=1, maxconn=maxconn, **kwargs) @classmethod def get_connection(cls): - return cls.__connection_pool.get_connection() + return cls.__connection_pool.getconn() @classmethod def return_connection(cls, connection): - connection.close() + cls.__connection_pool.putconn(connection) class CursorFromConnectionFromPool: - def __init__(self): + def __init__(self, cursor_factory=None): self.conn = None self.cursor = None + self.cursor_factory = cursor_factory def __enter__(self): self.conn = Database.get_connection() - self.cursor = self.conn.cursor() + self.cursor = self.conn.cursor(cursor_factory=self.cursor_factory) return self.cursor def __exit__(self, exception_type, exception_value, exception_traceback): if exception_value is not None: # This is equivalent of saying if there is an exception - self.conn.rollback() + try: + self.conn.rollback() + except InterfaceError as e: + if e != 'connection already closed': + raise else: self.cursor.close() self.conn.commit() diff --git a/llm_server/database/create.py b/llm_server/database/create.py index 302f8a6..3febc78 100644 --- a/llm_server/database/create.py +++ b/llm_server/database/create.py @@ -4,36 +4,39 @@ from llm_server.database.conn import CursorFromConnectionFromPool def create_db(): with CursorFromConnectionFromPool() as cursor: cursor.execute(''' - CREATE TABLE IF NOT EXISTS prompts ( - ip TEXT, - token TEXT DEFAULT NULL, - model TEXT, - backend_mode TEXT, - backend_url TEXT, - request_url TEXT, - generation_time FLOAT, - prompt LONGTEXT, - prompt_tokens INTEGER, - response LONGTEXT, - response_tokens INTEGER, - response_status INTEGER, - parameters TEXT, - # CHECK (parameters IS NULL OR JSON_VALID(parameters)), - headers TEXT, - # CHECK (headers IS NULL OR JSON_VALID(headers)), - timestamp INTEGER - ) + CREATE TABLE IF NOT EXISTS public.messages + ( + ip text COLLATE pg_catalog."default" NOT NULL, + token text COLLATE pg_catalog."default", + model text COLLATE pg_catalog."default" NOT NULL, + backend_mode text COLLATE pg_catalog."default" NOT NULL, + backend_url text COLLATE pg_catalog."default" NOT NULL, + request_url text COLLATE pg_catalog."default" NOT NULL, + generation_time double precision NOT NULL, + prompt text COLLATE pg_catalog."default" NOT NULL, + prompt_tokens integer NOT NULL, + response text COLLATE pg_catalog."default" NOT NULL, + response_tokens integer NOT NULL, + response_status integer NOT NULL, + parameters jsonb NOT NULL, + headers jsonb, + "timestamp" timestamp with time zone NOT NULL DEFAULT CURRENT_TIMESTAMP, + id SERIAL PRIMARY KEY + ); ''') cursor.execute(''' - CREATE TABLE IF NOT EXISTS token_auth ( - token TEXT, - UNIQUE (token), - type TEXT NOT NULL, - priority INTEGER DEFAULT 9999, - simultaneous_ip INTEGER DEFAULT NULL, - uses INTEGER DEFAULT 0, - max_uses INTEGER, - expire INTEGER, - disabled BOOLEAN DEFAULT 0 - ) + CREATE TABLE IF NOT EXISTS public.token_auth + ( + token text COLLATE pg_catalog."default" NOT NULL, + type text COLLATE pg_catalog."default" NOT NULL, + priority integer NOT NULL DEFAULT 9999, + simultaneous_ip text COLLATE pg_catalog."default", + openai_moderation_enabled boolean NOT NULL DEFAULT true, + uses integer NOT NULL DEFAULT 0, + max_uses integer, + expire timestamp with time zone, + disabled boolean NOT NULL DEFAULT false, + notes text COLLATE pg_catalog."default" NOT NULL DEFAULT ''::text, + CONSTRAINT token_auth_pkey PRIMARY KEY (token) + ) ''') diff --git a/llm_server/database/database.py b/llm_server/database/database.py index 1d5e46e..90cf219 100644 --- a/llm_server/database/database.py +++ b/llm_server/database/database.py @@ -1,6 +1,7 @@ import json import time import traceback +from datetime import datetime, timedelta from typing import Union from llm_server.cluster.cluster_config import cluster_config @@ -51,10 +52,10 @@ def do_db_log(ip: str, token: str, prompt: str, response: Union[str, None], gen_ backend_info = cluster_config.get_backend(backend_url) running_model = backend_info.get('model') backend_mode = backend_info['mode'] - timestamp = int(time.time()) + timestamp = datetime.now() with CursorFromConnectionFromPool() as cursor: cursor.execute(""" - INSERT INTO prompts + INSERT INTO messages (ip, token, model, backend_mode, backend_url, request_url, generation_time, prompt, prompt_tokens, response, response_tokens, response_status, parameters, headers, timestamp) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) """, @@ -65,12 +66,12 @@ def is_valid_api_key(api_key): with CursorFromConnectionFromPool() as cursor: cursor.execute("SELECT token, uses, max_uses, expire, disabled FROM token_auth WHERE token = %s", (api_key,)) row = cursor.fetchone() - if row is not None: - token, uses, max_uses, expire, disabled = row - disabled = bool(disabled) - if ((uses is None or max_uses is None) or uses < max_uses) and (expire is None or expire > time.time()) and not disabled: - return True - return False + if row is not None: + token, uses, max_uses, expire, disabled = row + disabled = bool(disabled) + if ((uses is None or max_uses is None) or uses < max_uses) and (expire is None or expire > time.time()) and not disabled: + return True + return False def is_api_key_moderated(api_key): @@ -146,9 +147,9 @@ def sum_column(table_name, column_name): def get_distinct_ips_24h(): # Get the current time and subtract 24 hours (in seconds) - past_24_hours = int(time.time()) - 24 * 60 * 60 + past_24_hours = datetime.now() - timedelta(days=1) with CursorFromConnectionFromPool() as cursor: - cursor.execute("SELECT COUNT(DISTINCT ip) FROM prompts WHERE timestamp >= %s AND (token NOT LIKE 'SYSTEM__%%' OR token IS NULL)", (past_24_hours,)) + cursor.execute("SELECT COUNT(DISTINCT ip) FROM messages WHERE timestamp >= %s AND (token NOT LIKE 'SYSTEM__%%' OR token IS NULL)", (past_24_hours,)) result = cursor.fetchone() return result[0] if result else 0 diff --git a/llm_server/opts.py b/llm_server/globals.py similarity index 85% rename from llm_server/opts.py rename to llm_server/globals.py index ef735b2..4a755cb 100644 --- a/llm_server/opts.py +++ b/llm_server/globals.py @@ -1,10 +1,6 @@ # Read-only global variables -default_openai_system_prompt = """You are an assistant chatbot. Your main function is to provide accurate and helpful responses to the user's queries. You should always be polite, respectful, and patient. You should not provide any personal opinions or advice unless specifically asked by the user. You should not make any assumptions about the user's knowledge or abilities. You should always strive to provide clear and concise answers. If you do not understand a user's query, ask for clarification. If you cannot provide an answer, apologize and suggest the user seek help elsewhere.\nLines that start with "### ASSISTANT" were messages you sent previously.\nLines that start with "### USER" were messages sent by the user you are chatting with.\nYou will respond to the "### RESPONSE:" prompt as the assistant and follow the instructions given by the user.\n\n""" - -# cluster = {} - - +DEFAULT_OPENAI_SYSTEM_PROMPT = """You are an assistant chatbot. Your main function is to provide accurate and helpful responses to the user's queries. You should always be polite, respectful, and patient. You should not provide any personal opinions or advice unless specifically asked by the user. You should not make any assumptions about the user's knowledge or abilities. You should always strive to provide clear and concise answers. If you do not understand a user's query, ask for clarification. If you cannot provide an answer, apologize and suggest the user seek help elsewhere.\nLines that start with "### ASSISTANT" were messages you sent previously.\nLines that start with "### USER" were messages sent by the user you are chatting with.\nYou will respond to the "### RESPONSE:" prompt as the assistant and follow the instructions given by the user.\n\n""" REDIS_STREAM_TIMEOUT = 25000 - LOGGING_FORMAT = "%(asctime)s: %(levelname)s:%(name)s - %(message)s" +BACKEND_OFFLINE = 'The model you requested is not a valid choice. Please retry your query.' diff --git a/llm_server/helpers.py b/llm_server/helpers.py index add04ff..38e348d 100644 --- a/llm_server/helpers.py +++ b/llm_server/helpers.py @@ -15,19 +15,6 @@ def resolve_path(*p: str): return Path(*p).expanduser().resolve().absolute() -def safe_list_get(l, idx, default): - """ - https://stackoverflow.com/a/5125636 - :param l: - :param idx: - :param default: - :return: - """ - try: - return l[idx] - except IndexError: - return default - def deep_sort(obj): if isinstance(obj, dict): diff --git a/llm_server/llm/generator.py b/llm_server/llm/generator.py index c924d38..754c7e0 100644 --- a/llm_server/llm/generator.py +++ b/llm_server/llm/generator.py @@ -1,4 +1,4 @@ -from llm_server import opts +from llm_server import globals from llm_server.cluster.cluster_config import cluster_config diff --git a/llm_server/llm/oobabooga/generate.py b/llm_server/llm/oobabooga/generate.py index 98b0d4c..4e64a72 100644 --- a/llm_server/llm/oobabooga/generate.py +++ b/llm_server/llm/oobabooga/generate.py @@ -8,14 +8,14 @@ import requests from llm_server.config.global_config import GlobalConfig -def generate(json_data: dict): - try: - r = requests.post(f'{GlobalConfig.get().backend_url}/api/v1/generate', json=json_data, verify=GlobalConfig.get().verify_ssl, timeout=GlobalConfig.get().backend_generate_request_timeout) - except requests.exceptions.ReadTimeout: - return False, None, 'Request to backend timed out' - except Exception as e: - traceback.print_exc() - return False, None, 'Request to backend encountered error' - if r.status_code != 200: - return False, r, f'Backend returned {r.status_code}' - return True, r, None +# def generate(json_data: dict): +# try: +# r = requests.post(f'{GlobalConfig.get().backend_url}/api/v1/generate', json=json_data, verify=GlobalConfig.get().verify_ssl, timeout=GlobalConfig.get().backend_generate_request_timeout) +# except requests.exceptions.ReadTimeout: +# return False, None, 'Request to backend timed out' +# except Exception as e: +# traceback.print_exc() +# return False, None, 'Request to backend encountered error' +# if r.status_code != 200: +# return False, r, f'Backend returned {r.status_code}' +# return True, r, None diff --git a/llm_server/llm/openai/oai_to_vllm.py b/llm_server/llm/openai/oai_to_vllm.py index 4f8dd24..ef2bed0 100644 --- a/llm_server/llm/openai/oai_to_vllm.py +++ b/llm_server/llm/openai/oai_to_vllm.py @@ -98,3 +98,14 @@ def return_invalid_model_err(requested_model: str): "code": "model_not_found" } }), 404 + + +def return_oai_internal_server_error(): + return jsonify({ + "error": { + "message": "Internal server error", + "type": "auth_subrequest_error", + "param": None, + "code": "internal_error" + } + }), 500 diff --git a/llm_server/llm/vllm/generate.py b/llm_server/llm/vllm/generate.py index 3e7926e..72996ac 100644 --- a/llm_server/llm/vllm/generate.py +++ b/llm_server/llm/vllm/generate.py @@ -31,7 +31,7 @@ def handle_blocking_request(json_data: dict, cluster_backend, timeout: int = 10) return False, None, 'Request to backend timed out' except Exception as e: # print(f'Failed to reach VLLM inference endpoint -', f'{e.__class__.__name__}: {e}') - return False, None, 'Request to backend encountered error' + return False, None, f'Request to backend encountered error -- {e.__class__.__name__}: {e}' if r.status_code != 200: # print(f'Failed to reach VLLM inference endpoint - got code {r.status_code}') return False, r, f'Backend returned {r.status_code}' diff --git a/llm_server/logging.py b/llm_server/logging.py index 1d42fd0..4f43dfb 100644 --- a/llm_server/logging.py +++ b/llm_server/logging.py @@ -1,16 +1,14 @@ import logging -import sys -from pathlib import Path import coloredlogs -from llm_server import opts +from llm_server import globals class LoggingInfo: def __init__(self): self._level = logging.INFO - self._format = opts.LOGGING_FORMAT + self._format = globals.LOGGING_FORMAT @property def level(self): @@ -30,30 +28,17 @@ class LoggingInfo: logging_info = LoggingInfo() -LOG_DIRECTORY = None -def init_logging(filepath: Path = None): +def init_logging(): """ Set up the parent logger. Ensures this logger and all children to log to a file. This is only called by `server.py` since there is wierdness with Gunicorn. The deamon doesn't need this. :return: """ - global LOG_DIRECTORY logger = logging.getLogger('llm_server') logger.setLevel(logging_info.level) - if filepath: - p = Path(filepath) - if not p.parent.is_dir(): - logger.fatal(f'Log directory does not exist: {p.parent}') - sys.exit(1) - LOG_DIRECTORY = p.parent - handler = logging.FileHandler(filepath) - formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') - handler.setFormatter(formatter) - logger.addHandler(handler) - def create_logger(name): logger = logging.getLogger('llm_server').getChild(name) @@ -65,7 +50,4 @@ def create_logger(name): handler.setFormatter(formatter) logger.addHandler(handler) coloredlogs.install(logger=logger, level=logging_info.level) - if LOG_DIRECTORY: - handler = logging.FileHandler(LOG_DIRECTORY / f'{name}.log') - logger.addHandler(handler) return logger diff --git a/llm_server/messages.py b/llm_server/messages.py deleted file mode 100644 index c7e3eb7..0000000 --- a/llm_server/messages.py +++ /dev/null @@ -1 +0,0 @@ -BACKEND_OFFLINE = 'The model you requested is not a valid choice. Please retry your query.' diff --git a/llm_server/routes/ooba_request_handler.py b/llm_server/routes/ooba_request_handler.py index 96ad44f..e848d90 100644 --- a/llm_server/routes/ooba_request_handler.py +++ b/llm_server/routes/ooba_request_handler.py @@ -3,7 +3,7 @@ from typing import Tuple import flask from flask import jsonify, request -from llm_server import messages +import llm_server.globals from llm_server.config.global_config import GlobalConfig from llm_server.database.log_to_db import log_to_db from llm_server.logging import create_logger @@ -21,7 +21,7 @@ class OobaRequestHandler(RequestHandler): assert not self.used if self.offline: # _logger.debug(f'This backend is offline.') - return self.handle_error(messages.BACKEND_OFFLINE) + return self.handle_error(llm_server.globals.BACKEND_OFFLINE) request_valid, invalid_response = self.validate_request() if not request_valid: diff --git a/llm_server/routes/openai/__init__.py b/llm_server/routes/openai/__init__.py index 6591700..70e63b6 100644 --- a/llm_server/routes/openai/__init__.py +++ b/llm_server/routes/openai/__init__.py @@ -2,6 +2,7 @@ from flask import Blueprint from ..request_handler import before_request from ...config.global_config import GlobalConfig +from ...llm.openai.oai_to_vllm import return_oai_internal_server_error from ...logging import create_logger _logger = create_logger('OpenAI') @@ -26,15 +27,8 @@ def handle_error(e): "auth_subrequest_error" """ - _logger(f'OAI returning error: {e}') - return jsonify({ - "error": { - "message": "Internal server error", - "type": "auth_subrequest_error", - "param": None, - "code": "internal_error" - } - }), 500 + _logger.error(f'OAI returning error: {e}') + return_oai_internal_server_error() from .models import openai_list_models diff --git a/llm_server/routes/openai/chat_completions.py b/llm_server/routes/openai/chat_completions.py index cc25b09..7691a7d 100644 --- a/llm_server/routes/openai/chat_completions.py +++ b/llm_server/routes/openai/chat_completions.py @@ -13,7 +13,7 @@ from ..openai_request_handler import OpenAIRequestHandler from ..queue import priority_queue from ...config.global_config import GlobalConfig from ...database.log_to_db import log_to_db -from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai +from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai, return_oai_internal_server_error from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit from ...logging import create_logger @@ -32,7 +32,8 @@ def openai_chat_completions(model_name=None): else: handler = OpenAIRequestHandler(incoming_request=request, incoming_json=request_json_body, selected_model=model_name) if handler.offline: - return return_invalid_model_err(model_name) + # return return_invalid_model_err(model_name) + return_oai_internal_server_error() if not request_json_body.get('stream'): try: diff --git a/llm_server/routes/openai_request_handler.py b/llm_server/routes/openai_request_handler.py index 40e1559..5948d6e 100644 --- a/llm_server/routes/openai_request_handler.py +++ b/llm_server/routes/openai_request_handler.py @@ -14,7 +14,7 @@ from llm_server.custom_redis import redis from llm_server.database.database import is_api_key_moderated from llm_server.database.log_to_db import log_to_db from llm_server.llm import get_token_count -from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai, return_invalid_model_err +from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai, return_invalid_model_err, return_oai_internal_server_error from llm_server.llm.openai.transform import ANTI_CONTINUATION_RE, ANTI_RESPONSE_RE, generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit from llm_server.logging import create_logger from llm_server.routes.request_handler import RequestHandler @@ -31,9 +31,10 @@ class OpenAIRequestHandler(RequestHandler): def handle_request(self) -> Tuple[flask.Response, int]: assert not self.used if self.offline: - msg = return_invalid_model_err(self.selected_model) - _logger.error(f'OAI is offline: {msg}') - return self.handle_error(msg) + # msg = return_invalid_model_err(self.selected_model) + # _logger.error(f'OAI is offline: {msg}') + # return self.handle_error(msg) + return_oai_internal_server_error() if GlobalConfig.get().openai_silent_trim: oai_messages = trim_messages_to_fit(self.request.json['messages'], self.cluster_backend_info['model_config']['max_position_embeddings'], self.backend_url) @@ -109,7 +110,7 @@ class OpenAIRequestHandler(RequestHandler): return response, 429 def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]: - _logger.error('OAI Error: {error_msg}') + _logger.error(f'OAI Error: {error_msg}') return jsonify({ "error": { "message": "Invalid request, check your parameters and try again.", diff --git a/llm_server/routes/v1/generate_stats.py b/llm_server/routes/v1/generate_stats.py index f665deb..1c38983 100644 --- a/llm_server/routes/v1/generate_stats.py +++ b/llm_server/routes/v1/generate_stats.py @@ -34,7 +34,7 @@ def generate_stats(regen: bool = False): 'proompts_total': get_total_proompts() if GlobalConfig.get().show_num_prompts else None, 'uptime': int((datetime.now() - server_start_time).total_seconds()) if GlobalConfig.get().show_uptime else None, # 'estimated_avg_tps': estimated_avg_tps, - 'tokens_generated': sum_column('prompts', 'response_tokens') if GlobalConfig.get().show_total_output_tokens else None, + 'tokens_generated': sum_column('messages', 'response_tokens') if GlobalConfig.get().show_total_output_tokens else None, 'num_backends': len(cluster_config.all()) if GlobalConfig.get().show_backends else None, }, 'endpoints': { diff --git a/llm_server/workers/inferencer.py b/llm_server/workers/inferencer.py index 21e45d0..80a156c 100644 --- a/llm_server/workers/inferencer.py +++ b/llm_server/workers/inferencer.py @@ -8,6 +8,7 @@ import ujson from redis import Redis from llm_server.cluster.cluster_config import cluster_config +from llm_server.config.global_config import GlobalConfig from llm_server.custom_redis import RedisCustom, redis from llm_server.llm.generator import generator from llm_server.logging import create_logger @@ -148,12 +149,12 @@ def worker(backend_url): status_redis.setp(str(worker_id), None) -def start_workers(cluster: dict): +def start_workers(): logger = create_logger('inferencer') i = 0 - for item in cluster: - for _ in range(item['concurrent_gens']): - t = threading.Thread(target=worker, args=(item['backend_url'],)) + for item in GlobalConfig.get().cluster: + for _ in range(item.concurrent_gens): + t = threading.Thread(target=worker, args=(item.backend_url,)) t.daemon = True t.start() i += 1 diff --git a/llm_server/workers/mainer.py b/llm_server/workers/mainer.py index 772452c..100ddfd 100644 --- a/llm_server/workers/mainer.py +++ b/llm_server/workers/mainer.py @@ -49,10 +49,10 @@ def main_background_thread(): def calc_stats_for_backend(backend_url, running_model, backend_mode): # exclude_zeros=True filters out rows where an error message was returned. Previously, if there was an error, 0 # was entered into the column. The new code enters null instead but we need to be backwards compatible for now. - average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time', + average_generation_elapsed_sec = weighted_average_column_for_model('messages', 'generation_time', running_model, backend_mode, backend_url, exclude_zeros=True, include_system_tokens=GlobalConfig.get().include_system_tokens_in_stats) or 0 - average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens', + average_output_tokens = weighted_average_column_for_model('messages', 'response_tokens', running_model, backend_mode, backend_url, exclude_zeros=True, include_system_tokens=GlobalConfig.get().include_system_tokens_in_stats) or 0 estimated_avg_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0 # Avoid division by zero diff --git a/llm_server/workers/threader.py b/llm_server/workers/threader.py index 51e1b77..11da2e6 100644 --- a/llm_server/workers/threader.py +++ b/llm_server/workers/threader.py @@ -22,7 +22,7 @@ def cache_stats(): def start_background(): logger = create_logger('threader') - start_workers(GlobalConfig.get().cluster) + start_workers() t = Thread(target=main_background_thread) t.daemon = True @@ -46,7 +46,7 @@ def start_background(): t = Thread(target=console_printer) t.daemon = True t.start() - logger.info('Started the console logger.infoer.') + logger.info('Started the console logger.') t = Thread(target=cluster_worker) t.daemon = True diff --git a/other/gunicorn_conf.py b/other/gunicorn_conf.py new file mode 100644 index 0000000..fd4d4b7 --- /dev/null +++ b/other/gunicorn_conf.py @@ -0,0 +1,54 @@ +from llm_server.helpers import resolve_path + +try: + import gevent.monkey + + gevent.monkey.patch_all() +except ImportError: + pass + +import logging +import os +import sys +import time +from pathlib import Path + +from llm_server.config.global_config import GlobalConfig +from llm_server.config.load import load_config +from llm_server.custom_redis import redis +from llm_server.database.conn import Database +from llm_server.database.create import create_db +from llm_server.logging import init_logging, create_logger + + +def post_fork(server, worker): + """ + Initalize the worker after gunicorn has forked. This is done to avoid issues with the database manager. + """ + script_path = Path(os.path.dirname(os.path.realpath(__file__))) + config_path_environ = os.getenv("CONFIG_PATH") + if config_path_environ: + config_path = config_path_environ + else: + config_path = script_path / '../config/config.yml' + config_path = resolve_path(config_path) + + success, msg = load_config(config_path) + if not success: + logger = logging.getLogger('llm_server') + logger.setLevel(logging.INFO) + logger.error(f'Failed to load config: {msg}') + sys.exit(1) + + init_logging() + logger = create_logger('Server') + logger.debug('Debug logging enabled.') + + while not redis.get('daemon_started', dtype=bool): + logger.warning('Could not find the key daemon_started in Redis. Did you forget to start the daemon process?') + time.sleep(10) + + Database.initialise(**GlobalConfig.get().postgresql.dict()) + create_db() + + logger.info('Started HTTP worker!') diff --git a/other/local-llm-server.service b/other/local-llm-server.service index f1bbe56..fe1c9f8 100644 --- a/other/local-llm-server.service +++ b/other/local-llm-server.service @@ -11,6 +11,8 @@ WorkingDirectory=/srv/server/local-llm-server # Sometimes the old processes aren't terminated when the service is restarted. ExecStartPre=/usr/bin/pkill -9 -f "/srv/server/local-llm-server/venv/bin/python3 /srv/server/local-llm-server/venv/bin/gunicorn" +# TODO: make sure gunicorn logs to stdout and logging also goes to stdout + # Need a lot of workers since we have long-running requests. This takes about 3.5G memory. ExecStart=/srv/server/local-llm-server/venv/bin/gunicorn --workers 20 --bind 0.0.0.0:5000 server:app --timeout 60 --worker-class gevent --access-logfile '-' --error-logfile '-' diff --git a/requirements.txt b/requirements.txt index 059d2f0..f503974 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,6 @@ Flask-Caching==2.0.2 requests~=2.31.0 tiktoken~=0.5.0 gevent~=23.9.0.post1 -mysql-connector-python==8.4.0 simplejson~=3.19.1 websockets~=11.0.3 basicauth~=1.0.0 @@ -17,3 +16,4 @@ vllm==0.2.7 coloredlogs~=15.0.1 git+https://git.evulid.cc/cyberes/bison.git pydantic +psycopg2-binary==2.9.9 \ No newline at end of file diff --git a/server.py b/server.py index ec69a04..e9b358e 100644 --- a/server.py +++ b/server.py @@ -1,33 +1,13 @@ -import time - -from llm_server.config.global_config import GlobalConfig - -try: - import gevent.monkey - - gevent.monkey.patch_all() -except ImportError: - pass - -import logging -import os -import sys -from pathlib import Path - import simplejson as json from flask import Flask, jsonify, render_template, request, Response -import config from llm_server.cluster.backend import get_model_choices from llm_server.cluster.cluster_config import cluster_config -from llm_server.config.config import mode_ui_names -from llm_server.config.load import load_config +from llm_server.config.config import MODE_UI_NAMES +from llm_server.config.global_config import GlobalConfig from llm_server.custom_redis import flask_cache, redis -from llm_server.database.conn import Database -from llm_server.database.create import create_db from llm_server.helpers import auto_set_base_client_api from llm_server.llm.vllm.info import vllm_info -from llm_server.logging import init_logging, create_logger from llm_server.routes.openai import openai_bp, openai_model_bp from llm_server.routes.server_error import handle_server_error from llm_server.routes.v1 import bp @@ -63,32 +43,6 @@ from llm_server.sock import init_wssocket # TODO: add more excluding to SYSTEM__ tokens # TODO: return 200 when returning formatted sillytavern error -script_path = os.path.dirname(os.path.realpath(__file__)) -config_path_environ = os.getenv("CONFIG_PATH") -if config_path_environ: - config_path = config_path_environ -else: - config_path = Path(script_path, 'config', 'config.yml') - -success, msg = load_config(config_path) -if not success: - logger = logging.getLogger('llm_server') - logger.setLevel(logging.INFO) - logger.error(f'Failed to load config: {msg}') - sys.exit(1) - -init_logging(Path(GlobalConfig.get().webserver_log_directory) / 'server.log') -logger = create_logger('Server') -logger.debug('Debug logging enabled.') - -while not redis.get('daemon_started', dtype=bool): - logger.warning('Could not find the key daemon_started in Redis. Did you forget to start the daemon process?') - time.sleep(10) - -logger.info('Started HTTP worker!') - -Database.initialise(maxconn=GlobalConfig.get().mysql.maxconn, host=GlobalConfig.get().mysql.host, user=GlobalConfig.get().mysql.username, password=GlobalConfig.get().mysql.password, database=GlobalConfig.get().mysql.database) -create_db() app = Flask(__name__) @@ -139,13 +93,13 @@ def home(): # to None by the daemon. default_model_info['context_size'] = '-' - if len(config['analytics_tracking_code']): - analytics_tracking_code = f"" + if len(GlobalConfig.get().analytics_tracking_code): + analytics_tracking_code = f"" else: analytics_tracking_code = '' - if config['info_html']: - info_html = config['info_html'] + if GlobalConfig.get().info_html: + info_html = GlobalConfig.get().info_html else: info_html = '' @@ -166,9 +120,9 @@ def home(): client_api=f'https://{base_client_api}', ws_client_api=f'wss://{base_client_api}/v1/stream' if GlobalConfig.get().enable_streaming else 'disabled', default_estimated_wait=default_estimated_wait_sec, - mode_name=mode_ui_names[GlobalConfig.get().frontend_api_mode][0], - api_input_textbox=mode_ui_names[GlobalConfig.get().frontend_api_mode][1], - streaming_input_textbox=mode_ui_names[GlobalConfig.get().frontend_api_mode][2], + mode_name=MODE_UI_NAMES[GlobalConfig.get().frontend_api_mode].name, + api_input_textbox=MODE_UI_NAMES[GlobalConfig.get().frontend_api_mode].api_name, + streaming_input_textbox=MODE_UI_NAMES[GlobalConfig.get().frontend_api_mode].streaming_name, default_context_size=default_model_info['context_size'], stats_json=json.dumps(stats, indent=4, ensure_ascii=False), extra_info=mode_info, @@ -212,6 +166,6 @@ def before_app_request(): if __name__ == "__main__": - # server_startup(None) - print('FLASK MODE - Startup complete!') - app.run(host='0.0.0.0', threaded=False, processes=15) + print('Do not run this file directly. Instead, use gunicorn:') + print("gunicorn -c other/gunicorn_conf.py server:app -b 0.0.0.0:5000 --worker-class gevent --workers 3 --access-logfile '-' --error-logfile '-'") + quit(1)