diff --git a/config.sample.yaml b/config.sample.yaml index 9178996..d658e8e 100644 --- a/config.sample.yaml +++ b/config.sample.yaml @@ -7,7 +7,7 @@ bot_auth: password: password1234 homeserver: matrix.example.com store_path: 'bot-store/' - device_id: ABCDEFGHIJ + device_id: DEVICE1 openai_api_key: sk-J12J3O12U3J1LK2J310283JIJ1L2K3J openai_model: gpt-3.5-turbo diff --git a/matrix_gpt/bot/callbacks.py b/matrix_gpt/bot/callbacks.py index b66c1ed..5c82a2c 100644 --- a/matrix_gpt/bot/callbacks.py +++ b/matrix_gpt/bot/callbacks.py @@ -5,7 +5,7 @@ import time from nio import (AsyncClient, InviteMemberEvent, JoinError, MatrixRoom, MegolmEvent, RoomMessageText, UnknownEvent, ) from .bot_commands import Command -from .chat_functions import check_authorized, get_thread_content, is_thread, process_chat, react_to_event, send_text_to_room +from .chat_functions import check_authorized, get_thread_content, is_this_our_thread, is_thread, process_chat, react_to_event, send_text_to_room # from .config import Config from .storage import Storage @@ -72,7 +72,8 @@ class Callbacks: # room.member_count > 2 ... we assume a public room # room.member_count <= 2 ... we assume a DM # General message listener - if not msg.startswith(f'{self.command_prefix} ') and is_thread(event) and not self.store.check_seen_event(event.event_id): + if not msg.startswith(f'{self.command_prefix} ') and is_thread(event) and not self.store.check_seen_event(event.event_id) and (await is_this_our_thread(self.client, room, event, self.command_prefix)): + print(t) await self.client.room_typing(room.room_id, typing_state=True, timeout=3000) thread_content = await get_thread_content(self.client, room, event) api_data = [] @@ -85,10 +86,8 @@ class Callbacks: return else: thread_msg = event.body.strip().strip('\n') - api_data.append({ - 'role': 'assistant' if event.sender == self.client.user_id else 'user', - 'content': thread_msg if not thread_msg.startswith(self.command_prefix) else thread_msg[len(self.command_prefix):].strip() - }) # if len(thread_content) >= 2 and thread_content[0].body.startswith(self.command_prefix): # if thread_content[len(thread_content) - 2].sender == self.client.user + api_data.append({'role': 'assistant' if event.sender == self.client.user_id else 'user', 'content': thread_msg if not thread_msg.startswith(self.command_prefix) else thread_msg[ + len(self.command_prefix):].strip()}) # if len(thread_content) >= 2 and thread_content[0].body.startswith(self.command_prefix): # if thread_content[len(thread_content) - 2].sender == self.client.user # message = Message(self.client, self.store, msg, room, event, self.reply_in_thread) # await message.process() diff --git a/matrix_gpt/bot/chat_functions.py b/matrix_gpt/bot/chat_functions.py index f614182..410c5ef 100644 --- a/matrix_gpt/bot/chat_functions.py +++ b/matrix_gpt/bot/chat_functions.py @@ -2,29 +2,12 @@ import logging from typing import List, Optional, Union from markdown import markdown -from nio import ( - AsyncClient, - ErrorResponse, - Event, MatrixRoom, - MegolmEvent, - Response, - RoomMessageText, RoomSendResponse, - SendRetryError, -) +from nio import (AsyncClient, ErrorResponse, Event, MatrixRoom, MegolmEvent, Response, RoomMessageText, RoomSendResponse, SendRetryError, ) logger = logging.getLogger('MatrixGPT') -async def send_text_to_room( - client: AsyncClient, - room_id: str, - message: str, - notice: bool = False, - markdown_convert: bool = True, - reply_to_event_id: Optional[str] = None, - thread: bool = False, - thread_root_id: str = None -) -> Union[RoomSendResponse, ErrorResponse]: +async def send_text_to_room(client: AsyncClient, room_id: str, message: str, notice: bool = False, markdown_convert: bool = True, reply_to_event_id: Optional[str] = None, thread: bool = False, thread_root_id: str = None) -> Union[RoomSendResponse, ErrorResponse]: """Send text to a matrix room. Args: @@ -49,35 +32,19 @@ async def send_text_to_room( # Determine whether to ping room members or not msgtype = "m.notice" if notice else "m.text" - content = { - "msgtype": msgtype, - "format": "org.matrix.custom.html", - "body": message, - } + content = {"msgtype": msgtype, "format": "org.matrix.custom.html", "body": message, } if markdown_convert: content["formatted_body"] = markdown(message) if reply_to_event_id: if thread: - content["m.relates_to"] = { - 'event_id': thread_root_id, - 'is_falling_back': True, - "m.in_reply_to": { - "event_id": reply_to_event_id - }, - 'rel_type': "m.thread" - } + content["m.relates_to"] = {'event_id': thread_root_id, 'is_falling_back': True, "m.in_reply_to": {"event_id": reply_to_event_id}, 'rel_type': "m.thread"} else: content["m.relates_to"] = {"m.in_reply_to": {"event_id": reply_to_event_id}} try: - return await client.room_send( - room_id, - "m.room.message", - content, - ignore_unverified_devices=True, - ) + return await client.room_send(room_id, "m.room.message", content, ignore_unverified_devices=True, ) except SendRetryError: logger.exception(f"Unable to send message response to {room_id}") @@ -102,12 +69,7 @@ def make_pill(user_id: str, displayname: str = None) -> str: return f'{displayname}' -async def react_to_event( - client: AsyncClient, - room_id: str, - event_id: str, - reaction_text: str, -) -> Union[Response, ErrorResponse]: +async def react_to_event(client: AsyncClient, room_id: str, event_id: str, reaction_text: str, ) -> Union[Response, ErrorResponse]: """Reacts to a given event in a room with the given reaction text Args: @@ -125,20 +87,9 @@ async def react_to_event( Raises: SendRetryError: If the reaction was unable to be sent. """ - content = { - "m.relates_to": { - "rel_type": "m.annotation", - "event_id": event_id, - "key": reaction_text, - } - } + content = {"m.relates_to": {"rel_type": "m.annotation", "event_id": event_id, "key": reaction_text, }} - return await client.room_send( - room_id, - "m.reaction", - content, - ignore_unverified_devices=True, - ) + return await client.room_send(room_id, "m.reaction", content, ignore_unverified_devices=True, ) async def decryption_failure(self, room: MatrixRoom, event: MegolmEvent) -> None: @@ -153,23 +104,25 @@ async def decryption_failure(self, room: MatrixRoom, event: MegolmEvent) -> None # f"commands a second time)." # ) - user_msg = ( - "Unable to decrypt this message. " - "Check whether you've chosen to only encrypt to trusted devices." - ) + user_msg = ("Unable to decrypt this message. " + "Check whether you've chosen to only encrypt to trusted devices.") - await send_text_to_room( - self.client, - room.room_id, - user_msg, - reply_to_event_id=event.event_id, - ) + await send_text_to_room(self.client, room.room_id, user_msg, reply_to_event_id=event.event_id, ) def is_thread(event: RoomMessageText): return event.source['content'].get('m.relates_to', {}).get('rel_type') == 'm.thread' +async def is_this_our_thread(client: AsyncClient, room: MatrixRoom, event: RoomMessageText, command_flag: str): + base_event_id = event.source['content'].get('m.relates_to', {}).get('event_id') + if base_event_id: + return (await client.room_get_event(room.room_id, base_event_id)).event.body.startswith(f'{command_flag} ') + else: + # Better safe than sorry + return False + + async def get_thread_content(client: AsyncClient, room: MatrixRoom, base_event: RoomMessageText) -> List[Event]: messages = [] new_event = (await client.room_get_event(room.room_id, base_event.event_id)).event @@ -193,9 +146,7 @@ async def process_chat(client, room, event, command, store, openai, thread_root_ if isinstance(command, list): messages = command else: - messages = [ - {'role': 'user', 'content': command}, - ] + messages = [{'role': 'user', 'content': command}, ] if system_prompt: messages.insert(0, {"role": "system", "content": system_prompt})