MatrixGPT/matrix_gpt/config.py

168 lines
6.6 KiB
Python
Raw Normal View History

2024-04-07 19:41:19 -06:00
import copy
from pathlib import Path
from types import NoneType
import bison
2024-04-07 22:27:00 -06:00
from bison.errors import SchemeValidationError
2024-04-07 19:41:19 -06:00
# API backends a command may select through its `api_type` option.
VALID_API_TYPES = ['openai', 'anthropic', 'copilot']

# Bison scheme describing the bot's configuration file.
config_scheme = bison.Scheme(
    bison.Option('store_path', default='bot-store/', field_type=str),
    bison.DictOption('auth', scheme=bison.Scheme(
        bison.Option('username', field_type=str, required=True),
        bison.Option('password', field_type=str, required=True),
        bison.Option('homeserver', field_type=str, required=True),
        bison.Option('device_id', field_type=str, required=True),
    )),
    # Global permission lists; the per-command lists of the same name default to empty.
    bison.ListOption('allowed_to_chat', member_type=str, default=['all']),
    bison.ListOption('allowed_to_thread', member_type=str, default=['all']),
    bison.ListOption('allowed_to_invite', member_type=str, default=['all']),
    bison.ListOption('autojoin_rooms', default=[]),
    bison.ListOption('blacklist_rooms', default=[]),
    bison.Option('response_timeout', default=120, field_type=int),
    # One entry per chat command. The optional member defaults below are
    # mirrored in DEFAULT_LISTS['command']; keep the two in sync.
    bison.ListOption('command', required=True, member_scheme=bison.Scheme(
        bison.Option('trigger', field_type=str, required=True),
        bison.Option('api_type', field_type=str, choices=VALID_API_TYPES, required=True),
        bison.Option('model', field_type=str, required=True),
        bison.Option('max_tokens', field_type=int, default=0),
        bison.Option('temperature', field_type=[int, float], default=0.5),
        bison.ListOption('allowed_to_chat', member_type=str, default=[]),
        bison.ListOption('allowed_to_thread', member_type=str, default=[]),
        bison.ListOption('allowed_to_invite', member_type=str, default=[]),
        # Both prompts default to None, so None is listed as an accepted type,
        # matching the other nullable options (`api_base`, `help`).
        bison.Option('system_prompt', field_type=[str, NoneType], default=None),
        bison.Option('injected_system_prompt', field_type=[str, NoneType], default=None),
        bison.Option('api_base', field_type=[str, NoneType], default=None),
        bison.Option('vision', field_type=bool, default=False),
        bison.Option('help', field_type=[str, NoneType], default=None),
    )),
    # Per-backend credentials; every key is optional at the scheme level.
    bison.DictOption('openai', scheme=bison.Scheme(
        bison.Option('api_key', field_type=[str, NoneType], default=None, required=False),
    )),
    bison.DictOption('anthropic', scheme=bison.Scheme(
        bison.Option('api_key', field_type=[str, NoneType], required=False, default=None),
    )),
    bison.DictOption('copilot', scheme=bison.Scheme(
        bison.Option('api_key', field_type=[str, NoneType], required=False, default=None),
        bison.Option('event_encryption_key', field_type=[str, NoneType], required=False, default=None),
    )),
    bison.DictOption('logging', scheme=bison.Scheme(
        bison.Option('log_level', field_type=str, default='info'),
        bison.Option('log_full_response', field_type=bool, default=True),
    )),
)
2024-04-07 22:27:00 -06:00
# Bison does not support default options for list members in certain
# situations, so the defaults for each `command` entry are repeated here and
# merged in by ConfigManager._merge_in_list_defaults().
_COMMAND_DEFAULTS = {
    'max_tokens': 0,
    'temperature': 0.5,
    'allowed_to_chat': [],
    'allowed_to_thread': [],
    'allowed_to_invite': [],
    'system_prompt': None,
    'injected_system_prompt': None,
    'api_base': None,
    'vision': False,
    'help': None,
}

# Maps a top-level config key whose value is a list of dicts to the defaults
# applied to every dict in that list. Only one level of nesting is handled.
DEFAULT_LISTS = {
    'command': _COMMAND_DEFAULTS,
}
2024-04-07 19:41:19 -06:00
class ConfigManager:
    """Load, validate and expose the bot configuration.

    Wraps a ``bison.Bison`` object built from ``config_scheme``. Call
    ``load()`` once and then ``validate()`` once; afterwards the merged
    configuration is readable through item access, ``config`` and
    ``command_prefixes``. Item assignment and deletion are rejected.
    """

    def __init__(self):
        self._config = bison.Bison(scheme=config_scheme)
        # trigger -> merged command dict; filled in by validate().
        self._command_prefixes = {}
        # Raw config with the DEFAULT_LISTS defaults merged in; filled in by validate().
        self._parsed_config = {}
        self._loaded = False
        self._validated = False

    def load(self, path: Path):
        """Parse the YAML config found in the directory containing ``path``.

        Only ``path.parent`` is registered as a search path and the config
        name is fixed to ``'config'``.
        NOTE(review): the file-name part of ``path`` is never used, so a file
        not named ``config`` is presumably not picked up -- confirm intended.

        Raises:
            AssertionError: if the configuration was already loaded.
        """
        assert not self._loaded
        self._config.config_name = 'config'
        self._config.config_format = bison.bison.YAML
        self._config.add_config_paths(str(path.parent))
        self._config.parse()
        self._loaded = True

    def validate(self):
        """Validate the parsed config and build the merged, queryable view.

        Runs the bison scheme validation, then the cross-field checks that
        the scheme cannot express, then fills ``_parsed_config`` and
        ``_command_prefixes``.

        Raises:
            AssertionError: if validation already completed on this instance.
            SchemeValidationError: if any check fails.
        """
        assert not self._validated
        self._config.validate()
        raw = self._config.config

        # At least one backend needs a non-empty API key. A backend section
        # that is missing or empty simply counts as "no key".
        if not any((raw.get(api) or {}).get('api_key') for api in VALID_API_TYPES):
            raise SchemeValidationError('You need an API key')

        self._parsed_config = self._merge_in_list_defaults()

        for item in raw['command']:
            if item['api_type'] == 'copilot' and item['model'] != 'copilot':
                raise SchemeValidationError('The Copilot model type must be set to `copilot`')

        # Make sure there aren't duplicate triggers
        seen_triggers = set()
        for item in raw['command']:
            trigger = item['trigger']
            if trigger in seen_triggers:
                raise SchemeValidationError(f'Duplicate trigger {trigger}')
            seen_triggers.add(trigger)

        # NOTE(review): this fires whenever the `copilot` section is truthy,
        # whether or not any command uses copilot -- confirm intended.
        if raw.get('copilot') and not raw['copilot'].get('event_encryption_key'):
            raise SchemeValidationError('You must set `event_encryption_key` when using copilot')

        self._command_prefixes = self._generate_command_prefixes()
        # Record success only after every check passed, so a failed
        # validation can be retried but a successful one cannot be repeated.
        self._validated = True

    def _merge_in_list_defaults(self):
        """Return a shallow copy of the raw config with list-entry defaults applied.

        For each key in ``DEFAULT_LISTS`` that is present in the config, every
        dict in that list is replaced by ``defaults`` updated with the
        entry's own values.
        """
        raw = self._config.config
        new_config = copy.copy(raw)
        for section, defaults in DEFAULT_LISTS.items():
            if section not in raw:
                continue
            entries = raw[section]
            assert isinstance(entries, list)
            merged_entries = []
            for entry in entries:
                # Deep copy: the defaults contain lists, which must not be
                # shared between entries or with DEFAULT_LISTS itself.
                merged = copy.deepcopy(defaults)
                merged.update(entry)
                merged_entries.append(merged)
            new_config[section] = merged_entries
        return new_config

    @property
    def config(self):
        """Shallow copy of the merged config; nested values are shared, not copied."""
        return copy.copy(self._parsed_config)

    def _generate_command_prefixes(self):
        """Build the trigger -> command mapping from the merged config.

        Only meant to be called from ``validate()`` before it completes.

        Raises:
            SchemeValidationError: if an Anthropic command has ``max_tokens`` below 1.
        """
        assert not self._validated
        command_prefixes = {}
        for item in self._parsed_config['command']:
            command_prefixes[item['trigger']] = item
            if item['api_type'] == 'anthropic' and item.get('max_tokens', 0) < 1:
                raise SchemeValidationError('Anthropic requires `max_tokens`. See <https://support.anthropic.com/en/articles/7996856-what-is-the-maximum-prompt-length>')
        return command_prefixes

    @property
    def command_prefixes(self):
        """Mapping of command trigger to its merged command dict (the internal object, not a copy)."""
        return self._command_prefixes

    def get(self, key, default=None):
        """Return a shallow copy of ``key`` looked up on the underlying bison object.

        NOTE(review): unlike ``__getitem__`` this reads the bison object and
        not the merged config, so DEFAULT_LISTS defaults are presumably not
        applied here -- confirm intended.
        """
        return copy.copy(self._config.get(key, default))

    def __setitem__(self, key, item):
        # The configuration is read-only once constructed.
        raise TypeError('ConfigManager is read-only; item assignment is not supported')

    def __getitem__(self, key):
        return self._parsed_config[key]

    def __repr__(self):
        return repr(self._parsed_config)

    def __len__(self):
        return len(self._parsed_config)

    def __delitem__(self, key):
        # The configuration is read-only once constructed.
        raise TypeError('ConfigManager is read-only; item deletion is not supported')
# Module-level ConfigManager instance created at import time. It holds no
# configuration until load() and validate() have been called on it.
global_config = ConfigManager()