MatrixGPT/main.py

232 lines
9.1 KiB
Python
Raw Normal View History

2023-03-18 02:14:45 -06:00
#!/usr/bin/env python3
import argparse
import asyncio
import logging
import os
import sys
import time
import traceback
from pathlib import Path
2023-03-18 13:05:00 -06:00
from uuid import uuid4
2023-03-18 02:14:45 -06:00
import openai
import yaml
from aiohttp import ClientConnectionError, ServerDisconnectedError
from nio import InviteMemberEvent, JoinResponse, MegolmEvent, RoomMessageText, UnknownEvent
2023-03-18 02:14:45 -06:00
from matrix_gpt import MatrixNioGPTHelper
from matrix_gpt.bot.callbacks import Callbacks
from matrix_gpt.bot.storage import Storage
from matrix_gpt.config import check_config_value_exists
# Absolute directory of this script; the default config path is resolved relative to it.
script_directory = os.path.abspath(os.path.dirname(__file__))

# Install a root handler so the named logger below has somewhere to emit to.
logging.basicConfig()
logger = logging.getLogger('MatrixGPT')

# Command-line interface: the only option is an override for the config location.
parser = argparse.ArgumentParser(description='MatrixGPT Bot')
parser.add_argument(
    '--config',
    default=Path(script_directory, 'config.yaml'),
    help='Path to config.yaml if it is not located next to this executable.',
)
args = parser.parse_args()
# Load config.
# The process exits with status 1 if the file is missing, unreadable/unparsable,
# or does not contain a top-level YAML mapping.
if not Path(args.config).exists():
    print('Config file does not exist:', args.config)
    sys.exit(1)
try:
    # Explicit encoding so parsing does not depend on the platform's locale default.
    with open(args.config, 'r', encoding='utf-8') as file:
        config_data = yaml.safe_load(file)
except Exception as e:
    print(f'Failed to load config file: {e}')
    sys.exit(1)
if not isinstance(config_data, dict):
    # safe_load returns None for an empty file (or a scalar/list for other
    # documents); every later lookup indexes config_data as a mapping.
    print('Config file is empty or is not a YAML mapping:', args.config)
    sys.exit(1)
2023-03-19 14:46:42 -06:00
# Lazy way to validate config.
# check_config_value_exists() comes from matrix_gpt.config; presumably it
# aborts when a required key is missing (TODO confirm against its definition).
check_config_value_exists(config_data, 'bot_auth', dict)
for auth_key in ('username', 'password', 'homeserver', 'store_path'):
    check_config_value_exists(config_data['bot_auth'], auth_key)

check_config_value_exists(config_data, 'allowed_to_chat')
check_config_value_exists(config_data, 'allowed_to_invite', allow_empty=True)
for top_key in ('data_storage', 'command', 'logging'):
    check_config_value_exists(config_data, top_key)
check_config_value_exists(config_data['logging'], 'log_level')
check_config_value_exists(config_data, 'openai')

# Validate every configured command and index it by its prefix.
# Each command's settings dict is mutated in place to fill in defaults.
command_prefixes = {}
for k, v in config_data['command'].items():
    check_config_value_exists(v, 'model')
    check_config_value_exists(v, 'mode', default='default')
    # Commands without their own allow-list get the value 'all'.
    v.setdefault('allowed_to_chat', 'all')
    command_prefixes[k] = v

# 'autojoin_rooms' is optional and is therefore not validated here;
# it is read with config_data.get() when the bot logs in.
def retry(msg=None):
    """Log a "retrying in 15s" warning, then block the calling thread for 15 seconds.

    :param msg: optional context prepended to the warning; when falsy, a generic
        message is logged instead.
    """
    # NOTE(review): no caller of this helper is visible in this file.
    warning = f'{msg}, retrying in 15s...' if msg else 'Retrying in 15s...'
    logger.warning(warning)
    time.sleep(15)
def _resolve_log_level(level_name):
    """Translate the ``logging.log_level`` config value into a ``logging`` level.

    Matching is case-insensitive. Unrecognised or missing values fall back to
    ``logging.INFO``.
    """
    levels = {
        'debug': logging.DEBUG,
        'info': logging.INFO,
        'warning': logging.WARNING,
        'error': logging.ERROR,
        'critical': logging.CRITICAL,
    }
    return levels.get(str(level_name).lower(), logging.INFO)


def _ratelimit_wait_seconds(login_response):
    """Return half the retry delay advertised by an M_LIMIT_EXCEEDED login error, in seconds.

    Returns ``None`` when the delay cannot be parsed from the response.
    """
    # NOTE(review): assumes the response's string form ends with the delay in
    # milliseconds plus a two-character suffix (e.g. "... 1234ms"); confirm
    # against matrix-nio's LoginError formatting.
    try:
        retry_after_ms = int(str(login_response).split(' ')[-1][:-2])
    except ValueError:
        return None
    # Only wait half the ratelimited time.
    return int((retry_after_ms / 1000) / 2)


async def _login_until_success(matrix_helper):
    """Await ``matrix_helper.login()`` until it reports success, pausing between failures."""
    while True:
        login_success, login_response = await matrix_helper.login()
        if login_success:
            return
        if 'M_LIMIT_EXCEEDED' in str(login_response):
            wait = _ratelimit_wait_seconds(login_response)
            if wait is None:
                logger.error('Could not parse M_LIMIT_EXCEEDED')
                # Never retry immediately: that would hammer a homeserver that is
                # already rate-limiting us.
                wait = 5
            else:
                logger.error(f'Ratelimited, sleeping {wait}s...')
        else:
            logger.error(f'Failed to login, retrying: {login_response}')
            wait = 5
        # asyncio.sleep (not time.sleep) so the event loop is not blocked.
        await asyncio.sleep(wait)


async def _join_autojoin_rooms(client, rooms):
    """Join each room in *rooms*, logging failures and pausing 1.5s after each join attempt."""
    for room in rooms:
        r = await client.join(room)
        if not isinstance(r, JoinResponse):
            logger.critical(f'Failed to join room {room}: {vars(r)}')
        await asyncio.sleep(1.5)


async def _logout_other_devices(client, device_id, username, password):
    """Delete every device on the account except *device_id*, authenticating by password."""
    logger.info('Logging out other devices...')
    devices_response = await client.devices()
    device_list = [d.id for d in devices_response.devices if d.id != device_id]
    if not device_list:
        logger.info('No other devices to log out.')
        return
    # Interactive-auth payload for the password flow.
    # NOTE(review): the response is not inspected, so "Logged out" is logged even
    # if the server rejected the request.
    await client.delete_devices(device_list, {
        "type": "m.login.password",
        "user": username,
        "password": password,
    })
    logger.info(f'Logged out: {device_list}')


async def main():
    """Configure logging and OpenAI, build the Matrix client, and run the bot forever.

    Each iteration of the outer loop:
    1. logs in, retrying until success;
    2. runs the optional start-up chores (autojoin rooms, log out other devices);
    3. blocks in ``client.sync_forever``.

    Connection errors and unexpected exceptions are logged, and the cycle restarts
    after a pause. On KeyboardInterrupt the client is closed and the process exits.
    """
    log_level = _resolve_log_level(config_data['logging']['log_level'])
    logger.setLevel(log_level)
    # Announce the active level at that same level, so the message passes the filter.
    logger.log(log_level, f'Log level is {logging.getLevelName(log_level)}')

    commands = config_data['command']
    if len(commands) == 1 and next(iter(commands.values()))['mode'] == 'local':
        # Needs the logger configured above before this message is emitted.
        logger.info('Running in local mode, OpenAI API key not required.')
        # Placeholder key; presumably the openai client wants a non-empty value (TODO confirm).
        openai.api_key = 'abc123'
    else:
        check_config_value_exists(config_data['openai'], 'api_key')
        openai.api_key = config_data['openai']['api_key']
    logger.info(f'Command Prefixes: {list(command_prefixes)}')

    bot_auth = config_data['bot_auth']
    # Logging in with a new device each time seems to fix encryption errors.
    device_id = bot_auth.get('device_id', str(uuid4()))

    matrix_helper = MatrixNioGPTHelper(
        auth_file=Path(bot_auth['store_path'], 'bot_auth.json'),
        user_id=bot_auth['username'],
        passwd=bot_auth['password'],
        homeserver=bot_auth['homeserver'],
        store_path=bot_auth['store_path'],
        device_id=device_id,
    )
    client = matrix_helper.client

    if config_data['openai'].get('api_base'):
        logger.info(f'Set OpenAI API base URL to: {config_data["openai"].get("api_base")}')
        openai.api_base = config_data['openai'].get('api_base')

    storage = Storage(Path(config_data['data_storage'], 'matrixgpt.db'))

    # Set up event callbacks.
    callbacks = Callbacks(
        client,
        storage,
        openai_obj=openai,
        command_prefixes=command_prefixes,
        reply_in_thread=config_data.get('reply_in_thread', False),
        allowed_to_invite=config_data['allowed_to_invite'],
        allowed_to_chat=config_data['allowed_to_chat'],
        log_full_response=config_data['logging'].get('log_full_response', False),
        system_prompt=config_data['openai'].get('system_prompt'),
        injected_system_prompt=config_data['openai'].get('injected_system_prompt', False),
        openai_temperature=config_data['openai'].get('temperature', 0),
        send_extra_messages=config_data.get('send_extra_messages', False),
        log_level=log_level,
    )
    client.add_event_callback(callbacks.message, RoomMessageText)
    client.add_event_callback(callbacks.invite_event_filtered_callback, InviteMemberEvent)
    client.add_event_callback(callbacks.decryption_failure, MegolmEvent)
    client.add_event_callback(callbacks.unknown, UnknownEvent)
    # TODO: make the bot move its read marker on these events too:
    # RoomMessageImage
    # RoomMessageAudio
    # RoomMessageVideo
    # RoomMessageFile
    # Also see about this parent class: RoomMessageMedia

    # Keep trying to reconnect on failure (with some time in-between).
    while True:
        try:
            logger.info('Logging in...')
            await _login_until_success(matrix_helper)
            logger.info(f"Logged in as {client.user_id} using device {device_id}.")

            autojoin_rooms = config_data.get('autojoin_rooms')
            if autojoin_rooms:
                await _join_autojoin_rooms(client, autojoin_rooms)

            # Log out old devices to keep the session clean.
            if config_data.get('logout_other_devices', False):
                await _logout_other_devices(client, device_id, bot_auth['username'], bot_auth['password'])

            await client.sync_forever(timeout=10000, full_state=True)
        except (ClientConnectionError, ServerDisconnectedError):
            logger.warning("Unable to connect to homeserver, retrying in 15s...")
            await asyncio.sleep(15)
        except KeyboardInterrupt:
            await client.close()
            sys.exit()
        except Exception:
            logger.critical(traceback.format_exc())
            logger.critical('Sleeping 5s...')
            await asyncio.sleep(5)
2023-03-18 02:14:45 -06:00
if __name__ == "__main__":
    # Supervisor loop: whenever main() escapes with an exception, log the
    # traceback and start a fresh event loop 5 seconds later. Ctrl-C exits.
    while True:
        try:
            asyncio.run(main())
        except KeyboardInterrupt:
            sys.exit()
        except Exception:
            crash_trace = traceback.format_exc()
            logger.critical(crash_trace)
            time.sleep(5)