2024-04-07 19:41:19 -06:00
import copy
from pathlib import Path
from types import NoneType
import bison
2024-04-07 22:27:00 -06:00
from bison . errors import SchemeValidationError
from mergedeep import merge , Strategy
2024-04-07 19:41:19 -06:00
# Validation scheme for the YAML config; ConfigManager hands it to bison.Bison.
config_scheme = bison.Scheme(
    bison.Option("store_path", field_type=str, default="bot-store/"),
    # Login credentials - every field is mandatory.
    bison.DictOption("auth", scheme=bison.Scheme(
        bison.Option("username", field_type=str, required=True),
        bison.Option("password", field_type=str, required=True),
        bison.Option("homeserver", field_type=str, required=True),
        bison.Option("device_id", field_type=str, required=True),
    )),
    # Global access lists, defaulting to the single entry "all".
    bison.ListOption("allowed_to_chat", member_type=str, default=["all"]),
    bison.ListOption("allowed_to_thread", member_type=str, default=["all"]),
    bison.ListOption("allowed_to_invite", member_type=str, default=["all"]),
    bison.ListOption("autojoin_rooms", default=[]),
    bison.ListOption("blacklist_rooms", default=[]),
    bison.Option("response_timeout", field_type=int, default=120),
    # One entry per command. DEFAULT_LISTS re-applies the member defaults
    # below by hand, so the two must be kept in sync.
    bison.ListOption("command", required=True, member_scheme=bison.Scheme(
        bison.Option("trigger", field_type=str, required=True),
        bison.Option("api_type", field_type=str, required=True, choices=["openai", "anth"]),
        bison.Option("model", field_type=str, required=True),
        bison.Option("max_tokens", field_type=int, default=0),
        bison.Option("temperature", field_type=float, default=0.5),
        bison.ListOption("allowed_to_chat", member_type=str, default=[]),
        bison.ListOption("allowed_to_thread", member_type=str, default=[]),
        bison.ListOption("allowed_to_invite", member_type=str, default=[]),
        bison.Option("system_prompt", field_type=str, default=None),
        bison.Option("injected_system_prompt", field_type=str, default=None),
        bison.Option("api_base", field_type=[str, NoneType], default=None),
    )),
    # Provider API keys; ConfigManager.validate() insists on at least one.
    bison.DictOption("openai", scheme=bison.Scheme(
        bison.Option("api_key", field_type=[str, NoneType], required=False, default=None),
    )),
    bison.DictOption("anthropic", scheme=bison.Scheme(
        bison.Option("api_key", field_type=[str, NoneType], required=False, default=None),
    )),
    bison.DictOption("logging", scheme=bison.Scheme(
        bison.Option("log_level", field_type=str, default="info"),
        bison.Option("log_full_response", field_type=bool, default=True),
    )),
)
2024-04-07 22:27:00 -06:00
# Bison does not apply member defaults to list options in certain situations,
# so ConfigManager merges these per-item defaults into each list entry itself.
# Only top-level list keys are handled: the merge goes one level deep.
# The values mirror the `command` member scheme defaults in config_scheme.
DEFAULT_LISTS = {
    'command': dict(
        max_tokens=0,
        temperature=0.5,
        allowed_to_chat=[],
        allowed_to_thread=[],
        allowed_to_invite=[],
        system_prompt=None,
        injected_system_prompt=None,
        api_base=None,
    ),
}
2024-04-07 19:41:19 -06:00
class ConfigManager:
    """Load, validate and expose the bot's YAML configuration.

    Usage is two-phase: ``load()`` reads the file through bison, then
    ``validate()`` checks it against ``config_scheme``, merges the per-item
    defaults from ``DEFAULT_LISTS`` and indexes commands by trigger. The
    result is exposed read-only through ``config``, ``[]``, ``len()`` and
    ``command_prefixes``; item assignment and deletion raise ``TypeError``.
    """

    def __init__(self):
        self._config = bison.Bison(scheme=config_scheme)
        # trigger string -> command dict; populated by validate().
        self._command_prefixes = {}
        # Bison config with DEFAULT_LISTS merged in; populated by validate().
        self._parsed_config = {}
        self._loaded = False
        self._validated = False

    def load(self, path: Path):
        """Parse the YAML config from the directory containing ``path``.

        Only ``path.parent`` is added as a search path; bison is told to look
        for a file named ``config`` in YAML format.
        NOTE(review): ``path.name`` itself is ignored - confirm that is intended.

        Raises:
            AssertionError: if called more than once.
        """
        assert not self._loaded
        self._config.config_name = 'config'
        self._config.config_format = bison.bison.YAML
        self._config.add_config_paths(str(path.parent))
        self._config.parse()
        self._loaded = True

    def validate(self):
        """Validate the loaded config and build the derived views.

        Raises:
            SchemeValidationError: if the scheme check fails, if neither an
                OpenAI nor an Anthropic API key is set, or if an Anthropic
                command has no positive ``max_tokens``.
            AssertionError: if called more than once.
        """
        assert not self._validated
        self._config.validate()
        if not self._config.config['openai']['api_key'] and not self._config.config['anthropic']['api_key']:
            raise SchemeValidationError('You need an OpenAI or Anthropic API key')
        self._parsed_config = self._merge_in_list_defaults()
        self._command_prefixes = self._generate_command_prefixes()
        # Record completion so the "only once" assertions actually take effect.
        self._validated = True

    def _merge_in_list_defaults(self) -> dict:
        """Return the bison config with ``DEFAULT_LISTS`` applied to list items.

        For every top-level key that has an entry in ``DEFAULT_LISTS``, each
        item of that list is merged over its own copy of the defaults, so keys
        the user omitted take their default value.
        """
        new_config = copy.copy(self._config.config)
        for key, value in self._config.config.items():
            item_defaults = DEFAULT_LISTS.get(key)
            if item_defaults is None:
                continue
            assert isinstance(value, list)
            # mergedeep.merge() mutates and returns its first argument. Merging
            # into a fresh {} leaves DEFAULT_LISTS untouched and gives every
            # item its own dict, instead of all items aliasing the shared
            # defaults dict and accumulating each other's values.
            merged = [merge({}, item_defaults, item, strategy=Strategy.ADDITIVE) for item in value]
            # copy.copy() is shallow, so this list object is shared with bison's
            # config; replacing its contents keeps bison's view (read by get())
            # consistent with the parsed config.
            value[:] = merged
        return new_config

    @property
    def config(self):
        """Shallow copy of the parsed config (empty until validate() runs)."""
        return copy.copy(self._parsed_config)

    def _generate_command_prefixes(self) -> dict:
        """Index the parsed commands by their ``trigger`` string.

        Raises:
            SchemeValidationError: if a command with ``api_type == 'anth'``
                does not set ``max_tokens`` to at least 1.
        """
        assert not self._validated
        command_prefixes = {}
        for item in self._parsed_config['command']:
            command_prefixes[item['trigger']] = item
            # The requirement is Anthropic-specific (see the error message);
            # OpenAI commands may keep the scheme default of 0.
            if item['api_type'] == 'anth' and item.get('max_tokens', 0) < 1:
                raise SchemeValidationError('Anthropic requires `max_tokens`. See <https://support.anthropic.com/en/articles/7996856-what-is-the-maximum-prompt-length>')
        return command_prefixes

    @property
    def command_prefixes(self):
        """Mapping of trigger string to command dict (empty until validate())."""
        return self._command_prefixes

    def get(self, key, default=None):
        """Return a shallow copy of ``key`` from the underlying bison config."""
        return copy.copy(self._config.get(key, default))

    def __setitem__(self, key, item):
        raise TypeError('ConfigManager is read-only')

    def __getitem__(self, key):
        return self._parsed_config[key]

    def __repr__(self):
        return repr(self._parsed_config)

    def __len__(self):
        return len(self._parsed_config)

    def __delitem__(self, key):
        raise TypeError('ConfigManager is read-only')
# Module-wide shared instance. It holds no parsed data until load() and then
# validate() have been called on it.
global_config = ConfigManager()