97 lines
3.2 KiB
Python
97 lines
3.2 KiB
Python
import logging
|
|
from typing import List, Tuple
|
|
from urllib.parse import urlparse
|
|
|
|
from nio import AsyncClient, Event, MatrixRoom, RoomGetEventResponse, RoomMessageText
|
|
|
|
from matrix_gpt.config import global_config
|
|
from matrix_gpt.generate_clients.command_info import CommandInfo
|
|
|
|
logger = logging.getLogger('MatrixGPT').getChild('ChatFunctions')
|
|
|
|
|
|
def is_thread(event: RoomMessageText):
|
|
return event.source['content'].get('m.relates_to', {}).get('rel_type') == 'm.thread'
|
|
|
|
|
|
def check_command_prefix(string: str) -> Tuple[bool, CommandInfo | None]:
|
|
for k, v in global_config.command_prefixes.items():
|
|
if string.startswith(f'{k} '):
|
|
command_info = CommandInfo(**v)
|
|
return True, command_info
|
|
return False, None
|
|
|
|
|
|
async def is_this_our_thread(client: AsyncClient, room: MatrixRoom, event: RoomMessageText) -> Tuple[bool, CommandInfo | None]:
|
|
base_event_id = event.source['content'].get('m.relates_to', {}).get('event_id')
|
|
if base_event_id:
|
|
e = await client.room_get_event(room.room_id, base_event_id)
|
|
if not isinstance(e, RoomGetEventResponse):
|
|
logger.critical(f'Failed to get event in is_this_our_thread(): {vars(e)}')
|
|
return False, None
|
|
else:
|
|
return check_command_prefix(e.event.body)
|
|
else:
|
|
return False, None
|
|
|
|
|
|
async def get_thread_content(client: AsyncClient, room: MatrixRoom, base_event: RoomMessageText) -> List[Event]:
|
|
messages = []
|
|
|
|
# This is the event of the message that was just sent.
|
|
new_event = (await client.room_get_event(room.room_id, base_event.event_id)).event
|
|
|
|
while True:
|
|
if new_event.source['content'].get('m.relates_to', {}).get('rel_type') == 'm.thread':
|
|
# Put the event in the messages list only if it's related to the thread we're parsing.
|
|
messages.append(new_event)
|
|
else:
|
|
break
|
|
# Fetch the next event.
|
|
new_event = (await client.room_get_event(
|
|
room.room_id,
|
|
new_event.source['content']['m.relates_to']['m.in_reply_to']['event_id'])
|
|
).event
|
|
|
|
# Put the root event in the array.
|
|
messages.append((await client.room_get_event(
|
|
room.room_id, base_event.source['content']['m.relates_to']['event_id'])
|
|
).event)
|
|
messages.reverse()
|
|
return messages
|
|
|
|
|
|
def check_authorized(string, to_check):
|
|
def check_str(s, c):
|
|
if c == 'all':
|
|
return True
|
|
else:
|
|
if '@' not in c and ':' not in c:
|
|
# Homeserver
|
|
if s.split(':')[-1] in c:
|
|
return True
|
|
elif s in c:
|
|
# By username
|
|
return True
|
|
return False
|
|
|
|
if isinstance(to_check, str):
|
|
return check_str(string, to_check)
|
|
elif isinstance(to_check, list):
|
|
output = False
|
|
for item in to_check:
|
|
if check_str(string, item):
|
|
output = True
|
|
return output
|
|
else:
|
|
raise Exception
|
|
|
|
|
|
async def download_mxc(url: str, client: AsyncClient) -> bytes:
|
|
mxc = urlparse(url)
|
|
response = await client.download(mxc.netloc, mxc.path.strip("/"))
|
|
if hasattr(response, "body"):
|
|
return response.body
|
|
else:
|
|
return b''
|