2024-04-07 19:41:19 -06:00
import copy
from pathlib import Path
from types import NoneType
import bison
2024-04-07 22:27:00 -06:00
from bison . errors import SchemeValidationError
2024-04-07 19:41:19 -06:00
2024-04-10 16:42:52 -06:00
# Back-end identifiers accepted for a command's `api_type`. Each name also has
# a same-named provider section (holding its `api_key`) in config_scheme.
VALID_API_TYPES = 'openai anthropic copilot'.split()
2024-04-07 19:41:19 -06:00
# Declarative bison schema for the YAML config; checked by
# ConfigManager.validate() through bison's own validate().
config_scheme = bison.Scheme(
    bison.Option('store_path', default='bot-store/', field_type=str),
    # Login credentials; every field is mandatory.
    bison.DictOption('auth', scheme=bison.Scheme(
        bison.Option('username', field_type=str, required=True),
        bison.Option('password', field_type=str, required=True),
        bison.Option('homeserver', field_type=str, required=True),
        bison.Option('device_id', field_type=str, required=True),
    )),

    # Global permission lists (default ['all']).
    bison.ListOption('allowed_to_chat', member_type=str, default=['all']),
    bison.ListOption('allowed_to_thread', member_type=str, default=['all']),
    bison.ListOption('allowed_to_invite', member_type=str, default=['all']),

    bison.ListOption('autojoin_rooms', default=[]),
    bison.ListOption('blacklist_rooms', default=[]),
    bison.Option('response_timeout', default=120, field_type=int),

    # One entry per chat command. bison does not reliably apply the member
    # defaults below, so they are repeated in DEFAULT_LISTS and merged in by
    # ConfigManager after validation -- keep the two in sync.
    bison.ListOption('command', required=True, member_scheme=bison.Scheme(
        bison.Option('trigger', field_type=str, required=True),
        bison.Option('api_type', field_type=str, choices=VALID_API_TYPES, required=True),
        bison.Option('model', field_type=str, required=True),
        bison.Option('max_tokens', field_type=int, default=0),
        bison.Option('temperature', field_type=[int, float], default=0.5),
        bison.ListOption('allowed_to_chat', member_type=str, default=[]),
        bison.ListOption('allowed_to_thread', member_type=str, default=[]),
        bison.ListOption('allowed_to_invite', member_type=str, default=[]),
        # The default is None, so None must be an accepted type too (as it is
        # for api_base/help); with plain `str`, an explicit `null` or an empty
        # YAML value would fail validation.
        bison.Option('system_prompt', field_type=[str, NoneType], default=None),
        bison.Option('injected_system_prompt', field_type=[str, NoneType], default=None),
        bison.Option('api_base', field_type=[str, NoneType], default=None),
        bison.Option('vision', field_type=bool, default=False),
        bison.Option('help', field_type=[str, NoneType], default=None),
    )),

    # Provider credential sections, one per VALID_API_TYPES entry.
    # ConfigManager.validate() requires at least one `api_key` among them.
    bison.DictOption('openai', scheme=bison.Scheme(
        bison.Option('api_key', field_type=[str, NoneType], default=None, required=False),
    )),
    bison.DictOption('anthropic', scheme=bison.Scheme(
        bison.Option('api_key', field_type=[str, NoneType], required=False, default=None),
    )),
    bison.DictOption('copilot', scheme=bison.Scheme(
        bison.Option('api_key', field_type=[str, NoneType], required=False, default=None),
        bison.Option('event_encryption_key', field_type=[str, NoneType], required=False, default=None),
    )),

    bison.DictOption('logging', scheme=bison.Scheme(
        bison.Option('log_level', field_type=str, default='info'),
        bison.Option('log_full_response', field_type=bool, default=True),
    )),
)
2024-04-07 22:27:00 -06:00
# bison does not reliably fill in `default=` values for the entries of a
# ListOption that has a `member_scheme`, so the per-entry defaults are
# restated here and merged into each entry by
# ConfigManager._merge_in_list_defaults(). Maps a list option's name to the
# default values for each of its entries; only one level of nesting is
# supported. Keep in sync with the `command` member scheme.
DEFAULT_LISTS = dict(
    command=dict(
        max_tokens=0,
        temperature=0.5,
        allowed_to_chat=[],
        allowed_to_thread=[],
        allowed_to_invite=[],
        system_prompt=None,
        injected_system_prompt=None,
        api_base=None,
        vision=False,
        help=None,
    ),
)
2024-04-07 19:41:19 -06:00
class ConfigManager:
    """Load, validate and expose the bot's YAML configuration.

    Intended lifecycle: construct, call ``load()`` once, then ``validate()``
    once. After validation, mapping-style access (``cm['key']``, ``len(cm)``,
    ``cm.config``) reads the validated config with DEFAULT_LISTS merged into
    each list entry. The manager is read-only: item assignment and deletion
    raise ``TypeError``.
    """

    def __init__(self):
        self._config = bison.Bison(scheme=config_scheme)
        # trigger string -> merged command dict; populated by validate().
        self._command_prefixes = {}
        # Validated config with list defaults merged in; populated by validate().
        self._parsed_config = {}
        self._loaded = False
        self._validated = False

    def load(self, path: Path):
        """Parse the YAML config found in the directory containing *path*.

        NOTE(review): only ``path.parent`` is used as a search path; the
        config name is fixed to ``'config'``, so *path*'s own file name is
        ignored. Confirm callers always pass a file named like that.
        """
        assert not self._loaded
        self._config.config_name = 'config'
        self._config.config_format = bison.bison.YAML
        self._config.add_config_paths(str(path.parent))
        self._config.parse()
        self._loaded = True

    def validate(self):
        """Validate the loaded config and build the derived lookup tables.

        Runs bison's scheme validation, then the cross-field checks the scheme
        cannot express, merges DEFAULT_LISTS into every command, and indexes
        commands by trigger. Succeeds at most once per instance.

        Raises:
            SchemeValidationError: if no provider section has an API key, a
                copilot command's model is not ``copilot``, two commands share
                a trigger, copilot is in use without ``event_encryption_key``,
                or an anthropic command has no positive ``max_tokens``.
        """
        assert not self._validated
        self._config.validate()
        raw = self._config.config
        commands = raw['command']

        # At least one provider section must supply an API key.
        if not any((raw.get(api) or {}).get('api_key') for api in VALID_API_TYPES):
            raise SchemeValidationError('You need an API key')

        self._parsed_config = self._merge_in_list_defaults()

        for item in commands:
            if item['api_type'] == 'copilot' and item['model'] != 'copilot':
                raise SchemeValidationError('The Copilot model type must be set to `copilot`')

        # Triggers are the lookup keys in command_prefixes, so they must be unique.
        seen_triggers = set()
        for item in commands:
            trigger = item['trigger']
            if trigger in seen_triggers:
                raise SchemeValidationError(f'Duplicate trigger {trigger}')
            seen_triggers.add(trigger)

        # Every option in the `copilot` scheme has a default, so bison may
        # materialize the section as a (truthy) dict of Nones even when the
        # user never configured copilot. Only demand the encryption key when
        # copilot is actually in use: it has a key or a command selects it.
        copilot = raw.get('copilot') or {}
        uses_copilot = bool(copilot.get('api_key')) or any(
            item['api_type'] == 'copilot' for item in commands
        )
        if uses_copilot and not copilot.get('event_encryption_key'):
            raise SchemeValidationError('You must set `event_encryption_key` when using copilot')

        self._command_prefixes = self._generate_command_prefixes()
        # Arm the once-only guards in validate()/_generate_command_prefixes().
        self._validated = True

    def _merge_in_list_defaults(self):
        """Return a shallow copy of the raw config with list-entry defaults applied.

        For each list option named in DEFAULT_LISTS, every entry is replaced
        by a new dict: the defaults overlaid with the entry's own keys. The
        defaults are deep-copied per entry so mutable defaults (the empty
        permission lists) are never shared between commands or with
        DEFAULT_LISTS itself.
        """
        raw = self._config.config
        new_config = copy.copy(raw)
        for list_name, entry_defaults in DEFAULT_LISTS.items():
            if list_name not in raw:
                continue
            entries = raw[list_name]
            assert isinstance(entries, list)
            merged_entries = []
            for entry in entries:
                merged = copy.deepcopy(entry_defaults)
                merged.update(entry)  # explicit config values win over defaults
                merged_entries.append(merged)
            new_config[list_name] = merged_entries
        return new_config

    @property
    def config(self):
        """Shallow copy of the validated, defaults-merged config (nested values are shared)."""
        return copy.copy(self._parsed_config)

    def _generate_command_prefixes(self):
        """Map each command's trigger to its merged command dict.

        Raises:
            SchemeValidationError: if an anthropic command has ``max_tokens`` < 1.
        """
        assert not self._validated
        command_prefixes = {}
        for item in self._parsed_config['command']:
            command_prefixes[item['trigger']] = item
            if item['api_type'] == 'anthropic' and item.get('max_tokens', 0) < 1:
                raise SchemeValidationError('Anthropic requires `max_tokens`. See <https://support.anthropic.com/en/articles/7996856-what-is-the-maximum-prompt-length>')
        return command_prefixes

    @property
    def command_prefixes(self):
        """Trigger -> command dict, built by validate()."""
        return self._command_prefixes

    def get(self, key, default=None):
        """Return a shallow copy of ``key`` from the raw bison config, or *default*.

        Unlike ``self[key]`` this reads bison's config, not the
        defaults-merged view produced by validate().
        """
        return copy.copy(self._config.get(key, default))

    def __setitem__(self, key, item):
        raise TypeError(f'{type(self).__name__} is read-only')

    def __getitem__(self, key):
        return self._parsed_config[key]

    def __repr__(self):
        return repr(self._parsed_config)

    def __len__(self):
        return len(self._parsed_config)

    def __delitem__(self, key):
        raise TypeError(f'{type(self).__name__} is read-only')
# Shared module-level instance. Construction only builds the bison object;
# load() and validate() must still be called before its contents are usable.
global_config: ConfigManager = ConfigManager()