MatrixGPT/main.py

232 lines
9.1 KiB
Python
Executable File

#!/usr/bin/env python3
import argparse
import asyncio
import logging
import os
import sys
import time
import traceback
from pathlib import Path
from uuid import uuid4
import openai
import yaml
from aiohttp import ClientConnectionError, ServerDisconnectedError
from nio import InviteMemberEvent, JoinResponse, MegolmEvent, RoomMessageText, UnknownEvent
from matrix_gpt import MatrixNioGPTHelper
from matrix_gpt.bot.callbacks import Callbacks
from matrix_gpt.bot.storage import Storage
from matrix_gpt.config import check_config_value_exists
# Directory this script lives in; used to locate the default config file.
script_directory = os.path.abspath(os.path.dirname(__file__))

# Attach a default root handler; the bot's own logger level is set in main().
logging.basicConfig()
logger = logging.getLogger('MatrixGPT')

# Command-line interface: the only option is an alternate config location.
parser = argparse.ArgumentParser(description='MatrixGPT Bot')
parser.add_argument(
    '--config',
    default=Path(script_directory, 'config.yaml'),
    help='Path to config.yaml if it is not located next to this executable.',
)
args = parser.parse_args()
# Load config: exit with status 1 if the file is missing, unreadable,
# not valid YAML, or does not contain a top-level mapping.
if not Path(args.config).exists():
    print('Config file does not exist:', args.config)
    sys.exit(1)
try:
    with open(args.config, 'r', encoding='utf-8') as file:
        config_data = yaml.safe_load(file)
except Exception as e:
    print(f'Failed to load config file: {e}')
    sys.exit(1)
# yaml.safe_load returns None for an empty document (and a scalar/list for
# other top-level shapes); all validation below indexes config_data as a
# mapping, so reject anything else up front with a readable message.
if not isinstance(config_data, dict):
    print(f'Config file must contain a YAML mapping at the top level: {args.config}')
    sys.exit(1)
# Lazy way to validate config: each required key is checked in turn by the
# project helper check_config_value_exists.
check_config_value_exists(config_data, 'bot_auth', dict)
for auth_key in ('username', 'password', 'homeserver', 'store_path'):
    check_config_value_exists(config_data['bot_auth'], auth_key)
check_config_value_exists(config_data, 'allowed_to_chat')
check_config_value_exists(config_data, 'allowed_to_invite', allow_empty=True)
for top_level_key in ('data_storage', 'command', 'logging'):
    check_config_value_exists(config_data, top_level_key)
check_config_value_exists(config_data['logging'], 'log_level')
check_config_value_exists(config_data, 'openai')

# Per-command settings, keyed by command prefix. The dicts are the same
# objects as in config_data['command'], so defaults written here are shared.
command_prefixes = {}
for prefix, command_cfg in config_data['command'].items():
    check_config_value_exists(command_cfg, 'model')
    check_config_value_exists(command_cfg, 'mode', default='default')
    # A command without its own allow-list is open to everyone.
    command_cfg.setdefault('allowed_to_chat', 'all')
    command_prefixes[prefix] = command_cfg

# check_config_value_exists(config_data, 'autojoin_rooms')
def retry(msg=None):
    """Log a retry warning, then block the calling thread for 15 seconds.

    Args:
        msg: Optional context text; when truthy it prefixes the warning.
    """
    warning_text = f'{msg}, retrying in 15s...' if msg else 'Retrying in 15s...'
    logger.warning(warning_text)
    time.sleep(15)
# Accepted values for the ``logging.log_level`` config key (matched
# case-insensitively). Anything else falls back to INFO.
_LOG_LEVELS = {
    'debug': logging.DEBUG,
    'info': logging.INFO,
    'warning': logging.WARNING,
    'error': logging.ERROR,
    'critical': logging.CRITICAL,
}

# Seconds to pause when a rate-limit reply cannot be parsed, so the login
# loop never retries without sleeping.
_RATELIMIT_FALLBACK_WAIT = 5


def _resolve_log_level(name):
    """Return the ``logging`` level for config value *name* (default INFO)."""
    return _LOG_LEVELS.get(str(name).lower(), logging.INFO)


def _parse_ratelimit_wait(login_response):
    """Return half of the rate-limit delay carried by *login_response*, in seconds.

    NOTE(review): assumes the response text ends with the retry delay in
    milliseconds followed by a two-character suffix (e.g. ``...10000ms``);
    confirm against MatrixNioGPTHelper.login(). Returns None when that
    trailing token is not an integer.
    """
    try:
        retry_after_ms = int(str(login_response).split(' ')[-1][:-2])
    except ValueError:
        return None
    # Only wait half the ratelimited time, but never spin with a 0s sleep.
    return max(int((retry_after_ms / 1000) / 2), 1)


async def _login_with_retry(matrix_helper):
    """Await ``matrix_helper.login()`` until it reports success.

    Rate-limited attempts (``M_LIMIT_EXCEEDED``) sleep for half the advertised
    delay (or a fixed fallback if it cannot be parsed); other failures sleep
    5s. Sleeps are awaited so the event loop is not blocked.
    """
    while True:
        login_success, login_response = await matrix_helper.login()
        if login_success:
            return
        if 'M_LIMIT_EXCEEDED' in str(login_response):
            wait = _parse_ratelimit_wait(login_response)
            if wait is None:
                logger.error('Could not parse M_LIMIT_EXCEEDED')
                wait = _RATELIMIT_FALLBACK_WAIT
            else:
                logger.error(f'Ratelimited, sleeping {wait}s...')
            await asyncio.sleep(wait)
        else:
            logger.error(f'Failed to login, retrying: {login_response}')
            await asyncio.sleep(5)


async def _autojoin_rooms(client):
    """Join each room listed in the ``autojoin_rooms`` config key.

    A non-JoinResponse result is logged as critical; joins are spaced 1.5s
    apart.
    """
    for room in config_data.get('autojoin_rooms') or []:
        r = await client.join(room)
        if not isinstance(r, JoinResponse):
            logger.critical(f'Failed to join room {room}: {vars(r)}')
        await asyncio.sleep(1.5)


async def _logout_other_devices(client, device_id):
    """Delete every device on the bot account except *device_id*.

    Authenticates the deletion with the configured username/password.
    """
    logger.info('Logging out other devices...')
    devices = (await client.devices()).devices
    device_list = [x.id for x in devices if x.id != device_id]
    if not device_list:
        # Nothing to delete; skip the request instead of sending an empty list.
        logger.info('No other devices to log out.')
        return
    await client.delete_devices(device_list, {
        "type": "m.login.password",
        "user": config_data['bot_auth']['username'],
        "password": config_data['bot_auth']['password'],
    })
    logger.info(f'Logged out: {device_list}')


async def main():
    """Configure OpenAI and the Matrix client, then log in and sync forever.

    Loops indefinitely: homeserver connection errors wait 15s and unexpected
    exceptions are logged and wait 5s before logging in again. KeyboardInterrupt
    closes the client and exits.
    """
    log_level = _resolve_log_level(config_data['logging']['log_level'])
    logger.setLevel(log_level)
    effective_level = logger.getEffectiveLevel()
    # Announce the level at that same level so the message is never filtered.
    logger.log(effective_level, f'Log level is {logging.getLevelName(effective_level)}')

    commands = list(config_data['command'].values())
    if len(commands) == 1 and commands[0]['mode'] == 'local':
        # Need the logger to be initialized for this
        logger.info('Running in local mode, OpenAI API key not required.')
        openai.api_key = 'abc123'
    else:
        check_config_value_exists(config_data['openai'], 'api_key')
        openai.api_key = config_data['openai']['api_key']
    logger.info(f'Command Prefixes: {list(command_prefixes)}')

    # Logging in with a new device each time seems to fix encryption errors
    device_id = config_data['bot_auth'].get('device_id', str(uuid4()))
    matrix_helper = MatrixNioGPTHelper(
        auth_file=Path(config_data['bot_auth']['store_path'], 'bot_auth.json'),
        user_id=config_data['bot_auth']['username'],
        passwd=config_data['bot_auth']['password'],
        homeserver=config_data['bot_auth']['homeserver'],
        store_path=config_data['bot_auth']['store_path'],
        device_id=device_id,
    )
    client = matrix_helper.client

    api_base = config_data['openai'].get('api_base')
    if api_base:
        logger.info(f'Set OpenAI API base URL to: {api_base}')
        openai.api_base = api_base

    storage = Storage(Path(config_data['data_storage'], 'matrixgpt.db'))

    # Set up event callbacks
    callbacks = Callbacks(
        client,
        storage,
        openai_obj=openai,
        command_prefixes=command_prefixes,
        reply_in_thread=config_data.get('reply_in_thread', False),
        allowed_to_invite=config_data['allowed_to_invite'],
        allowed_to_chat=config_data['allowed_to_chat'],
        log_full_response=config_data['logging'].get('log_full_response', False),
        system_prompt=config_data['openai'].get('system_prompt'),
        injected_system_prompt=config_data['openai'].get('injected_system_prompt', False),
        openai_temperature=config_data['openai'].get('temperature', 0),
        send_extra_messages=config_data.get('send_extra_messages', False),
        log_level=log_level,
    )
    client.add_event_callback(callbacks.message, RoomMessageText)
    client.add_event_callback(callbacks.invite_event_filtered_callback, InviteMemberEvent)
    client.add_event_callback(callbacks.decryption_failure, MegolmEvent)
    client.add_event_callback(callbacks.unknown, UnknownEvent)
    # TODO: make the bot move its read marker on these events too:
    # RoomMessageImage
    # RoomMessageAudio
    # RoomMessageVideo
    # RoomMessageFile
    # Also see about this parent class: RoomMessageMedia

    # Keep trying to reconnect on failure (with some time in-between).
    # All pauses are awaited so the event loop is never blocked.
    while True:
        try:
            logger.info('Logging in...')
            await _login_with_retry(matrix_helper)
            # Login succeeded!
            logger.info(f"Logged in as {client.user_id} using device {device_id}.")
            await _autojoin_rooms(client)
            # Log out old devices to keep the session clean
            if config_data.get('logout_other_devices', False):
                await _logout_other_devices(client, device_id)
            await client.sync_forever(timeout=10000, full_state=True)
        except (ClientConnectionError, ServerDisconnectedError):
            logger.warning("Unable to connect to homeserver, retrying in 15s...")
            await asyncio.sleep(15)
        except KeyboardInterrupt:
            await client.close()
            sys.exit()
        except Exception:
            logger.critical(traceback.format_exc())
            logger.critical('Sleeping 5s...')
            await asyncio.sleep(5)
if __name__ == "__main__":
    # Supervisor loop: restart the bot whenever main() dies with an unexpected
    # error (after logging the traceback); Ctrl+C exits the process.
    restart_delay = 5
    while True:
        try:
            asyncio.run(main())
        except KeyboardInterrupt:
            sys.exit()
        except Exception:
            logger.critical(traceback.format_exc())
            time.sleep(restart_delay)