61 lines
2.0 KiB
Python
61 lines
2.0 KiB
Python
import yaml
|
|
|
|
# Fallback values for optional settings. Presumably passed to ConfigLoader as
# ``default_vars`` so absent keys get filled in -- TODO confirm against caller.
config_default_vars = dict(
    log_prompts=False,
    database_path='./proxy-server.db',
    auth_required=False,
    frontend_api_client='',
    verify_ssl=True,
    load_num_prompts=False,
    show_num_prompts=True,
    show_uptime=True,
    analytics_tracking_code='',
    average_generation_time_mode='database',
    info_html=None,
    show_total_output_tokens=True,
)

# Settings that have no default. Presumably passed to ConfigLoader as
# ``required_vars`` so loading fails when any is absent -- TODO confirm.
config_required_vars = [
    'token_limit',
    'concurrent_gens',
    'mode',
    'llm_middleware_name',
]

# Backend mode name -> pair of UI label strings.
# NOTE(review): the first item reads like a display name and the second like a
# label for the backend URL field; assumed from the values, confirm in the UI.
# 'hf-textgen' still carries placeholder text.
mode_ui_names = {
    'oobabooga': ('Text Gen WebUI (ooba)', 'Blocking API url'),
    'hf-textgen': ('UNDEFINED', 'UNDEFINED'),
}
class ConfigLoader:
|
|
"""
|
|
default_vars = {"var1": "default1", "var2": "default2"}
|
|
required_vars = ["var1", "var3"]
|
|
config_loader = ConfigLoader("config.yaml", default_vars, required_vars)
|
|
config = config_loader.load_config()
|
|
"""
|
|
|
|
def __init__(self, config_file, default_vars=None, required_vars=None):
|
|
self.config_file = config_file
|
|
self.default_vars = default_vars if default_vars else {}
|
|
self.required_vars = required_vars if required_vars else []
|
|
self.config = {}
|
|
|
|
def load_config(self) -> (bool, str | None, str | None):
|
|
with open(self.config_file, 'r') as stream:
|
|
try:
|
|
self.config = yaml.safe_load(stream)
|
|
except yaml.YAMLError as exc:
|
|
return False, None, exc
|
|
|
|
if self.config is None:
|
|
# Handle empty file
|
|
self.config = {}
|
|
|
|
# Set default variables if they are not present in the config file
|
|
for var, default_value in self.default_vars.items():
|
|
if var not in self.config:
|
|
self.config[var] = default_value
|
|
|
|
# Check if required variables are present in the config file
|
|
for var in self.required_vars:
|
|
if var not in self.config:
|
|
return False, None, f'Required variable "{var}" is missing from the config file'
|
|
|
|
return True, self.config, None
|