rewrite
This commit is contained in:
parent c56bf9160e
commit e008cc2014
|
@ -3,70 +3,43 @@ import argparse
|
|||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from uuid import uuid4
|
||||
|
||||
import openai
|
||||
import yaml
|
||||
from aiohttp import ClientConnectionError, ServerDisconnectedError
|
||||
from bison.errors import SchemeValidationError
|
||||
from nio import InviteMemberEvent, JoinResponse, MegolmEvent, RoomMessageText, UnknownEvent
|
||||
|
||||
from matrix_gpt import MatrixNioGPTHelper
|
||||
from matrix_gpt.bot.callbacks import Callbacks
|
||||
from matrix_gpt.bot.storage import Storage
|
||||
from matrix_gpt.config import check_config_value_exists
|
||||
from matrix_gpt import MatrixClientHelper
|
||||
from matrix_gpt.callbacks import MatrixBotCallbacks
|
||||
from matrix_gpt.config import global_config
|
||||
|
||||
script_directory = os.path.abspath(os.path.dirname(__file__))
|
||||
SCRIPT_DIR = os.path.abspath(os.path.dirname(__file__))
|
||||
|
||||
logging.basicConfig()
|
||||
logger = logging.getLogger('MatrixGPT')
|
||||
|
||||
parser = argparse.ArgumentParser(description='MatrixGPT Bot')
|
||||
parser.add_argument('--config', default=Path(script_directory, 'config.yaml'), help='Path to config.yaml if it is not located next to this executable.')
|
||||
parser.add_argument('--config', default=Path(SCRIPT_DIR, 'config.yaml'), help='Path to config.yaml if it is not located next to this executable.')
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load config
|
||||
if not Path(args.config).exists():
|
||||
print('Config file does not exist:', args.config)
|
||||
args.config = Path(args.config)
|
||||
if not args.config.exists():
|
||||
logger.critical(f'Config file does not exist: {args.config}')
|
||||
sys.exit(1)
|
||||
else:
|
||||
try:
|
||||
with open(args.config, 'r') as file:
|
||||
config_data = yaml.safe_load(file)
|
||||
except Exception as e:
|
||||
print(f'Failed to load config file: {e}')
|
||||
sys.exit(1)
|
||||
|
||||
# Lazy way to validate config
|
||||
check_config_value_exists(config_data, 'bot_auth', dict)
|
||||
check_config_value_exists(config_data['bot_auth'], 'username')
|
||||
check_config_value_exists(config_data['bot_auth'], 'password')
|
||||
check_config_value_exists(config_data['bot_auth'], 'homeserver')
|
||||
check_config_value_exists(config_data['bot_auth'], 'store_path')
|
||||
check_config_value_exists(config_data, 'allowed_to_chat')
|
||||
check_config_value_exists(config_data, 'allowed_to_invite', allow_empty=True)
|
||||
check_config_value_exists(config_data, 'data_storage')
|
||||
check_config_value_exists(config_data, 'command')
|
||||
global_config.load(args.config)
|
||||
try:
|
||||
global_config.validate()
|
||||
except SchemeValidationError as e:
|
||||
logger.critical(f'Config validation error: {e}')
|
||||
sys.exit(1)
|
||||
config_data = global_config.config.config
|
||||
|
||||
check_config_value_exists(config_data, 'logging')
|
||||
check_config_value_exists(config_data['logging'], 'log_level')
|
||||
|
||||
check_config_value_exists(config_data, 'openai')
|
||||
|
||||
command_prefixes = {}
|
||||
for k, v in config_data['command'].items():
|
||||
check_config_value_exists(v, 'model')
|
||||
check_config_value_exists(v, 'mode', default='default')
|
||||
if 'allowed_to_chat' not in v.keys():
|
||||
# Set default value
|
||||
v['allowed_to_chat'] = 'all'
|
||||
command_prefixes[k] = v
|
||||
|
||||
|
||||
# check_config_value_exists(config_data, 'autojoin_rooms')
|
||||
|
||||
def retry(msg=None):
|
||||
if msg:
|
||||
|
@ -104,68 +77,39 @@ async def main():
|
|||
logger.info(f'Log level is {l}')
|
||||
del l
|
||||
|
||||
if len(config_data['command'].keys()) == 1 and config_data['command'][list(config_data['command'].keys())[0]].get('mode') == 'local':
|
||||
# Need the logger to be initialized for this
|
||||
if len(config_data['command']) == 1 and config_data['command'][0].get('mode') == 'local':
|
||||
logger.info('Running in local mode, OpenAI API key not required.')
|
||||
openai.api_key = 'abc123'
|
||||
else:
|
||||
check_config_value_exists(config_data['openai'], 'api_key')
|
||||
openai.api_key = config_data['openai']['api_key']
|
||||
|
||||
logger.info(f'Command Prefixes: {[k for k, v in command_prefixes.items()]}')
|
||||
logger.debug(f'Command Prefixes: {[k for k, v in global_config.command_prefixes.items()]}')
|
||||
|
||||
# Logging in with a new device each time seems to fix encryption errors
|
||||
device_id = config_data['bot_auth'].get('device_id', str(uuid4()))
|
||||
|
||||
matrix_helper = MatrixNioGPTHelper(
|
||||
auth_file=Path(config_data['bot_auth']['store_path'], 'bot_auth.json'),
|
||||
user_id=config_data['bot_auth']['username'],
|
||||
passwd=config_data['bot_auth']['password'],
|
||||
homeserver=config_data['bot_auth']['homeserver'],
|
||||
store_path=config_data['bot_auth']['store_path'],
|
||||
device_id=device_id,
|
||||
client_helper = MatrixClientHelper(
|
||||
user_id=config_data['auth']['username'],
|
||||
passwd=config_data['auth']['password'],
|
||||
homeserver=config_data['auth']['homeserver'],
|
||||
store_path=config_data['store_path'],
|
||||
device_name='MatrixGPT'
|
||||
)
|
||||
client = matrix_helper.client
|
||||
client = client_helper.client
|
||||
|
||||
if config_data['openai'].get('api_base'):
|
||||
logger.info(f'Set OpenAI API base URL to: {config_data["openai"].get("api_base")}')
|
||||
openai.api_base = config_data['openai'].get('api_base')
|
||||
|
||||
storage = Storage(Path(config_data['data_storage'], 'matrixgpt.db'))
|
||||
|
||||
# Set up event callbacks
|
||||
callbacks = Callbacks(client, storage,
|
||||
openai_obj=openai,
|
||||
command_prefixes=command_prefixes,
|
||||
# openai_model=config_data['openai']['model'],
|
||||
reply_in_thread=config_data.get('reply_in_thread', False),
|
||||
allowed_to_invite=config_data['allowed_to_invite'],
|
||||
allowed_to_chat=config_data['allowed_to_chat'],
|
||||
log_full_response=config_data['logging'].get('log_full_response', False),
|
||||
system_prompt=config_data['openai'].get('system_prompt'),
|
||||
injected_system_prompt=config_data['openai'].get('injected_system_prompt', False),
|
||||
openai_temperature=config_data['openai'].get('temperature', 0),
|
||||
send_extra_messages=config_data.get('send_extra_messages', False),
|
||||
log_level=log_level
|
||||
)
|
||||
client.add_event_callback(callbacks.message, RoomMessageText)
|
||||
client.add_event_callback(callbacks.invite_event_filtered_callback, InviteMemberEvent)
|
||||
callbacks = MatrixBotCallbacks(client=client_helper)
|
||||
client.add_event_callback(callbacks.handle_message, RoomMessageText)
|
||||
client.add_event_callback(callbacks.handle_invite, InviteMemberEvent)
|
||||
client.add_event_callback(callbacks.decryption_failure, MegolmEvent)
|
||||
client.add_event_callback(callbacks.unknown, UnknownEvent)
|
||||
|
||||
# TODO: make the bot move its read marker on these events too:
|
||||
# TODO: multimedia mode?
|
||||
# RoomMessageImage
|
||||
# RoomMessageAudio
|
||||
# RoomMessageVideo
|
||||
# RoomMessageFile
|
||||
# Also see about this parent class: RoomMessageMedia
|
||||
|
||||
# Keep trying to reconnect on failure (with some time in-between)
|
||||
while True:
|
||||
try:
|
||||
logger.info('Logging in...')
|
||||
while True:
|
||||
login_success, login_response = await matrix_helper.login()
|
||||
login_success, login_response = await client_helper.login()
|
||||
if not login_success:
|
||||
if 'M_LIMIT_EXCEEDED' in str(login_response):
|
||||
try:
|
||||
|
@ -181,7 +125,7 @@ async def main():
|
|||
break
|
||||
|
||||
# Login succeeded!
|
||||
logger.info(f"Logged in as {client.user_id} using device {device_id}.")
|
||||
logger.info(f'Logged in as {client.user_id}')
|
||||
if config_data.get('autojoin_rooms'):
|
||||
for room in config_data.get('autojoin_rooms'):
|
||||
r = await client.join(room)
|
||||
|
@ -189,31 +133,18 @@ async def main():
|
|||
logger.critical(f'Failed to join room {room}: {vars(r)}')
|
||||
time.sleep(1.5)
|
||||
|
||||
# Log out old devices to keep the session clean
|
||||
if config_data.get('logout_other_devices', False):
|
||||
logger.info('Logging out other devices...')
|
||||
devices = list((await client.devices()).devices)
|
||||
device_list = [x.id for x in devices]
|
||||
if device_id in device_list:
|
||||
device_list.remove(device_id)
|
||||
x = await client.delete_devices(device_list, {
|
||||
"type": "m.login.password",
|
||||
"user": config_data['bot_auth']['username'],
|
||||
"password": config_data['bot_auth']['password']
|
||||
}
|
||||
)
|
||||
logger.info(f'Logged out: {device_list}')
|
||||
logger.info('Performing initial sync...')
|
||||
last_sync = (await client_helper.sync()).next_batch
|
||||
client_helper.run_sync_in_bg() # start a background thread to record our sync tokens
|
||||
|
||||
await client.sync_forever(timeout=10000, full_state=True)
|
||||
# except LocalProtocolError:
|
||||
# logger.error(f'Failed to login, retrying in 5s...')
|
||||
# time.sleep(5)
|
||||
logger.info('Bot is active')
|
||||
await client.sync_forever(timeout=10000, full_state=True, since=last_sync)
|
||||
except (ClientConnectionError, ServerDisconnectedError):
|
||||
logger.warning("Unable to connect to homeserver, retrying in 15s...")
|
||||
time.sleep(15)
|
||||
except KeyboardInterrupt:
|
||||
await client.close()
|
||||
sys.exit()
|
||||
os.kill(os.getpid(), signal.SIGTERM)
|
||||
except Exception:
|
||||
logger.critical(traceback.format_exc())
|
||||
logger.critical('Sleeping 5s...')
|
||||
|
@ -225,7 +156,7 @@ if __name__ == "__main__":
|
|||
try:
|
||||
asyncio.run(main())
|
||||
except KeyboardInterrupt:
|
||||
sys.exit()
|
||||
os.kill(os.getpid(), signal.SIGTERM)
|
||||
except Exception:
|
||||
logger.critical(traceback.format_exc())
|
||||
time.sleep(5)
|
||||
|
|
|
@ -1 +1 @@
from .matrix import MatrixNioGPTHelper
from .matrix import MatrixClientHelper
@ -1,126 +0,0 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from types import ModuleType
|
||||
|
||||
from nio import AsyncClient, MatrixRoom, RoomMessageText
|
||||
|
||||
from .chat_functions import process_chat, react_to_event, send_text_to_room
|
||||
# from .config import Config
|
||||
from .storage import Storage
|
||||
|
||||
logger = logging.getLogger('MatrixGPT')
|
||||
|
||||
|
||||
class Command:
|
||||
def __init__(
|
||||
self,
|
||||
client: AsyncClient,
|
||||
store: Storage,
|
||||
# config: Config,
|
||||
command: str,
|
||||
room: MatrixRoom,
|
||||
event: RoomMessageText,
|
||||
openai_obj: ModuleType,
|
||||
openai_model: str,
|
||||
reply_in_thread,
|
||||
openai_temperature: float = 0,
|
||||
system_prompt: str = None,
|
||||
injected_system_prompt: str = None,
|
||||
log_full_response: bool = False,
|
||||
send_extra_messages: bool = True
|
||||
):
|
||||
"""A command made by a user.
|
||||
|
||||
Args:
|
||||
client: The client to communicate to matrix with.
|
||||
|
||||
store: Bot storage.
|
||||
|
||||
config: Bot configuration parameters.
|
||||
|
||||
command: The command and arguments.
|
||||
|
||||
room: The room the command was sent in.
|
||||
|
||||
event: The event describing the command.
|
||||
"""
|
||||
self.client = client
|
||||
self.store = store
|
||||
# self.config = config
|
||||
self.command = command
|
||||
self.room = room
|
||||
self.event = event
|
||||
self.args = self.command.split()[1:]
|
||||
self.openai_model = openai_model
|
||||
self.reply_in_thread = reply_in_thread
|
||||
self.system_prompt = system_prompt
|
||||
self.injected_system_prompt = injected_system_prompt
|
||||
self.log_full_response = log_full_response
|
||||
self.openai_obj = openai_obj
|
||||
self.openai_temperature = openai_temperature
|
||||
self.send_extra_messages = send_extra_messages
|
||||
|
||||
async def process(self):
|
||||
"""Process the command"""
|
||||
await self.client.room_read_markers(self.room.room_id, self.event.event_id, self.event.event_id)
|
||||
self.command = self.command.strip()
|
||||
# if self.command.startswith("echo"):
|
||||
# await self._echo()
|
||||
# elif self.command.startswith("react"):
|
||||
# await self._react()
|
||||
# if self.command.startswith("help"):
|
||||
# await self._show_help()
|
||||
# else:
|
||||
try:
|
||||
await self._process_chat()
|
||||
except Exception:
|
||||
await react_to_event(self.client, self.room.room_id, self.event.event_id, '❌')
|
||||
raise
|
||||
|
||||
async def _process_chat(self):
|
||||
async def inner():
|
||||
await process_chat(
|
||||
self.client,
|
||||
self.room,
|
||||
self.event,
|
||||
self.command,
|
||||
self.store,
|
||||
openai_obj=self.openai_obj,
|
||||
openai_model=self.openai_model,
|
||||
openai_temperature=self.openai_temperature,
|
||||
system_prompt=self.system_prompt,
|
||||
injected_system_prompt=self.injected_system_prompt,
|
||||
log_full_response=self.log_full_response,
|
||||
send_extra_messages=self.send_extra_messages
|
||||
)
|
||||
|
||||
asyncio.get_event_loop().create_task(inner())
|
||||
|
||||
async def _show_help(self):
|
||||
"""Show the help text"""
|
||||
# if not self.args:
|
||||
# text = (
|
||||
# "Hello, I am a bot made with matrix-nio! Use `help commands` to view "
|
||||
# "available commands."
|
||||
# )
|
||||
# await send_text_to_room(self.client, self.room.room_id, text)
|
||||
# return
|
||||
|
||||
# topic = self.args[0]
|
||||
# if topic == "rules":
|
||||
# text = "These are the rules!"
|
||||
# elif topic == "commands":
|
||||
# text = """Available commands:"""
|
||||
# else:
|
||||
# text = "Unknown help topic!"
|
||||
|
||||
text = 'Send your message to ChatGPT like this: `!c Hi ChatGPT, how are you?`'
|
||||
|
||||
await send_text_to_room(self.client, self.room.room_id, text)
|
||||
|
||||
async def _unknown_command(self):
|
||||
await send_text_to_room(
|
||||
self.client,
|
||||
self.room.room_id,
|
||||
f"Unknown command '{self.command}'. Try the 'help' command for more information.",
|
||||
)
|
|
@ -1,269 +0,0 @@
|
|||
# https://github.com/anoadragon453/nio-template
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import traceback
|
||||
from types import ModuleType
|
||||
|
||||
from nio import (AsyncClient, InviteMemberEvent, JoinError, MatrixRoom, MegolmEvent, RoomMessageText, UnknownEvent)
|
||||
|
||||
from .bot_commands import Command
|
||||
from .chat_functions import check_authorized, get_thread_content, is_this_our_thread, is_thread, process_chat, react_to_event, send_text_to_room, check_command_prefix
|
||||
# from .config import Config
|
||||
from .storage import Storage
|
||||
|
||||
logger = logging.getLogger('MatrixGPT')
|
||||
|
||||
|
||||
class Callbacks:
|
||||
def __init__(self,
|
||||
client: AsyncClient,
|
||||
store: Storage,
|
||||
command_prefixes: dict,
|
||||
openai_obj: ModuleType,
|
||||
# openai_model: str,
|
||||
reply_in_thread: bool,
|
||||
allowed_to_invite: list,
|
||||
allowed_to_chat: str = 'all',
|
||||
system_prompt: str = None,
|
||||
log_full_response: bool = False,
|
||||
injected_system_prompt: str = False,
|
||||
openai_temperature: float = 0,
|
||||
log_level=logging.INFO,
|
||||
send_extra_messages: bool = False
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
client: nio client used to interact with matrix.
|
||||
store: Bot storage.
|
||||
config: Bot configuration parameters.
|
||||
"""
|
||||
self.client = client
|
||||
self.store = store
|
||||
# self.config = config
|
||||
self.command_prefixes = command_prefixes
|
||||
# self.openai_model = openai_model
|
||||
self.startup_ts = time.time_ns() // 1_000_000
|
||||
self.reply_in_thread = reply_in_thread
|
||||
self.allowed_to_invite = allowed_to_invite if allowed_to_invite else []
|
||||
self.allowed_to_chat = allowed_to_chat
|
||||
self.system_prompt = system_prompt
|
||||
self.log_full_response = log_full_response
|
||||
self.injected_system_prompt = injected_system_prompt
|
||||
self.openai_obj = openai_obj
|
||||
self.openai_temperature = openai_temperature
|
||||
# self.gpt4_enabled = gpt4_enabled
|
||||
self.log_level = log_level
|
||||
self.send_extra_messages = send_extra_messages
|
||||
|
||||
async def message(self, room: MatrixRoom, event: RoomMessageText) -> None:
|
||||
"""Callback for when a message event is received
|
||||
|
||||
Args:
|
||||
room: The room the event came from.
|
||||
event: The event defining the message.
|
||||
"""
|
||||
# Extract the message text
|
||||
await self.client.room_read_markers(room.room_id, event.event_id, event.event_id)
|
||||
|
||||
# Ignore messages from ourselves
|
||||
if event.sender == self.client.user_id:
|
||||
return
|
||||
|
||||
if not check_authorized(event.sender, self.allowed_to_chat):
|
||||
await react_to_event(self.client, room.room_id, event.event_id, "🚫")
|
||||
return
|
||||
|
||||
if event.server_timestamp < self.startup_ts:
|
||||
logger.debug(f'Skipping event as it was sent before startup time: {event.event_id}')
|
||||
return
|
||||
if self.store.check_seen_event(event.event_id):
|
||||
logger.debug(f'Skipping seen event: {event.event_id}')
|
||||
return
|
||||
|
||||
msg = event.body.strip().strip('\n')
|
||||
|
||||
logger.debug(f"Bot message received from {event.sender} in {room.room_id} --> {msg}")
|
||||
|
||||
# if room.member_count > 2:
|
||||
# has_command_prefix =
|
||||
# else:
|
||||
# has_command_prefix = False
|
||||
|
||||
command_activated, sent_command_prefix, command_info = check_command_prefix(msg, self.command_prefixes)
|
||||
|
||||
if not command_activated and is_thread(event): # Threaded messages
|
||||
is_our_thread, sent_command_prefix, command_info = await is_this_our_thread(self.client, room, event, self.command_prefixes)
|
||||
|
||||
if is_our_thread or room.member_count == 2:
|
||||
# Wrap this in a try/catch so we can add reaction on failure.
|
||||
# But don't want to spam the chat with errors.
|
||||
try:
|
||||
if not check_authorized(event.sender, command_info['allowed_to_chat']):
|
||||
await react_to_event(self.client, room.room_id, event.event_id, "🚫")
|
||||
return
|
||||
|
||||
await self.client.room_typing(room.room_id, typing_state=True, timeout=3000)
|
||||
thread_content = await get_thread_content(self.client, room, event)
|
||||
api_data = []
|
||||
for event in thread_content:
|
||||
if isinstance(event, MegolmEvent):
|
||||
resp = await send_text_to_room(self.client,
|
||||
room.room_id,
|
||||
'❌ 🔐 Decryption Failure',
|
||||
reply_to_event_id=event.event_id,
|
||||
thread=True,
|
||||
thread_root_id=thread_content[0].event_id
|
||||
)
|
||||
logger.critical(f'Decryption failure for event {event.event_id} in room {room.room_id}')
|
||||
await self.client.room_typing(room.room_id, typing_state=False, timeout=3000)
|
||||
self.store.add_event_id(resp.event_id)
|
||||
return
|
||||
else:
|
||||
thread_msg = event.body.strip().strip('\n')
|
||||
api_data.append(
|
||||
{
|
||||
'role': 'assistant' if event.sender == self.client.user_id else 'user',
|
||||
'content': thread_msg if not check_command_prefix(thread_msg, self.command_prefixes)[0] else thread_msg[len(sent_command_prefix):].strip()
|
||||
}
|
||||
) # if len(thread_content) >= 2 and thread_content[0].body.startswith(self.command_prefix): # if thread_content[len(thread_content) - 2].sender == self.client.user
|
||||
|
||||
# TODO: process_chat() will set typing as false after generating.
|
||||
# TODO: If there is still another query in-progress that typing state will be overwritten by the one that just finished.
|
||||
async def inner():
|
||||
await process_chat(
|
||||
self.client,
|
||||
room,
|
||||
event,
|
||||
api_data,
|
||||
self.store,
|
||||
openai_obj=self.openai_obj,
|
||||
openai_model=command_info['model'],
|
||||
openai_temperature=self.openai_temperature,
|
||||
thread_root_id=thread_content[0].event_id,
|
||||
system_prompt=self.system_prompt,
|
||||
log_full_response=self.log_full_response,
|
||||
injected_system_prompt=self.injected_system_prompt,
|
||||
send_extra_messages=self.send_extra_messages
|
||||
)
|
||||
|
||||
asyncio.get_event_loop().create_task(inner())
|
||||
except:
|
||||
await react_to_event(self.client, room.room_id, event.event_id, '❌')
|
||||
raise
|
||||
return
|
||||
elif (command_activated or room.member_count == 2) and not is_thread(event): # Everything else
|
||||
if command_info.get('allowed_to_chat') and not check_authorized(event.sender, command_info['allowed_to_chat']):
|
||||
await react_to_event(self.client, room.room_id, event.event_id, "🚫")
|
||||
return
|
||||
try:
|
||||
msg = msg if not command_activated else msg[len(sent_command_prefix):].strip() # Remove the command prefix
|
||||
command = Command(
|
||||
self.client,
|
||||
self.store,
|
||||
msg,
|
||||
room,
|
||||
event,
|
||||
openai_obj=self.openai_obj,
|
||||
openai_model=command_info['model'],
|
||||
openai_temperature=self.openai_temperature,
|
||||
reply_in_thread=self.reply_in_thread,
|
||||
system_prompt=self.system_prompt,
|
||||
injected_system_prompt=self.injected_system_prompt,
|
||||
log_full_response=self.log_full_response
|
||||
)
|
||||
await command.process()
|
||||
except:
|
||||
await react_to_event(self.client, room.room_id, event.event_id, '❌')
|
||||
raise
|
||||
else:
|
||||
# We don't want this debug info to crash the entire process if an error is encountered
|
||||
try:
|
||||
if self.log_level == logging.DEBUG:
|
||||
# This may be a little slow
|
||||
debug = {
|
||||
'command_prefix': sent_command_prefix,
|
||||
'are_we_activated': command_activated,
|
||||
'is_dm': room.member_count == 2,
|
||||
'is_thread': is_thread(event),
|
||||
'is_our_thread': await is_this_our_thread(self.client, room, event, self.command_prefixes)[0]
|
||||
|
||||
}
|
||||
logger.debug(f"Bot not reacting to event {event.event_id}: {json.dumps(debug)}")
|
||||
except Exception:
|
||||
logger.critical(traceback.format_exc())
|
||||
|
||||
async def invite(self, room: MatrixRoom, event: InviteMemberEvent) -> None:
|
||||
"""Callback for when an invite is received. Join the room specified in the invite.
|
||||
|
||||
Args:
|
||||
room: The room that we are invited to.
|
||||
event: The invite event.
|
||||
"""
|
||||
if not check_authorized(event.sender, self.allowed_to_invite):
|
||||
logger.info(f"Got invite to {room.room_id} from {event.sender} but rejected.")
|
||||
return
|
||||
|
||||
logger.debug(f"Got invite to {room.room_id} from {event.sender}.")
|
||||
|
||||
# Attempt to join 3 times before giving up
|
||||
for attempt in range(3):
|
||||
result = await self.client.join(room.room_id)
|
||||
if type(result) == JoinError:
|
||||
logger.error(f"Error joining room {room.room_id} (attempt %d): %s", attempt, result.message, )
|
||||
else:
|
||||
logger.info(f"Joined via invite: {room.room_id}")
|
||||
return
|
||||
else:
|
||||
logger.error("Unable to join room: %s", room.room_id)
|
||||
|
||||
async def invite_event_filtered_callback(self, room: MatrixRoom, event: InviteMemberEvent) -> None:
|
||||
"""
|
||||
Since the InviteMemberEvent is fired for every m.room.member state received
|
||||
in a sync response's `rooms.invite` section, we will receive some that are
|
||||
not actually our own invite event (such as the inviter's membership).
|
||||
This makes sure we only call `callbacks.invite` with our own invite events.
|
||||
"""
|
||||
if event.state_key == self.client.user_id:
|
||||
await self.invite(room, event)
|
||||
|
||||
async def decryption_failure(self, room: MatrixRoom, event: MegolmEvent) -> None:
|
||||
"""Callback for when an event fails to decrypt. Inform the user.
|
||||
|
||||
Args:
|
||||
room: The room that the event that we were unable to decrypt is in.
|
||||
event: The encrypted event that we were unable to decrypt.
|
||||
"""
|
||||
# logger.error(f"Failed to decrypt event '{event.event_id}' in room '{room.room_id}'!"
|
||||
# f"\n\n"
|
||||
# f"Tip: try using a different device ID in your config file and restart."
|
||||
# f"\n\n"
|
||||
# f"If all else fails, delete your store directory and let the bot recreate "
|
||||
# f"it (your reminders will NOT be deleted, but the bot may respond to existing "
|
||||
# f"commands a second time).")
|
||||
await self.client.room_read_markers(room.room_id, event.event_id, event.event_id)
|
||||
if event.server_timestamp > self.startup_ts:
|
||||
logger.critical(f'Decryption failure for event {event.event_id} in room {room.room_id}')
|
||||
await react_to_event(self.client, room.room_id, event.event_id, "❌ 🔐")
|
||||
|
||||
async def unknown(self, room: MatrixRoom, event: UnknownEvent) -> None:
|
||||
"""Callback for when an event with a type that is unknown to matrix-nio is received.
|
||||
Currently this is used for reaction events, which are not yet part of a released
|
||||
matrix spec (and are thus unknown to nio).
|
||||
|
||||
Args:
|
||||
room: The room the reaction was sent in.
|
||||
|
||||
event: The event itself.
|
||||
"""
|
||||
# if event.type == "m.reaction":
|
||||
# # Get the ID of the event this was a reaction to
|
||||
# relation_dict = event.source.get("content", {}).get("m.relates_to", {})
|
||||
#
|
||||
# reacted_to = relation_dict.get("event_id")
|
||||
# if reacted_to and relation_dict.get("rel_type") == "m.annotation":
|
||||
# await self._reaction(room, event, reacted_to)
|
||||
# return
|
||||
await self.client.room_read_markers(room.room_id, event.event_id, event.event_id)
|
||||
logger.debug(f"Got unknown event with type to {event.type} from {event.sender} in {room.room_id}.")
|
|
@ -1,331 +0,0 @@
|
|||
import asyncio
|
||||
import functools
|
||||
import logging
|
||||
import time
|
||||
from types import ModuleType
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import stopit
|
||||
from markdown import markdown
|
||||
from nio import AsyncClient, ErrorResponse, Event, MatrixRoom, MegolmEvent, Response, RoomGetEventResponse, \
|
||||
RoomMessageText, RoomSendResponse, SendRetryError
|
||||
|
||||
logger = logging.getLogger('MatrixGPT')
|
||||
|
||||
|
||||
async def send_text_to_room(client: AsyncClient, room_id: str, message: str, notice: bool = False,
|
||||
markdown_convert: bool = True, reply_to_event_id: Optional[str] = None,
|
||||
thread: bool = False, thread_root_id: str = None, extra_error: str = False,
|
||||
extra_msg: str = False) -> Union[RoomSendResponse, ErrorResponse]:
|
||||
"""Send text to a matrix room.
|
||||
|
||||
Args:
|
||||
client: The client to communicate to matrix with.
|
||||
room_id: The ID of the room to send the message to.
|
||||
message: The message content.
|
||||
notice: Whether the message should be sent with an "m.notice" message type
|
||||
(will not ping users).
|
||||
markdown_convert: Whether to convert the message content to markdown.
|
||||
Defaults to true.
|
||||
reply_to_event_id: Whether this message is a reply to another event. The event
|
||||
ID this message is a reply to.
|
||||
thread:
|
||||
thread_root_id:
|
||||
|
||||
Returns:
|
||||
A RoomSendResponse if the request was successful, else an ErrorResponse.
|
||||
|
||||
"""
|
||||
# Determine whether to ping room members or not
|
||||
msgtype = "m.notice" if notice else "m.text"
|
||||
|
||||
content = {"msgtype": msgtype, "format": "org.matrix.custom.html", "body": message, }
|
||||
|
||||
if markdown_convert:
|
||||
content["formatted_body"] = markdown(message, extensions=['fenced_code'])
|
||||
|
||||
if reply_to_event_id:
|
||||
if thread:
|
||||
content["m.relates_to"] = {
|
||||
'event_id': thread_root_id,
|
||||
'is_falling_back': True,
|
||||
"m.in_reply_to": {
|
||||
"event_id": reply_to_event_id
|
||||
},
|
||||
'rel_type': "m.thread"
|
||||
}
|
||||
else:
|
||||
content["m.relates_to"] = {
|
||||
"m.in_reply_to": {
|
||||
"event_id": reply_to_event_id
|
||||
}
|
||||
}
|
||||
|
||||
# TODO: don't force this to string. what if we want to send an array?
|
||||
content["m.matrixgpt"] = {
|
||||
"error": str(extra_error),
|
||||
"msg": str(extra_msg),
|
||||
}
|
||||
try:
|
||||
return await client.room_send(room_id, "m.room.message", content, ignore_unverified_devices=True)
|
||||
except SendRetryError:
|
||||
logger.exception(f"Unable to send message response to {room_id}")
|
||||
|
||||
|
||||
def make_pill(user_id: str, displayname: str = None) -> str:
|
||||
"""Convert a user ID (and optionally a display name) to a formatted user 'pill'
|
||||
|
||||
Args:
|
||||
user_id: The MXID of the user.
|
||||
|
||||
displayname: An optional displayname. Clients like Element will figure out the
|
||||
correct display name no matter what, but other clients may not. If not
|
||||
provided, the MXID will be used instead.
|
||||
|
||||
Returns:
|
||||
The formatted user pill.
|
||||
"""
|
||||
if not displayname:
|
||||
displayname = user_id
|
||||
return f'<a href="https://matrix.to/#/{user_id}">{displayname}</a>'
|
||||
|
||||
|
||||
async def react_to_event(client: AsyncClient, room_id: str, event_id: str, reaction_text: str, extra_error: str = False,
|
||||
extra_msg: str = False) -> Union[
|
||||
Response, ErrorResponse]:
|
||||
"""Reacts to a given event in a room with the given reaction text
|
||||
|
||||
Args:
|
||||
client: The client to communicate to matrix with.
|
||||
|
||||
room_id: The ID of the room to send the message to.
|
||||
|
||||
event_id: The ID of the event to react to.
|
||||
|
||||
reaction_text: The string to react with. Can also be (one or more) emoji characters.
|
||||
|
||||
Returns:
|
||||
A nio.Response or nio.ErrorResponse if an error occurred.
|
||||
|
||||
Raises:
|
||||
SendRetryError: If the reaction was unable to be sent.
|
||||
"""
|
||||
content = {
|
||||
"m.relates_to": {
|
||||
"rel_type": "m.annotation",
|
||||
"event_id": event_id,
|
||||
"key": reaction_text
|
||||
},
|
||||
"m.matrixgpt": {
|
||||
"error": str(extra_error),
|
||||
"msg": str(extra_msg),
|
||||
}
|
||||
}
|
||||
return await client.room_send(room_id, "m.reaction", content, ignore_unverified_devices=True, )
|
||||
|
||||
|
||||
async def decryption_failure(self, room: MatrixRoom, event: MegolmEvent) -> None:
|
||||
"""Callback for when an event fails to decrypt. Inform the user"""
|
||||
# logger.error(
|
||||
# f"Failed to decrypt event '{event.event_id}' in room '{room.room_id}'!"
|
||||
# f"\n\n"
|
||||
# f"Tip: try using a different device ID in your config file and restart."
|
||||
# f"\n\n"
|
||||
# f"If all else fails, delete your store directory and let the bot recreate "
|
||||
# f"it (your reminders will NOT be deleted, but the bot may respond to existing "
|
||||
# f"commands a second time)."
|
||||
# )
|
||||
|
||||
user_msg = "Unable to decrypt this message. Check whether you've chosen to only encrypt to trusted devices."
|
||||
await send_text_to_room(self.client, room.room_id, user_msg, reply_to_event_id=event.event_id, )
|
||||
|
||||
|
||||
def is_thread(event: RoomMessageText):
|
||||
return event.source['content'].get('m.relates_to', {}).get('rel_type') == 'm.thread'
|
||||
|
||||
|
||||
def check_command_prefix(string: str, prefixes: dict):
|
||||
for k, v in prefixes.items():
|
||||
if string.startswith(f'{k} '):
|
||||
return True, k, v
|
||||
return False, None, None
|
||||
|
||||
|
||||
async def is_this_our_thread(client: AsyncClient, room: MatrixRoom, event: RoomMessageText, command_prefixes: dict) -> \
|
||||
tuple[bool, any, any]:
|
||||
base_event_id = event.source['content'].get('m.relates_to', {}).get('event_id')
|
||||
if base_event_id:
|
||||
e = await client.room_get_event(room.room_id, base_event_id)
|
||||
if not isinstance(e, RoomGetEventResponse):
|
||||
logger.critical(f'Failed to get event in is_this_our_thread(): {vars(e)}')
|
||||
return False, None, None
|
||||
else:
|
||||
return check_command_prefix(e.event.body, command_prefixes)
|
||||
else:
|
||||
return False, None, None
|
||||
|
||||
|
||||
async def get_thread_content(client: AsyncClient, room: MatrixRoom, base_event: RoomMessageText) -> List[Event]:
|
||||
messages = []
|
||||
new_event = (await client.room_get_event(room.room_id, base_event.event_id)).event
|
||||
while True:
|
||||
if new_event.source['content'].get('m.relates_to', {}).get('rel_type') == 'm.thread':
|
||||
messages.append(new_event)
|
||||
else:
|
||||
break
|
||||
new_event = (await client.room_get_event(room.room_id,
|
||||
new_event.source['content']['m.relates_to']['m.in_reply_to'][
|
||||
'event_id'])).event
|
||||
messages.append((await client.room_get_event(room.room_id, base_event.source['content']['m.relates_to'][
|
||||
'event_id'])).event) # put the root event in the array
|
||||
messages.reverse()
|
||||
return messages
|
||||
|
||||
|
||||
async def process_chat(
|
||||
client,
|
||||
room,
|
||||
event,
|
||||
command,
|
||||
store,
|
||||
openai_obj: ModuleType,
|
||||
openai_model: str,
|
||||
openai_temperature: float,
|
||||
openai_retries: int = 3,
|
||||
thread_root_id: str = None,
|
||||
system_prompt: str = None,
|
||||
log_full_response: bool = False,
|
||||
injected_system_prompt: str = False,
|
||||
send_extra_messages: bool = True
|
||||
):
|
||||
try:
|
||||
if not store.check_seen_event(event.event_id):
|
||||
await client.room_typing(room.room_id, typing_state=True, timeout=90000)
|
||||
# if self.reply_in_thread:
|
||||
# thread_content = await get_thread_content(self.client, self.room, self.event)
|
||||
|
||||
if isinstance(command, list):
|
||||
messages = command
|
||||
else:
|
||||
messages = [{'role': 'user', 'content': command}]
|
||||
|
||||
if system_prompt:
|
||||
messages.insert(0, {"role": "system", "content": system_prompt})
|
||||
if injected_system_prompt:
|
||||
if messages[-1]['role'] == 'system':
|
||||
del messages[-1]
|
||||
index = -9999
|
||||
if len(messages) >= 3: # only inject the system prompt if this isn't the first reply
|
||||
index = -1
|
||||
elif not system_prompt:
|
||||
index = 0
|
||||
if index != -9999:
|
||||
messages.insert(index, {"role": "system", "content": injected_system_prompt})
|
||||
|
||||
logger.debug(f'Generating reply to event {event.event_id}')
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
# I don't think the OpenAI py api has a built-in timeout
|
||||
@stopit.threading_timeoutable(default=(None, None))
|
||||
async def generate():
|
||||
if openai_model.startswith('gpt-3') or openai_model.startswith('gpt-4') or openai_model == 'local':
|
||||
r = await loop.run_in_executor(None, functools.partial(openai_obj.ChatCompletion.create,
|
||||
model=openai_model, messages=messages,
|
||||
temperature=openai_temperature, timeout=900, max_tokens=None if openai_model != 'local' else 320))
|
||||
return r.choices[0].message.content
|
||||
elif openai_model in ['text-davinci-003', 'davinci-instruct-beta', 'text-davinci-001',
|
||||
'text-davinci-002', 'text-curie-001', 'text-babbage-001']:
|
||||
r = await loop.run_in_executor(None,
|
||||
functools.partial(openai_obj.Completion.create, model=openai_model,
|
||||
temperature=openai_temperature,
|
||||
request_timeout=900,
|
||||
max_tokens=4096))
|
||||
return r.choices[0].text
|
||||
else:
|
||||
raise Exception(f'Model {openai_model} not found!')
|
||||
|
||||
response = None
|
||||
openai_gen_error = None
|
||||
for i in range(1, openai_retries):
|
||||
sleep_time = i * 5
|
||||
try:
|
||||
task = asyncio.create_task(generate(timeout=900))
|
||||
asyncio.as_completed(task)
|
||||
response = await task
|
||||
if response is not None:
|
||||
break
|
||||
else:
|
||||
openai_gen_error = 'response was null'
|
||||
logger.warning(
|
||||
f'Response to event {event.event_id} was null, retrying {i}/{openai_retries} after {sleep_time}s.')
|
||||
# time.sleep(2)
|
||||
except Exception as e: # (stopit.utils.TimeoutException, openai.error.APIConnectionError)
|
||||
openai_gen_error = e
|
||||
logger.warning(
|
||||
f'Got exception when generating response to event {event.event_id}, retrying {i}/{openai_retries} after {sleep_time}s. Error: {e}')
|
||||
await client.room_typing(room.room_id, typing_state=True, timeout=15000)
|
||||
time.sleep(sleep_time)
|
||||
continue
|
||||
|
||||
if response is None:
|
||||
logger.critical(f'Response to event {event.event_id} in room {room.room_id} was null.')
|
||||
await client.room_typing(room.room_id, typing_state=False, timeout=15000)
|
||||
await react_to_event(client, room.room_id, event.event_id, '❌',
|
||||
extra_error=(openai_gen_error if send_extra_messages else False))
|
||||
return
|
||||
text_response = response.strip().strip('\n')
|
||||
|
||||
# Logging stuff
|
||||
if log_full_response:
|
||||
logger.debug(
|
||||
{'event_id': event.event_id, 'room': room.room_id, 'messages': messages, 'response': response})
|
||||
z = text_response.replace("\n", "\\n")
|
||||
if isinstance(command, str):
|
||||
x = command.replace("\n", "\\n")
|
||||
elif isinstance(command, list):
|
||||
x = command[-1]['content'].replace("\n", "\\n")
|
||||
else:
|
||||
x = command
|
||||
logger.info(f'Reply to {event.event_id} --> "{x}" and bot ({openai_model}) responded with "{z}"')
|
||||
|
||||
resp = await send_text_to_room(client, room.room_id, text_response, reply_to_event_id=event.event_id,
|
||||
thread=True,
|
||||
thread_root_id=thread_root_id if thread_root_id else event.event_id)
|
||||
await client.room_typing(room.room_id, typing_state=False, timeout=3000)
|
||||
|
||||
store.add_event_id(event.event_id)
|
||||
if not isinstance(resp, RoomSendResponse):
|
||||
logger.critical(f'Failed to respond to event {event.event_id} in room {room.room_id}:\n{vars(resp)}')
|
||||
await react_to_event(client, room.room_id, event.event_id, '❌')
|
||||
else:
|
||||
store.add_event_id(resp.event_id)
|
||||
except Exception:
|
||||
await react_to_event(client, room.room_id, event.event_id, '❌')
|
||||
raise
|
||||
|
||||
|
||||
def check_authorized(string, to_check):
|
||||
def check_str(s, c):
|
||||
if c != 'all':
|
||||
if '@' not in c and ':' not in c:
|
||||
# Homeserver
|
||||
if s.split(':')[-1] in c:
|
||||
return True
|
||||
elif s in c:
|
||||
# By username
|
||||
return True
|
||||
elif c == 'all':
|
||||
return True
|
||||
return False
|
||||
|
||||
if isinstance(to_check, str):
|
||||
return check_str(string, to_check)
|
||||
elif isinstance(to_check, list):
|
||||
output = False
|
||||
for item in to_check:
|
||||
if check_str(string, item):
|
||||
output = True
|
||||
return output
|
||||
else:
|
||||
raise Exception
|
|
@ -1,55 +0,0 @@
|
|||
import logging
|
||||
|
||||
from nio import AsyncClient, MatrixRoom, RoomMessageText
|
||||
|
||||
from .chat_functions import send_text_to_room
|
||||
|
||||
# from .config import Config
|
||||
from .storage import Storage
|
||||
|
||||
logger = logging.getLogger('MatrixGPT')
|
||||
|
||||
|
||||
class Message:
|
||||
def __init__(
|
||||
self,
|
||||
client: AsyncClient,
|
||||
store: Storage,
|
||||
# config: Config,
|
||||
message_content: str,
|
||||
room: MatrixRoom,
|
||||
event: RoomMessageText,
|
||||
openai,
|
||||
|
||||
):
|
||||
"""Initialize a new Message
|
||||
|
||||
Args:
|
||||
client: nio client used to interact with matrix.
|
||||
|
||||
store: Bot storage.
|
||||
|
||||
config: Bot configuration parameters.
|
||||
|
||||
message_content: The body of the message.
|
||||
|
||||
room: The room the event came from.
|
||||
|
||||
event: The event defining the message.
|
||||
"""
|
||||
self.client = client
|
||||
self.store = store
|
||||
# self.config = config
|
||||
self.message_content = message_content
|
||||
self.room = room
|
||||
self.event = event
|
||||
|
||||
async def process(self) -> None:
|
||||
"""Process and possibly respond to the message"""
|
||||
if self.message_content.lower() == "hello world":
|
||||
await self._hello_world()
|
||||
|
||||
async def _hello_world(self) -> None:
|
||||
"""Say hello"""
|
||||
text = "Hello, world!"
|
||||
await send_text_to_room(self.client, self.room.room_id, text)
|
|
@ -1,32 +0,0 @@
|
|||
import logging
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
logger = logging.getLogger('MatrixGPT')
|
||||
|
||||
|
||||
class Storage:
|
||||
insert_event = "INSERT INTO `seen_events` (`event_id`) VALUES (?);"
|
||||
seen_events = set()
|
||||
|
||||
def __init__(self, database_file: Union[str, Path]):
|
||||
self.conn = sqlite3.connect(database_file)
|
||||
self.cursor = self.conn.cursor()
|
||||
|
||||
table_exists = self.cursor.execute("SELECT count(*) FROM sqlite_master WHERE type='table' AND name='seen_events';").fetchall()[0][0]
|
||||
if table_exists == 0:
|
||||
self.cursor.execute("CREATE TABLE `seen_events` (`event_id` text NOT NULL);")
|
||||
logger.info('Created new database file.')
|
||||
|
||||
# This does not work
|
||||
# db_seen_events = self.cursor.execute("SELECT `event_id` FROM `seen_events`;").fetchall()
|
||||
|
||||
def add_event_id(self, event_id):
|
||||
self.seen_events.add(event_id)
|
||||
|
||||
# This makes the program exit???
|
||||
# self.cursor.execute(self.insert_event, (event_id))
|
||||
|
||||
def check_seen_event(self, event_id):
|
||||
return event_id in self.seen_events
|
|
@ -0,0 +1,84 @@
|
|||
# https://github.com/anoadragon453/nio-template
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
|
||||
from nio import (AsyncClient, InviteMemberEvent, MatrixRoom, MegolmEvent, RoomMessageText, UnknownEvent)
|
||||
|
||||
from .chat_functions import check_authorized, is_thread, check_command_prefix
|
||||
from .config import global_config
|
||||
from .handle_actions import do_reply_msg, do_reply_threaded_msg, do_join_channel
|
||||
from .matrix import MatrixClientHelper
|
||||
|
||||
logger = logging.getLogger('MatrixGPT')
|
||||
|
||||
|
||||
class MatrixBotCallbacks:
|
||||
def __init__(self, client: MatrixClientHelper):
|
||||
self.client_helper = client
|
||||
self.client: AsyncClient = client.client
|
||||
self.logger = logging.getLogger('ExportBot').getChild('MatrixBotCallbacks')
|
||||
self.startup_ts = time.time() * 1000
|
||||
|
||||
async def handle_message(self, room: MatrixRoom, requestor_event: RoomMessageText) -> None:
|
||||
"""
|
||||
Callback for when a message event is received
|
||||
"""
|
||||
# Mark all messages as read.
|
||||
mark_read_task = asyncio.create_task(self.client.room_read_markers(room.room_id, requestor_event.event_id, requestor_event.event_id))
|
||||
|
||||
msg = requestor_event.body.strip().strip('\n')
|
||||
if msg == "** Unable to decrypt: The sender's device has not sent us the keys for this message. **":
|
||||
self.logger.debug(f'Unable to decrypt event "{requestor_event.event_id}" in room {room.room_id}')
|
||||
return
|
||||
if requestor_event.server_timestamp < self.startup_ts:
|
||||
return
|
||||
if requestor_event.sender == self.client.user_id:
|
||||
return
|
||||
command_activated, sent_command_prefix, command_info = check_command_prefix(msg)
|
||||
|
||||
if not command_activated and is_thread(requestor_event):
|
||||
# Threaded messages
|
||||
logger.debug(f'Message from {requestor_event.sender} in {room.room_id} --> "{msg}"')
|
||||
# Start the task in the background and don't wait for it here or else we'll block everything.
|
||||
task = asyncio.create_task(do_reply_threaded_msg(self.client_helper, room, requestor_event, command_info, command_activated, sent_command_prefix))
|
||||
elif (command_activated or room.member_count == 2) and not is_thread(requestor_event):
|
||||
# Everything else
|
||||
logger.debug(f'Message from {requestor_event.sender} in {room.room_id} --> "{msg}"')
|
||||
allowed_to_chat = command_info['allowed_to_chat'] + global_config['allowed_to_chat']
|
||||
if not check_authorized(requestor_event.sender, allowed_to_chat):
|
||||
await self.client_helper.react_to_event(room.room_id, requestor_event.event_id, '🚫', extra_error='Not allowed to chat.' if global_config['send_extra_messages'] else None)
|
||||
return
|
||||
task = asyncio.create_task(do_reply_msg(self.client_helper, room, requestor_event, command_info, command_activated, sent_command_prefix))
|
||||
|
||||
async def handle_invite(self, room: MatrixRoom, event: InviteMemberEvent) -> None:
|
||||
"""Callback for when an invite is received. Join the room specified in the invite.
|
||||
Args:
|
||||
room: The room that we are invited to.
|
||||
event: The invite event.
|
||||
"""
|
||||
"""
|
||||
Since the InviteMemberEvent is fired for every m.room.member state received
|
||||
in a sync response's `rooms.invite` section, we will receive some that are
|
||||
not actually our own invite event (such as the inviter's membership).
|
||||
This makes sure we only call `callbacks.invite` with our own invite events.
|
||||
"""
|
||||
if event.state_key == self.client.user_id:
|
||||
task = asyncio.create_task(do_join_channel(self.client_helper, room, event))
|
||||
|
||||
async def decryption_failure(self, room: MatrixRoom, event: MegolmEvent) -> None:
|
||||
"""
|
||||
Callback for when an event fails to decrypt. Inform the user.
|
||||
"""
|
||||
await self.client.room_read_markers(room.room_id, event.event_id, event.event_id)
|
||||
if event.server_timestamp > self.startup_ts:
|
||||
logger.critical(f'Decryption failure for event {event.event_id} in room {room.room_id}')
|
||||
await self.client_helper.react_to_event(room.room_id, event.event_id, "❌ 🔐")
|
||||
|
||||
async def unknown(self, room: MatrixRoom, event: UnknownEvent) -> None:
|
||||
"""
|
||||
Callback for when an event with a type that is unknown to matrix-nio is received.
|
||||
Currently this is used for reaction events, which are not yet part of a released
|
||||
matrix spec (and are thus unknown to nio).
|
||||
"""
|
||||
await self.client.room_read_markers(room.room_id, event.event_id, event.event_id)
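# For context, a usage sketch: this is how the new callbacks are wired up in
# main.py (shown earlier in this diff); `client_helper` is the MatrixClientHelper
# created there.
callbacks = MatrixBotCallbacks(client=client_helper)
client = client_helper.client
client.add_event_callback(callbacks.handle_message, RoomMessageText)
client.add_event_callback(callbacks.handle_invite, InviteMemberEvent)
client.add_event_callback(callbacks.decryption_failure, MegolmEvent)
client.add_event_callback(callbacks.unknown, UnknownEvent)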
|
|
@ -0,0 +1,77 @@
|
|||
import logging
|
||||
from typing import List
|
||||
|
||||
from nio import AsyncClient, Event, MatrixRoom, RoomGetEventResponse, RoomMessageText
|
||||
|
||||
from matrix_gpt.config import global_config
|
||||
|
||||
logger = logging.getLogger('ChatFunctions')
|
||||
|
||||
|
||||
def is_thread(event: RoomMessageText):
|
||||
return event.source['content'].get('m.relates_to', {}).get('rel_type') == 'm.thread'
|
||||
|
||||
|
||||
def check_command_prefix(string: str):
|
||||
for k, v in global_config.command_prefixes.items():
|
||||
if string.startswith(f'{k} '):
|
||||
return True, k, v
|
||||
return False, None, None
|
||||
|
||||
|
||||
async def is_this_our_thread(client: AsyncClient, room: MatrixRoom, event: RoomMessageText) -> tuple[bool, any, any]:
|
||||
base_event_id = event.source['content'].get('m.relates_to', {}).get('event_id')
|
||||
if base_event_id:
|
||||
e = await client.room_get_event(room.room_id, base_event_id)
|
||||
if not isinstance(e, RoomGetEventResponse):
|
||||
logger.critical(f'Failed to get event in is_this_our_thread(): {vars(e)}')
|
||||
return False, None, None
|
||||
else:
|
||||
return check_command_prefix(e.event.body)
|
||||
else:
|
||||
return False, None, None
|
||||
|
||||
|
||||
async def get_thread_content(client: AsyncClient, room: MatrixRoom, base_event: RoomMessageText) -> List[Event]:
|
||||
messages = []
|
||||
new_event = (await client.room_get_event(room.room_id, base_event.event_id)).event
|
||||
while True:
|
||||
if new_event.source['content'].get('m.relates_to', {}).get('rel_type') == 'm.thread':
|
||||
messages.append(new_event)
|
||||
else:
|
||||
break
|
||||
new_event = (await client.room_get_event(
|
||||
room.room_id,
|
||||
new_event.source['content']['m.relates_to']['m.in_reply_to']['event_id'])
|
||||
).event
|
||||
messages.append((await client.room_get_event(
|
||||
room.room_id, base_event.source['content']['m.relates_to']['event_id'])
|
||||
).event) # put the root event in the array
|
||||
messages.reverse()
|
||||
return messages
|
||||
|
||||
|
||||
def check_authorized(string, to_check):
|
||||
def check_str(s, c):
|
||||
if c == 'all':
|
||||
return True
|
||||
else:
|
||||
if '@' not in c and ':' not in c:
|
||||
# Homeserver
|
||||
if s.split(':')[-1] in c:
|
||||
return True
|
||||
elif s in c:
|
||||
# By username
|
||||
return True
|
||||
return False
|
||||
|
||||
if isinstance(to_check, str):
|
||||
return check_str(string, to_check)
|
||||
elif isinstance(to_check, list):
|
||||
output = False
|
||||
for item in to_check:
|
||||
if check_str(string, item):
|
||||
output = True
|
||||
return output
|
||||
else:
|
||||
raise Exception
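# Quick usage sketch for the helpers above. The '!c' trigger and the MXIDs are
# made-up examples; real triggers come from the 'command' section of the config,
# and global_config must already be loaded.
from matrix_gpt.chat_functions import check_authorized, check_command_prefix

activated, prefix, command_info = check_command_prefix('!c hello there')
# -> (True, '!c', {...}) when '!c' is a configured trigger, else (False, None, None)

check_authorized('@alice:example.com', 'all')                   # True: everyone allowed
check_authorized('@alice:example.com', ['@alice:example.com'])  # True: matched by full MXID
check_authorized('@alice:example.com', ['example.com'])         # True: matched by homeserver
check_authorized('@bob:other.org', ['example.com'])             # False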
|
|
@ -1,20 +1,102 @@
|
|||
import sys
|
||||
import copy
|
||||
from pathlib import Path
|
||||
from types import NoneType
|
||||
from typing import Union
|
||||
|
||||
import bison
|
||||
|
||||
OPENAI_DEFAULT_SYSTEM_PROMPT = ""
|
||||
OPENAI_DEFAULT_INJECTED_SYSTEM_PROMPT = ""
|
||||
|
||||
config_scheme = bison.Scheme(
|
||||
bison.Option('store_path', default='bot-store/', field_type=str),
|
||||
bison.DictOption('auth', scheme=bison.Scheme(
|
||||
bison.Option('username', field_type=str, required=True),
|
||||
bison.Option('password', field_type=str, required=True),
|
||||
bison.Option('homeserver', field_type=str, required=True),
|
||||
bison.Option('device_id', field_type=str, required=True),
|
||||
)),
|
||||
bison.ListOption('allowed_to_chat', default=['all']),
|
||||
bison.ListOption('allowed_to_thread', default=['all']),
|
||||
bison.ListOption('allowed_to_invite', default=['all']),
|
||||
bison.ListOption('autojoin_rooms', default=[]),
|
||||
bison.ListOption('whitelist_rooms', default=[]),
|
||||
bison.ListOption('blacklist_rooms', default=[]),
|
||||
bison.Option('reply_in_thread', default=True, field_type=bool),
|
||||
bison.Option('set_avatar', default=True, field_type=bool),
|
||||
bison.Option('response_timeout', default=120, field_type=int),
|
||||
bison.ListOption('command', member_scheme=bison.Scheme(
|
||||
bison.Option('trigger', field_type=str, required=True),
|
||||
bison.Option('model', field_type=str, required=True),
|
||||
bison.ListOption('allowed_to_chat', default=['all']),
|
||||
bison.ListOption('allowed_to_thread', default=['all']),
|
||||
bison.Option('max_tokens', field_type=int, default=0, required=False),
|
||||
)),
|
||||
bison.DictOption('openai', scheme=bison.Scheme(
|
||||
bison.Option('api_key', field_type=str, required=True),
|
||||
bison.Option('api_base', field_type=[str, NoneType], default=None, required=False),
|
||||
bison.Option('api_retries', field_type=int, default=2),
|
||||
bison.Option('temperature', field_type=float, default=0.5),
|
||||
bison.Option('system_prompt', field_type=[str, NoneType], default=OPENAI_DEFAULT_SYSTEM_PROMPT),
|
||||
bison.Option('injected_system_prompt', field_type=[str, NoneType], default=OPENAI_DEFAULT_INJECTED_SYSTEM_PROMPT),
|
||||
)),
|
||||
bison.DictOption('logging', scheme=bison.Scheme(
|
||||
bison.Option('log_level', field_type=str, default='info'),
|
||||
bison.Option('log_full_response', field_type=bool, default=True),
|
||||
)),
|
||||
)
|
||||
|
||||
|
||||
def check_config_value_exists(config_part, key, check_type=None, allow_empty=False, choices: list = None, default=None) -> bool:
|
||||
if default and key not in config_part.keys():
|
||||
return default
|
||||
else:
|
||||
if key not in config_part.keys():
|
||||
print(f'Config key not found: "{key}"')
|
||||
sys.exit(1)
|
||||
if not allow_empty and config_part[key] is None or config_part[key] == '':
|
||||
print(f'Config key "{key}" must not be empty.')
|
||||
sys.exit(1)
|
||||
if check_type and not isinstance(config_part[key], check_type):
|
||||
print(f'Config key "{key}" must be type "{check_type}", not "{type(config_part[key])}".')
|
||||
sys.exit(1)
|
||||
if choices and config_part[key] not in choices:
|
||||
print(f'Invalid choice for config key "{key}". Choices: {choices}')
|
||||
sys.exit(1)
|
||||
return True
|
||||
class ConfigManager:
|
||||
def __init__(self):
|
||||
self._config = bison.Bison(scheme=config_scheme)
|
||||
self._command_prefixes = {}
|
||||
self._loaded = False
|
||||
|
||||
def load(self, path: Path):
|
||||
if self._loaded:
|
||||
raise Exception('Already loaded')
|
||||
self._config.config_name = 'config'
|
||||
self._config.config_format = bison.bison.YAML
|
||||
self._config.add_config_paths(str(path.parent))
|
||||
self._config.parse()
|
||||
self._command_prefixes = self._generate_command_prefixes()
|
||||
self._loaded = True
|
||||
|
||||
def validate(self):
|
||||
self._config.validate()
|
||||
|
||||
@property
|
||||
def config(self):
|
||||
return copy.deepcopy(self._config)
|
||||
|
||||
def _generate_command_prefixes(self):
|
||||
command_prefixes = {}
|
||||
for item in self._config.config['command']:
|
||||
command_prefixes[item['trigger']] = item
|
||||
return command_prefixes
|
||||
|
||||
@property
|
||||
def command_prefixes(self):
|
||||
return self._command_prefixes
|
||||
|
||||
def get(self, key, default=None):
|
||||
return copy.copy(self._config.get(key, default))
|
||||
|
||||
def __setitem__(self, key, item):
|
||||
raise Exception
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self._config.config[key]
|
||||
|
||||
def __repr__(self):
|
||||
return repr(self._config.config)
|
||||
|
||||
def __len__(self):
|
||||
return len(self._config.config)
|
||||
|
||||
def __delitem__(self, key):
|
||||
raise Exception
|
||||
|
||||
|
||||
global_config = ConfigManager()
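# A minimal usage sketch of the ConfigManager above. The config path and the
# '!c' trigger are assumptions for illustration, not values from this commit.
from pathlib import Path

from matrix_gpt.config import global_config

global_config.load(Path('/etc/matrix-gpt/config.yaml'))  # parses <dir>/config.yaml via bison
global_config.validate()                                 # raises SchemeValidationError on a bad config

homeserver = global_config['auth']['homeserver']          # dict-style access to parsed values
temperature = global_config['openai']['temperature']
command_info = global_config.command_prefixes.get('!c')   # per-trigger settings (trigger, model, max_tokens, ...)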
|
||||
|
|
|
@ -0,0 +1,135 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import traceback
|
||||
from typing import Union
|
||||
|
||||
from nio import RoomSendResponse
|
||||
|
||||
from matrix_gpt import MatrixClientHelper
|
||||
from matrix_gpt.config import global_config
|
||||
from matrix_gpt.openai_client import openai_client
|
||||
|
||||
logger = logging.getLogger('ProcessChat')
|
||||
|
||||
|
||||
# TODO: generate_ai_response() will set typing to false after generating.
# TODO: if another query is still in progress, its typing state will be overwritten by the one that just finished.
|
||||
|
||||
async def generate_ai_response(
|
||||
client_helper: MatrixClientHelper,
|
||||
room,
|
||||
event,
|
||||
msg: Union[str, list],
|
||||
sent_command_prefix: str,
|
||||
openai_model: str,
|
||||
thread_root_id: str = None,
|
||||
):
|
||||
client = client_helper.client
|
||||
try:
|
||||
await client.room_typing(room.room_id, typing_state=True, timeout=global_config['response_timeout'] * 1000)
|
||||
|
||||
# Set up the messages list.
|
||||
if isinstance(msg, list):
|
||||
messages = msg
|
||||
else:
|
||||
messages = [{'role': 'user', 'content': msg}]
|
||||
|
||||
# Inject the system prompt.
|
||||
system_prompt = global_config['openai'].get('system_prompt', '')
|
||||
injected_system_prompt = global_config['openai'].get('injected_system_prompt', '')
|
||||
if isinstance(system_prompt, str) and len(system_prompt):
|
||||
messages.insert(0, {"role": "system", "content": global_config['openai']['system_prompt']})
|
||||
if (isinstance(injected_system_prompt, str) and len(injected_system_prompt)) and len(messages) >= 3:
|
||||
# Only inject the system prompt if this isn't the first reply.
|
||||
if messages[-1]['role'] == 'system':
|
||||
# Delete the last system message since we want to replace it with our injected prompt.
|
||||
del messages[-1]
|
||||
messages.insert(-1, {"role": "system", "content": global_config['openai']['injected_system_prompt']})
|
||||
|
||||
max_tokens = global_config.command_prefixes[sent_command_prefix]['max_tokens']
|
||||
|
||||
async def generate():
|
||||
if openai_model in ['text-davinci-003', 'davinci-instruct-beta', 'text-davinci-001',
|
||||
'text-davinci-002', 'text-curie-001', 'text-babbage-001']:
|
||||
r = await openai_client.client().completions.create(
|
||||
model=openai_model,
|
||||
temperature=global_config['openai']['temperature'],
|
||||
request_timeout=global_config['response_timeout'],
|
||||
max_tokens=None if max_tokens == 0 else max_tokens
|
||||
)
|
||||
return r.choices[0].text
|
||||
else:
|
||||
r = await openai_client.client().chat.completions.create(
|
||||
model=openai_model, messages=messages,
|
||||
temperature=global_config['openai']['temperature'],
|
||||
timeout=global_config['response_timeout'],
|
||||
max_tokens=None if max_tokens == 0 else max_tokens
|
||||
)
|
||||
return r.choices[0].message.content
|
||||
|
||||
        response = None
        try:
            task = asyncio.create_task(generate())
            for completed in asyncio.as_completed([task], timeout=global_config['response_timeout']):
                try:
                    response = await completed
                    break
                except asyncio.TimeoutError:
                    logger.warning(f'Response to event {event.event_id} timed out.')
                    await client_helper.react_to_event(
                        room.room_id,
                        event.event_id,
                        '🕒',
                        extra_error='Request timed out.' if global_config['send_extra_messages'] else None
                    )
                    await client.room_typing(room.room_id, typing_state=False, timeout=1000)
                    return
        except Exception:
            logger.error(f'Exception when generating for event {event.event_id}: {traceback.format_exc()}')
            await client_helper.react_to_event(
                room.room_id,
                event.event_id,
                '❌',
                extra_error='Exception' if global_config['send_extra_messages'] else None
            )
            await client.room_typing(room.room_id, typing_state=False, timeout=1000)
            return

        if not response:
            logger.warning(f'Response to event {event.event_id} in room {room.room_id} was null.')
            await client_helper.react_to_event(
                room.room_id,
                event.event_id,
                '❌',
                extra_error='Response was null.' if global_config['send_extra_messages'] else None
            )
            await client.room_typing(room.room_id, typing_state=False, timeout=1000)
            return

        # The AI's response.
        text_response = response.strip().strip('\n')

        # Logging
        if global_config['logging']['log_full_response']:
            logger.debug(
                {'event_id': event.event_id, 'room': room.room_id, 'messages': messages, 'response': response}
            )
        single_line = text_response.replace("\n", "\\n")
        logger.info(f'Reply to {event.event_id} --> {openai_model} responded with "{single_line}"')

        # Send the response to the room.
        resp = await client_helper.send_text_to_room(
            room.room_id,
            text_response,
            reply_to_event_id=event.event_id,
            thread=True,
            thread_root_id=thread_root_id if thread_root_id else event.event_id
        )
        await client.room_typing(room.room_id, typing_state=False, timeout=1000)
        if not isinstance(resp, RoomSendResponse):
            logger.critical(f'Failed to respond to event {event.event_id} in room {room.room_id}:\n{vars(resp)}')
            await client_helper.react_to_event(room.room_id, event.event_id, '❌', extra_error='Exception' if global_config['send_extra_messages'] else None)
    except Exception:
        await client_helper.react_to_event(room.room_id, event.event_id, '❌', extra_error='Exception' if global_config['send_extra_messages'] else None)
        raise

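The TODO at the top of this file flags a race: each request clears the room's typing indicator when it finishes, so a long-running query can have its indicator wiped by a shorter one that completes first. A minimal sketch of one way to handle it, assuming a single-process bot (set_typing is a hypothetical helper, not part of this commit):

import asyncio
from collections import defaultdict

_typing_refcount = defaultdict(int)
_typing_lock = asyncio.Lock()


async def set_typing(client, room_id: str, active: bool, timeout_ms: int = 30000):
    # Track how many requests are in flight per room and only clear the typing
    # indicator once the last one finishes.
    async with _typing_lock:
        _typing_refcount[room_id] += 1 if active else -1
        still_busy = _typing_refcount[room_id] > 0
    await client.room_typing(room_id, typing_state=active or still_busy, timeout=timeout_ms)

Callers would use set_typing(client, room.room_id, True) before generating and set_typing(client, room.room_id, False) afterwards instead of calling client.room_typing() directly; a shared Redis counter would be needed if the bot ever runs as multiple processes, as the TODO in the thread handler suggests.
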
@ -0,0 +1,104 @@
import asyncio
import logging
import traceback

from nio import RoomMessageText, MatrixRoom, MegolmEvent, InviteMemberEvent, JoinError

from matrix_gpt import MatrixClientHelper
from matrix_gpt.chat_functions import is_this_our_thread, get_thread_content, check_command_prefix, check_authorized
from matrix_gpt.config import global_config
from matrix_gpt.generate import generate_ai_response

logger = logging.getLogger('HandleMessage')


async def do_reply_msg(client_helper: MatrixClientHelper, room: MatrixRoom, requestor_event: RoomMessageText, command_info, command_activated: bool, sent_command_prefix: str):
    try:
        raw_msg = requestor_event.body.strip().strip('\n')
        msg = raw_msg if not command_activated else raw_msg[len(sent_command_prefix):].strip()  # Remove the command prefix.
        await generate_ai_response(
            client_helper=client_helper,
            room=room,
            event=requestor_event,
            msg=msg,
            sent_command_prefix=sent_command_prefix,
            openai_model=command_info['model'],
        )
    except Exception:
        logger.critical(traceback.format_exc())
        await client_helper.react_to_event(room.room_id, requestor_event.event_id, '❌')
        raise


async def do_reply_threaded_msg(client_helper: MatrixClientHelper, room: MatrixRoom, requestor_event: RoomMessageText, command_info, command_activated: bool, sent_command_prefix: str):
    client = client_helper.client

    is_our_thread, sent_command_prefix, command_info = await is_this_our_thread(client, room, requestor_event)
    if not is_our_thread:  # or room.member_count == 2
        return

    allowed_to_chat = command_info['allowed_to_chat'] + global_config['allowed_to_chat'] + command_info['allowed_to_thread'] + global_config['allowed_to_thread']
    if not check_authorized(requestor_event.sender, allowed_to_chat):
        await client_helper.react_to_event(room.room_id, requestor_event.event_id, '🚫', extra_error='Not allowed to chat and/or thread.' if global_config['send_extra_messages'] else None)
        return

    try:
        # TODO: sync this with redis so that we don't clear the typing state if another response is also processing.
        await client.room_typing(room.room_id, typing_state=True, timeout=30000)

        thread_content = await get_thread_content(client, room, requestor_event)
        api_data = []
        for event in thread_content:
            if isinstance(event, MegolmEvent):
                await client_helper.send_text_to_room(
                    room.room_id,
                    '❌ 🔐 Decryption Failure',
                    reply_to_event_id=event.event_id,
                    thread=True,
                    thread_root_id=thread_content[0].event_id
                )
                logger.critical(f'Decryption failure for event {event.event_id} in room {room.room_id}')
                await client.room_typing(room.room_id, typing_state=False, timeout=1000)
                return
            else:
                thread_msg = event.body.strip().strip('\n')
                api_data.append(
                    {
                        # Messages the bot sent become "assistant" turns; everything else is a "user" turn.
                        'role': 'assistant' if event.sender == client.user_id else 'user',
                        'content': thread_msg if not check_command_prefix(thread_msg)[0] else thread_msg[len(sent_command_prefix):].strip()
                    }
                )

        await generate_ai_response(
            client_helper=client_helper,
            room=room,
            event=requestor_event,
            msg=api_data,
            sent_command_prefix=sent_command_prefix,
            openai_model=command_info['model'],
            thread_root_id=thread_content[0].event_id
        )
    except Exception:
        await client_helper.react_to_event(room.room_id, requestor_event.event_id, '❌')
        raise


async def do_join_channel(client_helper: MatrixClientHelper, room: MatrixRoom, event: InviteMemberEvent):
    if not check_authorized(event.sender, global_config['allowed_to_invite']):
        logger.info(f'Got invite to {room.room_id} from {event.sender} but rejected it')
        return

    logger.info(f'Got invite to {room.room_id} from {event.sender}')

    # Attempt to join 3 times before giving up.
    client = client_helper.client
    for attempt in range(3):
        result = await client.join(room.room_id)
        if isinstance(result, JoinError):
            logger.error(f'Error joining room {room.room_id} (attempt {attempt}): "{result.message}"')
            await asyncio.sleep(5)
        else:
            logger.info(f'Joined via invite: {room.room_id}')
            return
    else:
        # Runs only if every attempt failed (the loop never returned).
        logger.error(f'Unable to join room: {room.room_id}')

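For reference, do_reply_threaded_msg() flattens a Matrix thread into the same messages structure that generate_ai_response() expects: events the bot sent become assistant turns and everything else becomes a user turn, with any command prefix stripped. A hypothetical three-message thread would produce something like:

api_data = [
    {'role': 'user', 'content': 'What is the capital of France?'},
    {'role': 'assistant', 'content': 'The capital of France is Paris.'},
    {'role': 'user', 'content': 'And its population?'},  # the reply that triggered this handler
]
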
@ -1,26 +1,24 @@
import asyncio
import json
import logging
import os
from pathlib import Path
-from typing import Union
from typing import Union, Optional

-from nio import AsyncClient, AsyncClientConfig, LoginError
-from nio import LoginResponse

-logger = logging.getLogger('MatrixGPT')
from markdown import markdown
from nio import AsyncClient, AsyncClientConfig, LoginError, Response, ErrorResponse, RoomSendResponse, SendRetryError, SyncError
from nio.responses import LoginResponse, SyncResponse


-class MatrixNioGPTHelper:
class MatrixClientHelper:
    """
    A simple wrapper class for common matrix-nio actions.
    """
    client = None

    # Encryption is disabled because it's handled by Pantalaimon.
    client_config = AsyncClientConfig(max_limit_exceeded=0, max_timeouts=0, store_sync_tokens=True, encryption_enabled=False)

-    def __init__(self, auth_file: Union[Path, str], user_id: str, passwd: str, homeserver: str, store_path: str, device_name: str = 'MatrixGPT', device_id: str = None):
-        self.auth_file = auth_file
    def __init__(self, user_id: str, passwd: str, homeserver: str, store_path: str, device_name: str):
        self.user_id = user_id
        self.passwd = passwd

@ -28,54 +26,153 @@ class MatrixNioGPTHelper:
        if not (self.homeserver.startswith("https://") or self.homeserver.startswith("http://")):
            self.homeserver = "https://" + self.homeserver

-        self.store_path = store_path
-        Path(self.store_path).mkdir(parents=True, exist_ok=True)
        self.store_path = Path(store_path).absolute().expanduser().resolve()
        self.store_path.mkdir(parents=True, exist_ok=True)
        self.auth_file = self.store_path / (device_name.lower() + '.json')

        self.device_name = device_name
-        self.client = AsyncClient(homeserver=self.homeserver, user=self.user_id, config=self.client_config, device_id=device_id)
        self.client: AsyncClient = AsyncClient(homeserver=self.homeserver, user=self.user_id, config=self.client_config, device_id=device_name)
        self.logger = logging.getLogger('MatrixGPT').getChild('MatrixClientHelper')

-    async def login(self) -> tuple[bool, LoginError] | tuple[bool, LoginResponse | None]:
    async def login(self) -> tuple[bool, LoginResponse | LoginError | None]:
        try:
            # If there are no previously-saved credentials, we'll use the password.
            if not os.path.exists(self.auth_file):
-                logger.info('Using username/password.')
                self.logger.info('Using username/password')
                resp = await self.client.login(self.passwd, device_name=self.device_name)

                # check that we logged in successfully.
                if isinstance(resp, LoginResponse):
-                    self.write_details_to_disk(resp)
                    self._write_details_to_disk(resp)
                    return True, resp
                else:
                    return False, resp
            else:
                # Otherwise the config file exists, so we'll use the stored credentials.
-                logger.info('Using cached credentials.')
-                with open(self.auth_file, "r") as f:
-                    config = json.load(f)
-                client = AsyncClient(config["homeserver"])
-                client.access_token = config["access_token"]
-                client.user_id = config["user_id"]
-                client.device_id = config["device_id"]
                self.logger.info('Using cached credentials')

                auth_details = self._read_details_from_disk()['auth']
                client = AsyncClient(auth_details["homeserver"])
                client.access_token = auth_details["access_token"]
                client.user_id = auth_details["user_id"]
                client.device_id = auth_details["device_id"]

                resp = await self.client.login(self.passwd, device_name=self.device_name)
                if isinstance(resp, LoginResponse):
-                    self.write_details_to_disk(resp)
                    self._write_details_to_disk(resp)
                    return True, resp
                else:
                    return False, resp
        except Exception:
-            return False, None
            raise

-    def write_details_to_disk(self, resp: LoginResponse) -> None:
-        """Writes the required login details to disk so we can log in later without
-        using a password.
    async def sync(self) -> SyncResponse | SyncError:
        last_sync = self._read_details_from_disk().get('extra', {}).get('last_sync')
        response = await self.client.sync(timeout=10000, full_state=True, since=last_sync)
        if isinstance(response, SyncError):
            raise Exception(response)
        self._write_details_to_disk(extra_data={'last_sync': response.next_batch})
        return response

-        Arguments:
-            resp {LoginResponse} -- the successful client login response.
-            homeserver -- URL of homeserver, e.g. "https://matrix.example.org"
    def run_sync_in_bg(self):
        """
-        with open(self.auth_file, "w") as f:
-            json.dump({"homeserver": self.homeserver,  # e.g. "https://matrix.example.org"
-                       "user_id": resp.user_id,  # e.g. "@user:example.org"
-                       "device_id": resp.device_id,  # device ID, 10 uppercase letters
-                       "access_token": resp.access_token,  # cryptogr. access token
-                       }, f, )
        Run a sync in the background to update the `last_sync` value every 3 minutes.
        """
        asyncio.create_task(self._do_run_sync_in_bg())

    async def _do_run_sync_in_bg(self):
        while True:
            await self.sync()
            await asyncio.sleep(180)  # 3 minutes

    def _read_details_from_disk(self):
        if not self.auth_file.exists():
            return {}
        with open(self.auth_file, "r") as f:
            return json.load(f)

    def _write_details_to_disk(self, resp: LoginResponse = None, extra_data: dict = None) -> None:
        data = self._read_details_from_disk()
        if resp:
            data['auth'] = {
                'homeserver': self.homeserver,
                'user_id': resp.user_id,
                'device_id': resp.device_id,
                'access_token': resp.access_token,
            }
        if extra_data:
            data['extra'] = extra_data
        with open(self.auth_file, 'w') as f:
            json.dump(data, f, indent=4)

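Taken together, _read_details_from_disk() and _write_details_to_disk() persist a single JSON file named after the device (e.g. matrixgpt.json) inside the store path. Its shape ends up roughly like this; all values below are placeholders:

{
    "auth": {
        "homeserver": "https://matrix.example.org",
        "user_id": "@mybot:example.org",
        "device_id": "MATRIXGPT",
        "access_token": "<access token>"
    },
    "extra": {
        "last_sync": "<next_batch token from the most recent sync>"
    }
}
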
    async def react_to_event(self, room_id: str, event_id: str, reaction_text: str, extra_error: str = None, extra_msg: str = None) -> Union[Response, ErrorResponse]:
        content = {
            "m.relates_to": {
                "rel_type": "m.annotation",
                "event_id": event_id,
                "key": reaction_text
            },
            "m.matrixbot": {}
        }
        if extra_error:
            content["m.matrixbot"]["error"] = str(extra_error)
        if extra_msg:
            content["m.matrixbot"]["msg"] = str(extra_msg)
        return await self.client.room_send(room_id, "m.reaction", content, ignore_unverified_devices=True)

    async def send_text_to_room(self, room_id: str, message: str, notice: bool = False,
                                markdown_convert: bool = True, reply_to_event_id: Optional[str] = None,
                                thread: bool = False, thread_root_id: Optional[str] = None, extra_error: Optional[str] = None,
                                extra_msg: Optional[str] = None) -> Union[RoomSendResponse, ErrorResponse]:
        """Send text to a matrix room.

        Args:
            room_id: The ID of the room to send the message to.
            message: The message content.
            notice: Whether the message should be sent with an "m.notice" message type
                (will not ping users).
            markdown_convert: Whether to convert the message content to markdown.
                Defaults to true.
            reply_to_event_id: The event ID this message is a reply to, if any.
            thread: Whether to send the message as a threaded reply.
            thread_root_id: The event ID of the thread root (used when thread is True).
            extra_error: Optional error string attached under the "m.matrixgpt" key.
            extra_msg: Optional extra message string attached under the "m.matrixgpt" key.

        Returns:
            A RoomSendResponse if the request was successful, else an ErrorResponse.
        """
        # Determine whether to ping room members or not.
        msgtype = "m.notice" if notice else "m.text"

        content = {"msgtype": msgtype, "format": "org.matrix.custom.html", "body": message}

        if markdown_convert:
            content["formatted_body"] = markdown(message, extensions=['fenced_code'])

        if reply_to_event_id:
            if thread:
                content["m.relates_to"] = {
                    'event_id': thread_root_id,
                    'is_falling_back': True,
                    "m.in_reply_to": {
                        "event_id": reply_to_event_id
                    },
                    'rel_type': "m.thread"
                }
            else:
                content["m.relates_to"] = {
                    "m.in_reply_to": {
                        "event_id": reply_to_event_id
                    }
                }

        # TODO: don't force this to string. what if we want to send an array?
        content["m.matrixgpt"] = {
            "error": str(extra_error),
            "msg": str(extra_msg),
        }
        try:
            return await self.client.room_send(room_id, "m.room.message", content, ignore_unverified_devices=True)
        except SendRetryError:
            self.logger.exception(f"Unable to send message response to {room_id}")

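When thread=True, send_text_to_room() emits the standard m.thread relation with m.in_reply_to kept as a fallback for clients that don't understand threads. A threaded reply body therefore looks roughly like this (event IDs and text are placeholders):

{
    "msgtype": "m.text",
    "format": "org.matrix.custom.html",
    "body": "The capital of France is Paris.",
    "formatted_body": "<p>The capital of France is Paris.</p>",
    "m.relates_to": {
        "rel_type": "m.thread",
        "event_id": "$thread_root_event_id",
        "is_falling_back": true,
        "m.in_reply_to": {"event_id": "$event_being_replied_to"}
    },
    "m.matrixgpt": {"error": "None", "msg": "None"}
}
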
@ -0,0 +1,35 @@
from openai import AsyncOpenAI

from matrix_gpt.config import global_config

"""
Module-level singleton so that every importer shares the same configured client manager.
"""


class OpenAIClientManager:
    def __init__(self):
        self.api_key = None
        self.api_base = None

    def _set_from_config(self):
        """
        Re-read the config on every call since it may not have been loaded yet when this module was imported.
        """
        if global_config['openai']['api_base']:
            # When a custom api_base is configured (e.g. a local OpenAI-compatible server), fall back to a placeholder key.
            self.api_key = 'abc123'
        else:
            self.api_key = global_config['openai']['api_key']
        self.api_base = None
        if global_config['openai'].get('api_base'):
            self.api_base = global_config['openai'].get('api_base')

    def client(self):
        self._set_from_config()
        return AsyncOpenAI(
            api_key=self.api_key,
            base_url=self.api_base
        )


openai_client = OpenAIClientManager()
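Because client() rebuilds an AsyncOpenAI instance from the current configuration on every call, callers simply grab it at request time, which is what generate.py does. A minimal usage sketch, assuming the global config has already been loaded (the model name and prompt are just examples):

import asyncio

from matrix_gpt.openai_client import openai_client


async def demo():
    r = await openai_client.client().chat.completions.create(
        model='gpt-4',  # example model name
        messages=[{'role': 'user', 'content': 'Say hello.'}],
    )
    print(r.choices[0].message.content)


asyncio.run(demo())
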
@ -1,6 +1,6 @@
-matrix-nio[e2e]
matrix-nio[e2e]==0.24.0
pyyaml
markdown
python-olm
-openai
-stopit
openai==1.16.2
git+https://github.com/Cyberes/bison.git