diff --git a/daemon.py b/daemon.py
index 3f67a76..909cf34 100644
--- a/daemon.py
+++ b/daemon.py
@@ -8,9 +8,12 @@ from pathlib import Path
 from redis import Redis
 
 from llm_server.cluster.cluster_config import cluster_config
-from llm_server.config.load import load_config, parse_backends
+from llm_server.config.global_config import GlobalConfig
+from llm_server.config.load import load_config
 from llm_server.custom_redis import redis
+from llm_server.database.conn import Database
 from llm_server.database.create import create_db
+from llm_server.database.database import get_number_of_rows
 from llm_server.logging import create_logger, logging_info, init_logging
 from llm_server.routes.v1.generate_stats import generate_stats
 from llm_server.workers.threader import start_background
@@ -39,19 +42,23 @@ if __name__ == "__main__":
     Redis().flushall()
     logger.info('Flushed Redis.')
 
-    success, config, msg = load_config(config_path)
+    success, msg = load_config(config_path)
     if not success:
         logger.info(f'Failed to load config: {msg}')
         sys.exit(1)
 
+    Database.initialise(maxconn=GlobalConfig.get().mysql.maxconn, host=GlobalConfig.get().mysql.host, user=GlobalConfig.get().mysql.username, password=GlobalConfig.get().mysql.password, database=GlobalConfig.get().mysql.database)
     create_db()
 
     cluster_config.clear()
-    cluster_config.load(parse_backends(config))
+    cluster_config.load()
 
     logger.info('Loading backend stats...')
     generate_stats(regen=True)
 
+    if GlobalConfig.get().load_num_prompts:
+        redis.set('proompts', get_number_of_rows('prompts'))
+
     start_background()
 
     # Give some time for the background threads to get themselves ready to go.
diff --git a/llm_server/cluster/backend.py b/llm_server/cluster/backend.py
index 1114439..7fa26c8 100644
--- a/llm_server/cluster/backend.py
+++ b/llm_server/cluster/backend.py
@@ -1,8 +1,8 @@
 import numpy as np
 
-from llm_server import opts
 from llm_server.cluster.cluster_config import get_a_cluster_backend, cluster_config
 from llm_server.cluster.stores import redis_running_models
+from llm_server.config.global_config import GlobalConfig
 from llm_server.custom_redis import redis
 from llm_server.llm.generator import generator
 from llm_server.llm.info import get_info
@@ -108,8 +108,8 @@ def get_model_choices(regen: bool = False) -> tuple[dict, dict]:
         model_choices[model] = {
             'model': model,
             'client_api': f'https://{base_client_api}/{model}',
-            'ws_client_api': f'wss://{base_client_api}/{model}/v1/stream' if opts.enable_streaming else None,
-            'openai_client_api': f'https://{base_client_api}/openai/{model}/v1' if opts.enable_openi_compatible_backend else 'disabled',
+            'ws_client_api': f'wss://{base_client_api}/{model}/v1/stream' if GlobalConfig.get().enable_streaming else None,
+            'openai_client_api': f'https://{base_client_api}/openai/{model}/v1' if GlobalConfig.get().enable_openi_compatible_backend else 'disabled',
             'backend_count': len(b),
             'estimated_wait': estimated_wait_sec,
             'queued': proompters_in_queue,
diff --git a/llm_server/cluster/cluster_config.py b/llm_server/cluster/cluster_config.py
index 8040fa0..ecee40d 100644
--- a/llm_server/cluster/cluster_config.py
+++ b/llm_server/cluster/cluster_config.py
@@ -2,9 +2,9 @@ import hashlib
 import pickle
 import traceback
 
-from llm_server import opts
 from llm_server.cluster.redis_cycle import add_backend_cycler, redis_cycle
 from llm_server.cluster.stores import redis_running_models
+from llm_server.config.global_config import GlobalConfig
 from llm_server.custom_redis import RedisCustom
 from llm_server.logging import create_logger
 from llm_server.routes.helpers.model import estimate_model_size
@@ -26,8 +26,13 @@ class RedisClusterStore:
     def clear(self):
         self.config_redis.flush()
 
-    def load(self, config: dict):
-        for k, v in config.items():
+    def load(self):
+        stuff = {}
+        for item in GlobalConfig.get().cluster:
+            backend_url = item.backend_url.strip('/')
+            item.backend_url = backend_url
+            stuff[backend_url] = item
+        for k, v in stuff.items():
             self.add_backend(k, v)
 
     def add_backend(self, name: str, values: dict):
@@ -92,7 +97,7 @@ def get_backends():
             result[k] = {'status': status, 'priority': priority}
 
     try:
-        if not opts.prioritize_by_size:
+        if not GlobalConfig.get().prioritize_by_size:
             online_backends = sorted(
                 ((url, info) for url, info in backends.items() if info['online']),
                 key=lambda kv: -kv[1]['priority'],
diff --git a/llm_server/config/config.py b/llm_server/config/config.py
index 9d40948..55c8538 100644
--- a/llm_server/config/config.py
+++ b/llm_server/config/config.py
@@ -1,81 +1,14 @@
-import yaml
+from llm_server.config.global_config import GlobalConfig
+
+
+def cluster_worker_count():
+    count = 0
+    for item in GlobalConfig.get().cluster:
+        count += item.concurrent_gens
+    return count
 
-config_default_vars = {
-    'log_prompts': False,
-    'auth_required': False,
-    'frontend_api_client': '',
-    'verify_ssl': True,
-    'load_num_prompts': False,
-    'show_num_prompts': True,
-    'show_uptime': True,
-    'analytics_tracking_code': '',
-    'average_generation_time_mode': 'database',
-    'info_html': None,
-    'show_total_output_tokens': True,
-    'simultaneous_requests_per_ip': 3,
-    'max_new_tokens': 500,
-    'manual_model_name': False,
-    'enable_streaming': True,
-    'enable_openi_compatible_backend': True,
-    'openai_api_key': None,
-    'expose_openai_system_prompt': True,
-    'openai_system_prompt': """You are an assistant chatbot. Your main function is to provide accurate and helpful responses to the user's queries. You should always be polite, respectful, and patient. You should not provide any personal opinions or advice unless specifically asked by the user. You should not make any assumptions about the user's knowledge or abilities. You should always strive to provide clear and concise answers. If you do not understand a user's query, ask for clarification. If you cannot provide an answer, apologize and suggest the user seek help elsewhere.\nLines that start with "### ASSISTANT" were messages you sent previously.\nLines that start with "### USER" were messages sent by the user you are chatting with.\nYou will respond to the "### RESPONSE:" prompt as the assistant and follow the instructions given by the user.\n\n""",
-    'http_host': None,
-    'admin_token': None,
-    'openai_expose_our_model': False,
-    'openai_force_no_hashes': True,
-    'include_system_tokens_in_stats': True,
-    'openai_moderation_scan_last_n': 5,
-    'openai_org_name': 'OpenAI',
-    'openai_silent_trim': False,
-    'openai_moderation_enabled': True,
-    'netdata_root': None,
-    'show_backends': True,
-    'background_homepage_cacher': True,
-    'openai_moderation_timeout': 5,
-    'prioritize_by_size': False
-}
-config_required_vars = ['cluster', 'frontend_api_mode', 'llm_middleware_name']
 
 mode_ui_names = {
     'ooba': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),
     'vllm': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),
 }
-
-
-class ConfigLoader:
-    """
-    default_vars = {"var1": "default1", "var2": "default2"}
-    required_vars = ["var1", "var3"]
-    config_loader = ConfigLoader("config.yaml", default_vars, required_vars)
-    config = config_loader.load_config()
-    """
-
-    def __init__(self, config_file, default_vars=None, required_vars=None):
-        self.config_file = config_file
-        self.default_vars = default_vars if default_vars else {}
-        self.required_vars = required_vars if required_vars else []
-        self.config = {}
-
-    def load_config(self) -> (bool, str | None, str | None):
-        with open(self.config_file, 'r') as stream:
-            try:
-                self.config = yaml.safe_load(stream)
-            except yaml.YAMLError as exc:
-                return False, None, exc
-
-        if self.config is None:
-            # Handle empty file
-            self.config = {}
-
-        # Set default variables if they are not present in the config file
-        for var, default_value in self.default_vars.items():
-            if var not in self.config:
-                self.config[var] = default_value
-
-        # Check if required variables are present in the config file
-        for var in self.required_vars:
-            if var not in self.config:
-                return False, None, f'Required variable "{var}" is missing from the config file'
-
-        return True, self.config, None
diff --git a/llm_server/config/global_config.py b/llm_server/config/global_config.py
new file mode 100644
index 0000000..cc62e46
--- /dev/null
+++ b/llm_server/config/global_config.py
@@ -0,0 +1,15 @@
+from llm_server.config.model import ConfigModel
+
+
+class GlobalConfig:
+    __config_model: ConfigModel = None
+
+    @classmethod
+    def initalize(cls, config: ConfigModel):
+        if cls.__config_model is not None:
+            raise Exception('Config is already initialised')
+        cls.__config_model = config
+
+    @classmethod
+    def get(cls):
+        return cls.__config_model
diff --git a/llm_server/config/load.py b/llm_server/config/load.py
index 0917f25..39786bb 100644
--- a/llm_server/config/load.py
+++ b/llm_server/config/load.py
@@ -1,94 +1,86 @@
 import re
 import sys
+from pathlib import Path
 
 import openai
+from bison import bison, Option, ListOption, Scheme
 
 import llm_server
-from llm_server import opts
-from llm_server.config.config import ConfigLoader, config_default_vars, config_required_vars
+from llm_server.config.global_config import GlobalConfig
+from llm_server.config.model import ConfigModel
+from llm_server.config.scheme import config_scheme
 from llm_server.custom_redis import redis
-from llm_server.database.conn import Database
-from llm_server.database.database import get_number_of_rows
 from llm_server.logging import create_logger
 from llm_server.routes.queue import PriorityQueue
 
 _logger = create_logger('config')
 
 
-def load_config(config_path):
-    config_loader = ConfigLoader(config_path, config_default_vars, config_required_vars)
-    success, config, msg = config_loader.load_config()
-    if not success:
-        return success, config, msg
+def validate_config(config: bison.Bison):
+    def do(v, scheme: Scheme = None):
+        if isinstance(v, Option) and v.choices is None:
+            if not isinstance(config.config[v.name], v.type):
+                raise ValueError(f'"{v.name}" must be type {v.type}. Current value: "{config.config[v.name]}"')
+        elif isinstance(v, Option) and v.choices is not None:
+            if config.config[v.name] not in v.choices:
+                raise ValueError(f'"{v.name}" must be one of {v.choices}. Current value: "{config.config[v.name]}"')
+        elif isinstance(v, ListOption):
+            if isinstance(config.config[v.name], list):
+                for item in config.config[v.name]:
+                    do(item, v.member_scheme)
+            elif isinstance(config.config[v.name], dict):
+                for kk, vv in config.config[v.name].items():
+                    scheme_dict = v.member_scheme.flatten()
+                    if not isinstance(vv, scheme_dict[kk].type):
+                        raise ValueError(f'"{kk}" must be type {scheme_dict[kk].type}. Current value: "{vv}"')
+                    elif isinstance(scheme_dict[kk], Option) and scheme_dict[kk].choices is not None:
+                        if vv not in scheme_dict[kk].choices:
+                            raise ValueError(f'"{kk}" must be one of {scheme_dict[kk].choices}. Current value: "{vv}"')
+        elif isinstance(v, dict) and scheme is not None:
+            scheme_dict = scheme.flatten()
+            for kk, vv in v.items():
+                if not isinstance(vv, scheme_dict[kk].type):
+                    raise ValueError(f'"{kk}" must be type {scheme_dict[kk].type}. Current value: "{vv}"')
+                elif isinstance(scheme_dict[kk], Option) and scheme_dict[kk].choices is not None:
+                    if vv not in scheme_dict[kk].choices:
+                        raise ValueError(f'"{kk}" must be one of {scheme_dict[kk].choices}. Current value: "{vv}"')
 
-    # TODO: this is atrocious
-    opts.auth_required = config['auth_required']
-    opts.log_prompts = config['log_prompts']
-    opts.frontend_api_client = config['frontend_api_client']
-    opts.show_num_prompts = config['show_num_prompts']
-    opts.show_uptime = config['show_uptime']
-    opts.cluster = config['cluster']
-    opts.show_total_output_tokens = config['show_total_output_tokens']
-    opts.netdata_root = config['netdata_root']
-    opts.simultaneous_requests_per_ip = config['simultaneous_requests_per_ip']
-    opts.max_new_tokens = config['max_new_tokens']
-    opts.manual_model_name = config['manual_model_name']
-    opts.llm_middleware_name = config['llm_middleware_name']
-    opts.enable_openi_compatible_backend = config['enable_openi_compatible_backend']
-    opts.openai_system_prompt = config['openai_system_prompt']
-    opts.expose_openai_system_prompt = config['expose_openai_system_prompt']
-    opts.enable_streaming = config['enable_streaming']
-    opts.openai_api_key = config['openai_api_key']
-    openai.api_key = opts.openai_api_key
-    opts.admin_token = config['admin_token']
-    opts.openai_expose_our_model = config['openai_expose_our_model']
-    opts.openai_force_no_hashes = config['openai_force_no_hashes']
-    opts.include_system_tokens_in_stats = config['include_system_tokens_in_stats']
-    opts.openai_moderation_scan_last_n = config['openai_moderation_scan_last_n']
-    opts.openai_org_name = config['openai_org_name']
-    opts.openai_silent_trim = config['openai_silent_trim']
-    opts.openai_moderation_enabled = config['openai_moderation_enabled']
-    opts.show_backends = config['show_backends']
-    opts.background_homepage_cacher = config['background_homepage_cacher']
-    opts.openai_moderation_timeout = config['openai_moderation_timeout']
-    opts.frontend_api_mode = config['frontend_api_mode']
-    opts.prioritize_by_size = config['prioritize_by_size']
+    for k, v in config_scheme.flatten().items():
+        do(v)
 
-    # Scale the number of workers.
-    for item in config['cluster']:
-        opts.cluster_workers += item['concurrent_gens']
-    llm_server.routes.queue.priority_queue = PriorityQueue([x['backend_url'] for x in config['cluster']])
+def load_config(config_path: Path):
+    config = bison.Bison(scheme=config_scheme)
+    config.config_name = 'config'
+    config.add_config_paths(str(config_path.parent))
+    config.parse()
 
-    if opts.openai_expose_our_model and not opts.openai_api_key:
+    try:
+        validate_config(config)
+    except ValueError as e:
+        return False, str(e)
+
+    config_model = ConfigModel(**config.config)
+    GlobalConfig.initalize(config_model)
+
+    if not (0 < GlobalConfig.get().mysql.maxconn <= 32):
+        return False, f'"maxconn" must be greater than 0 and less than or equal to 32. Current value: "{GlobalConfig.get().mysql.maxconn}"'
+
+    openai.api_key = GlobalConfig.get().openai_api_key
+
+    llm_server.routes.queue.priority_queue = PriorityQueue(set([x.backend_url for x in config_model.cluster]))
+
+    if GlobalConfig.get().openai_expose_our_model and not GlobalConfig.get().openai_api_key:
         _logger.error('If you set openai_expose_our_model to false, you must set your OpenAI key in openai_api_key.')
         sys.exit(1)
 
-    opts.verify_ssl = config['verify_ssl']
-    if not opts.verify_ssl:
+    if not GlobalConfig.get().verify_ssl:
         import urllib3
         urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
 
-    if config['http_host']:
+    if GlobalConfig.get().http_host:
         http_host = re.sub(r'https?://', '', config["http_host"])
         redis.set('http_host', http_host)
-        redis.set('base_client_api', f'{http_host}/{opts.frontend_api_client.strip("/")}')
+        redis.set('base_client_api', f'{http_host}/{GlobalConfig.get().frontend_api_client.strip("/")}')
 
-    Database.initialise(maxconn=config['mysql']['maxconn'], host=config['mysql']['host'], user=config['mysql']['username'], password=config['mysql']['password'], database=config['mysql']['database'])
-
-    if config['load_num_prompts']:
-        redis.set('proompts', get_number_of_rows('prompts'))
-
-    return success, config, msg
-
-
-def parse_backends(config):
-    if not config.get('cluster'):
-        return False
-    cluster = config.get('cluster')
-    config = {}
-    for item in cluster:
-        backend_url = item['backend_url'].strip('/')
-        item['backend_url'] = backend_url
-        config[backend_url] = item
-    return config
+    return True, None
diff --git a/llm_server/config/model.py b/llm_server/config/model.py
new file mode 100644
index 0000000..b7cf156
--- /dev/null
+++ b/llm_server/config/model.py
@@ -0,0 +1,74 @@
+from enum import Enum
+from typing import Union, List
+
+from pydantic import BaseModel
+
+
+class ConfigClusterMode(str, Enum):
+    vllm = 'vllm'
+
+
+class ConfigCluser(BaseModel):
+    backend_url: str
+    concurrent_gens: int
+    mode: ConfigClusterMode
+    priority: int
+
+
+class ConfigFrontendApiModes(str, Enum):
+    ooba = 'ooba'
+
+
+class ConfigMysql(BaseModel):
+    host: str
+    username: str
+    password: str
+    database: str
+    maxconn: int
+
+
+class ConfigAvgGenTimeModes(str, Enum):
+    database = 'database'
+    minute = 'minute'
+
+
+class ConfigModel(BaseModel):
+    frontend_api_mode: ConfigFrontendApiModes
+    cluster: List[ConfigCluser]
+    prioritize_by_size: bool
+    admin_token: Union[str, None]
+    mysql: ConfigMysql
+    http_host: str
+    webserver_log_directory: str
+    include_system_tokens_in_stats: bool
+    background_homepage_cacher: bool
+    max_new_tokens: int
+    enable_streaming: bool
+    show_backends: bool
+    log_prompts: bool
+    verify_ssl: bool
+    auth_required: bool
+    simultaneous_requests_per_ip: int
+    max_queued_prompts_per_ip: int
+    llm_middleware_name: str
+    analytics_tracking_code: Union[str, None]
+    info_html: Union[str, None]
+    enable_openi_compatible_backend: bool
+    openai_api_key: Union[str, None]
+    expose_openai_system_prompt: bool
+    openai_expose_our_model: bool
+    openai_force_no_hashes: bool
+    openai_moderation_enabled: bool
+    openai_moderation_timeout: int
+    openai_moderation_scan_last_n: int
+    openai_org_name: str
+    openai_silent_trim: bool
+    frontend_api_client: str
+    average_generation_time_mode: ConfigAvgGenTimeModes
+    show_num_prompts: bool
+    show_uptime: bool
+    show_total_output_tokens: bool
+    show_backend_info: bool
+    load_num_prompts: bool
+    manual_model_name: Union[str, None]
+    backend_request_timeout: int
diff --git a/llm_server/config/scheme.py b/llm_server/config/scheme.py
new file mode 100644
index 0000000..822a4b3
--- /dev/null
+++ b/llm_server/config/scheme.py
@@ -0,0 +1,59 @@
+from typing import Union
+
+import bison
+
+from llm_server.opts import default_openai_system_prompt
+
+config_scheme = bison.Scheme(
+    bison.Option('frontend_api_mode', choices=['ooba'], field_type=str),
+    bison.ListOption('cluster', member_scheme=bison.Scheme(
+        bison.Option('backend_url', field_type=str),
+        bison.Option('concurrent_gens', field_type=int),
+        bison.Option('mode', choices=['vllm'], field_type=str),
+        bison.Option('priority', field_type=int),
+    )),
+    bison.Option('prioritize_by_size', default=True, field_type=bool),
+    bison.Option('admin_token', default=None, field_type=Union[str, None]),
+    bison.ListOption('mysql', member_scheme=bison.Scheme(
+        bison.Option('host', field_type=str),
+        bison.Option('username', field_type=str),
+        bison.Option('password', field_type=str),
+        bison.Option('database', field_type=str),
+        bison.Option('maxconn', field_type=int)
+    )),
+    bison.Option('http_host', default='', field_type=str),
+    bison.Option('webserver_log_directory', default='/var/log/localllm', field_type=str),
+    bison.Option('include_system_tokens_in_stats', default=True, field_type=bool),
+    bison.Option('background_homepage_cacher', default=True, field_type=bool),
+    bison.Option('max_new_tokens', default=500, field_type=int),
+    bison.Option('enable_streaming', default=True, field_type=bool),
+    bison.Option('show_backends', default=True, field_type=bool),
+    bison.Option('log_prompts', default=True, field_type=bool),
+    bison.Option('verify_ssl', default=False, field_type=bool),
+    bison.Option('auth_required', default=False, field_type=bool),
+    bison.Option('simultaneous_requests_per_ip', default=1, field_type=int),
+    bison.Option('max_queued_prompts_per_ip', default=1, field_type=int),
+    bison.Option('llm_middleware_name', default='LocalLLM', field_type=str),
+    bison.Option('analytics_tracking_code', default=None, field_type=Union[str, None]),
+    bison.Option('info_html', default=None, field_type=Union[str, None]),
+    bison.Option('enable_openi_compatible_backend', default=True, field_type=bool),
+    bison.Option('openai_api_key', default=None, field_type=Union[str, None]),
+    bison.Option('expose_openai_system_prompt', default=True, field_type=bool),
+    bison.Option('openai_expose_our_model', default=False, field_type=bool),
+    bison.Option('openai_force_no_hashes', default=True, field_type=bool),
+    bison.Option('openai_system_prompt', default=default_openai_system_prompt, field_type=str),
+    bison.Option('openai_moderation_enabled', default=False, field_type=bool),
+    bison.Option('openai_moderation_timeout', default=5, field_type=int),
+    bison.Option('openai_moderation_scan_last_n', default=5, field_type=int),
+    bison.Option('openai_org_name', default='OpenAI', field_type=str),
+    bison.Option('openai_silent_trim', default=True, field_type=bool),
+    bison.Option('frontend_api_client', default='/api', field_type=str),
+    bison.Option('average_generation_time_mode', default='database', choices=['database', 'minute'], field_type=str),
+    bison.Option('show_num_prompts', default=True, field_type=bool),
+    bison.Option('show_uptime', default=True, field_type=bool),
+    bison.Option('show_total_output_tokens', default=True, field_type=bool),
+    bison.Option('show_backend_info', default=True, field_type=bool),
+    bison.Option('load_num_prompts', default=True, field_type=bool),
+    bison.Option('manual_model_name', default=None, field_type=Union[str, None]),
+    bison.Option('backend_request_timeout', default=30, field_type=int)
+)
diff --git a/llm_server/database/conn.py b/llm_server/database/conn.py
index 261b4c5..6a0f063 100644
--- a/llm_server/database/conn.py
+++ b/llm_server/database/conn.py
@@ -5,7 +5,7 @@ class Database:
     __connection_pool = None
 
     @classmethod
-    def initialise(cls, maxconn, **kwargs):
+    def initialise(cls, maxconn: int, **kwargs):
         if cls.__connection_pool is not None:
             raise Exception('Database connection pool is already initialised')
         cls.__connection_pool = pooling.MySQLConnectionPool(pool_size=maxconn,
diff --git a/llm_server/database/database.py b/llm_server/database/database.py
index a1c5d80..1d5e46e 100644
--- a/llm_server/database/database.py
+++ b/llm_server/database/database.py
@@ -3,8 +3,8 @@ import time
 import traceback
 from typing import Union
 
-from llm_server import opts
 from llm_server.cluster.cluster_config import cluster_config
+from llm_server.config.global_config import GlobalConfig
 from llm_server.database.conn import CursorFromConnectionFromPool
 from llm_server.llm import get_token_count
 
@@ -38,10 +38,10 @@ def do_db_log(ip: str, token: str, prompt: str, response: Union[str, None], gen_
     if is_error:
         gen_time = None
 
-    if not opts.log_prompts:
+    if not GlobalConfig.get().log_prompts:
         prompt = None
 
-    if not opts.log_prompts and not is_error:
+    if not GlobalConfig.get().log_prompts and not is_error:
         # TODO: test and verify this works as expected
         response = None
 
@@ -75,13 +75,13 @@ def is_valid_api_key(api_key):
 
 def is_api_key_moderated(api_key):
     if not api_key:
-        return opts.openai_moderation_enabled
+        return GlobalConfig.get().openai_moderation_enabled
     with CursorFromConnectionFromPool() as cursor:
         cursor.execute("SELECT openai_moderation_enabled FROM token_auth WHERE token = %s", (api_key,))
        row = cursor.fetchone()
        if row is not None:
            return bool(row[0])
-    return opts.openai_moderation_enabled
+    return GlobalConfig.get().openai_moderation_enabled
 
 
 def get_number_of_rows(table_name):
@@ -160,7 +160,7 @@ def increment_token_uses(token):
 
 def get_token_ratelimit(token):
     priority = 9990
-    simultaneous_ip = opts.simultaneous_requests_per_ip
+    simultaneous_ip = GlobalConfig.get().simultaneous_requests_per_ip
     if token:
         with CursorFromConnectionFromPool() as cursor:
             cursor.execute("SELECT priority, simultaneous_ip FROM token_auth WHERE token = %s", (token,))
diff --git a/llm_server/helpers.py b/llm_server/helpers.py
index 91f3b15..add04ff 100644
--- a/llm_server/helpers.py
+++ b/llm_server/helpers.py
@@ -7,7 +7,7 @@ from typing import Union
 import simplejson as json
 from flask import make_response
 
-from llm_server import opts
+from llm_server.config.global_config import GlobalConfig
 from llm_server.custom_redis import redis
 
 
@@ -68,4 +68,4 @@ def auto_set_base_client_api(request):
         return
     else:
         redis.set('http_host', host)
-        redis.set('base_client_api', f'{host}/{opts.frontend_api_client.strip("/")}')
+        redis.set('base_client_api', f'{host}/{GlobalConfig.get().frontend_api_client.strip("/")}')
diff --git a/llm_server/llm/info.py b/llm_server/llm/info.py
index d1218e2..a4a5654 100644
--- a/llm_server/llm/info.py
+++ b/llm_server/llm/info.py
@@ -1,19 +1,19 @@
 import requests
 
-from llm_server import opts
+from llm_server.config.global_config import GlobalConfig
 
 
 def get_running_model(backend_url: str, mode: str):
     if mode == 'ooba':
         try:
-            backend_response = requests.get(f'{backend_url}/api/v1/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl)
+            backend_response = requests.get(f'{backend_url}/api/v1/model', timeout=GlobalConfig.get().backend_request_timeout, verify=GlobalConfig.get().verify_ssl)
             r_json = backend_response.json()
             return r_json['result'], None
         except Exception as e:
             return False, e
     elif mode == 'vllm':
         try:
-            backend_response = requests.get(f'{backend_url}/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl)
+            backend_response = requests.get(f'{backend_url}/model', timeout=GlobalConfig.get().backend_request_timeout, verify=GlobalConfig.get().verify_ssl)
             r_json = backend_response.json()
             return r_json['model'], None
         except Exception as e:
@@ -28,7 +28,7 @@ def get_info(backend_url: str, mode: str):
         # raise NotImplementedError
     elif mode == 'vllm':
         try:
-            r = requests.get(f'{backend_url}/info', verify=opts.verify_ssl, timeout=opts.backend_request_timeout)
+            r = requests.get(f'{backend_url}/info', verify=GlobalConfig.get().verify_ssl, timeout=GlobalConfig.get().backend_request_timeout)
             j = r.json()
         except Exception as e:
             return {}
diff --git a/llm_server/llm/oobabooga/generate.py b/llm_server/llm/oobabooga/generate.py
index e352a30..98b0d4c 100644
--- a/llm_server/llm/oobabooga/generate.py
+++ b/llm_server/llm/oobabooga/generate.py
@@ -5,12 +5,12 @@ import traceback
 
 import requests
 
-from llm_server import opts
+from llm_server.config.global_config import GlobalConfig
 
 
 def generate(json_data: dict):
     try:
-        r = requests.post(f'{opts.backend_url}/api/v1/generate', json=json_data, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
+        r = requests.post(f'{GlobalConfig.get().backend_url}/api/v1/generate', json=json_data, verify=GlobalConfig.get().verify_ssl, timeout=GlobalConfig.get().backend_generate_request_timeout)
     except requests.exceptions.ReadTimeout:
         return False, None, 'Request to backend timed out'
     except Exception as e:
diff --git a/llm_server/llm/openai/moderation.py b/llm_server/llm/openai/moderation.py
index ee63e5a..92409e7 100644
--- a/llm_server/llm/openai/moderation.py
+++ b/llm_server/llm/openai/moderation.py
@@ -1,6 +1,6 @@
 import requests
 
-from llm_server import opts
+from llm_server.config.global_config import GlobalConfig
 from llm_server.logging import create_logger
 
 _logger = create_logger('moderation')
@@ -9,7 +9,7 @@ _logger = create_logger('moderation')
 def check_moderation_endpoint(prompt: str):
     headers = {
         'Content-Type': 'application/json',
-        'Authorization': f"Bearer {opts.openai_api_key}",
+        'Authorization': f"Bearer {GlobalConfig.get().openai_api_key}",
     }
     response = requests.post('https://api.openai.com/v1/moderations', headers=headers, json={"input": prompt}, timeout=10)
     if response.status_code != 200:
diff --git a/llm_server/llm/openai/oai_to_vllm.py b/llm_server/llm/openai/oai_to_vllm.py
index feb9364..4f8dd24 100644
--- a/llm_server/llm/openai/oai_to_vllm.py
+++ b/llm_server/llm/openai/oai_to_vllm.py
@@ -1,6 +1,6 @@
 from flask import jsonify
 
-from llm_server import opts
+from llm_server.config.global_config import GlobalConfig
 from llm_server.logging import create_logger
 
 _logger = create_logger('oai_to_vllm')
@@ -14,7 +14,7 @@ def oai_to_vllm(request_json_body, stop_hashes: bool, mode):
         request_json_body['stop'] = [request_json_body['stop']]
 
     if stop_hashes:
-        if opts.openai_force_no_hashes:
+        if GlobalConfig.get().openai_force_no_hashes:
             request_json_body['stop'].append('###')
         else:
             # TODO: make stopping strings a configurable
@@ -30,7 +30,7 @@ def oai_to_vllm(request_json_body, stop_hashes: bool, mode):
     if mode == 'vllm' and request_json_body.get('top_p') == 0:
         request_json_body['top_p'] = 0.01
 
-    request_json_body['max_tokens'] = min(max(request_json_body.get('max_new_tokens', 0), request_json_body.get('max_tokens', 0)), opts.max_new_tokens)
+    request_json_body['max_tokens'] = min(max(request_json_body.get('max_new_tokens', 0), request_json_body.get('max_tokens', 0)), GlobalConfig.get().max_new_tokens)
     if request_json_body['max_tokens'] == 0:
         # We don't want to set any defaults here.
         del request_json_body['max_tokens']
diff --git a/llm_server/llm/openai/transform.py b/llm_server/llm/openai/transform.py
index daec3dc..88681c2 100644
--- a/llm_server/llm/openai/transform.py
+++ b/llm_server/llm/openai/transform.py
@@ -7,7 +7,7 @@ from typing import Dict, List
 
 import tiktoken
 
-from llm_server import opts
+from llm_server.config.global_config import GlobalConfig
 from llm_server.llm import get_token_count
 
 ANTI_RESPONSE_RE = re.compile(r'^### (.*?)(?:\:)?\s')  # Match a "### XXX" line.
@@ -85,7 +85,7 @@ def trim_string_to_fit(prompt: str, context_token_limit: int, backend_url: str)
 
 def transform_messages_to_prompt(oai_messages):
     try:
-        prompt = f'### INSTRUCTION: {opts.openai_system_prompt}'
+        prompt = f'### INSTRUCTION: {GlobalConfig.get().openai_system_prompt}'
         for msg in oai_messages:
             if 'content' not in msg.keys() or 'role' not in msg.keys():
                 return False
diff --git a/llm_server/llm/vllm/generate.py b/llm_server/llm/vllm/generate.py
index 31cd511..3e7926e 100644
--- a/llm_server/llm/vllm/generate.py
+++ b/llm_server/llm/vllm/generate.py
@@ -4,7 +4,7 @@ This file is used by the worker that processes requests.
 
 import requests
 
-from llm_server import opts
+from llm_server.config.global_config import GlobalConfig
 
 
 # TODO: make the VLMM backend return TPS and time elapsed
@@ -25,7 +25,7 @@ def transform_prompt_to_text(prompt: list):
 
 def handle_blocking_request(json_data: dict, cluster_backend, timeout: int = 10):
     try:
-        r = requests.post(f'{cluster_backend}/generate', json=prepare_json(json_data), verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout if not timeout else timeout)
+        r = requests.post(f'{cluster_backend}/generate', json=prepare_json(json_data), verify=GlobalConfig.get().verify_ssl, timeout=GlobalConfig.get().backend_generate_request_timeout if not timeout else timeout)
     except requests.exceptions.ReadTimeout:
         # print(f'Failed to reach VLLM inference endpoint - request to backend timed out')
         return False, None, 'Request to backend timed out'
@@ -41,7 +41,7 @@ def handle_blocking_request(json_data: dict, cluster_backend, timeout: int = 10)
 def generate(json_data: dict, cluster_backend, timeout: int = None):
     if json_data.get('stream'):
         try:
-            return requests.post(f'{cluster_backend}/generate', json=prepare_json(json_data), stream=True, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout if not timeout else timeout)
+            return requests.post(f'{cluster_backend}/generate', json=prepare_json(json_data), stream=True, verify=GlobalConfig.get().verify_ssl, timeout=GlobalConfig.get().backend_generate_request_timeout if not timeout else timeout)
         except Exception as e:
             return False
     else:
diff --git a/llm_server/llm/vllm/info.py b/llm_server/llm/vllm/info.py
index 0142301..b6205c3 100644
--- a/llm_server/llm/vllm/info.py
+++ b/llm_server/llm/vllm/info.py
@@ -1,7 +1,3 @@
-import requests
-
-from llm_server import opts
-
 vllm_info = """

Important: This endpoint is running vllm and not all Oobabooga parameters are supported.

Supported Parameters: