MatrixGPT/matrix_gpt/chat_functions.py

97 lines
3.2 KiB
Python
Raw Permalink Normal View History

2024-04-07 19:41:19 -06:00
import logging
2024-04-07 22:27:00 -06:00
from typing import List, Tuple
2024-04-09 19:26:44 -06:00
from urllib.parse import urlparse
2024-04-07 19:41:19 -06:00
from nio import AsyncClient, Event, MatrixRoom, RoomGetEventResponse, RoomMessageText
from matrix_gpt.config import global_config
2024-04-07 22:27:00 -06:00
from matrix_gpt.generate_clients.command_info import CommandInfo
2024-04-07 19:41:19 -06:00
2024-04-08 00:11:19 -06:00
# Module logger: the 'ChatFunctions' child of the 'MatrixGPT' logger,
# i.e. its full name is 'MatrixGPT.ChatFunctions'.
_APP_LOGGER_NAME = 'MatrixGPT'
logger = logging.getLogger(_APP_LOGGER_NAME).getChild('ChatFunctions')
2024-04-07 19:41:19 -06:00
def is_thread(event: RoomMessageText):
    """Return True when *event* carries an ``m.thread`` relation.

    The relation is read from ``content['m.relates_to']['rel_type']`` of the
    event's raw source; a missing relation block counts as "not a thread".
    """
    relation = event.source['content'].get('m.relates_to', {})
    rel_type = relation.get('rel_type')
    return rel_type == 'm.thread'
def check_command_prefix(string: str) -> Tuple[bool, CommandInfo | None]:
    """Match *string* against the configured command prefixes.

    A prefix only counts when it is followed by a space. Prefixes are tried in
    the iteration order of ``global_config.command_prefixes``; the first match
    wins.

    :param string: message text to inspect.
    :return: ``(True, CommandInfo(**settings))`` built from the matching
        prefix's settings, or ``(False, None)`` if no prefix matches.
    """
    for prefix, settings in global_config.command_prefixes.items():
        if not string.startswith(f'{prefix} '):
            continue
        return True, CommandInfo(**settings)
    return False, None
2024-04-07 19:41:19 -06:00
async def is_this_our_thread(client: AsyncClient, room: MatrixRoom, event: RoomMessageText) -> Tuple[bool, CommandInfo | None]:
    """Check whether *event* relates to a message that starts with a command prefix.

    The event referenced by ``m.relates_to.event_id`` (for a thread message,
    the thread root) is fetched from the server, and its body is passed to
    ``check_command_prefix``.

    :param client: client used to fetch the related event.
    :param room: room the event was sent in.
    :param event: the incoming message.
    :return: ``(True, CommandInfo)`` when the related event's body starts with a
        configured command prefix; otherwise ``(False, None)``. The latter also
        covers three cases: *event* has no related event, the fetch fails
        (logged), or the related event has no text body.
    """
    base_event_id = event.source['content'].get('m.relates_to', {}).get('event_id')
    if not base_event_id:
        return False, None

    e = await client.room_get_event(room.room_id, base_event_id)
    if not isinstance(e, RoomGetEventResponse):
        logger.critical(f'Failed to get event in is_this_our_thread(): {vars(e)}')
        return False, None

    # The fetched event is not guaranteed to expose a text ``body`` attribute
    # (NOTE(review): presumably e.g. undecrypted or non-message events). Treat
    # such events as "not ours" instead of letting AttributeError escape.
    body = getattr(e.event, 'body', None)
    if not isinstance(body, str):
        return False, None
    return check_command_prefix(body)
2024-04-07 19:41:19 -06:00
async def _fetch_event(client: AsyncClient, room: MatrixRoom, event_id: str) -> Event | None:
    """Fetch one event from *room*; log and return ``None`` if the server returns an error response."""
    response = await client.room_get_event(room.room_id, event_id)
    if not isinstance(response, RoomGetEventResponse):
        logger.critical(f'Failed to get event {event_id} in get_thread_content(): {vars(response)}')
        return None
    return response.event


async def get_thread_content(client: AsyncClient, room: MatrixRoom, base_event: RoomMessageText) -> List[Event]:
    """Collect the messages of the thread that *base_event* belongs to.

    The walk starts at the message that was just sent and follows each event's
    ``m.relates_to.m.in_reply_to.event_id`` link backwards. It keeps every event
    that is an ``m.thread`` relation of the same thread root. The root itself
    (``base_event``'s ``m.relates_to.event_id``) is placed first.

    :param client: client used to fetch events.
    :param room: room containing the thread.
    :param base_event: the newly received thread message. It must carry
        ``content['m.relates_to']['event_id']``; ``KeyError`` is raised otherwise.
    :return: the collected events ordered root first, *base_event* last. Events
        that cannot be fetched end the walk early (and are logged) instead of
        raising.
    """
    root_id = base_event.source['content']['m.relates_to']['event_id']
    messages = []
    root_event = None

    # This is the event of the message that was just sent, as returned by the server.
    event = await _fetch_event(client, room, base_event.event_id)
    while event is not None:
        if event.event_id == root_id:
            # The reply chain reached the root: keep it rather than fetching it again below.
            root_event = event
            break
        relation = event.source['content'].get('m.relates_to', {})
        # Only keep events that belong to *this* thread (same root), not just any thread.
        if relation.get('rel_type') != 'm.thread' or relation.get('event_id') != root_id:
            break
        messages.append(event)
        # Follow the m.in_reply_to link to the previous event; stop if there is none.
        parent_id = relation.get('m.in_reply_to', {}).get('event_id')
        if not parent_id:
            break
        event = await _fetch_event(client, room, parent_id)

    if root_event is None:
        root_event = await _fetch_event(client, room, root_id)
    if root_event is not None:
        messages.append(root_event)

    # Events were gathered newest-first; callers get them root-first.
    messages.reverse()
    return messages
def check_authorized(string, to_check):
    """Return whether the Matrix user ID *string* is allowed by the rule(s) in *to_check*.

    *to_check* is a single rule (``str``) or a ``list`` of rules. A list
    authorizes the user when any of its rules matches. Each rule is one of:

    * ``'all'``: matches every user;
    * a value containing neither ``'@'`` nor ``':'``: a homeserver name,
      compared exactly with the part of the user ID after its first ``':'``;
    * anything else: a full user ID, compared exactly.

    Comparisons are exact, never substring tests. Otherwise a user such as
    ``@bob:ample.com`` would pass an ``example.com`` rule, and
    ``@bob:example.co`` would pass a ``@bob:example.com`` rule.

    :raises TypeError: if *to_check* is neither a ``str`` nor a ``list``.
    """
    def check_str(s, c):
        if c == 'all':
            return True
        if '@' not in c and ':' not in c:
            # Homeserver rule: the whole server part must equal the rule.
            return s.partition(':')[2] == c
        # User-ID rule.
        return s == c

    if isinstance(to_check, str):
        return check_str(string, to_check)
    if isinstance(to_check, list):
        return any(check_str(string, item) for item in to_check)
    raise TypeError(f'to_check must be a str or a list, not {type(to_check).__name__}')
2024-04-09 19:26:44 -06:00
async def download_mxc(url: str, client: AsyncClient) -> bytes:
    """Download the media referenced by an ``mxc://<server>/<media_id>`` URL.

    :param url: the content URI; its netloc is passed as the server name and
        its path (stripped of slashes) as the media ID.
    :param client: client whose ``download`` method performs the request.
    :return: the response's ``body``. If the response has no ``body``
        attribute (presumably an error response), returns ``b''`` and logs a
        warning instead of raising.
    """
    mxc = urlparse(url)
    response = await client.download(mxc.netloc, mxc.path.strip("/"))
    if hasattr(response, "body"):
        return response.body
    # Deliberately best-effort for callers, but leave a trace so failures are diagnosable.
    logger.warning(f'Failed to download {url}: {response}')
    return b''